use crate::imp_prelude::*;
use crate::numeric_util;
use crate::{LinalgScalar, Zip};
use std::any::TypeId;
#[cfg(feature = "blas")]
use std::cmp;
#[cfg(feature = "blas")]
use std::mem::swap;
#[cfg(feature = "blas")]
use std::os::raw::c_int;
#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
#[cfg(feature = "blas")]
const DOT_BLAS_CUTOFF: usize = 32;
#[cfg(feature = "blas")]
const GEMM_BLAS_CUTOFF: usize = 7;
#[cfg(feature = "blas")]
#[allow(non_camel_case_types)]
type blas_index = c_int;
impl<A, S> ArrayBase<S, Ix1>
where
S: Data<Elem = A>,
{
pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
where
Self: Dot<Rhs>,
{
Dot::dot(self, rhs)
}
fn dot_generic<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
where
S2: Data<Elem = A>,
A: LinalgScalar,
{
debug_assert_eq!(self.len(), rhs.len());
assert!(self.len() == rhs.len());
if let Some(self_s) = self.as_slice() {
if let Some(rhs_s) = rhs.as_slice() {
return numeric_util::unrolled_dot(self_s, rhs_s);
}
}
let mut sum = A::zero();
for i in 0..self.len() {
unsafe {
sum = sum + *self.uget(i) * *rhs.uget(i);
}
}
sum
}
#[cfg(not(feature = "blas"))]
fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
where
S2: Data<Elem = A>,
A: LinalgScalar,
{
self.dot_generic(rhs)
}
#[cfg(feature = "blas")]
fn dot_impl<S2>(&self, rhs: &ArrayBase<S2, Ix1>) -> A
where
S2: Data<Elem = A>,
A: LinalgScalar,
{
if self.len() >= DOT_BLAS_CUTOFF {
debug_assert_eq!(self.len(), rhs.len());
assert!(self.len() == rhs.len());
macro_rules! dot {
($ty:ty, $func:ident) => {{
if blas_compat_1d::<$ty, _>(self) && blas_compat_1d::<$ty, _>(rhs) {
unsafe {
let (lhs_ptr, n, incx) =
blas_1d_params(self.ptr.as_ptr(), self.len(), self.strides()[0]);
let (rhs_ptr, _, incy) =
blas_1d_params(rhs.ptr.as_ptr(), rhs.len(), rhs.strides()[0]);
let ret = blas_sys::$func(
n,
lhs_ptr as *const $ty,
incx,
rhs_ptr as *const $ty,
incy,
);
return cast_as::<$ty, A>(&ret);
}
}
}};
}
dot! {f32, cblas_sdot};
dot! {f64, cblas_ddot};
}
self.dot_generic(rhs)
}
}
#[cfg(feature = "blas")]
unsafe fn blas_1d_params<A>(
ptr: *const A,
len: usize,
stride: isize,
) -> (*const A, blas_index, blas_index) {
if stride >= 0 || len == 0 {
(ptr, len as blas_index, stride as blas_index)
} else {
let ptr = ptr.offset((len - 1) as isize * stride);
(ptr, len as blas_index, stride as blas_index)
}
}
pub trait Dot<Rhs> {
type Output;
fn dot(&self, rhs: &Rhs) -> Self::Output;
}
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix1>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = A;
fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> A {
self.dot_impl(rhs)
}
}
impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix1>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = Array<A, Ix1>;
fn dot(&self, rhs: &ArrayBase<S2, Ix2>) -> Array<A, Ix1> {
rhs.t().dot(self)
}
}
impl<A, S> ArrayBase<S, Ix2>
where
S: Data<Elem = A>,
{
pub fn dot<Rhs>(&self, rhs: &Rhs) -> <Self as Dot<Rhs>>::Output
where
Self: Dot<Rhs>,
{
Dot::dot(self, rhs)
}
}
impl<A, S, S2> Dot<ArrayBase<S2, Ix2>> for ArrayBase<S, Ix2>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = Array2<A>;
fn dot(&self, b: &ArrayBase<S2, Ix2>) -> Array2<A> {
let a = self.view();
let b = b.view();
let ((m, k), (k2, n)) = (a.dim(), b.dim());
if k != k2 || m.checked_mul(n).is_none() {
dot_shape_error(m, k, k2, n);
}
let lhs_s0 = a.strides()[0];
let rhs_s0 = b.strides()[0];
let column_major = lhs_s0 == 1 && rhs_s0 == 1;
let mut v = Vec::with_capacity(m * n);
let mut c;
unsafe {
v.set_len(m * n);
c = Array::from_shape_vec_unchecked((m, n).set_f(column_major), v);
}
mat_mul_impl(A::one(), &a, &b, A::zero(), &mut c.view_mut());
c
}
}
#[cold]
#[inline(never)]
fn dot_shape_error(m: usize, k: usize, k2: usize, n: usize) -> ! {
match m.checked_mul(n) {
Some(len) if len <= ::std::isize::MAX as usize => {}
_ => panic!("ndarray: shape {} × {} overflows isize", m, n),
}
panic!(
"ndarray: inputs {} × {} and {} × {} are not compatible for matrix multiplication",
m, k, k2, n
);
}
#[cold]
#[inline(never)]
fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c2: usize) -> ! {
panic!("ndarray: inputs {} × {}, {} × {}, and output {} × {} are not compatible for matrix multiplication",
m, k, k2, n, c1, c2);
}
impl<A, S, S2> Dot<ArrayBase<S2, Ix1>> for ArrayBase<S, Ix2>
where
S: Data<Elem = A>,
S2: Data<Elem = A>,
A: LinalgScalar,
{
type Output = Array<A, Ix1>;
fn dot(&self, rhs: &ArrayBase<S2, Ix1>) -> Array<A, Ix1> {
let ((m, a), n) = (self.dim(), rhs.dim());
if a != n {
dot_shape_error(m, a, n, 1);
}
unsafe {
let mut c = Array::uninitialized(m);
general_mat_vec_mul(A::one(), self, rhs, A::zero(), &mut c);
c
}
}
}
impl<A, S, D> ArrayBase<S, D>
where
S: Data<Elem = A>,
D: Dimension,
{
pub fn scaled_add<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
where
S: DataMut,
S2: Data<Elem = A>,
A: LinalgScalar,
E: Dimension,
{
self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
}
}
#[cfg(not(feature = "blas"))]
use self::mat_mul_general as mat_mul_impl;
#[cfg(feature = "blas")]
fn mat_mul_impl<A>(
alpha: A,
lhs: &ArrayView2<'_, A>,
rhs: &ArrayView2<'_, A>,
beta: A,
c: &mut ArrayViewMut2<'_, A>,
) where
A: LinalgScalar,
{
let cut = GEMM_BLAS_CUTOFF;
let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim());
if !(m > cut || n > cut || a > cut) || !(same_type::<A, f32>() || same_type::<A, f64>()) {
return mat_mul_general(alpha, lhs, rhs, beta, c);
}
{
let mut lhs_ = lhs.view();
let mut rhs_ = rhs.view();
let mut c_ = c.view_mut();
let lhs_s0 = lhs_.strides()[0];
let rhs_s0 = rhs_.strides()[0];
let both_f = lhs_s0 == 1 && rhs_s0 == 1;
let mut lhs_trans = CblasNoTrans;
let mut rhs_trans = CblasNoTrans;
if both_f {
let lhs_t = lhs_.reversed_axes();
lhs_ = rhs_.reversed_axes();
rhs_ = lhs_t;
c_ = c_.reversed_axes();
swap(&mut m, &mut n);
} else if lhs_s0 == 1 && m == a {
lhs_ = lhs_.reversed_axes();
lhs_trans = CblasTrans;
} else if rhs_s0 == 1 && a == n {
rhs_ = rhs_.reversed_axes();
rhs_trans = CblasTrans;
}
macro_rules! gemm {
($ty:ty, $gemm:ident) => {
if blas_row_major_2d::<$ty, _>(&lhs_)
&& blas_row_major_2d::<$ty, _>(&rhs_)
&& blas_row_major_2d::<$ty, _>(&c_)
{
let (m, k) = match lhs_trans {
CblasNoTrans => lhs_.dim(),
_ => {
let (rows, cols) = lhs_.dim();
(cols, rows)
}
};
let n = match rhs_trans {
CblasNoTrans => rhs_.raw_dim()[1],
_ => rhs_.raw_dim()[0],
};
let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index);
let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index);
let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index);
unsafe {
blas_sys::$gemm(
CblasRowMajor,
lhs_trans,
rhs_trans,
m as blas_index,
n as blas_index,
k as blas_index,
cast_as(&alpha),
lhs_.ptr.as_ptr() as *const _,
lhs_stride,
rhs_.ptr.as_ptr() as *const _,
rhs_stride,
cast_as(&beta),
c_.ptr.as_ptr() as *mut _,
c_stride,
);
}
return;
}
};
}
gemm!(f32, cblas_sgemm);
gemm!(f64, cblas_dgemm);
}
mat_mul_general(alpha, lhs, rhs, beta, c)
}
fn mat_mul_general<A>(
alpha: A,
lhs: &ArrayView2<'_, A>,
rhs: &ArrayView2<'_, A>,
beta: A,
c: &mut ArrayViewMut2<'_, A>,
) where
A: LinalgScalar,
{
let ((m, k), (_, n)) = (lhs.dim(), rhs.dim());
let ap = lhs.as_ptr();
let bp = rhs.as_ptr();
let cp = c.as_mut_ptr();
let (rsc, csc) = (c.strides()[0], c.strides()[1]);
if same_type::<A, f32>() {
unsafe {
::matrixmultiply::sgemm(
m,
k,
n,
cast_as(&alpha),
ap as *const _,
lhs.strides()[0],
lhs.strides()[1],
bp as *const _,
rhs.strides()[0],
rhs.strides()[1],
cast_as(&beta),
cp as *mut _,
rsc,
csc,
);
}
} else if same_type::<A, f64>() {
unsafe {
::matrixmultiply::dgemm(
m,
k,
n,
cast_as(&alpha),
ap as *const _,
lhs.strides()[0],
lhs.strides()[1],
bp as *const _,
rhs.strides()[0],
rhs.strides()[1],
cast_as(&beta),
cp as *mut _,
rsc,
csc,
);
}
} else {
if c.is_empty() {
return;
}
if beta.is_zero() {
c.fill(beta);
}
let mut i = 0;
let mut j = 0;
loop {
unsafe {
let elt = c.uget_mut((i, j));
*elt = *elt * beta
+ alpha
* (0..k).fold(A::zero(), move |s, x| {
s + *lhs.uget((i, x)) * *rhs.uget((x, j))
});
}
j += 1;
if j == n {
j = 0;
i += 1;
if i == m {
break;
}
}
}
}
}
pub fn general_mat_mul<A, S1, S2, S3>(
alpha: A,
a: &ArrayBase<S1, Ix2>,
b: &ArrayBase<S2, Ix2>,
beta: A,
c: &mut ArrayBase<S3, Ix2>,
) where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
S3: DataMut<Elem = A>,
A: LinalgScalar,
{
let ((m, k), (k2, n)) = (a.dim(), b.dim());
let (m2, n2) = c.dim();
if k != k2 || m != m2 || n != n2 {
general_dot_shape_error(m, k, k2, n, m2, n2);
} else {
mat_mul_impl(alpha, &a.view(), &b.view(), beta, &mut c.view_mut());
}
}
#[allow(clippy::collapsible_if)]
pub fn general_mat_vec_mul<A, S1, S2, S3>(
alpha: A,
a: &ArrayBase<S1, Ix2>,
x: &ArrayBase<S2, Ix1>,
beta: A,
y: &mut ArrayBase<S3, Ix1>,
) where
S1: Data<Elem = A>,
S2: Data<Elem = A>,
S3: DataMut<Elem = A>,
A: LinalgScalar,
{
let ((m, k), k2) = (a.dim(), x.dim());
let m2 = y.dim();
if k != k2 || m != m2 {
general_dot_shape_error(m, k, k2, 1, m2, 1);
} else {
#[cfg(feature = "blas")]
macro_rules! gemv {
($ty:ty, $gemv:ident) => {
if let Some(layout) = blas_layout::<$ty, _>(&a) {
if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) {
let a_trans = CblasNoTrans;
let a_stride = match layout {
CBLAS_LAYOUT::CblasRowMajor => {
a.strides()[0].max(k as isize) as blas_index
}
CBLAS_LAYOUT::CblasColMajor => {
a.strides()[1].max(m as isize) as blas_index
}
};
let x_stride = x.strides()[0] as blas_index;
let y_stride = y.strides()[0] as blas_index;
unsafe {
blas_sys::$gemv(
layout,
a_trans,
m as blas_index,
k as blas_index,
cast_as(&alpha),
a.ptr.as_ptr() as *const _,
a_stride,
x.ptr.as_ptr() as *const _,
x_stride,
cast_as(&beta),
y.ptr.as_ptr() as *mut _,
y_stride,
);
}
return;
}
}
};
}
#[cfg(feature = "blas")]
gemv!(f32, cblas_sgemv);
#[cfg(feature = "blas")]
gemv!(f64, cblas_dgemv);
if beta.is_zero() {
Zip::from(a.outer_iter()).and(y).apply(|row, elt| {
*elt = row.dot(x) * alpha;
});
} else {
Zip::from(a.outer_iter()).and(y).apply(|row, elt| {
*elt = *elt * beta + row.dot(x) * alpha;
});
}
}
}
#[inline(always)]
fn same_type<A: 'static, B: 'static>() -> bool {
TypeId::of::<A>() == TypeId::of::<B>()
}
fn cast_as<A: 'static + Copy, B: 'static + Copy>(a: &A) -> B {
assert!(same_type::<A, B>());
unsafe { ::std::ptr::read(a as *const _ as *const B) }
}
#[cfg(feature = "blas")]
fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
where
S: Data,
A: 'static,
S::Elem: 'static,
{
if !same_type::<A, S::Elem>() {
return false;
}
if a.len() > blas_index::max_value() as usize {
return false;
}
let stride = a.strides()[0];
if stride > blas_index::max_value() as isize || stride < blas_index::min_value() as isize {
return false;
}
true
}
#[cfg(feature = "blas")]
enum MemoryOrder {
C,
F,
}
#[cfg(feature = "blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
S: Data,
A: 'static,
S::Elem: 'static,
{
if !same_type::<A, S::Elem>() {
return false;
}
is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
}
#[cfg(feature = "blas")]
fn blas_column_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
where
S: Data,
A: 'static,
S::Elem: 'static,
{
if !same_type::<A, S::Elem>() {
return false;
}
is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
}
#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool {
let (m, n) = dim.into_pattern();
let s0 = stride[0] as isize;
let s1 = stride[1] as isize;
let (inner_stride, outer_dim) = match order {
MemoryOrder::C => (s1, n),
MemoryOrder::F => (s0, m),
};
if !(inner_stride == 1 || outer_dim == 1) {
return false;
}
if s0 < 1 || s1 < 1 {
return false;
}
if (s0 > blas_index::max_value() as isize || s0 < blas_index::min_value() as isize)
|| (s1 > blas_index::max_value() as isize || s1 < blas_index::min_value() as isize)
{
return false;
}
if m > blas_index::max_value() as usize || n > blas_index::max_value() as usize {
return false;
}
true
}
#[cfg(feature = "blas")]
fn blas_layout<A, S>(a: &ArrayBase<S, Ix2>) -> Option<CBLAS_LAYOUT>
where
S: Data,
A: 'static,
S::Elem: 'static,
{
if blas_row_major_2d::<A, _>(a) {
Some(CBLAS_LAYOUT::CblasRowMajor)
} else if blas_column_major_2d::<A, _>(a) {
Some(CBLAS_LAYOUT::CblasColMajor)
} else {
None
}
}
#[cfg(test)]
#[cfg(feature = "blas")]
mod blas_tests {
use super::*;
#[test]
fn blas_row_major_2d_normal_matrix() {
let m: Array2<f32> = Array2::zeros((3, 5));
assert!(blas_row_major_2d::<f32, _>(&m));
assert!(!blas_column_major_2d::<f32, _>(&m));
}
#[test]
fn blas_row_major_2d_row_matrix() {
let m: Array2<f32> = Array2::zeros((1, 5));
assert!(blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}
#[test]
fn blas_row_major_2d_column_matrix() {
let m: Array2<f32> = Array2::zeros((5, 1));
assert!(blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}
#[test]
fn blas_row_major_2d_transposed_row_matrix() {
let m: Array2<f32> = Array2::zeros((1, 5));
let m_t = m.t();
assert!(blas_row_major_2d::<f32, _>(&m_t));
assert!(blas_column_major_2d::<f32, _>(&m_t));
}
#[test]
fn blas_row_major_2d_transposed_column_matrix() {
let m: Array2<f32> = Array2::zeros((5, 1));
let m_t = m.t();
assert!(blas_row_major_2d::<f32, _>(&m_t));
assert!(blas_column_major_2d::<f32, _>(&m_t));
}
#[test]
fn blas_column_major_2d_normal_matrix() {
let m: Array2<f32> = Array2::zeros((3, 5).f());
assert!(!blas_row_major_2d::<f32, _>(&m));
assert!(blas_column_major_2d::<f32, _>(&m));
}
}