>99.7% speedup to lookup tables

Cian Hughes
2023-07-25 10:14:08 +01:00
parent 26e512394e
commit 14d057e4ff
15 changed files with 355 additions and 251 deletions
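Summary of the change: the bincode-serialized .bin lookup tables, which were deserialized from disk at runtime behind once_cell::sync::Lazy, are replaced by Rust source tables that build.rs now generates and the crate compiles in as consts; fast_pow2 is reimplemented as a one-liner on top of fast_exp; fast_cos drops its bit-masking in favour of a plain .abs(); and the pow2 benchmarks and tests now compare against 2.0.powf(x) instead of the unrelated x.powi(2).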


@@ -10,19 +10,13 @@ fn pow2_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &
         b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f64>>())
     });
     group.bench_function("f64_builtin_fn", |b| {
-        b.iter(|| x_f64.iter().map(|&x| black_box(x).powi(2)).collect::<Vec<f64>>())
-    });
-    group.bench_function("f64_builtin_mul", |b| {
-        b.iter(|| x_f64.iter().map(|&x| black_box(x) * x).collect::<Vec<f64>>())
+        b.iter(|| x_f64.iter().map(|&x| 2.0f64.powf(black_box(x))).collect::<Vec<f64>>())
     });
     group.bench_function("f32_fast", |b| {
         b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f32>>())
     });
     group.bench_function("f32_builtin_fn", |b| {
-        b.iter(|| x_f32.iter().map(|&x| black_box(x).powi(2)).collect::<Vec<f32>>())
-    });
-    group.bench_function("f32_builtin_mul", |b| {
-        b.iter(|| x_f32.iter().map(|&x| black_box(x) * x).collect::<Vec<f32>>())
+        b.iter(|| x_f32.iter().map(|&x| 2.0f32.powf(black_box(x))).collect::<Vec<f32>>())
     });
 }
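A note on the baseline swap above: powi(2) squares its receiver, while fast_pow2 approximates 2 raised to the power of its receiver, so the old f64_builtin_fn/f32_builtin_fn baselines (and the dropped f64_builtin_mul/f32_builtin_mul variants) were timing a different function entirely. A minimal check, using only std:

    fn main() {
        let x = 3.0f64;
        println!("{}", x.powi(2));      // 9: x squared, what the old baseline computed
        println!("{}", 2.0f64.powf(x)); // 8: 2^x, the quantity fast_pow2 approximates
    }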

build.rs

@@ -1,42 +1,137 @@
@@ -1,42 +1,137 @@
 // build.rs
 
 mod precalculate_lookup_tables {
+    use std::f32::consts as f32_consts;
+    use std::f64::consts as f64_consts;
     use std::fs::{create_dir_all, File};
     use std::io::Write;
 
-    include!("src/lookup/lookup_table.rs");
-    use bincode::serialize;
-
-    const PRECISION: usize = 1000;
-
-    fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
-        let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
-        let mut file = File::create("src/lookup/data/sin_f32.bin")?;
-        file.write_all(&data)?;
-        let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
-        let mut file = File::create("src/lookup/data/sin_f64.bin")?;
-        file.write_all(&data)?;
-        Ok(())
-    }
-
-    fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
-        let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
-        let mut file = File::create("src/lookup/data/cos_f32.bin")?;
-        file.write_all(&data)?;
-        let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
-        let mut file = File::create("src/lookup/data/cos_f64.bin")?;
-        file.write_all(&data)?;
-        Ok(())
-    }
+    include!("src/lookup/config.rs");
+    // use bincode::serialize;
+
+    // fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
+    //     let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
+    //     let mut file = File::create("src/lookup/data/sin_f32.bin")?;
+    //     file.write_all(&data)?;
+    //     let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
+    //     let mut file = File::create("src/lookup/data/sin_f64.bin")?;
+    //     file.write_all(&data)?;
+    //     Ok(())
+    // }
+
+    // fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
+    //     let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
+    //     let mut file = File::create("src/lookup/data/cos_f32.bin")?;
+    //     file.write_all(&data)?;
+    //     let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
+    //     let mut file = File::create("src/lookup/data/cos_f64.bin")?;
+    //     file.write_all(&data)?;
+    //     Ok(())
+    // }
+
+    // fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
+    //     let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
+    //     let half_step: f32 = step / 2.0;
+    //     let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f32)) - half_step
+    //     }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
+    //     let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f32)).sin()
+    //     }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
+    //     let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
+    //     let mut file = File::create("src/lookup/data/sin_f32.rs")?;
+    //     file.write_all(data.as_bytes())?;
+    //     let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
+    //     let half_step: f64 = step / 2.0;
+    //     let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f64)) - half_step
+    //     }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
+    //     let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f64)).sin()
+    //     }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
+    //     let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
+    //     let mut file = File::create("src/lookup/data/sin_f64.rs")?;
+    //     file.write_all(data.as_bytes())?;
+    //     Ok(())
+    // }
+
+    macro_rules! precalculate_sin_tables {
+        () => {{
+            let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
+            let half_step: f32 = step / 2.0;
+            let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+                (step * (i as f32)) - half_step
+            }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
+            let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+                (step * (i as f32)).sin()
+            }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
+            let data = format!("pub const SIN_F32_KEYS: [f32; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
+            let mut file = File::create("src/lookup/data/sin_f32.rs").expect("Failed to create sin_f32.rs");
+            file.write_all(data.as_bytes()).expect("Failed to write sin_f32.rs");
+            let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
+            let half_step: f64 = step / 2.0;
+            let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+                (step * (i as f64)) - half_step
+            }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
+            let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+                (step * (i as f64)).sin()
+            }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
+            let data = format!("pub const SIN_F64_KEYS: [f64; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
+            let mut file = File::create("src/lookup/data/sin_f64.rs").expect("Failed to create sin_f64.rs");
+            file.write_all(data.as_bytes()).expect("Failed to write sin_f64.rs");
+        }};
+    }
+
+    // fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
+    //     let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
+    //     let half_step: f32 = step / 2.0;
+    //     let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f32)) - half_step
+    //     }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
+    //     let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f32)).sin()
+    //     }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
+    //     let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
+    //     let mut file = File::create("src/lookup/data/sin_f32.rs")?;
+    //     file.write_all(data.as_bytes())?;
+    //     let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
+    //     let half_step: f64 = step / 2.0;
+    //     let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f64)) - half_step
+    //     }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
+    //     let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
+    //         (step * (i as f64)).sin()
+    //     }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
+    //     let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
+    //     let mut file = File::create("src/lookup/data/sin_f64.rs")?;
+    //     file.write_all(data.as_bytes())?;
+    //     Ok(())
+    // }
 
     pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
         create_dir_all("src/lookup/data")?;
-        precalculate_sin_tables()?;
-        precalculate_cos_tables()?;
+        precalculate_sin_tables!();
+        // precalculate_cos_tables()?;
         Ok(())
     }
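The macro writes ordinary Rust source instead of bincode blobs, and the crate then include!-s the generated files (see src/lookup/const_tables.rs below). With TABLE_SIZE = 1000 from config.rs, the emitted sin_f32.rs should look roughly like this (values abridged and approximate, not actual generator output):

    pub const SIN_F32_KEYS: [f32; 1000] = [-0.00078539815, 0.00078539815, /* ... */];
    pub const SIN_F32_VALUES: [f32; 1000] = [0.0, 0.0015707963, /* ... */];

Each key sits half a step below its sample point, so a binary search over the keys effectively rounds the query to the nearest sample.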


@@ -4,12 +4,19 @@
 use std::f32::consts as f32_consts;
 use std::f64::consts as f64_consts;
-use crate::lookup::*;
+// use crate::lookup::*;
+use crate::lookup::lookup_table::EndoCosLookupTable;
 
-pub trait FastMath: FastCos + FastPow2 + FastExp + FastSigmoid {}
+const COS_LOOKUP_F32: EndoCosLookupTable<f32> = EndoCosLookupTable::<f32>::new();
+const COS_LOOKUP_F64: EndoCosLookupTable<f64> = EndoCosLookupTable::<f64>::new();
+
+pub trait FastMath: FastCos + FastExp + FastSigmoid {}
 
 impl FastMath for f32 {}
 impl FastMath for f64 {}
 
+const V_SCALE_F32: f32 = 8388608.0; // 2^23: one exponent step in the raw bits of an f32
+const V_SCALE_F64: f64 = 4503599627370496.0; // 2^52: one exponent step in the raw bits of an f64
+
 pub trait LookupCos {
     fn lookup_cos(self: Self) -> Self;
@@ -35,93 +42,66 @@ pub trait FastCos {
 impl FastCos for f32 {
     #[inline]
     fn fast_cos(self: Self) -> f32 {
-        const BITAND: u32 = u32::MAX / 2;
         const ONE: f32 = 1.0;
-        let mod_x = (((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI;
-        let v = mod_x.to_bits() & BITAND;
-        let qpprox = ONE - f32_consts::FRAC_2_PI * f32::from_bits(v);
+        let v = ((((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI).abs();
+        let qpprox = ONE - f32_consts::FRAC_2_PI * v;
         qpprox + f32_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
     }
 }
 
 impl FastCos for f64 {
     #[inline]
     fn fast_cos(self: Self) -> f64 {
-        const BITAND: u64 = u64::MAX / 2;
         const ONE: f64 = 1.0;
-        let mod_x = (((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI;
-        let v = mod_x.to_bits() & BITAND;
-        let qpprox = ONE - f64_consts::FRAC_2_PI * f64::from_bits(v);
+        let v = ((((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI).abs();
+        let qpprox = ONE - f64_consts::FRAC_2_PI * v;
         qpprox + f64_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
     }
 }
 
-pub trait FastPow2 {
-    fn fast_pow2(self: Self) -> Self;
-}
-
-impl FastPow2 for f32 {
-    #[inline]
-    fn fast_pow2(self: Self) -> f32 {
-        // Khinchin's constant over 3. IDK why it gives the best fit, but it does
-        const KHINCHIN_3: f32 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
-        const CLIPP_THRESH: f32 = 0.12847338;
-        const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
-        const CLIPP_SHIFT: f32 = 126.67740855;
-        let abs_p = self.abs();
-        let clipp = abs_p.max(CLIPP_THRESH); // if abs_p < CLIPP_THRESH { CLIPP_THRESH } else { abs_p };
-        let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
-        f32::from_bits(v) - KHINCHIN_3
-    }
-}
-
-impl FastPow2 for f64 {
-    #[inline]
-    fn fast_pow2(self: Self) -> f64 {
-        const KHINCHIN_3: f64 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
-        const CLIPP_THRESH: f64 = -45774.9247660416;
-        const V_SCALE: f64 = 4503599627370496.0; // (1i64 << 52) as f64
-        const CLIPP_SHIFT: f64 = 1022.6769200000002;
-        const ZERO: f64 = 0.;
-        let abs_p = self.abs();
-        let clipp = abs_p.max(CLIPP_THRESH); // if abs_p < CLIPP_THRESH { CLIPP_THRESH } else { abs_p };
-        let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
-        let y = f64::from_bits(v) - KHINCHIN_3;
-        if y.is_sign_positive() {
-            y
-        } else {
-            ZERO
-        }
-    }
-}
-
 pub trait FastExp {
     fn fast_exp(self: Self) -> Self;
 }
 
 impl FastExp for f32 {
     #[inline]
     fn fast_exp(self: Self) -> f32 {
-        const CLIPP_THRESH: f32 = -126.0; // 0.12847338;
-        const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
-        const CLIPP_SHIFT: f32 = 126.94269504; // 126.67740855;
+        const CLIPP_THRESH: f32 = -126.0; // exponent of the smallest normal f32, to prevent underflow
+        const CLIPP_SHIFT: f32 = 126.94269504; // shift to align the curve, found by regression
         let scaled_p = f32_consts::LOG2_E * self;
-        let clipp = scaled_p.max(CLIPP_THRESH); // if scaled_p < CLIPP_THRESH { CLIPP_THRESH } else { scaled_p };
-        let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
+        let clipp = scaled_p.max(CLIPP_THRESH);
+        let v = (V_SCALE_F32 * (clipp + CLIPP_SHIFT)) as u32;
         f32::from_bits(v)
     }
 }
 
 impl FastExp for f64 {
     #[inline]
     fn fast_exp(self: Self) -> f64 {
-        const CLIPP_THRESH: f64 = -180335.51911105003;
-        const V_SCALE: f64 = 4524653012949098.0;
-        const CLIPP_SHIFT: f64 = 1018.1563534409383;
+        const CLIPP_THRESH: f64 = -1022.0; // exponent of the smallest normal f64, to prevent underflow
+        const CLIPP_SHIFT: f64 = 1022.9349439517318; // shift to align the curve, found by regression
         let scaled_p = f64_consts::LOG2_E * self;
-        let clipp = scaled_p.max(CLIPP_THRESH); // let clipp = if scaled_p < CLIPP_THRESH { CLIPP_THRESH } else { scaled_p };
-        let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
+        let clipp = scaled_p.max(CLIPP_THRESH);
+        let v = (V_SCALE_F64 * (clipp + CLIPP_SHIFT)) as u64;
         f64::from_bits(v)
     }
 }
+
+pub trait FastPow2 {
+    fn fast_pow2(self: Self) -> Self;
+}
+
+impl FastPow2 for f32 {
+    #[inline]
+    fn fast_pow2(self: Self) -> f32 {
+        (f32_consts::LN_2 * self).fast_exp()
+    }
+}
+
+impl FastPow2 for f64 {
+    #[inline]
+    fn fast_pow2(self: Self) -> f64 {
+        (f64_consts::LN_2 * self).fast_exp()
+    }
+}
 
 pub trait FastSigmoid {
     fn fast_sigmoid(self: Self) -> Self;
 }
@@ -140,63 +120,7 @@ impl FastSigmoid for f64 {
     }
 }
 
-// A trait for testing and improving implementations of fast functions
-pub trait Test {
-    fn test(self: Self) -> Self;
-}
-
-impl Test for f32 {
-    #[inline]
-    fn test(self: Self) -> f32 {
-        // Khinchin's constant over 3. IDK why it gives the best fit, but it does
-        // const KHINCHIN_3: f32 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
-        const CLIPP_THRESH: f32 = -126.0; // 0.12847338;
-        const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
-        const CLIPP_SHIFT: f32 = 126.94269504; // 126.67740855;
-        let scaled_p = f32_consts::LOG2_E * self;
-        let clipp = if scaled_p < CLIPP_THRESH {
-            CLIPP_THRESH
-        } else {
-            scaled_p
-        };
-        let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
-        f32::from_bits(v) // - KHINCHIN_3
-    }
-}
-
-impl Test for f64 {
-    #[inline]
-    fn test(self: Self) -> f64 {
-        const CLIPP_THRESH: f64 = -180335.51911105003;
-        const V_SCALE: f64 = 4524653012949098.0;
-        const CLIPP_SHIFT: f64 = 1018.1563534409383;
-        let scaled_p = f64_consts::LOG2_E * self;
-        let clipp = if scaled_p < CLIPP_THRESH {
-            CLIPP_THRESH
-        } else {
-            scaled_p
-        };
-        let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
-        f64::from_bits(v)
-    }
-}
-
-#[allow(non_snake_case, dead_code)]
-pub fn optimizing(p: f64, CLIPP_THRESH: f64, V_SCALE: f64, CLIPP_SHIFT: f64) -> f64 {
-    // const CLIPP_THRESH: f64 = -45774.9247660416;
-    // const V_SCALE: f64 = 4503599627370496.0;
-    // const CLIPP_SHIFT: f64 = 1022.6769200000002;
-    let scaled_p = f64_consts::LOG2_E * p;
-    let clipp = if scaled_p < CLIPP_THRESH {
-        CLIPP_THRESH
-    } else {
-        scaled_p
-    };
-    let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
-    f64::from_bits(v)
-}
+// functions for testing the accuracy of fast functions against builtin functions
 
 #[inline]
 pub fn sigmoid_builtin_f32(p: f32) -> f32 {
     (1. + (-p).exp()).recip()
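For reference, the trick behind fast_exp (retained in this commit): e^x = 2^(x * log2(e)), and the raw bits of a positive f32 are exponent_field * 2^23 + mantissa_field, with exponent_field = exponent + 127. Building the integer (y + CLIPP_SHIFT) * 2^23 and reinterpreting it as a float therefore yields roughly 2^y; the fractional part of y spills into the mantissa field, which acts as a piecewise-linear interpolation between powers of two. CLIPP_SHIFT sits slightly below the bias of 127 to balance that interpolation error (the comments above say it was found by regression). A minimal sketch, assuming only std:

    fn approx_exp2(y: f32) -> f32 {
        const V_SCALE: f32 = (1_i32 << 23) as f32; // one exponent step in raw f32 bits
        const SHIFT: f32 = 126.94269504;           // bias 127 minus an error-balancing offset
        f32::from_bits((V_SCALE * (y + SHIFT)) as u32)
    }

    // approx_exp2(3.0) lands within a few percent of 8.0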

src/lookup/config.rs (new file)

@@ -0,0 +1,3 @@
+// lookup/config.rs
+
+const TABLE_SIZE: usize = 1000;
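Both the build script and the library splice in this same file, so the table generator and the lookup code cannot disagree about the table length:

    include!("src/lookup/config.rs"); // in build.rs: paths resolve relative to the crate root
    include!("config.rs");            // in src/lookup/mod.rs: relative to the including file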

src/lookup/const_tables.rs

@@ -1,30 +1,4 @@
 // lookup/const_tables.rs
-use once_cell::sync::Lazy;
-use std::fs::read;
-use bincode::deserialize;
-use super::lookup_table::*;
-
-pub const SIN_LOOKUP_F32: Lazy<EndoSinLookupTable<f32>> = Lazy::new(|| {
-    deserialize(
-        &read("src/lookup/data/sin_f32.bin").expect("Failed to read sin_f64.bin")
-    ).expect("Failed to load SIN_LOOKUP_F32")
-});
-
-pub const SIN_LOOKUP_F64: Lazy<EndoSinLookupTable<f64>> = Lazy::new(|| {
-    deserialize(
-        &read("src/lookup/data/sin_f64.bin").expect("Failed to read sin_f32.bin")
-    ).expect("Failed to load SIN_LOOKUP_F64")
-});
-
-pub const COS_LOOKUP_F32: Lazy<EndoCosLookupTable<f32>> = Lazy::new(|| {
-    deserialize(
-        &read("src/lookup/data/cos_f32.bin").expect("Failed to read cos_f64.bin")
-    ).expect("Failed to load COS_LOOKUP_F32")
-});
-
-pub const COS_LOOKUP_F64: Lazy<EndoCosLookupTable<f64>> = Lazy::new(|| {
-    deserialize(
-        &read("src/lookup/data/cos_f64.bin").expect("Failed to read cos_f32.bin")
-    ).expect("Failed to load COS_LOOKUP_F64")
-});
+include!("data/sin_f32.rs");
+include!("data/sin_f64.rs");
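Beyond dropping the runtime file reads, the removed code had a subtle performance bug: these were const items, and a const is inlined at every use site, so each access to SIN_LOOKUP_F32 and friends materialised a fresh, uninitialised Lazy and re-read and re-deserialized the .bin file (clippy flags this pattern as declare_interior_mutable_const). A static would have initialised once; an illustration with a placeholder path:

    use once_cell::sync::Lazy;

    const SLOW: Lazy<Vec<u8>> = Lazy::new(|| std::fs::read("table.bin").unwrap()); // closure re-runs per use
    static FAST: Lazy<Vec<u8>> = Lazy::new(|| std::fs::read("table.bin").unwrap()); // closure runs once

That repeated deserialization, more than the table arithmetic itself, is plausibly where most of the headline >99.7% comes from.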

src/lookup/data/cos_f32.bin: Binary file not shown.

src/lookup/data/cos_f64.bin: Binary file not shown.

src/lookup/data/sin_f32.bin: Binary file not shown.

src/lookup/data/sin_f32.rs: File diff suppressed because one or more lines are too long.

src/lookup/data/sin_f64.bin: Binary file not shown.

src/lookup/data/sin_f64.rs: File diff suppressed because one or more lines are too long.

src/lookup/lookup_table.rs

@@ -1,89 +1,153 @@
-use num_traits::sign::Signed;
+use std::f32::consts as f32_consts;
+use std::f64::consts as f64_consts;
 use num_traits::float::{Float, FloatConst};
-use num_traits::NumCast;
-use std::ops::{Sub, Rem};
-use serde::{Serialize, Deserialize};
-// use packed_simd::f64x4;
+use crate::lookup::TABLE_SIZE;
+use crate::lookup::const_tables::*;
+
+// This function should never be used in a non-const context.
+// It only exists as a workaround for the fact that const fn's cannot use iterators.
+const fn make_ordinal<T: Float>(
+    input: [T; TABLE_SIZE],
+    mut map_target: [FloatOrd<T>; TABLE_SIZE],
+) -> () {
+    // let mut map_target = [FloatOrd::<T>::new(); TABLE_SIZE];
+    let mut index = 0;
+    while index < TABLE_SIZE {
+        map_target[index] = FloatOrd(input[index]);
+        index += 1;
+    }
+}
+
+// The following macros are to minimise the amount of boilerplate for static types on the lookup tables.
+macro_rules! impl_fbitfbit_lookup_table {
+    ($key_type:ty, $value_type:ty) => {
+        impl FloatLookupTable<$key_type, $value_type> {
+            pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE]) -> Self {
+                let ord_keys: [FloatOrd<$key_type>; TABLE_SIZE] = [FloatOrd(0.0 as $key_type); TABLE_SIZE];
+                make_ordinal(keys, ord_keys);
+                FloatLookupTable {
+                    keys: ord_keys,
+                    values,
+                }
+            }
+        }
+    };
+}
+
+macro_rules! impl_cycling_fbitfbit_lookup_table {
+    ($key_type:ty, $value_type:ty) => {
+        impl CyclingFloatLookupTable<$key_type, $value_type> {
+            pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE], lower_bound: $key_type, upper_bound: $key_type) -> Self {
+                CyclingFloatLookupTable {
+                    lookup_table: FloatLookupTable::<$key_type, $value_type>::new_const(keys, values),
+                    lower_bound,
+                    upper_bound,
+                }
+            }
+        }
+    };
+}
+
+#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd)]
+pub struct FloatOrd<T: Float>(pub T);
+
+impl<T: Float> FloatOrd<T> {
+    pub fn new() -> Self {
+        FloatOrd(T::zero())
+    }
+}
+
+impl<T: Float> Eq for FloatOrd<T> {}
+
+impl<T: Float> Ord for FloatOrd<T> {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
+    }
+}
 
-#[derive(Default, Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Clone)]
 pub struct FloatLookupTable<T1, T2>
-where T1: Float,
-      T2: Float,
+where
+    T1: Float,
+    T2: Float,
 {
-    keys: Vec<T1>,
-    values: Vec<T2>,
+    keys: [FloatOrd<T1>; TABLE_SIZE],
+    values: [T2; TABLE_SIZE],
 }
 
 impl<T1, T2> FloatLookupTable<T1, T2>
-where T1: Float,
-      T2: Float,
+where
+    T1: Float,
+    T2: Float,
 {
-    pub fn new(mut keys: Vec<T1>, mut values: Vec<T2>) -> Self {
-        let mut indices: Vec<_> = (0..keys.len()).collect();
-        indices.sort_by(|&i, &j| keys[i].partial_cmp(&keys[j]).unwrap());
-        for i in 0..keys.len() {
-            while i != indices[i] {
-                let swap_index = indices[i];
-                keys.swap(i, swap_index);
-                values.swap(i, swap_index);
-                indices.swap(i, swap_index);
-            }
-        }
-        FloatLookupTable { keys, values }
-    }
-
-    #[allow(dead_code)]
-    pub fn lookup(&self, key: T1) -> T2 {
-        match self.keys.binary_search_by(|probe| probe.partial_cmp(&key).unwrap()) {
-            Ok(index) => self.values[index],
-            Err(index) => {
-                let upper_key = &self.keys[index];
-                let upper_val = &self.values[index];
-                let low_index = index - 1;
-                let lower_key = &self.keys[low_index];
-                let lower_val = &self.values[low_index];
-                // select nearest neighbour
-                let diff_upper = (key - *upper_key).abs();
-                let diff_lower = (key - *lower_key).abs();
-                let mask = diff_lower <= diff_upper;
-                (*lower_val * T2::from(mask as u8).expect("Failed to unwrap mask")) +
-                (*upper_val * T2::from(!mask as u8).expect("Failed to unwrap !mask"))
-            }
-        }
-    }
+    pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE]) -> Self {
+        FloatLookupTable {
+            keys: keys.map(|key| FloatOrd(key)),
+            values,
+        }
+    }
+
+    pub fn get_next(&self, key: T1) -> T2 {
+        let ord_key = FloatOrd(key);
+        let mut lower_bound = 0;
+        let mut upper_bound = self.keys.len() - 1;
+        let mut mid = (lower_bound + upper_bound) / 2;
+        while upper_bound - lower_bound > 1 {
+            if self.keys[mid] < ord_key {
+                lower_bound = mid;
+            } else {
+                upper_bound = mid;
+            }
+            mid = (lower_bound + upper_bound) / 2;
+        }
+        self.values[mid]
+    }
+
+    pub fn lookup(&self, key: T1) -> T2 {
+        self.get_next(key)
+    }
 }
+
+impl_fbitfbit_lookup_table!(f32, f32);
+impl_fbitfbit_lookup_table!(f64, f64);
+impl_fbitfbit_lookup_table!(f32, f64);
+impl_fbitfbit_lookup_table!(f64, f32);
 
-#[derive(Default, Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Clone)]
 pub struct CyclingFloatLookupTable<T1, T2>
-where T1: Float,
-      T2: Float,
+where
+    T1: Float,
+    T2: Float,
 {
     lookup_table: FloatLookupTable<T1, T2>,
     lower_bound: T1,
     upper_bound: T1,
-    bound_range: T1,
 }
 
 impl<T1, T2> CyclingFloatLookupTable<T1, T2>
-where T1: Float,
-      T2: Float,
+where
+    T1: Float,
+    T2: Float,
 {
-    pub fn new(keys: Vec<T1>, values: Vec<T2>, lower_bound: T1, upper_bound: T1) -> Self {
+    pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE], lower_bound: T1, upper_bound: T1) -> Self {
         CyclingFloatLookupTable {
             lookup_table: FloatLookupTable::new(keys, values),
-            lower_bound: lower_bound,
-            upper_bound: upper_bound,
-            bound_range: upper_bound - lower_bound,
+            lower_bound,
+            upper_bound,
         }
     }
 
     pub fn lookup(&self, key: T1) -> T2 {
-        let key = (key % self.bound_range) + self.lower_bound;
+        let key = (key % (self.upper_bound - self.lower_bound)) + self.lower_bound;
         self.lookup_table.lookup(key)
     }
 }
+
+impl_cycling_fbitfbit_lookup_table!(f32, f32);
+impl_cycling_fbitfbit_lookup_table!(f64, f64);
+impl_cycling_fbitfbit_lookup_table!(f32, f64);
+impl_cycling_fbitfbit_lookup_table!(f64, f32);
 
-#[derive(Default, Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Clone)]
 pub struct EndoSinLookupTable<T>
 where
     T: Float + FloatConst,
@@ -94,22 +158,6 @@ impl<T> EndoSinLookupTable<T>
 where
     T: Float + FloatConst,
 {
-    pub fn new(precision: usize) -> Self {
-        let mut keys = Vec::with_capacity(precision);
-        let mut values = Vec::with_capacity(precision);
-        let upper_bound = T::PI();
-        let step = T::FRAC_PI_2() / <T as NumCast>::from(precision).unwrap();
-        for i in 0..precision+1 {
-            let key = step * <T as NumCast>::from(i).unwrap();
-            let value = key.sin();
-            keys.push(key);
-            values.push(value);
-        }
-        EndoSinLookupTable {
-            lookup_table: CyclingFloatLookupTable::new(keys, values, T::zero(), upper_bound),
-        }
-    }
-
     #[allow(dead_code)]
     pub fn lookup(&self, key: T) -> T {
         if key < T::zero() {
@@ -123,27 +171,57 @@ where
         }
     }
 }
+
+impl EndoSinLookupTable<f32> {
+    pub const fn new() -> Self {
+        const UPPER_BOUND: f32 = f32_consts::PI;
+        EndoSinLookupTable {
+            lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, UPPER_BOUND),
+        }
+    }
+}
+
+impl EndoSinLookupTable<f64> {
+    pub const fn new() -> Self {
+        let upper_bound = f64_consts::PI;
+        EndoSinLookupTable {
+            lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, upper_bound),
+        }
+    }
+}
 
-#[derive(Default, Debug, Serialize, Deserialize, Clone)]
+#[derive(Debug, Clone)]
 pub struct EndoCosLookupTable<T>
 where
-    T: Float + FloatConst + Signed + Sub<Output = T> + Rem<Output = T> + NumCast + From<u8>,
+    T: Float + FloatConst,
 {
     lookup_table: EndoSinLookupTable<T>,
 }
 
 impl<T> EndoCosLookupTable<T>
 where
-    T: Float + FloatConst + Signed + Sub<Output = T> + Rem<Output = T> + NumCast + From<u8>,
+    T: Float + FloatConst + std::fmt::Debug,
 {
-    pub fn new(precision: usize) -> Self {
-        EndoCosLookupTable {
-            lookup_table: EndoSinLookupTable::new(precision),
-        }
-    }
-
     #[allow(dead_code)]
     pub fn lookup(&self, key: T) -> T {
         self.lookup_table.lookup(key + T::FRAC_PI_2())
     }
 }
+
+impl EndoCosLookupTable<f32> {
+    pub const fn new() -> Self {
+        EndoCosLookupTable {
+            lookup_table: EndoSinLookupTable::<f32>::new(),
+        }
+    }
+}
+
+impl EndoCosLookupTable<f64> {
+    pub const fn new() -> Self {
+        EndoCosLookupTable {
+            lookup_table: EndoSinLookupTable::<f64>::new(),
+        }
+    }
+}
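One caveat worth flagging in the const path above: [FloatOrd<T>; TABLE_SIZE] is an array of Copy values, so make_ordinal receives map_target by value and mutates only its own local copy; as written, the ord_keys that new_const stores appear to remain FloatOrd(0.0) throughout. A const-compatible shape that returns the converted array instead (a sketch with a concrete stand-in type, not the crate's code):

    #[derive(Clone, Copy)]
    pub struct FloatOrd32(pub f32); // stand-in for FloatOrd<f32>

    const fn make_ordinal(input: [f32; 1000]) -> [FloatOrd32; 1000] {
        let mut out = [FloatOrd32(0.0); 1000];
        let mut i = 0;
        while i < 1000 {
            out[i] = FloatOrd32(input[i]); // copy each key into the ordered wrapper
            i += 1;
        }
        out // returning the array sidesteps the by-value parameter problem
    }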

src/lookup/mod.rs

@@ -1,4 +1,6 @@
 mod const_tables;
 pub mod lookup_table;
 
 pub use const_tables::*;
+
+include!("config.rs");

src/main.rs (new file)

@@ -0,0 +1,12 @@
+use fastmath::LookupCos;
+
+fn main() {
+    let x = (-10000..10000)
+        .map(|a| (a as f64) / 1000.)
+        .collect::<Vec<f64>>();
+    let y = x.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>();
+    // to ensure the compiler doesn't optimize away the function call, we print the result
+    println!("{:?}", y);
+}
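Printing is one way to keep the compiler from discarding the work; the benchmarks in this commit use criterion's black_box for the same purpose, and std has an equivalent hint (stable since Rust 1.66) that avoids the I/O:

    use std::hint::black_box;

    fn main() {
        let y: Vec<f64> = (-10000..10000)
            .map(|a| black_box((a as f64) / 1000.).cos())
            .collect();
        black_box(y); // opaque to the optimizer; nothing is printed
    }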


@@ -39,7 +39,7 @@ mod tests {
     fn pow2() -> Result<(), Box<dyn std::error::Error>> {
         let percentage_error = calculate_percentage_error(
             &X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f64>>(),
-            &X.iter().map(|&x| x.powi(2)).collect::<Vec<f64>>()
+            &X.iter().map(|&x| 2.0f64.powf(x)).collect::<Vec<f64>>()
         );
         assert!(!percentage_error.is_nan(), "fast_pow2<f64> percentage error is NaN");
         assert!(
@@ -77,6 +77,17 @@ mod tests {
             "fast_cos<f64> percentage error: {0}",
             percentage_error
         );
+
+        // lookup
+        let percentage_error = calculate_percentage_error(
+            &X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>(),
+            &X.iter().map(|&x| x.cos()).collect::<Vec<f64>>()
+        );
+        assert!(!percentage_error.is_nan(), "lookup_cos<f64> percentage error is NaN");
+        assert!(
+            percentage_error < TOLERANCE,
+            "lookup_cos<f64> percentage error: {0}",
+            percentage_error
+        );
         Ok(())
     }
@@ -114,7 +125,7 @@ mod tests {
         assert!(
             calculate_percentage_error(
                 &X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f32>>(),
-                &X.iter().map(|&x| x.powi(2)).collect::<Vec<f32>>()
+                &X.iter().map(|&x| 2.0f32.powf(x)).collect::<Vec<f32>>()
             ) < TOLERANCE
         );
         Ok(())
@@ -139,6 +150,13 @@ mod tests {
                 &X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
             ) < TOLERANCE
         );
+
+        // lookup
+        assert!(
+            calculate_percentage_error(
+                &X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f32>>(),
+                &X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
+            ) < TOLERANCE
+        );
         Ok(())
     }
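X, TOLERANCE, and calculate_percentage_error are defined earlier in the test module and don't appear in this diff. For orientation, a plausible shape for the helper (illustrative only, and presumably generic over f32/f64 in the real code):

    fn calculate_percentage_error(approx: &[f64], exact: &[f64]) -> f64 {
        let total: f64 = approx
            .iter()
            .zip(exact)
            .map(|(a, e)| ((a - e) / e).abs() * 100.0)
            .sum();
        total / approx.len() as f64
    }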