>99.7% speedup to lookup tables

Cian Hughes
2023-07-25 10:14:08 +01:00
parent 26e512394e
commit 14d057e4ff
15 changed files with 355 additions and 251 deletions


@@ -10,19 +10,13 @@ fn pow2_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f64>>())
});
group.bench_function("f64_builtin_fn", |b| {
b.iter(|| x_f64.iter().map(|&x| black_box(x).powi(2)).collect::<Vec<f64>>())
});
group.bench_function("f64_builtin_mul", |b| {
b.iter(|| x_f64.iter().map(|&x| black_box(x) * x).collect::<Vec<f64>>())
b.iter(|| x_f64.iter().map(|&x| 2.0f64.powf(black_box(x))).collect::<Vec<f64>>())
});
group.bench_function("f32_fast", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f32>>())
});
group.bench_function("f32_builtin_fn", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).powi(2)).collect::<Vec<f32>>())
});
group.bench_function("f32_builtin_mul", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x) * x).collect::<Vec<f32>>())
b.iter(|| x_f32.iter().map(|&x| 2.0f32.powf(black_box(x))).collect::<Vec<f32>>())
});
}
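Note on this hunk: fast_pow2 computes 2^x, so the old baselines powi(2) and x * x (both x²) were measuring the wrong function; the benchmark now compares against 2.0f64.powf(x). (The surviving group keeps its historical "builtin_mul" name even though it now calls powf.) A minimal sanity check of the semantics — assuming the FastPow2 trait is exported from the crate root the same way LookupCos is in src/main.rs below:

    use fastmath::FastPow2; // assumed re-export, mirroring `use fastmath::LookupCos`

    fn main() {
        let x = 1.5f64;
        let approx = x.fast_pow2();  // bit-trick approximation of 2^x
        let exact = 2.0f64.powf(x);  // ≈ 2.8284
        // the approximation is only good to a few percent; 5% is a loose bound
        assert!((approx - exact).abs() / exact < 0.05);
    }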

build.rs (139 changed lines)

@@ -1,42 +1,137 @@
// build.rs
mod precalculate_lookup_tables {
use std::f32::consts as f32_consts;
use std::f64::consts as f64_consts;
use std::fs::{create_dir_all, File};
use std::io::Write;
include!("src/lookup/lookup_table.rs");
use bincode::serialize;
include!("src/lookup/config.rs");
// use bincode::serialize;
const PRECISION: usize = 1000;
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
// let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/sin_f32.bin")?;
// file.write_all(&data)?;
fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
let mut file = File::create("src/lookup/data/sin_f32.bin")?;
file.write_all(&data)?;
// let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/sin_f64.bin")?;
// file.write_all(&data)?;
let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
let mut file = File::create("src/lookup/data/sin_f64.bin")?;
file.write_all(&data)?;
// Ok(())
// }
Ok(())
// fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
// let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/cos_f32.bin")?;
// file.write_all(&data)?;
// let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/cos_f64.bin")?;
// file.write_all(&data)?;
// Ok(())
// }
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
// let half_step: f32 = step / 2.0;
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)) - half_step
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)).sin()
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
// file.write_all(data.as_bytes())?;
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
// let half_step: f64 = step / 2.0;
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)) - half_step
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)).sin()
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
// file.write_all(data.as_bytes())?;
// Ok(())
// }
macro_rules! precalculate_sin_tables {
() => {{
let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
let half_step: f32 = step / 2.0;
let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f32)) - half_step
}).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f32)).sin()
}).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
let data = format!("pub const SIN_F32_KEYS: [f32; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let mut file = File::create("src/lookup/data/sin_f32.rs").expect("Failed to create sin_f32.rs");
file.write_all(data.as_bytes()).expect("Failed to write sin_f32.rs");
let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
let half_step: f64 = step / 2.0;
let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f64)) - half_step
}).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f64)).sin()
}).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
let data = format!("pub const SIN_F64_KEYS: [f64; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let mut file = File::create("src/lookup/data/sin_f64.rs").expect("Failed to create sin_f64.rs");
file.write_all(data.as_bytes()).expect("Failed to write sin_f64.rs");
}};
}
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
// let half_step: f32 = step / 2.0;
fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
let mut file = File::create("src/lookup/data/cos_f32.bin")?;
file.write_all(&data)?;
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)) - half_step
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)).sin()
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
let mut file = File::create("src/lookup/data/cos_f64.bin")?;
file.write_all(&data)?;
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
// file.write_all(data.as_bytes())?;
Ok(())
}
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
// let half_step: f64 = step / 2.0;
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)) - half_step
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)).sin()
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
// file.write_all(data.as_bytes())?;
// Ok(())
// }
pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
create_dir_all("src/lookup/data")?;
precalculate_sin_tables()?;
precalculate_cos_tables()?;
precalculate_sin_tables!();
// precalculate_cos_tables()?;
Ok(())
}
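What changed in build.rs: instead of serializing lookup tables to .bin blobs with bincode, the build script now emits ordinary Rust source, so the tables become const arrays the crate can include! at compile time. TABLE_SIZE comes from src/lookup/config.rs, which is include!d both here and in src/lookup/mod.rs, so the generator and the consumer cannot drift apart. The generated files have roughly this shape (array contents illustrative and abbreviated, not the actual output):

    // src/lookup/data/sin_f32.rs — emitted by the precalculate_sin_tables! macro
    pub const SIN_F32_KEYS: [f32; 1000] = [-0.00078539815, 0.00078539815, /* … */];
    pub const SIN_F32_VALUES: [f32; 1000] = [0.0, 0.0015707957, /* … */];

Keys sit half a step below the sample points ((step * i) - half_step) while values are taken at the sample points themselves ((step * i).sin()); that offset is what lets the binary search in lookup_table.rs behave as nearest-neighbour rounding.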


@@ -4,12 +4,19 @@
use std::f32::consts as f32_consts;
use std::f64::consts as f64_consts;
use crate::lookup::*;
// use crate::lookup::*;
use crate::lookup::lookup_table::EndoCosLookupTable;
pub trait FastMath: FastCos + FastPow2 + FastExp + FastSigmoid {}
const COS_LOOKUP_F32: EndoCosLookupTable<f32> = EndoCosLookupTable::<f32>::new();
const COS_LOOKUP_F64: EndoCosLookupTable<f64> = EndoCosLookupTable::<f64>::new();
pub trait FastMath: FastCos + FastExp + FastSigmoid {}
impl FastMath for f32 {}
impl FastMath for f64 {}
const V_SCALE_F32: f32 = 8388608.0; // 2^23: the scale of the f32 exponent field in the bit representation
const V_SCALE_F64: f64 = 4503599627370496.0; // 2^52: the scale of the f64 exponent field in the bit representation
pub trait LookupCos {
fn lookup_cos(self: Self) -> Self;
@@ -35,93 +42,66 @@ pub trait FastCos {
impl FastCos for f32 {
#[inline]
fn fast_cos(self: Self) -> f32 {
const BITAND: u32 = u32::MAX / 2;
const ONE: f32 = 1.0;
let mod_x = (((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI;
let v = mod_x.to_bits() & BITAND;
let qpprox = ONE - f32_consts::FRAC_2_PI * f32::from_bits(v);
let v = ((((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI).abs();
let qpprox = ONE - f32_consts::FRAC_2_PI * v;
qpprox + f32_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
}
}
impl FastCos for f64 {
#[inline]
fn fast_cos(self: Self) -> f64 {
const BITAND: u64 = u64::MAX / 2;
const ONE: f64 = 1.0;
let mod_x = (((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI;
let v = mod_x.to_bits() & BITAND;
let qpprox = ONE - f64_consts::FRAC_2_PI * f64::from_bits(v);
let v = ((((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI).abs();
let qpprox = ONE - f64_consts::FRAC_2_PI * v;
qpprox + f64_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
}
}
pub trait FastPow2 {
fn fast_pow2(self: Self) -> Self;
}
impl FastPow2 for f32 {
#[inline]
fn fast_pow2(self: Self) -> f32 {
// Khinchin's constant over 3. IDK why it gives the best fit, but it does
const KHINCHIN_3: f32 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
const CLIPP_THRESH: f32 = 0.12847338;
const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
const CLIPP_SHIFT: f32 = 126.67740855;
let abs_p = self.abs();
let clipp = abs_p.max(CLIPP_THRESH); // if abs_p < CLIPP_THRESH { CLIPP_THRESH } else { abs_p };
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
f32::from_bits(v) - KHINCHIN_3
}
}
impl FastPow2 for f64 {
#[inline]
fn fast_pow2(self: Self) -> f64 {
const KHINCHIN_3: f64 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
const CLIPP_THRESH: f64 = -45774.9247660416;
const V_SCALE: f64 = 4503599627370496.0; // (1i64 << 52) as f64
const CLIPP_SHIFT: f64 = 1022.6769200000002;
const ZERO: f64 = 0.;
let abs_p = self.abs();
let clipp = abs_p.max(CLIPP_THRESH); // if abs_p < CLIPP_THRESH { CLIPP_THRESH } else { abs_p };
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
let y = f64::from_bits(v) - KHINCHIN_3;
if y.is_sign_positive() {
y
} else {
ZERO
}
}
}
pub trait FastExp {
fn fast_exp(self: Self) -> Self;
}
impl FastExp for f32 {
#[inline]
fn fast_exp(self: Self) -> f32 {
const CLIPP_THRESH: f32 = -126.0; // 0.12847338;
const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
const CLIPP_SHIFT: f32 = 126.94269504; // 126.67740855;
const CLIPP_THRESH: f32 = -126.0; // minimum normal f32 exponent, to prevent underflow
const CLIPP_SHIFT: f32 = 126.94269504; // shift to align curve, found by regression
let scaled_p = f32_consts::LOG2_E * self;
let clipp = scaled_p.max(CLIPP_THRESH); // if scaled_p < CLIPP_THRESH { CLIPP_THRESH } else { scaled_p };
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
let clipp = scaled_p.max(CLIPP_THRESH);
let v = (V_SCALE_F32 * (clipp + CLIPP_SHIFT)) as u32;
f32::from_bits(v)
}
}
impl FastExp for f64 {
#[inline]
fn fast_exp(self: Self) -> f64 {
const CLIPP_THRESH: f64 = -180335.51911105003;
const V_SCALE: f64 = 4524653012949098.0;
const CLIPP_SHIFT: f64 = 1018.1563534409383;
const CLIPP_THRESH: f64 = -1022.0; // minimum normal f64 exponent, to prevent underflow
const CLIPP_SHIFT: f64 = 1022.9349439517318; // shift to align curve, found by regression
let scaled_p = f64_consts::LOG2_E * self;
let clipp = scaled_p.max(CLIPP_THRESH); // let clipp = if scaled_p < CLIPP_THRESH { CLIPP_THRESH } else { scaled_p };
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
let clipp = scaled_p.max(CLIPP_THRESH);
let v = (V_SCALE_F64 * (clipp + CLIPP_SHIFT)) as u64;
f64::from_bits(v)
}
}
pub trait FastPow2 {
fn fast_pow2(self: Self) -> Self;
}
impl FastPow2 for f32 {
#[inline]
fn fast_pow2(self: Self) -> f32 {
(f32_consts::LN_2 * self).fast_exp()
}
}
impl FastPow2 for f64 {
#[inline]
fn fast_pow2(self: Self) -> f64 {
(f64_consts::LN_2 * self).fast_exp()
}
}
pub trait FastSigmoid {
fn fast_sigmoid(self: Self) -> Self;
}
@@ -140,63 +120,7 @@ impl FastSigmoid for f64 {
}
}
// A trait for testing and improving implementations of fast functions
pub trait Test {
fn test(self: Self) -> Self;
}
impl Test for f32 {
#[inline]
fn test(self: Self) -> f32 {
// Khinchin's constant over 3. IDK why it gives the best fit, but it does
// const KHINCHIN_3: f32 = 2.68545200106530644530971483548179569382038229399446295305115234555721885953715200280114117493184769799515 / 3.0;
const CLIPP_THRESH: f32 = -126.0; // 0.12847338;
const V_SCALE: f32 = 8388608.0; // (1_i32 << 23) as f32
const CLIPP_SHIFT: f32 = 126.94269504; // 126.67740855;
let scaled_p = f32_consts::LOG2_E * self;
let clipp = if scaled_p < CLIPP_THRESH {
CLIPP_THRESH
} else {
scaled_p
};
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u32;
f32::from_bits(v) // - KHINCHIN_3
}
}
impl Test for f64 {
#[inline]
fn test(self: Self) -> f64 {
const CLIPP_THRESH: f64 = -180335.51911105003;
const V_SCALE: f64 = 4524653012949098.0;
const CLIPP_SHIFT: f64 = 1018.1563534409383;
let scaled_p = f64_consts::LOG2_E * self;
let clipp = if scaled_p < CLIPP_THRESH {
CLIPP_THRESH
} else {
scaled_p
};
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
f64::from_bits(v)
}
}
#[allow(non_snake_case, dead_code)]
pub fn optimizing(p: f64, CLIPP_THRESH: f64, V_SCALE: f64, CLIPP_SHIFT: f64) -> f64 {
// const CLIPP_THRESH: f64 = -45774.9247660416;
// const V_SCALE: f64 = 4503599627370496.0;
// const CLIPP_SHIFT: f64 = 1022.6769200000002;
let scaled_p = f64_consts::LOG2_E * p;
let clipp = if scaled_p < CLIPP_THRESH {
CLIPP_THRESH
} else {
scaled_p
};
let v = (V_SCALE * (clipp + CLIPP_SHIFT)) as u64;
f64::from_bits(v)
}
// functions for testing the accuracy of fast functions against builtin functions
#[inline]
pub fn sigmoid_builtin_f32(p: f32) -> f32 {
(1. + (-p).exp()).recip()
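Two things happen in this file. First, fast_cos drops the manual sign-bit mask (to_bits() & (u32::MAX / 2) followed by from_bits) in favour of a plain .abs(), which is equivalent for finite inputs. Second, the f64 fast_exp constants are corrected to mirror the f32 version — V_SCALE becomes exactly 2^52 and CLIPP_THRESH becomes the minimum normal exponent — and fast_pow2 is reduced to (LN_2 * x).fast_exp(). The underlying trick (Schraudolph-style, as I read it; the diff itself doesn't name it) is that for a normal IEEE-754 f32, f32::from_bits(b) ≈ 2^(b / 2^23 − 127), so choosing b = 2^23 · (x·log₂e + shift) makes from_bits(b) ≈ eˣ. A standalone sketch with the diff's constants:

    // Schraudolph-style exp: pack x·log2(e) into the exponent bits.
    fn sketch_fast_exp(x: f32) -> f32 {
        const V_SCALE: f32 = 8388608.0;         // 2^23: one exponent step in bit units
        const CLIPP_SHIFT: f32 = 126.94269504;  // ≈ exponent bias, regression-tuned
        let scaled = std::f32::consts::LOG2_E * x;
        let clipped = scaled.max(-126.0);       // clamp below the normal range
        f32::from_bits((V_SCALE * (clipped + CLIPP_SHIFT)) as u32)
    }

    fn main() {
        for x in [-2.0f32, 0.0, 1.0, 3.5] {
            println!("exp({x}): fast {} vs std {}", sketch_fast_exp(x), x.exp());
        }
    }

Accuracy is on the order of a few percent, consistent with the percentage-error tolerances asserted in the test file at the end of this commit.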

src/lookup/config.rs (new file, 3 lines)

@@ -0,0 +1,3 @@
// lookup/config.rs
const TABLE_SIZE: usize = 1000;


@@ -1,30 +1,4 @@
// lookup/const_tables.rs
use once_cell::sync::Lazy;
use std::fs::read;
use bincode::deserialize;
use super::lookup_table::*;
pub const SIN_LOOKUP_F32: Lazy<EndoSinLookupTable<f32>> = Lazy::new(|| {
deserialize(
&read("src/lookup/data/sin_f32.bin").expect("Failed to read sin_f64.bin")
).expect("Failed to load SIN_LOOKUP_F32")
});
pub const SIN_LOOKUP_F64: Lazy<EndoSinLookupTable<f64>> = Lazy::new(|| {
deserialize(
&read("src/lookup/data/sin_f64.bin").expect("Failed to read sin_f32.bin")
).expect("Failed to load SIN_LOOKUP_F64")
});
pub const COS_LOOKUP_F32: Lazy<EndoCosLookupTable<f32>> = Lazy::new(|| {
deserialize(
&read("src/lookup/data/cos_f32.bin").expect("Failed to read cos_f64.bin")
).expect("Failed to load COS_LOOKUP_F32")
});
pub const COS_LOOKUP_F64: Lazy<EndoCosLookupTable<f64>> = Lazy::new(|| {
deserialize(
&read("src/lookup/data/cos_f64.bin").expect("Failed to read cos_f32.bin")
).expect("Failed to load COS_LOOKUP_F64")
});
include!("data/sin_f32.rs");
include!("data/sin_f64.rs");

Binary file not shown (×4).

File diff suppressed because one or more lines are too long (×2).


@@ -1,89 +1,153 @@
use num_traits::sign::Signed;
use std::f32::consts as f32_consts;
use std::f64::consts as f64_consts;
use num_traits::float::{Float, FloatConst};
use num_traits::NumCast;
use std::ops::{Sub, Rem};
use serde::{Serialize, Deserialize};
// use packed_simd::f64x4;
use crate::lookup::TABLE_SIZE;
use crate::lookup::const_tables::*;
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
// This function should never be used in a non-const context.
// It exists only as a workaround for the fact that const fns cannot use iterators.
// It must return the filled array: taking `map_target` by value and returning ()
// would only mutate a local copy, leaving the caller's array untouched.
const fn make_ordinal<T: Float>(
input: [T; TABLE_SIZE],
mut map_target: [FloatOrd<T>; TABLE_SIZE],
) -> [FloatOrd<T>; TABLE_SIZE] {
let mut index = 0;
while index < TABLE_SIZE {
map_target[index] = FloatOrd(input[index]);
index += 1;
}
map_target
}
// The following macros cut down the boilerplate for the concrete, statically typed lookup-table impls.
macro_rules! impl_fbitfbit_lookup_table {
($key_type:ty, $value_type:ty) => {
impl FloatLookupTable<$key_type, $value_type> {
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE]) -> Self {
let ord_keys: [FloatOrd<$key_type>; TABLE_SIZE] = make_ordinal(keys, [FloatOrd(0.0 as $key_type); TABLE_SIZE]);
FloatLookupTable {
keys: ord_keys,
values,
}
}
}
};
}
macro_rules! impl_cycling_fbitfbit_lookup_table {
($key_type:ty, $value_type:ty) => {
impl CyclingFloatLookupTable<$key_type, $value_type> {
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE], lower_bound: $key_type, upper_bound: $key_type) -> Self {
CyclingFloatLookupTable {
lookup_table: FloatLookupTable::<$key_type, $value_type>::new_const(keys, values),
lower_bound,
upper_bound,
}
}
}
};
}
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct FloatOrd<T: Float>(pub T);
impl<T: Float> FloatOrd<T> {
pub fn new() -> Self {
FloatOrd(T::zero())
}
}
impl<T: Float> Eq for FloatOrd<T> {}
impl<T: Float> Ord for FloatOrd<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
}
}
#[derive(Debug, Clone)]
pub struct FloatLookupTable<T1, T2>
where T1: Float,
T2: Float,
where
T1: Float,
T2: Float,
{
keys: Vec<T1>,
values: Vec<T2>,
keys: [FloatOrd<T1>; TABLE_SIZE],
values: [T2; TABLE_SIZE],
}
impl<T1, T2> FloatLookupTable<T1, T2>
where T1: Float,
T2: Float,
where
T1: Float,
T2: Float,
{
pub fn new(mut keys: Vec<T1>, mut values: Vec<T2>) -> Self {
let mut indices: Vec<_> = (0..keys.len()).collect();
indices.sort_by(|&i, &j| keys[i].partial_cmp(&keys[j]).unwrap());
for i in 0..keys.len() {
while i != indices[i] {
let swap_index = indices[i];
keys.swap(i, swap_index);
values.swap(i, swap_index);
indices.swap(i, swap_index);
}
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE]) -> Self {
FloatLookupTable {
keys: keys.map(|key| FloatOrd(key)),
values,
}
FloatLookupTable { keys, values }
}
#[allow(dead_code)]
pub fn lookup(&self, key: T1) -> T2 {
match self.keys.binary_search_by(|probe| probe.partial_cmp(&key).unwrap()) {
Ok(index) => self.values[index],
Err(index) => {
let upper_key = &self.keys[index];
let upper_val = &self.values[index];
let low_index = index - 1;
let lower_key = &self.keys[low_index];
let lower_val = &self.values[low_index];
// select nearest neighbour
let diff_upper = (key - *upper_key).abs();
let diff_lower = (key - *lower_key).abs();
let mask = diff_lower <= diff_upper;
(*lower_val * T2::from(mask as u8).expect("Failed to unwrap mask")) +
(*upper_val * T2::from(!mask as u8).expect("Failed to unwrap !mask"))
pub fn get_next(&self, key: T1) -> T2 {
let ord_key = FloatOrd(key);
// Bisect for the pair of adjacent keys bracketing `key` and return the value
// at the lower bracket. For the generated tables, whose keys are shifted down
// by half a step, this amounts to nearest-neighbour rounding.
let mut lower_bound = 0;
let mut upper_bound = self.keys.len() - 1;
let mut mid = (lower_bound + upper_bound) / 2;
while upper_bound - lower_bound > 1 {
if self.keys[mid] < ord_key {
lower_bound = mid;
} else {
upper_bound = mid;
}
mid = (lower_bound + upper_bound) / 2;
}
self.values[mid]
}
pub fn lookup(&self, key: T1) -> T2 {
self.get_next(key)
}
}
impl_fbitfbit_lookup_table!(f32, f32);
impl_fbitfbit_lookup_table!(f64, f64);
impl_fbitfbit_lookup_table!(f32, f64);
impl_fbitfbit_lookup_table!(f64, f32);
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Clone)]
pub struct CyclingFloatLookupTable<T1, T2>
where T1: Float,
T2: Float,
where
T1: Float,
T2: Float,
{
lookup_table: FloatLookupTable<T1, T2>,
lower_bound: T1,
upper_bound: T1,
bound_range: T1,
}
impl<T1, T2> CyclingFloatLookupTable<T1, T2>
where T1: Float,
T2: Float,
where
T1: Float,
T2: Float,
{
pub fn new(keys: Vec<T1>, values: Vec<T2>, lower_bound: T1, upper_bound: T1) -> Self {
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE], lower_bound: T1, upper_bound: T1) -> Self {
CyclingFloatLookupTable {
lookup_table: FloatLookupTable::new(keys, values),
lower_bound: lower_bound,
upper_bound: upper_bound,
bound_range: upper_bound - lower_bound,
lower_bound,
upper_bound,
}
}
pub fn lookup(&self, key: T1) -> T2 {
let key = (key % self.bound_range) + self.lower_bound;
let key = (key % (self.upper_bound - self.lower_bound)) + self.lower_bound;
self.lookup_table.lookup(key)
}
}
impl_cycling_fbitfbit_lookup_table!(f32, f32);
impl_cycling_fbitfbit_lookup_table!(f64, f64);
impl_cycling_fbitfbit_lookup_table!(f32, f64);
impl_cycling_fbitfbit_lookup_table!(f64, f32);
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Clone)]
pub struct EndoSinLookupTable<T>
where
T: Float + FloatConst,
@@ -94,22 +158,6 @@ impl<T> EndoSinLookupTable<T>
where
T: Float + FloatConst,
{
pub fn new(precision: usize) -> Self {
let mut keys = Vec::with_capacity(precision);
let mut values = Vec::with_capacity(precision);
let upper_bound = T::PI();
let step = T::FRAC_PI_2() / <T as NumCast>::from(precision).unwrap();
for i in 0..precision+1 {
let key = step * <T as NumCast>::from(i).unwrap();
let value = key.sin();
keys.push(key);
values.push(value);
}
EndoSinLookupTable {
lookup_table: CyclingFloatLookupTable::new(keys, values, T::zero(), upper_bound),
}
}
#[allow(dead_code)]
pub fn lookup(&self, key: T) -> T {
if key < T::zero() {
@@ -123,27 +171,57 @@ where
}
}
}
impl EndoSinLookupTable<f32>
{
pub const fn new() -> Self {
const UPPER_BOUND: f32 = f32_consts::PI;
EndoSinLookupTable {
lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, UPPER_BOUND),
}
}
}
impl EndoSinLookupTable<f64>
{
pub const fn new() -> Self {
let upper_bound = f64_consts::PI;
EndoSinLookupTable {
lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, upper_bound),
}
}
}
#[derive(Default, Debug, Serialize, Deserialize, Clone)]
#[derive(Debug, Clone)]
pub struct EndoCosLookupTable<T>
where
T: Float + FloatConst + Signed + Sub<Output = T> + Rem<Output = T> + NumCast + From<u8>,
T: Float + FloatConst,
{
lookup_table: EndoSinLookupTable<T>,
}
impl<T> EndoCosLookupTable<T>
where
T: Float + FloatConst + Signed + Sub<Output = T> + Rem<Output = T> + NumCast + From<u8>,
T: Float + FloatConst + std::fmt::Debug,
{
pub fn new(precision: usize) -> Self {
EndoCosLookupTable {
lookup_table: EndoSinLookupTable::new(precision),
}
}
#[allow(dead_code)]
pub fn lookup(&self, key: T) -> T {
self.lookup_table.lookup(key + T::FRAC_PI_2())
}
}
impl EndoCosLookupTable<f32>
{
pub const fn new() -> Self {
EndoCosLookupTable {
lookup_table: EndoSinLookupTable::<f32>::new(),
}
}
}
impl EndoCosLookupTable<f64>
{
pub const fn new() -> Self {
EndoCosLookupTable {
lookup_table: EndoSinLookupTable::<f64>::new(),
}
}
}
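The table types here trade Vec + serde for fixed-size arrays of FloatOrd keys, so the whole stack — FloatLookupTable → CyclingFloatLookupTable → EndoSin/EndoCosLookupTable — can be built in a const fn and baked into the binary's read-only data. Cosine never stores its own table: EndoCosLookupTable forwards lookup(key + π/2) to the sin table. A hypothetical usage sketch (module path assumed from mod.rs below):

    use fastmath::lookup::lookup_table::EndoCosLookupTable; // path assumed

    // Const-constructed: no file I/O, no lazy init; the data lives in the binary.
    static COS_TABLE: EndoCosLookupTable<f32> = EndoCosLookupTable::<f32>::new();

    fn main() {
        println!("cos(1.0): table {} vs std {}", COS_TABLE.lookup(1.0f32), 1.0f32.cos());
    }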


@@ -2,3 +2,5 @@ mod const_tables;
pub mod lookup_table;
pub use const_tables::*;
include!("config.rs");

src/main.rs (new file, 12 lines)

@@ -0,0 +1,12 @@
use fastmath::LookupCos;
fn main() {
let x = (-10000..10000)
.map(|a| (a as f64) / 1000.)
.collect::<Vec<f64>>();
let y = x.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>();
// to ensure the compiler doesn't optimize away the function call, we print the result
println!("{:?}", y);
}
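Printing 20 000 floats does keep the calls alive, but it also adds formatting and I/O to whatever is being timed. On a reasonably recent toolchain, std::hint::black_box is the cheaper way to keep the compiler honest — a possible variant of this main:

    use fastmath::LookupCos;

    fn main() {
        let x = (-10000..10000).map(|a| (a as f64) / 1000.).collect::<Vec<f64>>();
        let y = x.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>();
        std::hint::black_box(y); // observed by the optimizer, but never formatted or printed
    }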


@@ -39,7 +39,7 @@ mod tests {
fn pow2() -> Result<(), Box<dyn std::error::Error>> {
let percentage_error = calculate_percentage_error(
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f64>>(),
&X.iter().map(|&x| x.powi(2)).collect::<Vec<f64>>()
&X.iter().map(|&x| 2.0f64.powf(x)).collect::<Vec<f64>>()
);
assert!(!percentage_error.is_nan(), "fast_pow2<f64> percentage error is NaN");
assert!(
@@ -77,6 +77,17 @@ mod tests {
"fast_cos<f64> percentage error: {0}",
percentage_error
);
// lookup
let percentage_error = calculate_percentage_error(
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>(),
&X.iter().map(|&x| x.cos()).collect::<Vec<f64>>()
);
assert!(!percentage_error.is_nan(), "lookup_cos<f64> percentage error is NaN");
assert!(
percentage_error < TOLERANCE,
"lookup_cos<f64> percentage error: {0}",
percentage_error
);
Ok(())
}
@@ -114,7 +125,7 @@ mod tests {
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f32>>(),
&X.iter().map(|&x| x.powi(2)).collect::<Vec<f32>>()
&X.iter().map(|&x| 2.0f32.powf(x)).collect::<Vec<f32>>()
) < TOLERANCE
);
Ok(())
@@ -139,6 +150,13 @@ mod tests {
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
) < TOLERANCE
);
// lookup
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f32>>(),
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
) < TOLERANCE
);
Ok(())
}