1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
use smartnoise_validator::errors::*;

use ndarray::prelude::*;
use crate::NodeArguments;
use smartnoise_validator::base::{Array, ReleaseNode};
use smartnoise_validator::utilities::take_argument;
use crate::components::Evaluable;
use ndarray::{ArrayD, Axis, Array1};

use smartnoise_validator::proto;

use smartnoise_validator::utilities::array::slow_select;
use crate::utilities::to_nd;


impl Evaluable for proto::Filter {
    /// Evaluates the filter component on the supplied node arguments.
    ///
    /// Takes the `mask` argument (which must be a boolean array) and the `data`
    /// argument (an array of any supported element type) out of `arguments`,
    /// applies [`filter`] to them, and wraps the filtered array in a `ReleaseNode`.
    ///
    /// The privacy definition is not consulted by this component.
    fn evaluate(&self, _privacy_definition: &Option<proto::PrivacyDefinition>, mut arguments: NodeArguments) -> Result<ReleaseNode> {
        // The mask is taken first, then the data, so a missing or mistyped mask is reported first.
        let mask = take_argument(&mut arguments, "mask")?.array()?.bool()?;
        let data = take_argument(&mut arguments, "data")?.array()?;

        // `filter` is generic over the element type, so dispatch on the variant of the data array.
        let filtered = match data {
            Array::Str(rows) => filter(rows, mask)?.into(),
            Array::Float(rows) => filter(rows, mask)?.into(),
            Array::Int(rows) => filter(rows, mask)?.into(),
            Array::Bool(rows) => filter(rows, mask)?.into(),
        };

        Ok(ReleaseNode::new(filtered))
    }
}

/// Filters data down into only the desired rows.
///
/// Rows are selected along the first axis of `data`; row `i` is kept exactly when
/// entry `i` of the mask is `true`. The relative order of the kept rows is preserved.
///
/// # Arguments
/// * `data` - Data to be filtered.
/// * `mask` - Boolean mask giving whether or not each row should be kept.
///            It must contain exactly one entry per row of `data`.
///
/// # Return
/// Data with only the desired rows.
///
/// # Errors
/// Returns an error if
/// * `mask` cannot be converted into a one-dimensional array,
/// * `data` has no axes (there are no rows to select from), or
/// * the number of mask entries differs from the number of rows in `data`.
///
/// # Example
/// ```
/// use ndarray::{ArrayD, arr1, arr2};
/// use smartnoise_runtime::components::filter::filter;
///
/// let data = arr2(&[ [1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12] ]).into_dyn();
/// let mask = arr1(&[true, false, true, false]).into_dyn();
/// let filtered = filter(data, mask).unwrap();
/// assert_eq!(filtered, arr2(&[ [1, 2, 3], [7, 8, 9] ]).into_dyn());
/// ```
pub fn filter<T: Clone + Default>(data: ArrayD<T>, mask: ArrayD<bool>) -> Result<ArrayD<T>> {

    // Bring the mask to exactly one dimension so it can be walked entry by entry.
    let columnar_mask: Array1<bool> = to_nd(mask, 1)?.into_dimensionality::<Ix1>()?;

    // Rows live on the first axis; data without any axis has no rows to filter.
    let num_rows = match data.shape().first() {
        Some(&rows) => rows,
        None => return Err("filter: data must have at least one axis to select rows from".into()),
    };

    // Without this check, a mask longer than the data could hand out-of-range row
    // indices to `slow_select`, and a shorter mask would silently drop trailing rows.
    if columnar_mask.len() != num_rows {
        return Err(format!(
            "filter: mask has {} entries, but data has {} rows",
            columnar_mask.len(), num_rows).into());
    }

    // Positions of the `true` entries are the row indices to keep.
    let mask_indices: Vec<usize> = columnar_mask.iter().enumerate()
        .filter_map(|(index, &keep)| if keep { Some(index) } else { None })
        .collect();
    Ok(slow_select(&data, Axis(0), &mask_indices))
}