use crate::distribution::{Continuous, Univariate};
use crate::function::{beta, gamma};
use rand::distributions::Distribution;
use rand::Rng;
use crate::statistics::*;
use std::f64;
use crate::{Result, StatsError};
/// The Beta distribution on the interval `[0, 1]`, described by two shape
/// parameters `shape_a` and `shape_b`.
///
/// Instances are built through `Beta::new`, which rejects NaN and non-positive
/// shapes but accepts `+inf` (the distribution then degenerates to a point mass).
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct Beta {
    // First shape parameter; strictly positive, possibly +inf.
    shape_a: f64,
    // Second shape parameter; strictly positive, possibly +inf.
    shape_b: f64,
}
impl Beta {
    /// Constructs a Beta distribution with the given shape parameters.
    ///
    /// Infinite shapes are accepted.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::BadParams` if either shape is NaN or is not strictly
    /// greater than zero.
    pub fn new(shape_a: f64, shape_b: f64) -> Result<Beta> {
        // `v > 0.0` is false for NaN, so this one predicate rejects NaN, zero and
        // negative values alike.
        let both_positive = shape_a > 0.0 && shape_b > 0.0;
        if both_positive {
            Ok(Beta { shape_a, shape_b })
        } else {
            Err(StatsError::BadParams)
        }
    }

    /// Returns the first shape parameter.
    pub fn shape_a(&self) -> f64 {
        self.shape_a
    }

    /// Returns the second shape parameter.
    pub fn shape_b(&self) -> f64 {
        self.shape_b
    }
}
impl Distribution<f64> for Beta {
    /// Draws one sample as the ratio `g_a / (g_a + g_b)` of two gamma variates.
    ///
    /// `g_a` and `g_b` come from `super::gamma::sample_unchecked` with shapes
    /// `shape_a` and `shape_b` respectively, each called with a third argument of
    /// `1.0` (presumably unit scale/rate — confirm against the gamma module).
    /// `g_a` is always drawn first, so the RNG stream is consumed in that order.
    fn sample<R: Rng + ?Sized>(&self, r: &mut R) -> f64 {
        let draw_a = super::gamma::sample_unchecked(r, self.shape_a, 1.0);
        let draw_b = super::gamma::sample_unchecked(r, self.shape_b, 1.0);
        let total = draw_a + draw_b;
        draw_a / total
    }
}
impl Univariate<f64, f64> for Beta {
    /// Evaluates the cumulative distribution function at `x`.
    ///
    /// Returns `0` for `x < 0` and `1` for `x >= 1`. When a shape parameter is
    /// infinite the CDF is a step function: the step is at `0.5` when both shapes
    /// are infinite, at `1` when only `shape_a` is, and at `0` when only `shape_b`
    /// is. `Beta(1, 1)` yields `x` directly. Every other case delegates to the
    /// regularized incomplete beta function `beta::beta_reg`.
    fn cdf(&self, x: f64) -> f64 {
        if x < 0.0 {
            return 0.0;
        }
        if x >= 1.0 {
            return 1.0;
        }
        let (a, b) = (self.shape_a, self.shape_b);
        match (a == f64::INFINITY, b == f64::INFINITY) {
            (true, true) => {
                if x < 0.5 {
                    0.0
                } else {
                    1.0
                }
            }
            (true, false) => {
                // `x >= 1` was handled above, but a NaN `x` still reaches this arm
                // and must keep falling through to `1.0`.
                if x < 1.0 {
                    0.0
                } else {
                    1.0
                }
            }
            (false, true) => 1.0,
            (false, false) => {
                if a == 1.0 && b == 1.0 {
                    x
                } else {
                    beta::beta_reg(a, b, x)
                }
            }
        }
    }
}
impl Min<f64> for Beta {
    /// Returns the lower end of the support, which is always `0`.
    fn min(&self) -> f64 {
        0f64
    }
}
impl Max<f64> for Beta {
    /// Returns the upper end of the support, which is always `1`.
    fn max(&self) -> f64 {
        1f64
    }
}
impl Mean<f64> for Beta {
    /// Returns the mean `shape_a / (shape_a + shape_b)`.
    ///
    /// With infinite shapes the mean is the location of the degenerate point mass:
    /// `0.5` if both are infinite, `1` if only `shape_a` is, `0` if only `shape_b` is.
    fn mean(&self) -> f64 {
        let (a, b) = (self.shape_a, self.shape_b);
        match (a == f64::INFINITY, b == f64::INFINITY) {
            (true, true) => 0.5,
            (true, false) => 1.0,
            (false, true) => 0.0,
            (false, false) => a / (a + b),
        }
    }
}
impl Variance<f64> for Beta {
    /// Returns the variance `a * b / ((a + b)^2 * (a + b + 1))`.
    ///
    /// No special case exists for infinite shapes: the expression evaluates to
    /// `inf / inf`, so the result is NaN.
    fn variance(&self) -> f64 {
        let (a, b) = (self.shape_a, self.shape_b);
        let sum = a + b;
        a * b / (sum * sum * (sum + 1.0))
    }

    /// Returns the standard deviation, the square root of `variance()`.
    fn std_dev(&self) -> f64 {
        let var = self.variance();
        var.sqrt()
    }
}
impl Entropy<f64> for Beta {
    /// Returns the differential entropy.
    ///
    /// The finite case is computed as
    /// `ln B(a, b) - (a - 1) ψ(a) - (b - 1) ψ(b) + (a + b - 2) ψ(a + b)`,
    /// where `ψ` is `gamma::digamma`. If either shape is infinite the result is `0`.
    fn entropy(&self) -> f64 {
        let (a, b) = (self.shape_a, self.shape_b);
        if a == f64::INFINITY || b == f64::INFINITY {
            return 0.0;
        }
        beta::ln_beta(a, b)
            - (a - 1.0) * gamma::digamma(a)
            - (b - 1.0) * gamma::digamma(b)
            + (a + b - 2.0) * gamma::digamma(a + b)
    }
}
impl Skewness<f64> for Beta {
    /// Returns the skewness `2 (b - a) sqrt(a + b + 1) / ((a + b + 2) sqrt(a b))`.
    ///
    /// Infinite shapes use fixed values: `0` if both are infinite, `-2` if only
    /// `shape_a` is, and `2` if only `shape_b` is.
    fn skewness(&self) -> f64 {
        let (a, b) = (self.shape_a, self.shape_b);
        match (a == f64::INFINITY, b == f64::INFINITY) {
            (true, true) => 0.0,
            (true, false) => -2.0,
            (false, true) => 2.0,
            (false, false) => {
                let numerator = 2.0 * (b - a) * (a + b + 1.0).sqrt();
                let denominator = (a + b + 2.0) * (a * b).sqrt();
                numerator / denominator
            }
        }
    }
}
impl Mode<f64> for Beta {
    /// Returns the mode of the distribution.
    ///
    /// # Panics
    ///
    /// Panics whenever `checked_mode` returns an error (for this type, when
    /// `shape_a <= 1` or `shape_b <= 1`).
    fn mode(&self) -> f64 {
        let checked = self.checked_mode();
        checked.unwrap()
    }
}
impl CheckedMode<f64> for Beta {
    /// Returns the mode `(a - 1) / (a + b - 2)`.
    ///
    /// With infinite shapes the mode is the point-mass location: `0.5` if both
    /// are infinite, `1` if only `shape_a` is, `0` if only `shape_b` is.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::ArgGt("shape_a", 1.0)` if `shape_a <= 1`. Otherwise it
    /// returns `StatsError::ArgGt("shape_b", 1.0)` if `shape_b <= 1`. `shape_a` is
    /// checked first.
    fn checked_mode(&self) -> Result<f64> {
        let (a, b) = (self.shape_a, self.shape_b);
        if a <= 1.0 {
            return Err(StatsError::ArgGt("shape_a", 1.0));
        }
        if b <= 1.0 {
            return Err(StatsError::ArgGt("shape_b", 1.0));
        }
        let mode = match (a == f64::INFINITY, b == f64::INFINITY) {
            (true, true) => 0.5,
            (true, false) => 1.0,
            (false, true) => 0.0,
            (false, false) => (a - 1.0) / (a + b - 2.0),
        };
        Ok(mode)
    }
}
impl Continuous<f64, f64> for Beta {
    /// Calculates the probability density function at `x`.
    ///
    /// Returns `0` outside `[0, 1]`. With an infinite shape the distribution is a
    /// point mass, so the density is `+inf` at that point and `0` elsewhere.
    /// `Beta(1, 1)` returns `1` everywhere on `[0, 1]`. When either shape exceeds
    /// 80 the density is computed as `ln_pdf(x).exp()`, presumably so that the
    /// direct gamma-function ratio cannot overflow.
    fn pdf(&self, x: f64) -> f64 {
        if x < 0.0 || x > 1.0 {
            0.0
        } else if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
            if x == 0.5 {
                f64::INFINITY
            } else {
                0.0
            }
        } else if self.shape_a == f64::INFINITY {
            if x == 1.0 {
                f64::INFINITY
            } else {
                0.0
            }
        } else if self.shape_b == f64::INFINITY {
            if x == 0.0 {
                f64::INFINITY
            } else {
                0.0
            }
        } else if self.shape_a == 1.0 && self.shape_b == 1.0 {
            1.0
        } else if self.shape_a > 80.0 || self.shape_b > 80.0 {
            self.ln_pdf(x).exp()
        } else {
            let bb = gamma::gamma(self.shape_a + self.shape_b)
                / (gamma::gamma(self.shape_a) * gamma::gamma(self.shape_b));
            bb * x.powf(self.shape_a - 1.0) * (1.0 - x).powf(self.shape_b - 1.0)
        }
    }

    /// Calculates the natural logarithm of the probability density function at `x`.
    ///
    /// Returns `-inf` outside `[0, 1]`. Infinite shapes mirror `pdf`: `+inf` at the
    /// point mass and `-inf` elsewhere. At the boundaries `x = 0` and `x = 1` the
    /// result follows the limit of the density. It is `-inf` when the relevant
    /// shape (`shape_a` at 0, `shape_b` at 1) exceeds 1, `+inf` when that shape is
    /// below 1 (the density diverges there), and finite when it equals 1.
    fn ln_pdf(&self, x: f64) -> f64 {
        if x < 0.0 || x > 1.0 {
            f64::NEG_INFINITY
        } else if self.shape_a == f64::INFINITY && self.shape_b == f64::INFINITY {
            if x == 0.5 {
                f64::INFINITY
            } else {
                f64::NEG_INFINITY
            }
        } else if self.shape_a == f64::INFINITY {
            if x == 1.0 {
                f64::INFINITY
            } else {
                f64::NEG_INFINITY
            }
        } else if self.shape_b == f64::INFINITY {
            if x == 0.0 {
                f64::INFINITY
            } else {
                f64::NEG_INFINITY
            }
        } else if self.shape_a == 1.0 && self.shape_b == 1.0 {
            0.0
        } else {
            let aa = gamma::ln_gamma(self.shape_a + self.shape_b)
                - gamma::ln_gamma(self.shape_a)
                - gamma::ln_gamma(self.shape_b);
            // ln(0) is -inf, so at x == 0 the product (a - 1) * ln(x) is already
            // -inf for a > 1 and +inf for a < 1, matching the density's limit.
            // Only a == 1 needs a special case, because 0 * -inf would be NaN while
            // the true factor x^0 contributes ln(1) = 0.
            let bb = if self.shape_a == 1.0 && x == 0.0 {
                0.0
            } else {
                (self.shape_a - 1.0) * x.ln()
            };
            // Same reasoning for shape_b at the upper boundary x == 1.
            let cc = if self.shape_b == 1.0 && x == 1.0 {
                0.0
            } else {
                (self.shape_b - 1.0) * (1.0 - x).ln()
            };
            aa + bb + cc
        }
    }
}
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(test)]
mod test {
    use std::f64;
    use crate::statistics::*;
    use crate::distribution::{Univariate, Continuous, Beta};
    use crate::distribution::internal::*;

    /// Builds a `Beta`, failing the test if construction is rejected.
    fn try_create(shape_a: f64, shape_b: f64) -> Beta {
        let n = Beta::new(shape_a, shape_b);
        assert!(n.is_ok());
        n.unwrap()
    }

    /// Asserts that construction succeeds and the getters echo the inputs.
    fn create_case(shape_a: f64, shape_b: f64) {
        let n = try_create(shape_a, shape_b);
        assert_eq!(n.shape_a(), shape_a);
        assert_eq!(n.shape_b(), shape_b);
    }

    /// Asserts that construction is rejected.
    fn bad_create_case(shape_a: f64, shape_b: f64) {
        let n = Beta::new(shape_a, shape_b);
        assert!(n.is_err());
    }

    /// Builds a distribution and applies `eval` to it.
    fn get_value<F>(shape_a: f64, shape_b: f64, eval: F) -> f64
        where F: Fn(Beta) -> f64
    {
        let n = try_create(shape_a, shape_b);
        eval(n)
    }

    /// Asserts that `eval` yields exactly `expected`.
    fn test_case<F>(shape_a: f64, shape_b: f64, expected: f64, eval: F)
        where F: Fn(Beta) -> f64
    {
        let x = get_value(shape_a, shape_b, eval);
        assert_eq!(expected, x);
    }

    /// Asserts that `eval` yields `expected` within absolute tolerance `acc`.
    fn test_almost<F>(shape_a: f64, shape_b: f64, expected: f64, acc: f64, eval: F)
        where F: Fn(Beta) -> f64
    {
        let x = get_value(shape_a, shape_b, eval);
        assert_almost_eq!(expected, x, acc);
    }

    /// Asserts that `eval` yields NaN.
    fn test_is_nan<F>(shape_a: f64, shape_b: f64, eval: F)
        where F: Fn(Beta) -> f64
    {
        assert!(get_value(shape_a, shape_b, eval).is_nan())
    }

    #[test]
    fn test_create() {
        create_case(1.0, 1.0);
        create_case(9.0, 1.0);
        create_case(5.0, 100.0);
        create_case(1.0, f64::INFINITY);
        create_case(f64::INFINITY, 1.0);
    }

    #[test]
    fn test_bad_create() {
        bad_create_case(0.0, 0.0);
        bad_create_case(0.0, 0.1);
        bad_create_case(1.0, 0.0);
        bad_create_case(0.0, f64::INFINITY);
        bad_create_case(f64::INFINITY, 0.0);
        bad_create_case(f64::NAN, 1.0);
        bad_create_case(1.0, f64::NAN);
        bad_create_case(f64::NAN, f64::NAN);
        bad_create_case(1.0, -1.0);
        bad_create_case(-1.0, 1.0);
        bad_create_case(-1.0, -1.0);
    }

    #[test]
    fn test_mean() {
        test_case(1.0, 1.0, 0.5, |x| x.mean());
        test_case(9.0, 1.0, 0.9, |x| x.mean());
        test_case(5.0, 100.0, 0.047619047619047619047616, |x| x.mean());
        test_case(1.0, f64::INFINITY, 0.0, |x| x.mean());
        test_case(f64::INFINITY, 1.0, 1.0, |x| x.mean());
        test_case(f64::INFINITY, f64::INFINITY, 0.5, |x| x.mean());
    }

    #[test]
    fn test_variance() {
        test_case(1.0, 1.0, 1.0 / 12.0, |x| x.variance());
        test_case(9.0, 1.0, 9.0 / 1100.0, |x| x.variance());
        test_case(5.0, 100.0, 500.0 / 1168650.0, |x| x.variance());
        test_is_nan(1.0, f64::INFINITY, |x| x.variance());
        test_is_nan(f64::INFINITY, 1.0, |x| x.variance());
        test_is_nan(f64::INFINITY, f64::INFINITY, |x| x.variance());
    }

    #[test]
    fn test_std_dev() {
        test_case(1.0, 1.0, (1f64 / 12.0).sqrt(), |x| x.std_dev());
        test_case(9.0, 1.0, (9f64 / 1100.0).sqrt(), |x| x.std_dev());
        test_case(5.0, 100.0, (500f64 / 1168650.0).sqrt(), |x| x.std_dev());
        test_is_nan(1.0, f64::INFINITY, |x| x.std_dev());
        test_is_nan(f64::INFINITY, 1.0, |x| x.std_dev());
        test_is_nan(f64::INFINITY, f64::INFINITY, |x| x.std_dev());
    }

    #[test]
    fn test_entropy() {
        test_almost(1.0, 1.0, 0.0, 1e-15, |x| x.entropy());
        test_almost(9.0, 1.0, -1.3083356884473304939016015849561625204060922267565917, 1e-13, |x| x.entropy());
        test_almost(5.0, 100.0, -2.5201623187602743679459255108827601222133603091753153, 1e-13, |x| x.entropy());
        test_case(1.0, f64::INFINITY, 0.0, |x| x.entropy());
        test_case(f64::INFINITY, 1.0, 0.0, |x| x.entropy());
        test_case(f64::INFINITY, f64::INFINITY, 0.0, |x| x.entropy());
    }

    #[test]
    fn test_skewness() {
        test_case(1.0, 1.0, 0.0, |x| x.skewness());
        test_almost(9.0, 1.0, -1.4740554623801777107177478829647496373009282424841579, 1e-15, |x| x.skewness());
        test_almost(5.0, 100.0, 0.81759410927553430354583159143895018978562196953345572, 1e-15, |x| x.skewness());
        test_case(1.0, f64::INFINITY, 2.0, |x| x.skewness());
        test_case(f64::INFINITY, 1.0, -2.0, |x| x.skewness());
        test_case(f64::INFINITY, f64::INFINITY, 0.0, |x| x.skewness());
    }

    #[test]
    fn test_mode() {
        test_case(5.0, 100.0, 0.038834951456310676243255386452801758423447608947753906, |x| x.mode());
        test_case(2.0, f64::INFINITY, 0.0, |x| x.mode());
        test_case(f64::INFINITY, 2.0, 1.0, |x| x.mode());
        test_case(f64::INFINITY, f64::INFINITY, 0.5, |x| x.mode());
    }

    #[test]
    #[should_panic]
    fn test_mode_shape_a_lte_1() {
        get_value(1.0, 5.0, |x| x.mode());
    }

    #[test]
    #[should_panic]
    fn test_mode_shape_b_lte_1() {
        get_value(5.0, 1.0, |x| x.mode());
    }

    #[test]
    fn test_checked_mode_shape_a_lte_1() {
        let n = try_create(1.0, 5.0);
        assert!(n.checked_mode().is_err());
    }

    #[test]
    fn test_checked_mode_shape_b_lte_1() {
        let n = try_create(5.0, 1.0);
        assert!(n.checked_mode().is_err());
    }

    #[test]
    fn test_min_max() {
        test_case(1.0, 1.0, 0.0, |x| x.min());
        test_case(1.0, 1.0, 1.0, |x| x.max());
    }

    #[test]
    fn test_pdf() {
        test_case(1.0, 1.0, 1.0, |x| x.pdf(0.0));
        test_case(1.0, 1.0, 1.0, |x| x.pdf(0.5));
        test_case(1.0, 1.0, 1.0, |x| x.pdf(1.0));
        test_case(9.0, 1.0, 0.0, |x| x.pdf(0.0));
        test_almost(9.0, 1.0, 0.03515625, 1e-15, |x| x.pdf(0.5));
        test_almost(9.0, 1.0, 9.0, 1e-13, |x| x.pdf(1.0));
        test_case(5.0, 100.0, 0.0, |x| x.pdf(0.0));
        test_almost(5.0, 100.0, 4.534102298350337661e-23, 1e-35, |x| x.pdf(0.5));
        test_case(5.0, 100.0, 0.0, |x| x.pdf(1.0));
        test_case(1.0, f64::INFINITY, f64::INFINITY, |x| x.pdf(0.0));
        test_case(1.0, f64::INFINITY, 0.0, |x| x.pdf(0.5));
        test_case(1.0, f64::INFINITY, 0.0, |x| x.pdf(1.0));
        test_case(f64::INFINITY, 1.0, 0.0, |x| x.pdf(0.0));
        test_case(f64::INFINITY, 1.0, 0.0, |x| x.pdf(0.5));
        test_case(f64::INFINITY, 1.0, f64::INFINITY, |x| x.pdf(1.0));
        test_case(f64::INFINITY, f64::INFINITY, 0.0, |x| x.pdf(0.0));
        test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, |x| x.pdf(0.5));
        test_case(f64::INFINITY, f64::INFINITY, 0.0, |x| x.pdf(1.0));
    }

    #[test]
    fn test_pdf_input_lt_0() {
        test_case(1.0, 1.0, 0.0, |x| x.pdf(-1.0));
    }

    #[test]
    fn test_pdf_input_gt_1() {
        test_case(1.0, 1.0, 0.0, |x| x.pdf(2.0));
    }

    #[test]
    fn test_ln_pdf() {
        test_case(1.0, 1.0, 0.0, |x| x.ln_pdf(0.0));
        test_case(1.0, 1.0, 0.0, |x| x.ln_pdf(0.5));
        test_case(1.0, 1.0, 0.0, |x| x.ln_pdf(1.0));
        test_case(9.0, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(0.0));
        test_almost(9.0, 1.0, -3.3479528671433430925473664978203611353090199592365458, 1e-13, |x| x.ln_pdf(0.5));
        test_almost(9.0, 1.0, 2.1972245773362193827904904738450514092949811156454996, 1e-13, |x| x.ln_pdf(1.0));
        test_case(5.0, 100.0, f64::NEG_INFINITY, |x| x.ln_pdf(0.0));
        test_almost(5.0, 100.0, -51.447830024537682154565870837960406410586196074573801, 1e-12, |x| x.ln_pdf(0.5));
        test_case(5.0, 100.0, f64::NEG_INFINITY, |x| x.ln_pdf(1.0));
        test_case(1.0, f64::INFINITY, f64::INFINITY, |x| x.ln_pdf(0.0));
        test_case(1.0, f64::INFINITY, f64::NEG_INFINITY, |x| x.ln_pdf(0.5));
        test_case(1.0, f64::INFINITY, f64::NEG_INFINITY, |x| x.ln_pdf(1.0));
        test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(0.0));
        test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(0.5));
        test_case(f64::INFINITY, 1.0, f64::INFINITY, |x| x.ln_pdf(1.0));
        test_case(f64::INFINITY, f64::INFINITY, f64::NEG_INFINITY, |x| x.ln_pdf(0.0));
        test_case(f64::INFINITY, f64::INFINITY, f64::INFINITY, |x| x.ln_pdf(0.5));
        test_case(f64::INFINITY, f64::INFINITY, f64::NEG_INFINITY, |x| x.ln_pdf(1.0));
    }

    #[test]
    fn test_ln_pdf_input_lt_0() {
        test_case(1.0, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(-1.0));
    }

    #[test]
    fn test_ln_pdf_input_gt_1() {
        test_case(1.0, 1.0, f64::NEG_INFINITY, |x| x.ln_pdf(2.0));
    }

    #[test]
    fn test_cdf() {
        test_case(1.0, 1.0, 0.0, |x| x.cdf(0.0));
        test_case(1.0, 1.0, 0.5, |x| x.cdf(0.5));
        test_case(1.0, 1.0, 1.0, |x| x.cdf(1.0));
        test_case(9.0, 1.0, 0.0, |x| x.cdf(0.0));
        test_almost(9.0, 1.0, 0.001953125, 1e-16, |x| x.cdf(0.5));
        test_case(9.0, 1.0, 1.0, |x| x.cdf(1.0));
        test_case(5.0, 100.0, 0.0, |x| x.cdf(0.0));
        test_case(5.0, 100.0, 1.0, |x| x.cdf(0.5));
        test_case(5.0, 100.0, 1.0, |x| x.cdf(1.0));
        test_case(1.0, f64::INFINITY, 1.0, |x| x.cdf(0.0));
        test_case(1.0, f64::INFINITY, 1.0, |x| x.cdf(0.5));
        test_case(1.0, f64::INFINITY, 1.0, |x| x.cdf(1.0));
        test_case(f64::INFINITY, 1.0, 0.0, |x| x.cdf(0.0));
        test_case(f64::INFINITY, 1.0, 0.0, |x| x.cdf(0.5));
        test_case(f64::INFINITY, 1.0, 1.0, |x| x.cdf(1.0));
        test_case(f64::INFINITY, f64::INFINITY, 0.0, |x| x.cdf(0.0));
        test_case(f64::INFINITY, f64::INFINITY, 1.0, |x| x.cdf(0.5));
        test_case(f64::INFINITY, f64::INFINITY, 1.0, |x| x.cdf(1.0));
    }

    #[test]
    fn test_cdf_input_lt_0() {
        test_case(1.0, 1.0, 0.0, |x| x.cdf(-1.0));
    }

    #[test]
    fn test_cdf_input_gt_1() {
        test_case(1.0, 1.0, 1.0, |x| x.cdf(2.0));
    }

    #[test]
    fn test_continuous() {
        test::check_continuous_distribution(&try_create(1.2, 3.4), 0.0, 1.0);
        test::check_continuous_distribution(&try_create(4.5, 6.7), 0.0, 1.0);
    }
}