1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
#[cfg(target_arch="x86")] use core::arch::x86::*; #[cfg(target_arch="x86_64")] use core::arch::x86_64::*; #[macro_use] mod macros; pub(crate) struct FusedMulAdd; pub(crate) struct AvxMulAdd; pub(crate) trait SMultiplyAdd { const IS_FUSED: bool; unsafe fn multiply_add(__m256, __m256, __m256) -> __m256; } impl SMultiplyAdd for AvxMulAdd { const IS_FUSED: bool = false; #[inline(always)] unsafe fn multiply_add(a: __m256, b: __m256, c: __m256) -> __m256 { _mm256_add_ps(_mm256_mul_ps(a, b), c) } } impl SMultiplyAdd for FusedMulAdd { const IS_FUSED: bool = true; #[inline(always)] unsafe fn multiply_add(a: __m256, b: __m256, c: __m256) -> __m256 { _mm256_fmadd_ps(a, b, c) } } pub(crate) trait DMultiplyAdd { const IS_FUSED: bool; unsafe fn multiply_add(__m256d, __m256d, __m256d) -> __m256d; } impl DMultiplyAdd for AvxMulAdd { const IS_FUSED: bool = false; #[inline(always)] unsafe fn multiply_add(a: __m256d, b: __m256d, c: __m256d) -> __m256d { _mm256_add_pd(_mm256_mul_pd(a, b), c) } } impl DMultiplyAdd for FusedMulAdd { const IS_FUSED: bool = true; #[inline(always)] unsafe fn multiply_add(a: __m256d, b: __m256d, c: __m256d) -> __m256d { _mm256_fmadd_pd(a, b, c) } }