use crate::distribution::{Discrete, Univariate};
use crate::function::{factorial, gamma};
use rand::distributions::Distribution;
use rand::Rng;
use crate::statistics::*;
use std::f64;
use std::u64;
use crate::{Result, StatsError};
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Poisson {
lambda: f64,
}
impl Poisson {
pub fn new(lambda: f64) -> Result<Poisson> {
if lambda.is_nan() || lambda <= 0.0 {
Err(StatsError::BadParams)
} else {
Ok(Poisson { lambda: lambda })
}
}
pub fn lambda(&self) -> f64 {
self.lambda
}
}
impl Distribution<f64> for Poisson {
fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> f64 {
sample_unchecked(r, self.lambda)
}
}
impl Univariate<u64, f64> for Poisson {
fn cdf(&self, x: f64) -> f64 {
if x < 0.0 {
0.0
} else if x == f64::INFINITY {
1.0
} else {
1.0 - gamma::gamma_lr(x.floor() + 1.0, self.lambda)
}
}
}
impl Min<u64> for Poisson {
fn min(&self) -> u64 {
0
}
}
impl Max<u64> for Poisson {
fn max(&self) -> u64 {
u64::MAX
}
}
impl Mean<f64> for Poisson {
fn mean(&self) -> f64 {
self.lambda
}
}
impl Variance<f64> for Poisson {
fn variance(&self) -> f64 {
self.lambda
}
fn std_dev(&self) -> f64 {
self.variance().sqrt()
}
}
impl Entropy<f64> for Poisson {
fn entropy(&self) -> f64 {
0.5 * (2.0 * f64::consts::PI * f64::consts::E * self.lambda).ln()
- 1.0 / (12.0 * self.lambda)
- 1.0 / (24.0 * self.lambda * self.lambda)
- 19.0 / (360.0 * self.lambda * self.lambda * self.lambda)
}
}
impl Skewness<f64> for Poisson {
fn skewness(&self) -> f64 {
1.0 / self.lambda.sqrt()
}
}
impl Median<f64> for Poisson {
fn median(&self) -> f64 {
(self.lambda + 1.0 / 3.0 - 0.02 / self.lambda).floor()
}
}
impl Mode<u64> for Poisson {
fn mode(&self) -> u64 {
self.lambda.floor() as u64
}
}
impl Discrete<u64, f64> for Poisson {
fn pmf(&self, x: u64) -> f64 {
(-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x as u64)).exp()
}
fn ln_pmf(&self, x: u64) -> f64 {
-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x as u64)
}
}
fn sample_unchecked<R: Rng + ?Sized>(r: &mut R, lambda: f64) -> f64 {
if lambda < 30.0 {
let limit = (-lambda).exp();
let mut count = 0.0;
let mut product: f64 = r.gen();
while product >= limit {
count += 1.0;
product *= r.gen::<f64>();
}
count
} else {
let c = 0.767 - 3.36 / lambda;
let beta = f64::consts::PI / (3.0 * lambda).sqrt();
let alpha = beta * lambda;
let k = c.ln() - lambda - beta.ln();
loop {
let u: f64 = r.gen();
let x = (alpha - ((1.0 - u) / u).ln()) / beta;
let n = (x + 0.5).floor();
if n < 0.0 {
continue;
}
let v: f64 = r.gen();
let y = alpha - beta * x;
let temp = 1.0 + y.exp();
let lhs = y + (v / (temp * temp)).ln();
let rhs = k + n * lambda.ln() - factorial::ln_factorial(n as u64);
if lhs <= rhs {
return n;
}
}
}
}
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(test)]
mod test {
use std::fmt::Debug;
use std::f64;
use std::u64;
use crate::statistics::*;
use crate::distribution::{Univariate, Discrete, Poisson};
use crate::distribution::internal::*;
fn try_create(lambda: f64) -> Poisson {
let n = Poisson::new(lambda);
assert!(n.is_ok());
n.unwrap()
}
fn create_case(lambda: f64) {
let n = try_create(lambda);
assert_eq!(lambda, n.lambda());
}
fn bad_create_case(lambda: f64) {
let n = Poisson::new(lambda);
assert!(n.is_err());
}
fn get_value<T, F>(lambda: f64, eval: F) -> T
where T: PartialEq + Debug,
F: Fn(Poisson) -> T
{
let n = try_create(lambda);
eval(n)
}
fn test_case<T, F>(lambda: f64, expected: T, eval: F)
where T: PartialEq + Debug,
F: Fn(Poisson) -> T
{
let x = get_value(lambda, eval);
assert_eq!(expected, x);
}
fn test_almost<F>(lambda: f64, expected: f64, acc: f64, eval: F)
where F: Fn(Poisson) -> f64
{
let x = get_value(lambda, eval);
assert_almost_eq!(expected, x, acc);
}
#[test]
fn test_create() {
create_case(1.5);
create_case(5.4);
create_case(10.8);
}
#[test]
fn test_bad_create() {
bad_create_case(f64::NAN);
bad_create_case(-1.5);
bad_create_case(0.0);
}
#[test]
fn test_mean() {
test_case(1.5, 1.5, |x| x.mean());
test_case(5.4, 5.4, |x| x.mean());
test_case(10.8, 10.8, |x| x.mean());
}
#[test]
fn test_variance() {
test_case(1.5, 1.5, |x| x.variance());
test_case(5.4, 5.4, |x| x.variance());
test_case(10.8, 10.8, |x| x.variance());
}
#[test]
fn test_std_dev() {
test_case(1.5, (1.5f64).sqrt(), |x| x.std_dev());
test_case(5.4, (5.4f64).sqrt(), |x| x.std_dev());
test_case(10.8, (10.8f64).sqrt(), |x| x.std_dev());
}
#[test]
fn test_entropy() {
test_almost(1.5, 1.531959153102376331946, 1e-15, |x| x.entropy());
test_almost(5.4, 2.244941839577643504608, 1e-15, |x| x.entropy());
test_case(10.8, 2.600596429676975222694, |x| x.entropy());
}
#[test]
fn test_skewness() {
test_almost(1.5, 0.8164965809277260327324, 1e-15, |x| x.skewness());
test_almost(5.4, 0.4303314829119352094644, 1e-16, |x| x.skewness());
test_almost(10.8, 0.3042903097250922852539, 1e-16, |x| x.skewness());
}
#[test]
fn test_median() {
test_case(1.5, 1.0, |x| x.median());
test_case(5.4, 5.0, |x| x.median());
test_case(10.8, 11.0, |x| x.median());
}
#[test]
fn test_mode() {
test_case(1.5, 1, |x| x.mode());
test_case(5.4, 5, |x| x.mode());
test_case(10.8, 10, |x| x.mode());
}
#[test]
fn test_min_max() {
test_case(1.5, 0, |x| x.min());
test_case(5.4, 0, |x| x.min());
test_case(10.8, 0, |x| x.min());
test_case(1.5, u64::MAX, |x| x.max());
test_case(5.4, u64::MAX, |x| x.max());
test_case(10.8, u64::MAX, |x| x.max());
}
#[test]
fn test_pmf() {
test_almost(1.5, 0.334695240222645000000000000000, 1e-15, |x| x.pmf(1));
test_almost(1.5, 0.000003545747740570180000000000, 1e-20, |x| x.pmf(10));
test_almost(1.5, 0.000000000000000304971208961018, 1e-30, |x| x.pmf(20));
test_almost(5.4, 0.024389537090108400000000000000, 1e-17, |x| x.pmf(1));
test_almost(5.4, 0.026241240591792300000000000000, 1e-16, |x| x.pmf(10));
test_almost(5.4, 0.000000825202200316548000000000, 1e-20, |x| x.pmf(20));
test_almost(10.8, 0.000220314636840657000000000000, 1e-18, |x| x.pmf(1));
test_almost(10.8, 0.121365183659420000000000000000, 1e-15, |x| x.pmf(10));
test_almost(10.8, 0.003908139778574110000000000000, 1e-16, |x| x.pmf(20));
}
#[test]
fn test_ln_pmf() {
test_almost(1.5, -1.09453489189183485135413967177, 1e-15, |x| x.ln_pmf(1));
test_almost(1.5, -12.5497614919938728510400000000, 1e-14, |x| x.ln_pmf(10));
test_almost(1.5, -35.7263142985901000000000000000, 1e-13, |x| x.ln_pmf(20));
test_case(5.4, -3.71360104642977159156055355910, |x| x.ln_pmf(1));
test_almost(5.4, -3.64042303737322774736223038530, 1e-15, |x| x.ln_pmf(10));
test_almost(5.4, -14.0076373893489089949388000000, 1e-14, |x| x.ln_pmf(20));
test_almost(10.8, -8.42045386586982559781714423000, 1e-14, |x| x.ln_pmf(1));
test_almost(10.8, -2.10895123177378079525424989992, 1e-14, |x| x.ln_pmf(10));
test_almost(10.8, -5.54469377815000936289610059500, 1e-14, |x| x.ln_pmf(20));
}
#[test]
fn test_cdf() {
test_almost(1.5, 0.5578254003710750000000, 1e-15, |x| x.cdf(1.0));
test_almost(1.5, 0.9999994482467640000000, 1e-15, |x| x.cdf(10.0));
test_case(1.5, 1.0, |x| x.cdf(20.0));
test_almost(5.4, 0.0289061180327211000000, 1e-16, |x| x.cdf(1.0));
test_almost(5.4, 0.9774863006897650000000, 1e-15, |x| x.cdf(10.0));
test_almost(5.4, 0.9999997199928290000000, 1e-15, |x| x.cdf(20.0));
test_almost(10.8, 0.0002407141402518290000, 1e-16, |x| x.cdf(1.0));
test_almost(10.8, 0.4839692359955690000000, 1e-15, |x| x.cdf(10.0));
test_almost(10.8, 0.9961800769608090000000, 1e-15, |x| x.cdf(20.0));
}
#[test]
fn test_neg_cdf() {
test_case(1.5, 0.0, |x| x.cdf(-1.0));
}
#[test]
fn test_discrete() {
test::check_discrete_distribution(&try_create(0.3), 10);
test::check_discrete_distribution(&try_create(4.5), 30);
}
}