#[macro_use]
mod zipmacro;
use std::mem::MaybeUninit;
use crate::imp_prelude::*;
use crate::AssignElem;
use crate::IntoDimension;
use crate::Layout;
use crate::NdIndex;
use crate::indexes::{indices, Indices};
use crate::layout::{CORDER, FORDER};
// Evaluates an expression of type `FoldWhile<T>`.
// A `Continue` value is unwrapped and becomes the value of the macro expression.
// A `Done` value is returned unchanged from the *enclosing* function, which
// short-circuits the surrounding loop.
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Done(x) => return FoldWhile::Done(x),
            FoldWhile::Continue(x) => x,
        }
    };
}
/// Internal conversion of a producer into one broadcast to the shape `E`.
trait Broadcast<E>
where
E: IntoDimension,
{
/// Producer type after broadcasting; its dimension type is `E::Dim`.
type Output: NdProducer<Dim = E::Dim>;
/// Broadcast `self` to `shape`, consuming `self`.
///
/// NOTE(review): the `_unwrap` suffix suggests this panics when the shapes are
/// not broadcast-compatible — confirm against the implementations.
fn broadcast_unwrap(self, shape: E) -> Self::Output;
private_decl! {}
}
impl<S, D> ArrayBase<S, D>
where
    S: RawData,
    D: Dimension,
{
    /// Computes the contiguity flags `Zip` uses to pick its iteration strategy.
    ///
    /// - Standard (C-order) arrays get `CORDER`. With at most one dimension they
    ///   are also F-contiguous, so they get `CORDER | FORDER`.
    /// - Arrays with more than one dimension whose axis-reversed view is
    ///   standard get `FORDER`.
    /// - Anything else gets no flags.
    pub(crate) fn layout_impl(&self) -> Layout {
        let flags = if self.is_standard_layout() {
            if self.ndim() > 1 {
                CORDER
            } else {
                CORDER | FORDER
            }
        } else if self.ndim() > 1 && self.raw_view().reversed_axes().is_standard_layout() {
            FORDER
        } else {
            0
        };
        Layout::new(flags)
    }
}
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;

    /// Broadcasts this view to `shape` while keeping the original lifetime `'a`.
    ///
    /// Delegates to the inherent `ArrayBase::broadcast_unwrap`, which borrows
    /// `self`. That result is re-wrapped so it is no longer tied to the local
    /// borrow.
    fn broadcast_unwrap(self, shape: E) -> Self::Output {
        let target = shape.into_dimension();
        let borrowed: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(target);
        // SAFETY: `borrowed` addresses the same elements as `self`, which are
        // valid for `'a`. Rebuilding the view from its raw parts only lengthens
        // the lifetime from the temporary borrow to `'a`.
        unsafe { ArrayView::new(borrowed.ptr, borrowed.dim, borrowed.strides) }
    }

    private_impl! {}
}
/// Types that can be cut into two pieces along an axis.
pub trait Splittable: Sized {
/// Split `self` into two pieces at `index` along `axis`.
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self);
}
impl<D> Splittable for D
where
    D: Dimension,
{
    /// Splits a shape along `axis`.
    ///
    /// The first piece gets length `index` on that axis. The second piece gets
    /// the remaining `len - index`. All other axes are copied unchanged.
    fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
        let ax = axis.index();
        let mut head = self;
        let mut tail = head.clone();
        let total = head[ax];
        head[ax] = index;
        tail[ax] = total - index;
        (head, tail)
    }
}
/// Conversion into an `NdProducer`, used by `Zip::from`, `Zip::and` and friends.
///
/// This file implements it for every `NdProducer` (identity), for `&`/`&mut`
/// arrays (as views), and for slice and `Vec` references (as 1-D views).
pub trait IntoNdProducer {
/// Element type produced for each position.
type Item;
/// Dimension (shape) type of the resulting producer.
type Dim: Dimension;
/// The producer this value converts into.
type Output: NdProducer<Dim = Self::Dim, Item = Self::Item>;
/// Perform the conversion.
fn into_producer(self) -> Self::Output;
}
/// Every producer trivially converts into itself.
impl<P> IntoNdProducer for P
where
    P: NdProducer,
{
    type Output = P;
    type Item = P::Item;
    type Dim = P::Dim;

    fn into_producer(self) -> P {
        self
    }
}
/// A producer of elements over an n-dimensional shape; it can be one part of a `Zip`.
///
/// Most members are `#[doc(hidden)]` implementation details. `Zip` uses them to
/// walk the producer through raw pointers.
pub trait NdProducer {
/// Element type yielded at each position.
type Item;
/// Dimension (shape) type of the producer.
type Dim: Dimension;
/// Pointer-like cursor used to address an element; advanced via `Offset`.
#[doc(hidden)]
type Ptr: Offset<Stride = Self::Stride>;
/// Step type paired with `Ptr`, used to move the cursor along an axis.
#[doc(hidden)]
type Stride: Copy;
/// Contiguity flags (C/F order) of this producer.
#[doc(hidden)]
fn layout(&self) -> Layout;
/// Shape of this producer.
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim;
/// Whether this producer's shape equals `dim`; by default compares `raw_dim()`.
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.raw_dim() == *dim
}
/// Base pointer. `Zip`'s contiguous path treats it as element 0 and steps
/// from it by `contiguous_stride()`.
#[doc(hidden)]
fn as_ptr(&self) -> Self::Ptr;
/// Turn a cursor into the element it addresses.
///
/// # Safety
/// `ptr` must address a valid element of this producer. NOTE(review): presumably
/// one derived from `as_ptr`/`uget_ptr` and advanced within bounds — confirm.
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
/// Cursor for the element at index `i`, with no bounds checking.
///
/// # Safety
/// `i` must be an in-bounds index for this producer's shape.
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
/// Step between consecutive elements along `axis`.
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> <Self::Ptr as Offset>::Stride;
/// Step between consecutive elements when the layout is contiguous.
#[doc(hidden)]
fn contiguous_stride(&self) -> Self::Stride;
/// Split the producer into two pieces at `index` along `axis`.
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self)
where
Self: Sized;
private_decl! {}
}
/// Pointer-like cursors that can be advanced by a number of strides.
pub trait Offset: Copy {
/// The unit of one step.
type Stride: Copy;
/// Advance `self` by `index` steps of size `s`.
///
/// # Safety
/// For raw-pointer implementors, the result must stay inside the same
/// allocation, as required by pointer `offset`.
unsafe fn stride_offset(self, s: Self::Stride, index: usize) -> Self;
private_decl! {}
}
impl<T> Offset for *const T {
    type Stride = isize;

    /// Moves the pointer `index * s` elements forward (or backward if negative).
    ///
    /// # Safety
    /// Same contract as `<*const T>::offset`: the result must stay within, or
    /// one past the end of, the pointer's allocation.
    unsafe fn stride_offset(self, s: isize, index: usize) -> Self {
        let delta = index as isize * s;
        self.offset(delta)
    }

    private_impl! {}
}

impl<T> Offset for *mut T {
    type Stride = isize;

    /// Moves the pointer `index * s` elements forward (or backward if negative).
    ///
    /// # Safety
    /// Same contract as `<*mut T>::offset`: the result must stay within, or
    /// one past the end of, the pointer's allocation.
    unsafe fn stride_offset(self, s: isize, index: usize) -> Self {
        let delta = index as isize * s;
        self.offset(delta)
    }

    private_impl! {}
}
/// A tuple of producers walked in lockstep by `Zip`.
///
/// Each method mirrors the `NdProducer` method of the same name and applies it
/// component-wise (see the `zipt_impl!` implementations below).
trait ZippableTuple: Sized {
/// Tuple of the components' items.
type Item;
/// Tuple of the components' cursors.
type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
/// Shared dimension type of all components.
type Dim: Dimension;
/// Tuple of the components' strides.
type Stride: Copy;
fn as_ptr(&self) -> Self::Ptr;
/// # Safety
/// Every component cursor in `ptr` must be valid for its producer.
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
/// # Safety
/// `i` must be in bounds for the shared shape.
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
/// Strides of every component along axis number `index`.
fn stride_of(&self, index: usize) -> Self::Stride;
fn contiguous_stride(&self) -> Self::Stride;
/// Split every component at `index` along `axis`.
fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
impl<'a, A: 'a, S, D> IntoNdProducer for &'a ArrayBase<S, D>
where
D: Dimension,
S: Data<Elem = A>,
{
type Item = &'a A;
type Dim = D;
type Output = ArrayView<'a, A, D>;
fn into_producer(self) -> Self::Output {
self.view()
}
}
impl<'a, A: 'a, S, D> IntoNdProducer for &'a mut ArrayBase<S, D>
where
D: Dimension,
S: DataMut<Elem = A>,
{
type Item = &'a mut A;
type Dim = D;
type Output = ArrayViewMut<'a, A, D>;
fn into_producer(self) -> Self::Output {
self.view_mut()
}
}
impl<'a, A: 'a> IntoNdProducer for &'a [A] {
type Item = <Self::Output as NdProducer>::Item;
type Dim = Ix1;
type Output = ArrayView1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A: 'a> IntoNdProducer for &'a mut [A] {
type Item = <Self::Output as NdProducer>::Item;
type Dim = Ix1;
type Output = ArrayViewMut1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A: 'a> IntoNdProducer for &'a Vec<A> {
type Item = <Self::Output as NdProducer>::Item;
type Dim = Ix1;
type Output = ArrayView1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A: 'a> IntoNdProducer for &'a mut Vec<A> {
type Item = <Self::Output as NdProducer>::Item;
type Dim = Ix1;
type Output = ArrayViewMut1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
/// `ArrayView` produces shared references `&'a A`.
///
/// The bodies below call `self.raw_dim()`, `self.as_ptr()`, `self.stride_of(..)`
/// and `self.split_at(..)`. These resolve to `ArrayBase`'s inherent methods of
/// the same name, because inherent methods take precedence over trait methods
/// in method-call resolution. They are delegations, not recursive calls.
impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> {
type Item = &'a A;
type Dim = D;
// The cursor is a `*mut A` even for this shared view; it only ever becomes
// `&'a A` through `as_ref` below.
type Ptr = *mut A;
type Stride = isize;
private_impl! {}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
// Compares the stored shape directly via `Dimension::equal`.
self.dim.equal(dim)
}
#[doc(hidden)]
fn as_ptr(&self) -> *mut A {
// Cast the inherent pointer to this impl's `*mut A` cursor type.
self.as_ptr() as _
}
#[doc(hidden)]
fn layout(&self) -> Layout {
self.layout_impl()
}
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
&*ptr
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
// Element offset of `i` computed from the strides; the caller must pass an
// in-bounds index (nothing here checks it).
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
}
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.stride_of(axis)
}
// In a contiguous layout, consecutive elements are one element apart.
#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
1
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}
/// `ArrayViewMut` produces exclusive references `&'a mut A`.
///
/// As with `ArrayView`, the `self.raw_dim()`, `self.as_ptr()`, `self.stride_of(..)`
/// and `self.split_at(..)` calls below resolve to the inherent methods of the
/// same name. Inherent methods win over trait methods, so these are not
/// recursive calls.
impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> {
type Item = &'a mut A;
type Dim = D;
type Ptr = *mut A;
type Stride = isize;
private_impl! {}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
// Compares the stored shape directly via `Dimension::equal`.
self.dim.equal(dim)
}
#[doc(hidden)]
fn as_ptr(&self) -> *mut A {
// Cast the inherent pointer to this impl's `*mut A` cursor type.
self.as_ptr() as _
}
#[doc(hidden)]
fn layout(&self) -> Layout {
self.layout_impl()
}
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
&mut *ptr
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
// Element offset of `i` computed from the strides; the caller must pass an
// in-bounds index (nothing here checks it).
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
}
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.stride_of(axis)
}
// In a contiguous layout, consecutive elements are one element apart.
#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
1
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}
/// `RawArrayView` produces raw `*const A` pointers; no reference is ever created.
///
/// The `self.raw_dim()`, `self.as_ptr()`, `self.stride_of(..)` and
/// `self.split_at(..)` calls below resolve to inherent methods of the same name,
/// which take precedence over these trait methods.
impl<A, D: Dimension> NdProducer for RawArrayView<A, D> {
type Item = *const A;
type Dim = D;
type Ptr = *const A;
type Stride = isize;
private_impl! {}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
// Compares the stored shape directly via `Dimension::equal`.
self.dim.equal(dim)
}
#[doc(hidden)]
fn as_ptr(&self) -> *const A {
self.as_ptr()
}
#[doc(hidden)]
fn layout(&self) -> Layout {
self.layout_impl()
}
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *const A) -> *const A {
// The item is the cursor itself.
ptr
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A {
// Element offset of `i` computed from the strides; the caller must pass an
// in-bounds index (nothing here checks it).
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
}
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.stride_of(axis)
}
// In a contiguous layout, consecutive elements are one element apart.
#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
1
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}
/// `RawArrayViewMut` produces raw `*mut A` pointers; no reference is ever created.
///
/// The `self.raw_dim()`, `self.as_ptr()`, `self.stride_of(..)` and
/// `self.split_at(..)` calls below resolve to inherent methods of the same name,
/// which take precedence over these trait methods.
impl<A, D: Dimension> NdProducer for RawArrayViewMut<A, D> {
type Item = *mut A;
type Dim = D;
type Ptr = *mut A;
type Stride = isize;
private_impl! {}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
// Compares the stored shape directly via `Dimension::equal`.
self.dim.equal(dim)
}
#[doc(hidden)]
fn as_ptr(&self) -> *mut A {
// Cast the inherent pointer to this impl's `*mut A` cursor type.
self.as_ptr() as _
}
#[doc(hidden)]
fn layout(&self) -> Layout {
self.layout_impl()
}
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *mut A) -> *mut A {
// The item is the cursor itself.
ptr
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
// Element offset of `i` computed from the strides; the caller must pass an
// in-bounds index (nothing here checks it).
self.ptr.as_ptr().offset(i.index_unchecked(&self.strides))
}
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.stride_of(axis)
}
// In a contiguous layout, consecutive elements are one element apart.
#[inline(always)]
fn contiguous_stride(&self) -> Self::Stride {
1
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}
/// Lockstep iteration over several producers of the same shape.
///
/// `Parts` is a tuple of producers. Every part has dimension `D`, which is
/// checked when a part is added with `and`.
#[derive(Debug, Clone)]
pub struct Zip<Parts, D> {
// The tuple of producers being zipped.
parts: Parts,
// Common shape of all parts.
dimension: D,
// Combined contiguity flags of all parts (merged with `Layout::and`).
layout: Layout,
}
impl<P, D> Zip<(P,), D>
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    /// Starts a `Zip` from a single producer (or anything convertible into one).
    ///
    /// The producer's shape and layout become the `Zip`'s shape and layout.
    pub fn from<IP>(p: IP) -> Self
    where
        IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
    {
        let producer = p.into_producer();
        Zip {
            dimension: producer.raw_dim(),
            layout: producer.layout(),
            parts: (producer,),
        }
    }
}
impl<P, D> Zip<(Indices<D>, P), D>
where
    D: Dimension + Copy,
    P: NdProducer<Dim = D>,
{
    /// Starts a `Zip` whose first part is an `indices` producer over the same
    /// shape as `p`, followed by `p` itself.
    pub fn indexed<IP>(p: IP) -> Self
    where
        IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>,
    {
        let producer = p.into_producer();
        let index_part = indices(producer.raw_dim());
        Zip::from(index_part).and(producer)
    }
}
impl<Parts, D> Zip<Parts, D>
where
D: Dimension,
{
fn check<P>(&self, part: &P)
where
P: NdProducer<Dim = D>,
{
ndassert!(
part.equal_dim(&self.dimension),
"Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
self.dimension,
part.raw_dim()
);
}
pub fn size(&self) -> usize {
self.dimension.size()
}
fn len_of(&self, axis: Axis) -> usize {
self.dimension[axis.index()]
}
fn max_stride_axis(&self) -> Axis {
let i = match self.layout.flag() {
FORDER => self
.dimension
.slice()
.iter()
.rposition(|&len| len > 1)
.unwrap_or(self.dimension.ndim() - 1),
_ => self
.dimension
.slice()
.iter()
.position(|&len| len > 1)
.unwrap_or(0),
};
Axis(i)
}
}
impl<P, D> Zip<P, D>
where
    D: Dimension,
{
    /// Drives `function` over every tuple of elements, threading the accumulator.
    ///
    /// Iteration stops at the first `FoldWhile::Done`, which is returned as-is.
    /// Otherwise the final accumulator is returned in `FoldWhile::Continue`.
    ///
    /// Takes the contiguous fast path when the combined layout flags include C
    /// or F contiguity (as reported by `Layout::is`), and the strided path
    /// otherwise.
    fn apply_core<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.layout.is(CORDER | FORDER) {
            self.apply_core_contiguous(acc, function)
        } else {
            self.apply_core_strided(acc, function)
        }
    }

    /// Contiguous fast path.
    ///
    /// Visits elements `0..size` by stepping every part's base pointer
    /// `i * contiguous_stride()` forward.
    fn apply_core_contiguous<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(CORDER | FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        for i in 0..size {
            // SAFETY: the layout is contiguous, so for `i < size` each part's
            // element `i` sits `i` contiguous strides from its base pointer.
            unsafe {
                let ptr_i = ptrs.stride_offset(inner_strides, i);
                acc = fold_while![function(acc, self.parts.as_ref(ptr_i))];
            }
        }
        FoldWhile::Continue(acc)
    }

    /// General strided path.
    ///
    /// Iterates over every index of the outer axes. For each one it unrolls a
    /// tight loop over the innermost axis using that axis' strides.
    ///
    /// `self.dimension` is modified temporarily: the innermost axis length is
    /// set to 1 so `first_index`/`next_for` enumerate only the outer indices.
    /// The original length is restored on every exit path, including an early
    /// `Done`, so the `Zip` keeps its shape after this call.
    fn apply_core_strided<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        while let Some(index) = index_ {
            // SAFETY: `index` comes from iterating the (collapsed) shape, so it
            // is in bounds. The inner loop stays within `inner_len` steps
            // along the innermost axis.
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                for i in 0..inner_len {
                    let p = ptr.stride_offset(inner_strides, i);
                    match function(acc, self.parts.as_ref(p)) {
                        FoldWhile::Continue(next) => acc = next,
                        done => {
                            // Undo the temporary collapse before short-circuiting.
                            self.dimension[unroll_axis] = inner_len;
                            return done;
                        }
                    }
                }
            }
            index_ = self.dimension.next_for(index);
        }
        self.dimension[unroll_axis] = inner_len;
        FoldWhile::Continue(acc)
    }

    /// Allocates an uninitialized output array with this `Zip`'s shape.
    ///
    /// It uses F order only when the layout is F-contiguous but not
    /// C-contiguous; otherwise it uses C order.
    pub(crate) fn uninitalized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = !self.layout.is(CORDER) && self.layout.is(FORDER);
        Array::maybe_uninit(self.dimension.clone().set_f(is_f))
    }
}
/// Like `Offset`, but for a bundle of cursors (a tuple, or a single `*mut T`).
///
/// Each component advances by its own stride.
trait OffsetTuple {
/// The strides, one per component.
type Args;
/// Advance every component by `index` steps of its stride in `stride`.
///
/// # Safety
/// Each resulting cursor must stay inside its own allocation.
unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}
impl<T> OffsetTuple for *mut T {
    type Args = isize;

    /// Moves the pointer `index * stride` elements.
    ///
    /// # Safety
    /// Same contract as `<*mut T>::offset`.
    unsafe fn stride_offset(self, stride: isize, index: usize) -> Self {
        let step = stride * index as isize;
        self.offset(step)
    }
}
// Implements `OffsetTuple` for tuples of `Offset` cursors.
// `$param` lists the type parameters, which are reused as bindings for the
// cursors. `$q` lists fresh names for the matching strides.
// The fully qualified `Offset::stride_offset(..)` call is deliberate: `*mut T`
// implements both `Offset` and `OffsetTuple`, so a `.stride_offset(..)` method
// call would be ambiguous.
macro_rules! offset_impl {
($([$($param:ident)*][ $($q:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
type Args = ($($param::Stride,)*);
unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
let ($($param, )*) = self;
let ($($q, )*) = stride;
($(Offset::stride_offset($param, $q, index),)*)
}
}
)+
}
}
// `OffsetTuple` for tuples of 1 to 6 cursors, matching the tuple sizes `Zip` supports.
offset_impl! {
[A ][ a],
[A B][ a b],
[A B C][ a b c],
[A B C D][ a b c d],
[A B C D E][ a b c d e],
[A B C D E F][ a b c d e f],
}
// Implements `ZippableTuple` for tuples of producers that share one `Dim`.
// Each method applies the matching `NdProducer` method to every component and
// collects the results into a tuple.
// `$p` names the producer type parameters, which are reused as binding names.
// `$q` supplies a second set of names for `as_ref`, where producers (`$q`) and
// their cursors (`$p`) must be bound at the same time.
macro_rules! zipt_impl {
($([$($p:ident)*][ $($q:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
type Item = ($($p::Item, )*);
type Ptr = ($($p::Ptr, )*);
type Dim = Dim;
type Stride = ($($p::Stride,)* );
fn stride_of(&self, index: usize) -> Self::Stride {
let ($(ref $p,)*) = *self;
($($p.stride_of(Axis(index)), )*)
}
fn contiguous_stride(&self) -> Self::Stride {
let ($(ref $p,)*) = *self;
($($p.contiguous_stride(), )*)
}
fn as_ptr(&self) -> Self::Ptr {
let ($(ref $p,)*) = *self;
($($p.as_ptr(), )*)
}
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
// `$q` binds the producers and `$p` binds the cursors.
let ($(ref $q ,)*) = *self;
let ($($p,)*) = ptr;
($($q.as_ref($p),)*)
}
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
let ($(ref $p,)*) = *self;
($($p.uget_ptr(i), )*)
}
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
// Split each component into a (first, second) pair, then regroup
// all the first halves and all the second halves.
let ($($p,)*) = self;
let ($($p,)*) = (
$($p.split_at(axis, index), )*
);
(
($($p.0,)*),
($($p.1,)*)
)
}
}
)+
}
}
// `ZippableTuple` for tuples of 1 to 6 producers.
zipt_impl! {
[A ][ a],
[A B][ a b],
[A B C][ a b c],
[A B C D][ a b c d],
[A B C D E][ a b c d e],
[A B C D E F][ a b c d e f],
}
// Generates the public iteration API of `Zip` for each tuple arity.
// `$notlast` is `true` for arities that may still grow. Only for those does
// `expand_if!` emit the methods that add a part (`and`, `and_broadcast`,
// `apply_collect`, `apply_assign_into`).
// `$p` names the producer type parameters, which are reused as bindings for
// the individual items.
macro_rules! map_impl {
($([$notlast:ident $($p:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<D, $($p),*> Zip<($($p,)*), D>
where D: Dimension,
$($p: NdProducer<Dim=D> ,)*
{
/// Calls `function` once for each tuple of elements.
pub fn apply<F>(mut self, mut function: F)
where F: FnMut($($p::Item),*)
{
self.apply_core((), move |(), args| {
let ($($p,)*) = args;
FoldWhile::Continue(function($($p),*))
});
}
/// Folds over all tuples of elements, starting from `acc`.
pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
where
F: FnMut(Acc, $($p::Item),*) -> Acc,
{
self.apply_core(acc, move |acc, args| {
let ($($p,)*) = args;
FoldWhile::Continue(function(acc, $($p),*))
}).into_inner()
}
/// Folds with early exit: stops at the first `FoldWhile::Done` returned by
/// `function` and returns it. Otherwise returns `Continue(final_acc)`.
pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
-> FoldWhile<Acc>
where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
{
self.apply_core(acc, move |acc, args| {
let ($($p,)*) = args;
function(acc, $($p),*)
})
}
/// Returns `true` if `predicate` holds for every tuple of elements.
/// Stops at the first tuple for which it returns `false`.
pub fn all<F>(mut self, mut predicate: F) -> bool
where F: FnMut($($p::Item),*) -> bool
{
!self.apply_core((), move |_, args| {
let ($($p,)*) = args;
if predicate($($p),*) {
FoldWhile::Continue(())
} else {
FoldWhile::Done(())
}
}).is_done()
}
expand_if!(@bool [$notlast]
/// Adds the producer `p` as a new part of the `Zip`.
///
/// `p`'s shape must equal the `Zip`'s shape exactly; this is checked with `check`.
pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
where P: IntoNdProducer<Dim=D>,
{
let array = p.into_producer();
self.check(&array);
let part_layout = array.layout();
let ($($p,)*) = self.parts;
Zip {
parts: ($($p,)* array, ),
layout: self.layout.and(part_layout),
dimension: self.dimension,
}
}
/// Adds the view `p`, broadcast to the `Zip`'s shape, as a new part.
///
/// NOTE(review): incompatible shapes presumably panic inside `broadcast_unwrap`.
pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
-> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
D2: Dimension,
{
let array = p.into_producer().broadcast_unwrap(self.dimension.clone());
let part_layout = array.layout();
let ($($p,)*) = self.parts;
Zip {
parts: ($($p,)* array, ),
layout: self.layout.and(part_layout),
dimension: self.dimension,
}
}
/// Applies `f` to every tuple of elements and collects the results into a
/// new array. The array's memory order follows the `Zip`'s layout.
///
/// NOTE(review): the `R: Copy` bound presumably exists so that a panic in
/// `f` partway through never leaves initialized elements that need dropping
/// inside the `MaybeUninit` buffer.
pub fn apply_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D>
where R: Copy,
{
let mut output = self.uninitalized_for_current_layout::<R>();
self.apply_assign_into(&mut output, f);
// Every element was written by `apply_assign_into` above.
unsafe {
output.assume_init()
}
}
/// Applies `f` to every tuple of elements and assigns each result into the
/// matching element of `into` (via `AssignElem`).
pub fn apply_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
where Q: IntoNdProducer<Dim=D>,
Q::Item: AssignElem<R>
{
self.and(into)
.apply(move |$($p, )* output_| {
output_.assign_elem(f($($p ),*));
});
}
);
/// Splits the `Zip` into two halves along `max_stride_axis`, at the midpoint
/// of that axis. Debug builds assert that the `Zip` has at least 2 elements.
pub fn split(self) -> (Self, Self) {
debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
let axis = self.max_stride_axis();
let index = self.len_of(axis) / 2;
let (p1, p2) = self.parts.split_at(axis, index);
let (d1, d2) = self.dimension.split_at(axis, index);
(Zip {
dimension: d1,
layout: self.layout,
parts: p1,
},
Zip {
dimension: d2,
layout: self.layout,
parts: p2,
})
}
}
)+
}
}
// The `Zip` API for 1 to 6 parts. Rows marked `true` also get the methods that
// add a part. The 6-part row is `false`: `ZippableTuple` is implemented only up
// to 6-tuples, so that tuple cannot grow.
map_impl! {
[true P1],
[true P1 P2],
[true P1 P2 P3],
[true P1 P2 P3 P4],
[true P1 P2 P3 P4 P5],
[false P1 P2 P3 P4 P5 P6],
}
/// Result of one step of a short-circuiting fold such as `Zip::fold_while`.
///
/// `Continue(acc)` tells the fold to keep going with the new accumulator.
/// `Done(acc)` stops it early and carries the final accumulator.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum FoldWhile<T> {
    /// Keep folding with this accumulator value.
    Continue(T),
    /// Stop folding; this is the final value.
    Done(T),
}

impl<T> FoldWhile<T> {
    /// Returns the carried accumulator, whether or not the fold stopped early.
    pub fn into_inner(self) -> T {
        match self {
            FoldWhile::Continue(x) | FoldWhile::Done(x) => x,
        }
    }

    /// Returns `true` if the fold stopped early (this is `Done`).
    pub fn is_done(&self) -> bool {
        match *self {
            FoldWhile::Continue(_) => false,
            FoldWhile::Done(_) => true,
        }
    }
}