mirror of
https://github.com/Cian-H/fastmath.git
synced 2025-12-22 22:22:02 +00:00
>99.7% speedup to lookup tables
This commit is contained in:
@@ -10,19 +10,13 @@ fn pow2_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f64_builtin_fn", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).powi(2)).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f64_builtin_mul", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x) * x).collect::<Vec<f64>>())
|
||||
b.iter(|| x_f64.iter().map(|&x| 2.0f64.powf(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f32_fast", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f32>>())
|
||||
});
|
||||
group.bench_function("f32_builtin_fn", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).powi(2)).collect::<Vec<f32>>())
|
||||
});
|
||||
group.bench_function("f32_builtin_mul", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x) * x).collect::<Vec<f32>>())
|
||||
b.iter(|| x_f32.iter().map(|&x| 2.0f32.powf(black_box(x))).collect::<Vec<f32>>())
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
139
build.rs
139
build.rs
@@ -1,42 +1,137 @@
|
||||
// build.rs
|
||||
|
||||
mod precalculate_lookup_tables {
|
||||
use std::f32::consts as f32_consts;
|
||||
use std::f64::consts as f64_consts;
|
||||
use std::fs::{create_dir_all, File};
|
||||
use std::io::Write;
|
||||
include!("src/lookup/lookup_table.rs");
|
||||
use bincode::serialize;
|
||||
include!("src/lookup/config.rs");
|
||||
// use bincode::serialize;
|
||||
|
||||
const PRECISION: usize = 1000;
|
||||
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/sin_f32.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
|
||||
let mut file = File::create("src/lookup/data/sin_f32.bin")?;
|
||||
file.write_all(&data)?;
|
||||
// let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/sin_f64.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
|
||||
let mut file = File::create("src/lookup/data/sin_f64.bin")?;
|
||||
file.write_all(&data)?;
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
Ok(())
|
||||
// fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/cos_f32.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
// let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/cos_f64.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
|
||||
// let half_step: f32 = step / 2.0;
|
||||
|
||||
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)) - half_step
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)).sin()
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
|
||||
// let half_step: f64 = step / 2.0;
|
||||
|
||||
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)) - half_step
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)).sin()
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
// Generates `src/lookup/data/sin_f32.rs` and `src/lookup/data/sin_f64.rs`.
// Each file declares two `TABLE_SIZE`-long constant arrays: the keys
// (`step * i - step / 2`) and the values (`sin(step * i)`), where
// `step = FRAC_PI_2 / TABLE_SIZE`. Panics if a file cannot be created or written.
macro_rules! precalculate_sin_tables {
    () => {{
        {
            let step_f32: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
            let half_step_f32: f32 = step_f32 / 2.0;

            // Arrays are built in place; the length is fixed by the type.
            let keys_f32: [f32; TABLE_SIZE] =
                std::array::from_fn(|i| (step_f32 * (i as f32)) - half_step_f32);
            let values_f32: [f32; TABLE_SIZE] =
                std::array::from_fn(|i| (step_f32 * (i as f32)).sin());
            let source_f32 = format!("pub const SIN_F32_KEYS: [f32; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys_f32, TABLE_SIZE, values_f32);

            let mut out_f32 = File::create("src/lookup/data/sin_f32.rs").expect("Failed to create sin_f32.rs");
            out_f32.write_all(source_f32.as_bytes()).expect("Failed to write sin_f32.rs");
        }

        {
            let step_f64: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
            let half_step_f64: f64 = step_f64 / 2.0;

            let keys_f64: [f64; TABLE_SIZE] =
                std::array::from_fn(|i| (step_f64 * (i as f64)) - half_step_f64);
            let values_f64: [f64; TABLE_SIZE] =
                std::array::from_fn(|i| (step_f64 * (i as f64)).sin());
            let source_f64 = format!("pub const SIN_F64_KEYS: [f64; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys_f64, TABLE_SIZE, values_f64);

            let mut out_f64 = File::create("src/lookup/data/sin_f64.rs").expect("Failed to create sin_f64.rs");
            out_f64.write_all(source_f64.as_bytes()).expect("Failed to write sin_f64.rs");
        }
    }};
}
|
||||
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
|
||||
// let half_step: f32 = step / 2.0;
|
||||
|
||||
fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
|
||||
let mut file = File::create("src/lookup/data/cos_f32.bin")?;
|
||||
file.write_all(&data)?;
|
||||
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)) - half_step
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)).sin()
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
|
||||
let mut file = File::create("src/lookup/data/cos_f64.bin")?;
|
||||
file.write_all(&data)?;
|
||||
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
|
||||
// let half_step: f64 = step / 2.0;
|
||||
|
||||
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)) - half_step
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)).sin()
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
|
||||
create_dir_all("src/lookup/data")?;
|
||||
|
||||
precalculate_sin_tables()?;
|
||||
precalculate_cos_tables()?;
|
||||
precalculate_sin_tables!();
|
||||
// precalculate_cos_tables()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
152
src/fastmath.rs
152
src/fastmath.rs
@@ -4,12 +4,19 @@
|
||||
|
||||
use std::f32::consts as f32_consts;
|
||||
use std::f64::consts as f64_consts;
|
||||
use crate::lookup::*;
|
||||
// use crate::lookup::*;
|
||||
use crate::lookup::lookup_table::EndoCosLookupTable;
|
||||
|
||||
pub trait FastMath: FastCos + FastPow2 + FastExp + FastSigmoid {}
|
||||
const COS_LOOKUP_F32: EndoCosLookupTable<f32> = EndoCosLookupTable::<f32>::new();
|
||||
const COS_LOOKUP_F64: EndoCosLookupTable<f64> = EndoCosLookupTable::<f64>::new();
|
||||
|
||||
pub trait FastMath: FastCos + FastExp + FastSigmoid {}
|
||||
impl FastMath for f32 {}
|
||||
impl FastMath for f64 {}
|
||||
|
||||
const V_SCALE_F32: f32 = 8388608.0; // the largest possible mantissa of an f32
|
||||
const V_SCALE_F64: f64 = 4503599627370496.0; // the largest possible mantissa of an f64
|
||||
|
||||
|
||||
pub trait LookupCos {
|
||||
fn lookup_cos(self: Self) -> Self;
|
||||
@@ -35,93 +42,66 @@ pub trait FastCos {
|
||||
impl FastCos for f32 {
|
||||
#[inline]
|
||||
fn fast_cos(self: Self) -> f32 {
|
||||
const BITAND: u32 = u32::MAX / 2;
|
||||
const ONE: f32 = 1.0;
|
||||
let mod_x = (((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI;
|
||||
let v = mod_x.to_bits() & BITAND;
|
||||
let qpprox = ONE - f32_consts::FRAC_2_PI * f32::from_bits(v);
|
||||
let v = ((((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI).abs();
|
||||
let qpprox = ONE - f32_consts::FRAC_2_PI * v;
|
||||
qpprox + f32_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
|
||||
}
|
||||
}
|
||||
impl FastCos for f64 {
|
||||
#[inline]
|
||||
fn fast_cos(self: Self) -> f64 {
|
||||
const BITAND: u64 = u64::MAX / 2;
|
||||
const ONE: f64 = 1.0;
|
||||
let mod_x = (((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI;
|
||||
let v = mod_x.to_bits() & BITAND;
|
||||
let qpprox = ONE - f64_consts::FRAC_2_PI * f64::from_bits(v);
|
||||
let v = ((((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI).abs();
|
||||
let qpprox = ONE - f64_consts::FRAC_2_PI * v;
|
||||
qpprox + f64_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FastPow2 {
|
||||
fn fast_pow2(self: Self) -> Self;
|
||||
}
|
||||
impl FastPow2 for f32 {
|
||||
#[inline]
|
||||
fn fast_pow2(self: Self) -> f32 {
|
||||
// Khinchins constant over 3. IDK why it gives the best fit, but it does
|
||||
const KHINCHIN_3: f32 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
|
||||
const CLIPP_THRESH: f32 = 0.12847338;
|
||||
const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
|
||||
const CLIPP_SHIFT: f32 = 126.67740855;
|
||||
let abs_p = self.abs();
|
||||
let clipp = abs_p.max(CLIPP_THRESH); // if abs_p < CLIPP_THRESH { CLIPP_THRESH } else { abs_p };
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
|
||||
f32::from_bits(v) - KHINCHIN_3
|
||||
}
|
||||
}
|
||||
impl FastPow2 for f64 {
|
||||
#[inline]
|
||||
fn fast_pow2(self: Self) -> f64 {
|
||||
const KHINCHIN_3: f64 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
|
||||
const CLIPP_THRESH: f64 = -45774.9247660416;
|
||||
const V_SCALE: f64 = 4503599627370496.0; // (1i64 << 52) as f64
|
||||
const CLIPP_SHIFT: f64 = 1022.6769200000002;
|
||||
const ZERO: f64 = 0.;
|
||||
let abs_p = self.abs();
|
||||
let clipp = abs_p.max(CLIPP_THRESH); // if abs_p < CLIPP_THRESH { CLIPP_THRESH } else { abs_p };
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
|
||||
let y = f64::from_bits(v) - KHINCHIN_3;
|
||||
if y.is_sign_positive() {
|
||||
y
|
||||
} else {
|
||||
ZERO
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FastExp {
|
||||
fn fast_exp(self: Self) -> Self;
|
||||
}
|
||||
impl FastExp for f32 {
|
||||
#[inline]
|
||||
fn fast_exp(self: Self) -> f32 {
|
||||
const CLIPP_THRESH: f32 = -126.0; // 0.12847338;
|
||||
const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
|
||||
const CLIPP_SHIFT: f32 = 126.94269504; // 126.67740855;
|
||||
const CLIPP_THRESH: f32 = -126.0; // exponent of smallest possible f32 to prevent underflow
|
||||
const CLIPP_SHIFT: f32 = 126.94269504; // shift to align curve, found by regression
|
||||
|
||||
let scaled_p = f32_consts::LOG2_E * self;
|
||||
let clipp = scaled_p.max(CLIPP_THRESH); // if scaled_p < CLIPP_THRESH { CLIPP_THRESH } else { scaled_p };
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
|
||||
let clipp = scaled_p.max(CLIPP_THRESH);
|
||||
let v = (V_SCALE_F32 * (clipp + CLIPP_SHIFT)) as u32;
|
||||
f32::from_bits(v)
|
||||
}
|
||||
}
|
||||
impl FastExp for f64 {
|
||||
#[inline]
|
||||
fn fast_exp(self: Self) -> f64 {
|
||||
const CLIPP_THRESH: f64 = -180335.51911105003;
|
||||
const V_SCALE: f64 = 4524653012949098.0;
|
||||
const CLIPP_SHIFT: f64 = 1018.1563534409383;
|
||||
const CLIPP_THRESH: f64 = -1022.0; // exponent of smallest possible f64 to prevent underflow
|
||||
const CLIPP_SHIFT: f64 = 1022.9349439517318; // shift to align curve, found by regression
|
||||
|
||||
let scaled_p = f64_consts::LOG2_E * self;
|
||||
let clipp = scaled_p.max(CLIPP_THRESH); // let clipp = if scaled_p < CLIPP_THRESH { CLIPP_THRESH } else { scaled_p };
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
|
||||
let clipp = scaled_p.max(CLIPP_THRESH);
|
||||
let v = (V_SCALE_F64 * (clipp + CLIPP_SHIFT)) as u64;
|
||||
f64::from_bits(v)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FastPow2 {
|
||||
fn fast_pow2(self: Self) -> Self;
|
||||
}
|
||||
impl FastPow2 for f32 {
|
||||
#[inline]
|
||||
fn fast_pow2(self: Self) -> f32 {
|
||||
(f32_consts::LN_2 * self).fast_exp()
|
||||
}
|
||||
}
|
||||
impl FastPow2 for f64 {
|
||||
#[inline]
|
||||
fn fast_pow2(self: Self) -> f64 {
|
||||
(f64_consts::LN_2 * self).fast_exp()
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FastSigmoid {
|
||||
fn fast_sigmoid(self: Self) -> Self;
|
||||
}
|
||||
@@ -140,63 +120,7 @@ impl FastSigmoid for f64 {
|
||||
}
|
||||
}
|
||||
|
||||
// A trait for testing and improving implementations of fast functions
|
||||
pub trait Test {
|
||||
fn test(self: Self) -> Self;
|
||||
}
|
||||
impl Test for f32 {
|
||||
#[inline]
|
||||
fn test(self: Self) -> f32 {
|
||||
// Khinchins constant over 3. IDK why it gives the best fit, but it does
|
||||
// const KHINCHIN_3: f32 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
|
||||
const CLIPP_THRESH: f32 = -126.0; // 0.12847338;
|
||||
const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
|
||||
const CLIPP_SHIFT: f32 = 126.94269504; // 126.67740855;
|
||||
|
||||
let scaled_p = f32_consts::LOG2_E * self;
|
||||
let clipp = if scaled_p < CLIPP_THRESH {
|
||||
CLIPP_THRESH
|
||||
} else {
|
||||
scaled_p
|
||||
};
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
|
||||
f32::from_bits(v) // - KHINCHIN_3
|
||||
}
|
||||
}
|
||||
impl Test for f64 {
|
||||
#[inline]
|
||||
fn test(self: Self) -> f64 {
|
||||
const CLIPP_THRESH: f64 = -180335.51911105003;
|
||||
const V_SCALE: f64 = 4524653012949098.0;
|
||||
const CLIPP_SHIFT: f64 = 1018.1563534409383;
|
||||
|
||||
let scaled_p = f64_consts::LOG2_E * self;
|
||||
let clipp = if scaled_p < CLIPP_THRESH {
|
||||
CLIPP_THRESH
|
||||
} else {
|
||||
scaled_p
|
||||
};
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
|
||||
f64::from_bits(v)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(non_snake_case, dead_code)]
|
||||
pub fn optimizing(p: f64, CLIPP_THRESH: f64, V_SCALE: f64, CLIPP_SHIFT: f64) -> f64 {
|
||||
// const CLIPP_THRESH: f64 = -45774.9247660416;
|
||||
// const V_SCALE: f64 = 4503599627370496.0;
|
||||
// const CLIPP_SHIFT: f64 = 1022.6769200000002;
|
||||
|
||||
let scaled_p = f64_consts::LOG2_E * p;
|
||||
let clipp = if scaled_p < CLIPP_THRESH {
|
||||
CLIPP_THRESH
|
||||
} else {
|
||||
scaled_p
|
||||
};
|
||||
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
|
||||
f64::from_bits(v)
|
||||
}
|
||||
|
||||
// functions for testing the accuracy of fast functions against builtin functions
|
||||
#[inline]
|
||||
pub fn sigmoid_builtin_f32(p: f32) -> f32 {
|
||||
(1. + (-p).exp()).recip()
|
||||
|
||||
3
src/lookup/config.rs
Normal file
3
src/lookup/config.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
// lookup/config.rs
|
||||
|
||||
const TABLE_SIZE: usize = 1000;
|
||||
@@ -1,30 +1,4 @@
|
||||
// lookup/const_tables.rs
|
||||
|
||||
use once_cell::sync::Lazy;
|
||||
use std::fs::read;
|
||||
use bincode::deserialize;
|
||||
use super::lookup_table::*;
|
||||
|
||||
pub const SIN_LOOKUP_F32: Lazy<EndoSinLookupTable<f32>> = Lazy::new(|| {
|
||||
deserialize(
|
||||
&read("src/lookup/data/sin_f32.bin").expect("Failed to read sin_f64.bin")
|
||||
).expect("Failed to load SIN_LOOKUP_F32")
|
||||
});
|
||||
|
||||
pub const SIN_LOOKUP_F64: Lazy<EndoSinLookupTable<f64>> = Lazy::new(|| {
|
||||
deserialize(
|
||||
&read("src/lookup/data/sin_f64.bin").expect("Failed to read sin_f32.bin")
|
||||
).expect("Failed to load SIN_LOOKUP_F64")
|
||||
});
|
||||
|
||||
pub const COS_LOOKUP_F32: Lazy<EndoCosLookupTable<f32>> = Lazy::new(|| {
|
||||
deserialize(
|
||||
&read("src/lookup/data/cos_f32.bin").expect("Failed to read cos_f64.bin")
|
||||
).expect("Failed to load COS_LOOKUP_F32")
|
||||
});
|
||||
|
||||
pub const COS_LOOKUP_F64: Lazy<EndoCosLookupTable<f64>> = Lazy::new(|| {
|
||||
deserialize(
|
||||
&read("src/lookup/data/cos_f64.bin").expect("Failed to read cos_f32.bin")
|
||||
).expect("Failed to load COS_LOOKUP_F64")
|
||||
});
|
||||
include!("data/sin_f32.rs");
|
||||
include!("data/sin_f64.rs");
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
2
src/lookup/data/sin_f32.rs
Normal file
2
src/lookup/data/sin_f32.rs
Normal file
File diff suppressed because one or more lines are too long
Binary file not shown.
2
src/lookup/data/sin_f64.rs
Normal file
2
src/lookup/data/sin_f64.rs
Normal file
File diff suppressed because one or more lines are too long
@@ -1,89 +1,153 @@
|
||||
use num_traits::sign::Signed;
|
||||
use std::f32::consts as f32_consts;
|
||||
use std::f64::consts as f64_consts;
|
||||
use num_traits::float::{Float, FloatConst};
|
||||
use num_traits::NumCast;
|
||||
use std::ops::{Sub, Rem};
|
||||
use serde::{Serialize, Deserialize};
|
||||
// use packed_simd::f64x4;
|
||||
use crate::lookup::TABLE_SIZE;
|
||||
use crate::lookup::const_tables::*;
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
|
||||
// This function should never be used in a non-const context.
|
||||
// It only exists as a workaround for the fact that const fn's cannot use iterators.
|
||||
const fn make_ordinal<T: Float>(
|
||||
input: [T; TABLE_SIZE],
|
||||
mut map_target: [FloatOrd<T>; TABLE_SIZE],
|
||||
) -> () {
|
||||
// let mut map_target = [FloatOrd::<T>::new(); TABLE_SIZE];
|
||||
let mut index = 0;
|
||||
while index < TABLE_SIZE {
|
||||
map_target[index] = FloatOrd(input[index]);
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// The following macros are to minimise the amount of boilerplate for static types on the lookup tables.
|
||||
macro_rules! impl_fbitfbit_lookup_table {
|
||||
($key_type:ty, $value_type:ty) => {
|
||||
impl FloatLookupTable<$key_type, $value_type> {
|
||||
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE]) -> Self {
|
||||
let ord_keys: [FloatOrd<$key_type>; TABLE_SIZE] = [FloatOrd(0.0 as $key_type); TABLE_SIZE];
|
||||
make_ordinal(keys, ord_keys);
|
||||
FloatLookupTable {
|
||||
keys: ord_keys,
|
||||
values,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Generates a const constructor for `CyclingFloatLookupTable` over one
// concrete (key, value) float type pair.
macro_rules! impl_cycling_fbitfbit_lookup_table {
    ($key_type:ty, $value_type:ty) => {
        impl CyclingFloatLookupTable<$key_type, $value_type> {
            /// Const constructor: builds the inner table from `keys`/`values`
            /// and stores the two bounds as given.
            pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE], lower_bound: $key_type, upper_bound: $key_type) -> Self {
                let lookup_table =
                    FloatLookupTable::<$key_type, $value_type>::new_const(keys, values);
                Self {
                    lookup_table,
                    lower_bound,
                    upper_bound,
                }
            }
        }
    };
}
|
||||
|
||||
|
||||
#[derive(Default, Debug,Clone, Copy, PartialEq, PartialOrd)]
|
||||
pub struct FloatOrd<T: Float>(pub T);
|
||||
impl<T: Float> FloatOrd<T> {
|
||||
pub fn new() -> Self {
|
||||
FloatOrd(T::zero())
|
||||
}
|
||||
}
|
||||
impl<T: Float> Eq for FloatOrd<T> {}
|
||||
impl<T: Float> Ord for FloatOrd<T> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FloatLookupTable<T1, T2>
|
||||
where T1: Float,
|
||||
T2: Float,
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
{
|
||||
keys: Vec<T1>,
|
||||
values: Vec<T2>,
|
||||
keys: [FloatOrd<T1>; TABLE_SIZE],
|
||||
values: [T2; TABLE_SIZE],
|
||||
}
|
||||
impl<T1, T2> FloatLookupTable<T1, T2>
|
||||
where T1: Float,
|
||||
T2: Float,
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
{
|
||||
pub fn new(mut keys: Vec<T1>, mut values: Vec<T2>) -> Self {
|
||||
let mut indices: Vec<_> = (0..keys.len()).collect();
|
||||
indices.sort_by(|&i, &j| keys[i].partial_cmp(&keys[j]).unwrap());
|
||||
for i in 0..keys.len() {
|
||||
while i != indices[i] {
|
||||
let swap_index = indices[i];
|
||||
keys.swap(i, swap_index);
|
||||
values.swap(i, swap_index);
|
||||
indices.swap(i, swap_index);
|
||||
}
|
||||
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE]) -> Self {
|
||||
FloatLookupTable {
|
||||
keys: keys.map(|key| FloatOrd(key)),
|
||||
values,
|
||||
}
|
||||
FloatLookupTable { keys, values }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn lookup(&self, key: T1) -> T2 {
|
||||
match self.keys.binary_search_by(|probe| probe.partial_cmp(&key).unwrap()) {
|
||||
Ok(index) => self.values[index],
|
||||
Err(index) => {
|
||||
let upper_key = &self.keys[index];
|
||||
let upper_val = &self.values[index];
|
||||
let low_index = index - 1;
|
||||
let lower_key = &self.keys[low_index];
|
||||
let lower_val = &self.values[low_index];
|
||||
// select nearest neighbour
|
||||
let diff_upper = (key - *upper_key).abs();
|
||||
let diff_lower = (key - *lower_key).abs();
|
||||
let mask = diff_lower <= diff_upper;
|
||||
(*lower_val * T2::from(mask as u8).expect("Failed to unwrap mask")) +
|
||||
(*upper_val * T2::from(!mask as u8).expect("Failed to unwrap !mask"))
|
||||
pub fn get_next(&self, key: T1) -> T2 {
|
||||
let ord_key = FloatOrd(key);
|
||||
let mut lower_bound = 0;
|
||||
let mut upper_bound = self.keys.len() - 1;
|
||||
let mut mid = (lower_bound + upper_bound) / 2;
|
||||
while upper_bound - lower_bound > 1 {
|
||||
if self.keys[mid] < ord_key {
|
||||
lower_bound = mid;
|
||||
} else {
|
||||
upper_bound = mid;
|
||||
}
|
||||
mid = (lower_bound + upper_bound) / 2;
|
||||
}
|
||||
self.values[mid]
|
||||
}
|
||||
|
||||
pub fn lookup(&self, key: T1) -> T2 {
|
||||
self.get_next(key)
|
||||
}
|
||||
}
|
||||
impl_fbitfbit_lookup_table!(f32, f32);
|
||||
impl_fbitfbit_lookup_table!(f64, f64);
|
||||
impl_fbitfbit_lookup_table!(f32, f64);
|
||||
impl_fbitfbit_lookup_table!(f64, f32);
|
||||
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CyclingFloatLookupTable<T1, T2>
|
||||
where T1: Float,
|
||||
T2: Float,
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
{
|
||||
lookup_table: FloatLookupTable<T1, T2>,
|
||||
lower_bound: T1,
|
||||
upper_bound: T1,
|
||||
bound_range: T1,
|
||||
}
|
||||
impl<T1, T2> CyclingFloatLookupTable<T1, T2>
|
||||
where T1: Float,
|
||||
T2: Float,
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
{
|
||||
pub fn new(keys: Vec<T1>, values: Vec<T2>, lower_bound: T1, upper_bound: T1) -> Self {
|
||||
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE], lower_bound: T1, upper_bound: T1) -> Self {
|
||||
CyclingFloatLookupTable {
|
||||
lookup_table: FloatLookupTable::new(keys, values),
|
||||
lower_bound: lower_bound,
|
||||
upper_bound: upper_bound,
|
||||
bound_range: upper_bound - lower_bound,
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lookup(&self, key: T1) -> T2 {
|
||||
let key = (key % self.bound_range) + self.lower_bound;
|
||||
let key = (key % (self.upper_bound - self.lower_bound)) + self.lower_bound;
|
||||
self.lookup_table.lookup(key)
|
||||
}
|
||||
}
|
||||
impl_cycling_fbitfbit_lookup_table!(f32, f32);
|
||||
impl_cycling_fbitfbit_lookup_table!(f64, f64);
|
||||
impl_cycling_fbitfbit_lookup_table!(f32, f64);
|
||||
impl_cycling_fbitfbit_lookup_table!(f64, f32);
|
||||
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EndoSinLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst,
|
||||
@@ -94,22 +158,6 @@ impl<T> EndoSinLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst,
|
||||
{
|
||||
pub fn new(precision: usize) -> Self {
|
||||
let mut keys = Vec::with_capacity(precision);
|
||||
let mut values = Vec::with_capacity(precision);
|
||||
let upper_bound = T::PI();
|
||||
let step = T::FRAC_PI_2() / <T as NumCast>::from(precision).unwrap();
|
||||
for i in 0..precision+1 {
|
||||
let key = step * <T as NumCast>::from(i).unwrap();
|
||||
let value = key.sin();
|
||||
keys.push(key);
|
||||
values.push(value);
|
||||
}
|
||||
EndoSinLookupTable {
|
||||
lookup_table: CyclingFloatLookupTable::new(keys, values, T::zero(), upper_bound),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn lookup(&self, key: T) -> T {
|
||||
if key < T::zero() {
|
||||
@@ -123,27 +171,57 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
impl EndoSinLookupTable<f32>
|
||||
{
|
||||
pub const fn new() -> Self {
|
||||
const UPPER_BOUND: f32 = f32_consts::PI;
|
||||
|
||||
EndoSinLookupTable {
|
||||
lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, UPPER_BOUND),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl EndoSinLookupTable<f64>
|
||||
{
|
||||
pub const fn new() -> Self {
|
||||
let upper_bound = f64_consts::PI;
|
||||
|
||||
EndoSinLookupTable {
|
||||
lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, upper_bound),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EndoCosLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst + Signed + Sub<Output = T> + Rem<Output = T> + NumCast + From<u8>,
|
||||
T: Float + FloatConst,
|
||||
{
|
||||
lookup_table: EndoSinLookupTable<T>,
|
||||
}
|
||||
impl<T> EndoCosLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst + Signed + Sub<Output = T> + Rem<Output = T> + NumCast + From<u8>,
|
||||
T: Float + FloatConst + std::fmt::Debug,
|
||||
{
|
||||
pub fn new(precision: usize) -> Self {
|
||||
EndoCosLookupTable {
|
||||
lookup_table: EndoSinLookupTable::new(precision),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn lookup(&self, key: T) -> T {
|
||||
self.lookup_table.lookup(key + T::FRAC_PI_2())
|
||||
}
|
||||
}
|
||||
impl EndoCosLookupTable<f32>
|
||||
{
|
||||
pub const fn new() -> Self {
|
||||
EndoCosLookupTable {
|
||||
lookup_table: EndoSinLookupTable::<f32>::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl EndoCosLookupTable<f64>
|
||||
{
|
||||
pub const fn new() -> Self {
|
||||
EndoCosLookupTable {
|
||||
lookup_table: EndoSinLookupTable::<f64>::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
mod const_tables;
|
||||
pub mod lookup_table;
|
||||
|
||||
pub use const_tables::*;
|
||||
pub use const_tables::*;
|
||||
|
||||
include!("config.rs");
|
||||
12
src/main.rs
Normal file
12
src/main.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
use fastmath::LookupCos;
|
||||
|
||||
fn main() {
|
||||
let x = (-10000..10000)
|
||||
.map(|a| (a as f64) / 1000.)
|
||||
.collect::<Vec<f64>>();
|
||||
|
||||
let y = x.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>();
|
||||
|
||||
// to ensure the compiler doesn't optimize away the function call, we print the result
|
||||
println!("{:?}", y);
|
||||
}
|
||||
22
src/tests.rs
22
src/tests.rs
@@ -39,7 +39,7 @@ mod tests {
|
||||
fn pow2() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| x.powi(2)).collect::<Vec<f64>>()
|
||||
&X.iter().map(|&x| 2.0f64.powf(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "fast_pow2<f64> percentage error is NaN");
|
||||
assert!(
|
||||
@@ -77,6 +77,17 @@ mod tests {
|
||||
"fast_cos<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
// lookup
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "lookup_cos<f64> percentage error is NaN");
|
||||
assert!(
|
||||
percentage_error < TOLERANCE,
|
||||
"lookup_cos<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -114,7 +125,7 @@ mod tests {
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| x.powi(2)).collect::<Vec<f32>>()
|
||||
&X.iter().map(|&x| 2.0f32.powf(x)).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
Ok(())
|
||||
@@ -139,6 +150,13 @@ mod tests {
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
// lookup
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user