use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::iter::FromIterator;
use indexmap::map::IndexMap;
use itertools::Itertools;
use ndarray::prelude::*;
use noisy_float::prelude::n64;
use crate::{base, Float, proto, Warnable};
use crate::base::{IndexKey, NodeProperties, Release, SensitivitySpace, Value, ValueProperties, ArrayProperties, Array};
use crate::components::*;
use crate::errors::*;
use crate::utilities::inference::infer_property;
use crate::utilities::privacy::spread_privacy_usage;
use std::ops::MulAssign;
pub mod json;
pub mod inference;
pub mod serial;
pub mod array;
pub mod privacy;
pub mod properties;
pub fn take_argument(
arguments: &mut IndexMap<base::IndexKey, Value>,
name: &str,
) -> Result<Value> {
arguments.remove::<base::IndexKey>(&name.into())
.ok_or_else(|| Error::from(name.to_string() + " must be defined"))
}
pub fn get_argument<'a>(
arguments: &IndexMap<base::IndexKey, &'a Value>,
name: &str,
) -> Result<&'a Value> {
arguments.get::<base::IndexKey>(&name.into()).cloned()
.ok_or_else(|| Error::from(name.to_string() + " must be defined"))
}
/// Collect the values of `component`'s arguments that have a public entry in `release`.
///
/// Arguments with no release entry, or whose release entry is not public, are omitted.
/// Entries keep the argument order of `component.arguments()`.
/// This function never fails; the `Result` is kept for interface compatibility.
pub fn get_public_arguments<'a>(
    component: &proto::Component,
    release: &'a Release,
) -> Result<IndexMap<base::IndexKey, &'a Value>> {
    let public_arguments = component
        .arguments()
        .into_iter()
        .filter_map(|(name, node_id)| {
            release
                .get(&node_id)
                .filter(|release_node| release_node.public)
                .map(|release_node| (name, &release_node.value))
        })
        .collect::<IndexMap<base::IndexKey, &'a Value>>();
    Ok(public_arguments)
}
/// Gather a clone of the entry in `graph_properties` for each argument of `component`.
///
/// Arguments whose node id has no entry in `graph_properties` are left out.
/// Entries keep the argument order of `component.arguments()`.
/// This function never fails; the `Result` is kept for interface compatibility.
pub fn get_input_properties<T>(
    component: &proto::Component,
    graph_properties: &HashMap<u32, T>,
) -> Result<IndexMap<base::IndexKey, T>> where T: std::clone::Clone {
    let collected = component
        .arguments()
        .into_iter()
        .filter_map(|(name, node_id)| {
            graph_properties
                .get(&node_id)
                .map(|property| (name, property.clone()))
        })
        .collect::<IndexMap<base::IndexKey, T>>();
    Ok(collected)
}
/// Propagate properties through `computation_graph`, expanding components as they are reached.
///
/// A stack of node ids is seeded from `get_traversal`. It is reversed, so the head of the
/// traversal is on top. For the node on top of the stack:
/// 1. If any of its arguments previously failed, the node is marked failed, popped and skipped.
/// 2. `expand_component` is called. The expansion's graph, properties, releases and warnings
///    are merged in. If the expansion supplies its own traversal, those nodes are pushed on top
///    and the current node is revisited after them.
/// 3. If some argument properties are still missing, those argument ids are pushed on top and
///    the node is revisited later.
/// 4. Otherwise the node is popped. If it has a public release, its property is inferred from
///    the released value; otherwise it is derived with `propagate_property`.
///
/// # Arguments
/// * `privacy_definition` - forwarded to `expand_component` and `propagate_property`
/// * `computation_graph` - the graph to traverse; extended in place with expanded components
/// * `release` - released values; extended in place with releases produced by expansions
/// * `properties` - properties already known for some node ids, if any
/// * `dynamic` - if true, errors from `expand_component`/`propagate_property` are collected as
///   warnings and the node is marked failed; if false, the first such error is returned
///
/// # Returns
/// The property map (including properties inferred for public releases) and all warnings.
///
/// # Errors
/// Errors from `get_traversal`, `get_public_arguments`, `get_input_properties` and
/// `infer_property` are always propagated, regardless of `dynamic`.
pub fn propagate_properties(
privacy_definition: &Option<proto::PrivacyDefinition>,
computation_graph: &mut HashMap<u32, proto::Component>,
release: &mut base::Release,
properties: Option<HashMap<u32, base::ValueProperties>>,
dynamic: bool
) -> Result<(HashMap<u32, ValueProperties>, Vec<Error>)> {
// dependency order, reversed so the next node to visit is at the end (used as a stack)
let mut traversal: Vec<u32> = get_traversal(&computation_graph)?;
traversal.reverse();
let mut properties = properties.unwrap_or_else(HashMap::new);
// largest node id currently in the graph; handed to expand_component, presumably so that
// expansions can allocate fresh ids above it
let mut maximum_id = computation_graph.keys().max().cloned().unwrap_or(0);
// every public release gets a property inferred from its value; HashMap::extend overwrites
// any property that was supplied for the same node id
properties.extend(release.iter()
.filter(|(_, release_node)| release_node.public)
.map(|(node_id, release_node)|
Ok((*node_id, infer_property(
&release_node.value,
properties.get(node_id), *node_id)?)))
.collect::<Result<HashMap<u32, ValueProperties>>>()?);
// nodes that failed (or depend on a failed node) when running in dynamic mode
let mut failed_ids = HashSet::new();
let mut warnings = Vec::new();
while !traversal.is_empty() {
let node_id = *traversal.last().unwrap();
// NOTE(review): panics if an id on the stack is not a key of computation_graph
let component: &proto::Component = computation_graph.get(&node_id).unwrap();
// skip nodes downstream of a failure, marking them failed as well
if component.arguments().values().any(|v| failed_ids.contains(v)) {
failed_ids.insert(traversal.pop().unwrap());
continue;
}
let mut expansion = match component
.expand_component(
privacy_definition,
&component,
&get_public_arguments(component, &release)?,
&get_input_properties(&component, &properties)?,
node_id,
maximum_id,
) {
Ok(expansion) => expansion,
Err(err) => if dynamic {
failed_ids.insert(traversal.pop().unwrap());
warnings.push(err);
continue;
} else { return Err(err) }
};
// keep maximum_id at or above every id introduced by the expansion
maximum_id = expansion.computation_graph.keys().max().cloned()
.unwrap_or(0).max(maximum_id);
computation_graph.extend(expansion.computation_graph);
properties.extend(expansion.properties);
release.extend(expansion.releases);
warnings.extend(expansion.warnings);
// the expansion asked for other nodes first: push them on top and revisit this node later
if !expansion.traversal.is_empty() {
expansion.traversal.reverse();
traversal.extend(expansion.traversal);
continue;
}
// re-borrow: the extend above may have replaced this node's component
let component: &proto::Component = computation_graph.get(&node_id).unwrap();
let mut input_properties = IndexMap::<base::IndexKey, ValueProperties>::new();
let mut missing_properties = Vec::new();
for (arg_name, arg_node_id) in component.arguments() {
if let Some(property) = properties.get(&arg_node_id) {
input_properties.insert(arg_name.to_owned(), property.clone());
} else {
missing_properties.push(arg_node_id);
}
}
// arguments without properties are pushed on top so they are processed before this node
if !missing_properties.is_empty() {
traversal.extend(missing_properties);
continue
}
// all inputs are available; this node is finished after this iteration
traversal.pop();
let release_node = release.get(&node_id);
// a public release fixes the value, so its property is inferred rather than propagated
let propagation_result = if release_node
.map(|release_node| release_node.public).unwrap_or(false) {
Ok(Warnable(infer_property(
&release_node.unwrap().value,
properties.get(&node_id), node_id)?, vec![]))
} else {
computation_graph.get(&node_id).unwrap()
.propagate_property(
privacy_definition,
get_public_arguments(component, &release)?,
input_properties,
node_id)
.chain_err(|| format!("at node_id {:?}", node_id))
};
match propagation_result {
Ok(propagation_result) => {
let Warnable(component_properties, propagation_warnings) = propagation_result;
// annotate each propagation warning with the node it came from
warnings.extend(propagation_warnings.into_iter()
.map(|err| err.chain_err(|| format!("at node_id {:?}", node_id)))
.collect::<Vec<Error>>());
properties.insert(node_id, component_properties);
},
Err(err) => if dynamic {
failed_ids.insert(node_id);
warnings.push(err);
} else { return Err(err) }
};
}
Ok((properties, warnings))
}
/// Compute a dependency-respecting ordering of the node ids in `graph`.
///
/// The traversal starts from nodes whose arguments all lie outside `graph` (including nodes
/// with no arguments). A node is appended once every one of its arguments has been appended.
///
/// NOTE: nodes that sit on a cycle never have all of their arguments traversed. They are
/// therefore never queued and are simply absent from the result. The same holds for nodes
/// that mix arguments inside and outside `graph`. The visit-count check below only reports
/// "Graph is cyclic." when a node is reached more times than it has arguments.
///
/// # Errors
/// Returns "Graph is cyclic." if that visit-count check fires.
pub fn get_traversal(
    graph: &HashMap<u32, proto::Component>
) -> Result<Vec<u32>> {
    // node id -> set of nodes that take it as an argument
    let mut dependents = HashMap::<u32, HashSet<u32>>::new();
    for (node_id, component) in graph {
        dependents.entry(*node_id).or_insert_with(HashSet::new);
        for argument_node_id in component.arguments().values() {
            dependents
                .entry(*argument_node_id)
                .or_insert_with(HashSet::new)
                .insert(*node_id);
        }
    }

    let mut traversal = Vec::new();
    // mirror of `traversal` for O(1) membership tests (previously a linear scan per edge)
    let mut traversed = HashSet::<u32>::new();

    // `all` is vacuously true for argument-less components, so they are seeded too
    let mut queue: Vec<u32> = graph
        .iter()
        .filter(|(_, component)| component
            .arguments()
            .values()
            .all(|argument_node_id| !graph.contains_key(argument_node_id)))
        .map(|(node_id, _)| *node_id)
        .collect();

    let mut visit_counts = HashMap::<u32, usize>::new();
    while let Some(queue_node_id) = queue.pop() {
        traversal.push(queue_node_id);
        traversed.insert(queue_node_id);

        let mut is_cyclic = false;
        // every queued id is a key of `graph`, so it has an entry in `dependents`
        for dependent_node_id in &dependents[&queue_node_id] {
            // dependents are always graph keys; read arguments without cloning the component
            let dependent_arguments = graph[dependent_node_id].arguments();

            let count = visit_counts.entry(*dependent_node_id).or_insert(0);
            *count += 1;
            if *count > dependent_arguments.len() {
                is_cyclic = true;
            }

            // the dependent becomes ready once all of its inputs have been traversed
            if dependent_arguments
                .values()
                .all(|argument_node_id| traversed.contains(argument_node_id)) {
                queue.push(*dependent_node_id);
            }
        }
        if is_cyclic {
            return Err("Graph is cyclic.".into());
        }
    }
    Ok(traversal)
}
/// Return the ids of nodes that no other node in `computation_graph` uses as an argument.
pub fn get_sinks(computation_graph: &HashMap<u32, proto::Component>) -> HashSet<u32> {
    // start from every node, then strike out anything referenced as an input
    let mut sinks = HashSet::<u32>::from_iter(computation_graph.keys().copied());
    for component in computation_graph.values() {
        for source_node_id in component.arguments().values() {
            sinks.remove(source_node_id);
        }
    }
    sinks
}
/// Stamp `node_id` onto `property`.
///
/// Array properties get the id directly. Dataframe and partition properties are walked
/// recursively through their children. Jagged and function properties are left untouched.
pub fn set_node_id(property: &mut ValueProperties, node_id: u32) -> () {
    match property {
        ValueProperties::Array(array) => array.node_id = node_id as i64,
        ValueProperties::Dataframe(dataframe) => {
            for child in dataframe.children.values_mut() {
                set_node_id(child, node_id);
            }
        }
        ValueProperties::Partitions(partitions) => {
            for child in partitions.children.values_mut() {
                set_node_id(child, node_id);
            }
        }
        ValueProperties::Jagged(_) | ValueProperties::Function(_) => {}
    }
}
#[doc(hidden)]
pub fn standardize_numeric_argument<T: Clone>(value: ArrayD<T>, length: i64) -> Result<ArrayD<T>> {
match value.ndim() {
0 => match value.first() {
Some(scalar) => Ok(ndarray::Array::from((0..length).map(|_| scalar.clone())
.collect::<Vec<T>>()).into_dyn()),
None => Err("value must be non-empty".into())
},
1 => if value.len() as i64 == length {
Ok(value)
} else { Err("value is of incompatible length".into()) },
_ => Err("value must be a scalar or vector".into())
}
}
/// Validate float categories and broadcast a single category set to `length` columns.
///
/// Every column must contain only finite floats with no duplicates.
/// If exactly one column is given, it is repeated `length` times.
///
/// # Errors
/// If `categories` is empty, if any value is not finite, or if a column contains duplicates.
pub fn standardize_float_argument(
    mut categories: Vec<Vec<Float>>,
    length: i64,
) -> Result<Vec<Vec<Float>>> {
    if categories.is_empty() {
        return Err("no categories are defined".into());
    }
    for column in &categories {
        if column.iter().any(|v| !v.is_finite()) {
            return Err("all floats must be finite".into());
        }
        // floats are not Hash/Eq, so wrap them in N64 before counting distinct values
        let distinct = deduplicate(column.iter().map(|v| n64(*v as f64)).collect());
        if distinct.len() < column.len() {
            return Err("floats must not contain duplicates".into());
        }
    }
    if categories.len() == 1 {
        let only_column = categories.remove(0);
        categories = (0..length).map(|_| only_column.clone()).collect();
    }
    Ok(categories)
}
#[doc(hidden)]
pub fn standardize_categorical_argument<T: Clone + Eq + Hash + Ord>(
categories: Vec<Vec<T>>,
length: i64,
) -> Result<Vec<Vec<T>>> {
let mut categories = categories.into_iter()
.map(deduplicate).collect::<Vec<Vec<T>>>();
if categories.is_empty() {
return Err("no categories are defined".into());
}
if categories.len() == 1 {
categories = (0..length).map(|_| categories.first().unwrap().clone()).collect();
}
Ok(categories)
}
/// Broadcast a single set of null candidates to `length` columns.
///
/// If more than one set is given, `value` is returned unchanged.
///
/// # Errors
/// If `value` is empty.
#[doc(hidden)]
pub fn standardize_null_candidates_argument<T: Clone>(
    value: Vec<Vec<T>>,
    length: i64,
) -> Result<Vec<Vec<T>>> {
    match value.len() {
        0 => Err("null values cannot be an empty vector".into()),
        1 => {
            let template = &value[0];
            Ok((0..length).map(|_| template.clone()).collect())
        }
        _ => Ok(value),
    }
}
/// Convert a null-replacement argument into one value per column.
///
/// If `value` already has `length` elements, its elements are returned in iteration order.
/// A single element is repeated `length` times.
///
/// # Errors
/// If `value` is empty, or has neither one nor `length` elements.
#[doc(hidden)]
pub fn standardize_null_target_argument<T: Clone>(
    value: ArrayD<T>,
    length: i64,
) -> Result<Vec<T>> {
    let elements: Vec<T> = value.iter().cloned().collect();
    if elements.is_empty() {
        return Err("null values cannot be empty".into());
    }
    if elements.len() == length as usize {
        return Ok(elements);
    }
    if elements.len() == 1 {
        let single = &elements[0];
        return Ok((0..length).map(|_| single.clone()).collect());
    }
    bail!("length of null must be one, or {}", length)
}
/// Produce one normalized probability vector per categorical column.
///
/// # Arguments
/// * `weights` - optional category weights:
///   - `None` or an empty vector yields uniform probabilities;
///   - a single vector is normalized and reused for every column;
///   - otherwise there must be exactly one vector per column.
/// * `lengths` - the number of categories in each column.
///
/// # Returns
/// One probability vector per element of `lengths`, each summing to one.
///
/// # Errors
/// * a weight is negative or NaN
/// * a weight vector sums to zero or to a non-finite value
/// * a weight vector's length differs from the number of categories in its column
/// * several weight vectors are given, but not exactly one per column
#[doc(hidden)]
pub fn standardize_weight_argument(
    weights: &Option<Vec<Vec<Float>>>,
    lengths: &[i64],
) -> Result<Vec<Vec<Float>>> {
    /// Equal probability for each of `length` categories.
    fn uniform_density(length: usize) -> Vec<Float> {
        (0..length).map(|_| 1. / (length as Float)).collect()
    }

    /// Rescale non-negative weights so that they sum to one.
    fn normalize_probabilities(weights: &[Float]) -> Result<Vec<Float>> {
        // `>=` is false for NaN, so NaN weights are rejected here as well
        if !weights.iter().all(|w| *w >= 0.) {
            return Err("all weights must be non-negative".into());
        }
        let sum: Float = weights.iter().sum();
        // an all-zero, empty or overflowing weight vector would otherwise normalize to NaNs
        if !(sum > 0. && sum.is_finite()) {
            return Err("weights must have a positive, finite sum".into());
        }
        Ok(weights.iter().map(|weight| weight / sum).collect())
    }

    /// Ensure a weight vector has exactly one entry per category.
    fn check_length(weights: &[Float], length: i64) -> Result<()> {
        if weights.len() as i64 == length {
            Ok(())
        } else {
            Err("length of weights does not match number of categories".into())
        }
    }

    // borrow instead of cloning the caller's weights
    let weights: &[Vec<Float>] = match weights {
        Some(weights) => weights.as_slice(),
        None => &[],
    };

    match weights.len() {
        0 => Ok(lengths.iter()
            .map(|length| uniform_density(*length as usize))
            .collect()),
        1 => {
            let column_weights = &weights[0];
            let probabilities = normalize_probabilities(column_weights)?;
            // the shared vector must fit every column: compare against its own length,
            // not against the number of weight vectors (which is always 1 here)
            lengths.iter()
                .map(|length| -> Result<Vec<Float>> {
                    check_length(column_weights, *length)?;
                    Ok(probabilities.clone())
                })
                .collect()
        }
        _ => {
            if weights.len() != lengths.len() {
                return Err("category weights must be the same length as categories, or none".into());
            }
            weights.iter()
                .zip(lengths)
                .map(|(column_weights, length)| -> Result<Vec<Float>> {
                    check_length(column_weights, *length)?;
                    normalize_probabilities(column_weights)
                })
                .collect()
        }
    }
}
/// Build a public literal node holding `value`.
///
/// Returns a `Literal` component with no arguments and `omit` set, tagged with `submission`,
/// paired with a public release node carrying `value` and no privacy usages.
/// This function never fails; the `Result` is kept for interface compatibility.
#[doc(hidden)]
pub fn get_literal(value: Value, submission: u32) -> Result<(proto::Component, base::ReleaseNode)> {
    let literal_component = proto::Component {
        variant: Some(proto::component::Variant::Literal(proto::Literal {})),
        arguments: None,
        submission,
        omit: true,
    };
    let literal_release = base::ReleaseNode {
        public: true,
        privacy_usages: None,
        value,
    };
    Ok((literal_component, literal_release))
}
#[doc(hidden)]
pub fn prepend(text: &str) -> impl Fn(Error) -> Error + '_ {
move |e| format!("{} {}", text, e).into()
}
/// Expand a noise mechanism component (Laplace, Gaussian, simple geometric or snapping).
///
/// The requested `privacy_usage` is spread across the columns of the `data` argument. Each
/// share is converted with `actual_to_effective`, using the data's sample proportion (1. when
/// absent), its `c_stability`, and the privacy definition's `group_size`. The result is written
/// into a clone of `component`.
///
/// If the properties contain a `sensitivity` argument, it is accepted only when
/// `protect_sensitivity` is disabled, and it must pass `check_sensitivity_properties` against
/// `data`. Otherwise the sensitivity is computed from the `data` property's aggregator and
/// scaled elementwise by the aggregator's Lipschitz constants (when any differ from 1). It is
/// then added as a public literal node with id `maximum_id + 1` and attached as the
/// component's `sensitivity` argument.
///
/// # Arguments
/// * `sensitivity_type` - forwarded to the aggregator's `compute_sensitivity`
/// * `privacy_definition` - must be `Some`
/// * `privacy_usage` - the usage requested for the whole mechanism
/// * `component` - the mechanism component being expanded
/// * `properties` - properties of the component's arguments; must include `data`
/// * `component_id` - id under which the updated component is stored in the expansion
/// * `maximum_id` - current largest node id; the sensitivity literal gets the next id
///
/// # Returns
/// A `ComponentExpansion` holding the updated component at `component_id`. When a sensitivity
/// was computed, it also holds the literal node, its inferred property and its release.
///
/// # Errors
/// * the privacy definition is missing
/// * `data` is missing or not an array
/// * the component is not one of the four supported mechanism variants
/// * a custom sensitivity is given while `protect_sensitivity` is enabled, or its shape
///   does not match
/// * the aggregator is missing, or its Lipschitz constants are not numeric
/// * errors propagated from usage spreading/conversion or sensitivity computation
// float_cmp is allowed for the exact `!= 1.` test on float Lipschitz constants below
#[allow(clippy::float_cmp)]
pub fn expand_mechanism(
sensitivity_type: &SensitivitySpace,
privacy_definition: &Option<proto::PrivacyDefinition>,
privacy_usage: &[proto::PrivacyUsage],
component: &proto::Component,
properties: &NodeProperties,
component_id: u32,
mut maximum_id: u32,
) -> Result<base::ComponentExpansion> {
let mut expansion = base::ComponentExpansion::default();
let privacy_definition = privacy_definition.as_ref()
.ok_or_else(|| "privacy definition must be defined")?;
let data_property: ArrayProperties = properties.get::<IndexKey>(&"data".into())
.ok_or("data: missing")?.array()
.map_err(prepend("data:"))?.clone();
// split the requested usage across the data columns, then convert each share from actual
// to effective usage; a missing sample_proportion is treated as 1.
let spread_usages = spread_privacy_usage(
privacy_usage, data_property.num_columns()? as usize)?;
let effective_usages = spread_usages.into_iter()
.map(|usage| usage.actual_to_effective(
data_property.sample_proportion.unwrap_or(1.),
data_property.c_stability,
privacy_definition.group_size))
.collect::<Result<Vec<proto::PrivacyUsage>>>()?;
let mut noise_component = component.clone();
// write the effective usages into whichever listed mechanism variant this component is
macro_rules! assign_usage {
($($variant:ident),*) => {
match noise_component.variant.as_mut() {
$(Some(proto::component::Variant::$variant(variant)) =>
variant.privacy_usage = effective_usages,)*
_ => return Err(Error::from("unrecognized component in expand_mechanism"))
}
}
}
assign_usage!(LaplaceMechanism, GaussianMechanism, SimpleGeometricMechanism, SnappingMechanism);
if let Some(sensitivity_property) = properties.get(&IndexKey::from("sensitivity")) {
// a user-supplied sensitivity is only allowed when protect_sensitivity is off
if privacy_definition.protect_sensitivity {
return Err(Error::from("custom sensitivities may only be passed if protect_sensitivity is disabled"))
}
check_sensitivity_properties(sensitivity_property.array()?, &data_property)?;
} else {
let aggregator = data_property.aggregator.as_ref()
.ok_or_else(|| Error::from("aggregator: missing"))?;
let mut sensitivity_value = aggregator.component.compute_sensitivity(
privacy_definition,
&aggregator.properties,
&sensitivity_type)?;
// scale the sensitivity elementwise by the Lipschitz constants, unless they are all 1
match aggregator.lipschitz_constants.clone().array()? {
Array::Float(lipschitz) => {
if lipschitz.iter().any(|v| v != &1.) {
let mut sensitivity = sensitivity_value.array()?.float()?;
sensitivity.mul_assign(&lipschitz);
sensitivity_value = sensitivity.into();
}
},
Array::Int(lipschitz) => {
if lipschitz.iter().any(|v| v != &1) {
let mut sensitivity = sensitivity_value.array()?.int()?;
sensitivity.mul_assign(&lipschitz);
sensitivity_value = sensitivity.into();
}
},
_ => return Err(Error::from("lipschitz constants must be numeric"))
};
// add the sensitivity as a new public literal node and wire it in as an argument
maximum_id += 1;
let id_sensitivity = maximum_id;
let (patch_node, release) = get_literal(sensitivity_value.clone(), component.submission)?;
expansion.computation_graph.insert(id_sensitivity, patch_node);
expansion.properties.insert(id_sensitivity, infer_property(&release.value, None, id_sensitivity)?);
expansion.releases.insert(id_sensitivity, release);
noise_component.insert_argument(&"sensitivity".into(), id_sensitivity);
}
expansion.computation_graph.insert(component_id, noise_component);
Ok(expansion)
}
pub fn check_sensitivity_properties(
sensitivity_property: &ArrayProperties, data_property: &ArrayProperties
) -> Result<()> {
if sensitivity_property.num_columns()? != data_property.num_columns()? {
return Err(Error::from(format!("sensitivity has {:?} columns, while the expected shape has {:?} columns.", sensitivity_property.num_columns()?, data_property.num_columns()?)));
}
if sensitivity_property.num_records()? != data_property.num_records()? {
return Err(Error::from(format!("sensitivity has {:?} records, while the expected shape has {:?} records.", sensitivity_property.num_records()?, data_property.num_records()?)));
}
if sensitivity_property.dimensionality.map(|dim| dim > 2).unwrap_or(false) {
return Err(Error::from("sensitivity may not have dimensionality greater than 2"))
}
Ok(())
}
/// Return the value shared by every element of `values`.
///
/// Returns `None` if `values` is empty or if any two elements differ.
// `&Vec` (not `&[T]`) is kept so existing call sites relying on inference keep compiling
#[allow(clippy::ptr_arg)]
pub fn get_common_value<T: Clone + Eq>(values: &Vec<T>) -> Option<T> {
    let first = values.first()?;
    if values.iter().all(|value| value == first) {
        Some(first.clone())
    } else {
        None
    }
}
/// Map each node id used as an argument to the set of nodes in `graph` that consume it.
///
/// Nodes that are never used as an argument have no entry.
pub fn get_dependents(graph: &HashMap<u32, proto::Component>) -> HashMap<u32, HashSet<u32>> {
    let mut dependents: HashMap<u32, HashSet<u32>> = HashMap::new();
    for (node_id, component) in graph {
        for source_node_id in component.arguments().values() {
            dependents.entry(*source_node_id).or_default().insert(*node_id);
        }
    }
    dependents
}
pub fn deduplicate<T: Eq + Hash + Clone>(values: Vec<T>) -> Vec<T> {
values.into_iter().unique().collect()
}
#[cfg(test)]
mod test_utilities {
    use crate::utilities::deduplicate;

    /// Repeated values are dropped while the first occurrence and input order are kept.
    #[test]
    fn test_deduplicate() {
        assert_eq!(deduplicate(vec![2, 0, 1, 0]), vec![2, 0, 1]);
    }
}