Dev checkpoint - testing lookup tables

This commit is contained in:
Cian Hughes
2023-09-15 09:34:15 +01:00
parent 14d057e4ff
commit a9b671986e
27 changed files with 771 additions and 411 deletions

10
.gitignore vendored
View File

@@ -1,4 +1,8 @@
/target
/.vscode
/tmp
Cargo.lock
debug/
target/
**/*.rs.bk
*.pdb
.vscode/
tmp/
*.ipynb

151
Cargo.lock generated
View File

@@ -120,6 +120,17 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
[[package]]
name = "colored"
version = "2.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6"
dependencies = [
"is-terminal",
"lazy_static",
"windows-sys 0.48.0",
]
[[package]]
name = "criterion"
version = "0.5.1"
@@ -213,7 +224,7 @@ checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a"
dependencies = [
"errno-dragonfly",
"libc",
"windows-sys",
"windows-sys 0.48.0",
]
[[package]]
@@ -232,9 +243,12 @@ version = "0.1.0"
dependencies = [
"bincode",
"criterion",
"log",
"num-traits",
"once_cell",
"serde",
"serde_json",
"simple_logger",
]
[[package]]
@@ -257,7 +271,7 @@ checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b"
dependencies = [
"hermit-abi",
"rustix",
"windows-sys",
"windows-sys 0.48.0",
]
[[package]]
@@ -284,6 +298,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.147"
@@ -336,6 +356,15 @@ dependencies = [
"libc",
]
[[package]]
name = "num_threads"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
dependencies = [
"libc",
]
[[package]]
name = "once_cell"
version = "1.18.0"
@@ -455,7 +484,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys",
"windows-sys 0.48.0",
]
[[package]]
@@ -501,15 +530,27 @@ dependencies = [
[[package]]
name = "serde_json"
version = "1.0.103"
version = "1.0.106"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d03b412469450d4404fe8499a268edd7f8b79fecb074b0d812ad64ca21f4031b"
checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "simple_logger"
version = "4.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2230cd5c29b815c9b699fb610b49a5ed65588f3509d9f0108be3a885da629333"
dependencies = [
"colored",
"log",
"time",
"windows-sys 0.42.0",
]
[[package]]
name = "syn"
version = "2.0.26"
@@ -521,6 +562,35 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "time"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
dependencies = [
"itoa",
"libc",
"num_threads",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]]
name = "time-macros"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
dependencies = [
"time-core",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"
@@ -642,6 +712,21 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm 0.42.2",
"windows_aarch64_msvc 0.42.2",
"windows_i686_gnu 0.42.2",
"windows_i686_msvc 0.42.2",
"windows_x86_64_gnu 0.42.2",
"windows_x86_64_gnullvm 0.42.2",
"windows_x86_64_msvc 0.42.2",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
@@ -657,51 +742,93 @@ version = "0.48.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
"windows_aarch64_gnullvm 0.48.0",
"windows_aarch64_msvc 0.48.0",
"windows_i686_gnu 0.48.0",
"windows_i686_msvc 0.48.0",
"windows_x86_64_gnu 0.48.0",
"windows_x86_64_gnullvm 0.48.0",
"windows_x86_64_msvc 0.48.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
[[package]]
name = "windows_i686_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
[[package]]
name = "windows_i686_msvc"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.0"

View File

@@ -7,9 +7,11 @@ edition = "2021"
[dependencies]
bincode = "1.3.3"
log = "0.4.19"
num-traits = "0.2.15"
once_cell = "1.18.0"
serde = {version = "1.0.171", features = ["derive"] }
simple_logger = "4.2.0"
[build-dependencies]
bincode = "1.3.3"
@@ -26,5 +28,11 @@ bench = true
name = "bench"
harness = false
[[bench]]
name = "devbench"
harness = false
[dev-dependencies]
criterion = "0.5.1"
serde = "1.0.171"
serde_json = "1.0.106"

View File

@@ -5,18 +5,23 @@ use fastmath::*;
use criterion::{Criterion, BenchmarkGroup, measurement::WallTime};
use criterion::{black_box, criterion_group, criterion_main};
pub mod exact {
include!("../src/tests/accuracy/exact.rs");
}
include!("../src/tests/accuracy/x.rs");
fn pow2_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[f32]) {
group.bench_function("f64_fast", |b| {
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f64>>())
});
group.bench_function("f64_builtin_fn", |b| {
b.iter(|| x_f64.iter().map(|&x| 2.0f64.powf(black_box(x))).collect::<Vec<f64>>())
b.iter(|| x_f64.iter().map(|&x| exact::f64::pow2(black_box(x))).collect::<Vec<f64>>())
});
group.bench_function("f32_fast", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f32>>())
});
group.bench_function("f32_builtin_fn", |b| {
b.iter(|| x_f32.iter().map(|&x| 2.0f32.powf(black_box(x))).collect::<Vec<f32>>())
b.iter(|| x_f32.iter().map(|&x| exact::f32::pow2(black_box(x))).collect::<Vec<f32>>())
});
}
@@ -25,13 +30,13 @@ fn exp_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_exp()).collect::<Vec<f64>>())
});
group.bench_function("f64_builtin", |b| {
b.iter(|| x_f64.iter().map(|&x| black_box(x).exp()).collect::<Vec<f64>>())
b.iter(|| x_f64.iter().map(|&x| exact::f64::exp(black_box(x))).collect::<Vec<f64>>())
});
group.bench_function("f32_fast", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_exp()).collect::<Vec<f32>>())
});
group.bench_function("f32_builtin", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).exp()).collect::<Vec<f32>>())
b.iter(|| x_f32.iter().map(|&x| exact::f32::exp(black_box(x))).collect::<Vec<f32>>())
});
}
@@ -43,7 +48,7 @@ fn cos_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[
b.iter(|| x_f64.iter().map(|&x| black_box(x).lookup_cos()).collect::<Vec<f64>>())
});
group.bench_function("f64_builtin", |b| {
b.iter(|| x_f64.iter().map(|&x| black_box(x).cos()).collect::<Vec<f64>>())
b.iter(|| x_f64.iter().map(|&x| exact::f64::cos(black_box(x))).collect::<Vec<f64>>())
});
group.bench_function("f32_fast", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_cos()).collect::<Vec<f32>>())
@@ -52,7 +57,7 @@ fn cos_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[
b.iter(|| x_f32.iter().map(|&x| black_box(x).lookup_cos()).collect::<Vec<f32>>())
});
group.bench_function("f32_builtin", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).cos()).collect::<Vec<f32>>())
b.iter(|| x_f32.iter().map(|&x| exact::f32::cos(black_box(x))).collect::<Vec<f32>>())
});
}
@@ -61,42 +66,35 @@ fn sigmoid_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_sigmoid()).collect::<Vec<f64>>())
});
group.bench_function("f64_builtin", |b| {
b.iter(|| x_f64.iter().map(|&x| sigmoid_builtin_f64(black_box(x))).collect::<Vec<f64>>())
b.iter(|| x_f64.iter().map(|&x| exact::f64::sigmoid(black_box(x))).collect::<Vec<f64>>())
});
group.bench_function("f32_fast", |b| {
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_sigmoid()).collect::<Vec<f32>>())
});
group.bench_function("f32_builtin", |b| {
b.iter(|| x_f32.iter().map(|&x| sigmoid_builtin_f32(black_box(x))).collect::<Vec<f32>>())
b.iter(|| x_f32.iter().map(|&x| exact::f32::sigmoid(black_box(x))).collect::<Vec<f32>>())
});
}
fn criterion_benchmark(c: &mut Criterion) {
// Prepare x values for testing functions
let x_f64 = (-10000..10000)
.map(|a| (a as f64) / 1000.)
.collect::<Vec<f64>>();
let x_f32 = (-10000..10000)
.map(|a| (a as f32) / 1000.)
.collect::<Vec<f32>>();
// to ensure tests are fair, we need to instantiate the lookup tables
1.0f64.lookup_cos();
1.0f32.lookup_cos();
// Then, tests can begin
let mut group = c.benchmark_group("pow2");
pow2_benchmarks(&mut group, &x_f64, &x_f32);
pow2_benchmarks(&mut group, &X_F64, &X_F32);
group.finish();
let mut group = c.benchmark_group("exp");
exp_benchmarks(&mut group, &x_f64, &x_f32);
exp_benchmarks(&mut group, &X_F64, &X_F32);
group.finish();
let mut group = c.benchmark_group("cos");
cos_benchmarks(&mut group, &x_f64, &x_f32);
cos_benchmarks(&mut group, &X_F64, &X_F32);
group.finish();
let mut group = c.benchmark_group("sigmoid");
sigmoid_benchmarks(&mut group, &x_f64, &x_f32);
sigmoid_benchmarks(&mut group, &X_F64, &X_F32);
group.finish();
}

46
benches/devbench.rs Normal file
View File

@@ -0,0 +1,46 @@
#![allow(dead_code, unused_imports)]
extern crate fastmath;
use fastmath::*;
use criterion::{Criterion, BenchmarkGroup, measurement::WallTime};
use criterion::{black_box, criterion_group, criterion_main};
use std::f32::consts as f32_consts;
use std::f64::consts as f64_consts;
pub mod exact {
include!("../src/tests/accuracy/exact.rs");
}
include!("../src/tests/accuracy/x.rs");
/// Experimental polynomial approximation of `cos(x)` used only by this dev benchmark.
///
/// The argument is first folded into a distance `folded` in `[0, PI]` (shift by `PI`,
/// reduce modulo `TAU`, shift back and take the magnitude), which exploits the evenness
/// and periodicity of cosine. A straight line through `(0, 1)` and `(PI, -1)` gives a
/// first estimate, which is then refined with a cubic correction term.
fn dev_cos(x: f64) -> f64 {
    // Range reduction: map any input onto [0, PI].
    let shifted = (x + f64_consts::PI).abs();
    let folded = ((shifted % f64_consts::TAU) - f64_consts::PI).abs();

    // Linear estimate: 1 at folded == 0, falling to -1 at folded == PI.
    let linear = 1.0_f64 - f64_consts::FRAC_2_PI * folded;

    // Cubic correction; it vanishes where `linear` is -1, 0 or 1.
    let correction = 1.0_f64 - linear * linear;
    linear + f64_consts::FRAC_PI_6 * linear * correction
}
/// Registers the cosine benchmarks in `group`.
///
/// Each benchmark maps one cosine implementation over every value in `X_F64` and collects
/// the results into a `Vec<f64>`; every input is passed through `black_box` first.
fn devbench(group: &mut BenchmarkGroup<WallTime>) {
    // Experimental approximation defined in this file.
    group.bench_function("dev_cos", |bencher| {
        bencher.iter(|| -> Vec<f64> {
            X_F64.iter().copied().map(|x| dev_cos(black_box(x))).collect()
        })
    });

    // Crate-provided fast approximation.
    group.bench_function("fast_cos", |bencher| {
        bencher.iter(|| -> Vec<f64> {
            X_F64.iter().copied().map(|x| black_box(x).fast_cos()).collect()
        })
    });

    // Crate-provided lookup-table implementation.
    group.bench_function("lookup_cos", |bencher| {
        bencher.iter(|| -> Vec<f64> {
            X_F64.iter().copied().map(|x| black_box(x).lookup_cos()).collect()
        })
    });

    // Reference implementation from the `exact` module.
    group.bench_function("builtin_cos", |bencher| {
        bencher.iter(|| -> Vec<f64> {
            X_F64.iter().copied().map(|x| exact::f64::cos(black_box(x))).collect()
        })
    });
}
/// Criterion entry point: runs every development benchmark inside a single
/// benchmark group named `"devbench"`.
fn criterion_benchmark(c: &mut Criterion) {
    let mut dev_group = c.benchmark_group("devbench");
    devbench(&mut dev_group);
    dev_group.finish();
}
criterion_group!(devbenches, criterion_benchmark);
criterion_main!(devbenches);

154
build.rs
View File

@@ -6,126 +6,39 @@ mod precalculate_lookup_tables {
use std::fs::{create_dir_all, File};
use std::io::Write;
include!("src/lookup/config.rs");
// use bincode::serialize;
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
// let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/sin_f32.bin")?;
// file.write_all(&data)?;
// let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/sin_f64.bin")?;
// file.write_all(&data)?;
// Ok(())
// }
// fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
// let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/cos_f32.bin")?;
// file.write_all(&data)?;
// let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
// let mut file = File::create("src/lookup/data/cos_f64.bin")?;
// file.write_all(&data)?;
// Ok(())
// }
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
// let half_step: f32 = step / 2.0;
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)) - half_step
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)).sin()
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
// file.write_all(data.as_bytes())?;
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
// let half_step: f64 = step / 2.0;
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)) - half_step
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)).sin()
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
// file.write_all(data.as_bytes())?;
// Ok(())
// }
include!("src/lookup/ordinal_float.rs");
macro_rules! precalculate_sin_tables {
() => {{
let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
let half_step: f32 = step / 2.0;
let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f32)) - half_step
}).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
let keys: [FloatOrd<f32>; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
FloatOrd( (step * (i as f32)) - half_step )
}).collect::<Vec<FloatOrd<f32>>>().try_into().unwrap_or([FloatOrd::new(); TABLE_SIZE]);
let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f32)).sin()
}).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
let data = format!("pub const SIN_F32_KEYS: [f32; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let data = format!("pub(crate) const SIN_F32_KEYS: [FloatOrd<f32>; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let mut file = File::create("src/lookup/data/sin_f32.rs").expect("Failed to create sin_f32.rs");
file.write_all(data.as_bytes()).expect("Failed to write sin_f32.rs");
let mut file = File::create("src/lookup/data/sin_f32.rs")?;
file.write_all(data.as_bytes())?;
let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
let half_step: f64 = step / 2.0;
let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f64)) - half_step
}).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
let keys: [FloatOrd<f64>; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
FloatOrd( (step * (i as f64)) - half_step )
}).collect::<Vec<FloatOrd<f64>>>().try_into().unwrap_or([FloatOrd::new(); TABLE_SIZE]);
let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
(step * (i as f64)).sin()
}).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
let data = format!("pub const SIN_F64_KEYS: [f64; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let data = format!("pub const SIN_F64_KEYS: [FloatOrd<f64>; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
let mut file = File::create("src/lookup/data/sin_f64.rs").expect("Failed to create sin_f64.rs");
file.write_all(data.as_bytes()).expect("Failed to write sin_f64.rs");
let mut file = File::create("src/lookup/data/sin_f64.rs")?;
file.write_all(data.as_bytes())?;
}};
}
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
// let half_step: f32 = step / 2.0;
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)) - half_step
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f32)).sin()
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
// file.write_all(data.as_bytes())?;
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
// let half_step: f64 = step / 2.0;
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)) - half_step
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
// (step * (i as f64)).sin()
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
// file.write_all(data.as_bytes())?;
// Ok(())
// }
pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
create_dir_all("src/lookup/data")?;
@@ -137,8 +50,49 @@ mod precalculate_lookup_tables {
}
}
/// Build-time generator for `src/tests/accuracy/x.rs`, the shared table of sample
/// inputs (`X_F32` / `X_F64`).
mod precalculate_test_tables {
    use std::fs::{create_dir_all, File};
    use std::io::Write;

    // X_SIZE, X_MIN and X_MAX are not declared in this module, so they presumably come
    // from this included config file — TODO confirm.
    include!("src/tests/accuracy/config.rs");

    /// Computes the sample tables and writes them out as Rust source.
    ///
    /// The body uses `?`, so the macro must be expanded inside a function whose error
    /// type can absorb both `&str` and `std::io::Error`.
    macro_rules! precalculate_test_tables {
        () => {{
            // X_SIZE evenly spaced samples starting at X_MIN; the last sample sits one
            // step below X_MAX.
            let scaling: f32 = X_SIZE as f32 / (X_MAX as f32 - X_MIN as f32);
            let x_f32: [f32; X_SIZE] =
                (0..X_SIZE)
                    .map(|a| ((a as f32) / scaling) + X_MIN as f32)
                    .collect::<Vec<f32>>()
                    .try_into().map_err(|_| "Failed to convert Vec<f32> to [f32; X_SIZE]")?;

            let scaling: f64 = X_SIZE as f64 / (X_MAX as f64 - X_MIN as f64);
            let x_f64: [f64; X_SIZE] =
                (0..X_SIZE)
                    .map(|a| ((a as f64) / scaling) + X_MIN as f64)
                    .collect::<Vec<f64>>()
                    .try_into().map_err(|_| "Failed to convert Vec<f64> to [f64; X_SIZE]")?;

            // The constants are deliberately NOT gated behind `#[cfg(test)]`: the
            // generated file is also `include!`d by the criterion benches, which are
            // declared with `harness = false` and therefore are not compiled with
            // `--test`, so a `cfg(test)` gate would hide X_F32 / X_F64 from them.
            // NOTE(review): confirm against the Cargo target documentation.
            let data = format!(
                "#[allow(dead_code)]\npub const X_F32: [f32; {}] = {:?};\n#[allow(dead_code)]\npub const X_F64: [f64; {}] = {:?};",
                X_SIZE, x_f32, X_SIZE, x_f64
            );

            let mut file = File::create("src/tests/accuracy/x.rs")?;
            file.write_all(data.as_bytes())?;
        }};
    }

    /// Creates `src/tests/accuracy` if needed and (re)writes `x.rs` inside it.
    ///
    /// # Errors
    /// Returns an error if the directory or file cannot be created or written, or if
    /// the sample vectors cannot be converted into fixed-size arrays.
    pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
        create_dir_all("src/tests/accuracy")?;
        precalculate_test_tables!();
        Ok(())
    }
}
/// Build-script entry point: generates the lookup tables first and, only if that
/// succeeds, the test-input tables. The first error encountered is returned.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    precalculate_lookup_tables::generate()
        .and_then(|()| precalculate_test_tables::generate())
}

73
src/bit_magic.rs Normal file
View File

@@ -0,0 +1,73 @@
/// Expands a `bool` into a full-width bit mask of type `T`.
///
/// Returns the bitwise complement of zero (all bits set for the primitive integer
/// types) when `b` is `true`, and zero when `b` is `false`.
///
/// This version needs neither `unsafe` nor any trait outside the standard library; the
/// bounds are a subset of the previous ones, so existing callers keep compiling.
fn bool_to_full_byte<T>(b: bool) -> T
where
    T: From<u8> + std::ops::Not<Output = T>,
{
    let zero = T::from(0u8);
    // NOTE(review): written as a select between two constants; check the generated
    // code if a strictly branch-free mask is required here.
    if b {
        !zero
    } else {
        zero
    }
}
/// Masks a floating-point value with a `bool`.
///
/// The value's bit pattern is ANDed with the mask produced by `bool_to_full_byte`:
/// all bits set for `true`, no bits set for `false`.
pub trait AndBool {
    /// Returns `self` unchanged when `b` is `true`, and positive zero (the all-zero
    /// bit pattern) when `b` is `false`.
    fn and(self, b: bool) -> Self;
}

impl AndBool for f32 {
    fn and(self, b: bool) -> f32 {
        let mask: u32 = bool_to_full_byte(b);
        // `to_bits` / `from_bits` are the safe equivalents of transmuting f32 <-> u32.
        f32::from_bits(self.to_bits() & mask)
    }
}

impl AndBool for f64 {
    fn and(self, b: bool) -> f64 {
        let mask: u64 = bool_to_full_byte(b);
        // `to_bits` / `from_bits` are the safe equivalents of transmuting f64 <-> u64.
        f64::from_bits(self.to_bits() & mask)
    }
}
/// Extracts the sign of a floating-point value as `1.0` or `-1.0`.
///
/// Only the sign bit of the input is inspected, so `-0.0` maps to `-1.0`, `0.0` maps to
/// `1.0`, and a NaN maps to `1.0` or `-1.0` according to its sign bit (unlike
/// `signum`, which propagates NaN).
pub trait GetSign {
    /// Returns `1.0` with the sign bit of `self` copied onto it.
    fn sign(self) -> Self;
}

impl GetSign for f32 {
    fn sign(self) -> f32 {
        const SIGN_MASK: u32 = 0x8000_0000;
        // Bit pattern of 1.0f32.
        const ONE_BITS: u32 = 0x3F80_0000;
        // Working on the integer bit pattern (instead of a transmuted byte array)
        // keeps the result independent of the target's byte order.
        f32::from_bits((self.to_bits() & SIGN_MASK) | ONE_BITS)
    }
}

impl GetSign for f64 {
    fn sign(self) -> f64 {
        const SIGN_MASK: u64 = 0x8000_0000_0000_0000;
        // Bit pattern of 1.0f64.
        const ONE_BITS: u64 = 0x3FF0_0000_0000_0000;
        // Working on the integer bit pattern (instead of a transmuted byte array)
        // keeps the result independent of the target's byte order.
        f64::from_bits((self.to_bits() & SIGN_MASK) | ONE_BITS)
    }
}

View File

@@ -1,12 +1,13 @@
//! A collection of fast (often approximate) mathematical functions for accelerating mathematical functions
// Optimisation note: lookup tables become faster when calculation takes > ~1ms
// Optimisation note: lookup tables become faster when calculation takes > ~400us
use std::f32::consts as f32_consts;
use std::f64::consts as f64_consts;
// use crate::lookup::*;
use crate::lookup::lookup_table::EndoCosLookupTable;
use crate::lookup::{EndoCosLookupTable, EndoSinLookupTable};
const SIN_LOOKUP_F32: EndoSinLookupTable<f32> = EndoSinLookupTable::<f32>::new();
const SIN_LOOKUP_F64: EndoSinLookupTable<f64> = EndoSinLookupTable::<f64>::new();
const COS_LOOKUP_F32: EndoCosLookupTable<f32> = EndoCosLookupTable::<f32>::new();
const COS_LOOKUP_F64: EndoCosLookupTable<f64> = EndoCosLookupTable::<f64>::new();
@@ -18,6 +19,25 @@ const V_SCALE_F32: f32 = 8388608.0; // the largest possible mantissa of an f32
const V_SCALE_F64: f64 = 4503599627370496.0; // the largest possible mantissa of an f64
pub trait LookupSin {
fn lookup_sin(self: Self) -> Self;
}
impl LookupSin for f64 {
#[inline]
fn lookup_sin(self: Self) -> f64 {
// Look up the value in the table
SIN_LOOKUP_F64.lookup(self)
}
}
impl LookupSin for f32 {
#[inline]
fn lookup_sin(self: Self) -> f32 {
// Look up the value in the table
SIN_LOOKUP_F32.lookup(self)
}
}
pub trait LookupCos {
fn lookup_cos(self: Self) -> Self;
}
@@ -118,15 +138,4 @@ impl FastSigmoid for f64 {
const ONE: f64 = 1.0;
(ONE + (-self).fast_exp()).recip()
}
}
// functions for testing the accuracy of fast functions against builtin functions
#[inline]
pub fn sigmoid_builtin_f32(p: f32) -> f32 {
(1. + (-p).exp()).recip()
}
#[inline]
pub fn sigmoid_builtin_f64(p: f64) -> f64 {
(1. + (-p).exp()).recip()
}
}

View File

@@ -2,9 +2,10 @@
#![allow(unused_imports)]
pub mod lookup;
mod fastmath;
pub mod macros;
mod fastmath;
pub use fastmath::*;
#[cfg(test)]
mod tests;
pub(crate) mod tests;

View File

@@ -1,4 +1,5 @@
// lookup/const_tables.rs
use crate::lookup::ordinal_float::FloatOrd;
include!("data/sin_f32.rs");
include!("data/sin_f64.rs");

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,75 +1,25 @@
use std::f32::consts as f32_consts;
use std::f64::consts as f64_consts;
use std::cmp::Ordering;
use num_traits::identities::One;
use num_traits::ops::wrapping::WrappingSub;
use num_traits::float::{Float, FloatConst};
use crate::{
impl_fbitfbit_lookup_table,
impl_cycling_fbitfbit_lookup_table,
};
use crate::lookup::TABLE_SIZE;
use crate::lookup::ordinal_float::FloatOrd;
use crate::lookup::const_tables::*;
// This function should never be used in a non-const context.
// It only exists as a workaround for the fact that const fn's cannot use iterators.
const fn make_ordinal<T: Float>(
input: [T; TABLE_SIZE],
mut map_target: [FloatOrd<T>; TABLE_SIZE],
) -> () {
// let mut map_target = [FloatOrd::<T>::new(); TABLE_SIZE];
let mut index = 0;
while index < TABLE_SIZE {
map_target[index] = FloatOrd(input[index]);
index += 1;
}
}
// The following macros are to minimise the amount of boilerplate for static types on the lookup tables.
macro_rules! impl_fbitfbit_lookup_table {
($key_type:ty, $value_type:ty) => {
impl FloatLookupTable<$key_type, $value_type> {
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE]) -> Self {
let ord_keys: [FloatOrd<$key_type>; TABLE_SIZE] = [FloatOrd(0.0 as $key_type); TABLE_SIZE];
make_ordinal(keys, ord_keys);
FloatLookupTable {
keys: ord_keys,
values,
}
}
}
};
}
macro_rules! impl_cycling_fbitfbit_lookup_table {
($key_type:ty, $value_type:ty) => {
impl CyclingFloatLookupTable<$key_type, $value_type> {
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE], lower_bound: $key_type, upper_bound: $key_type) -> Self {
CyclingFloatLookupTable {
lookup_table: FloatLookupTable::<$key_type, $value_type>::new_const(keys, values),
lower_bound,
upper_bound,
}
}
}
};
}
#[derive(Default, Debug,Clone, Copy, PartialEq, PartialOrd)]
pub struct FloatOrd<T: Float>(pub T);
impl<T: Float> FloatOrd<T> {
pub fn new() -> Self {
FloatOrd(T::zero())
}
}
impl<T: Float> Eq for FloatOrd<T> {}
impl<T: Float> Ord for FloatOrd<T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
}
}
#[derive(Debug, Clone)]
pub struct FloatLookupTable<T1, T2>
where
T1: Float,
T2: Float,
FloatOrd<T1>: Ord,
{
keys: [FloatOrd<T1>; TABLE_SIZE],
values: [T2; TABLE_SIZE],
@@ -78,6 +28,7 @@ impl<T1, T2> FloatLookupTable<T1, T2>
where
T1: Float,
T2: Float,
FloatOrd<T1>: Ord,
{
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE]) -> Self {
FloatLookupTable {
@@ -86,20 +37,22 @@ where
}
}
pub fn get_next(&self, key: T1) -> T2 {
pub fn get_next(&self, key: T1) -> T2
{
let ord_key = FloatOrd(key);
let mut lower_bound = 0;
let mut upper_bound = self.keys.len() - 1;
let mut mid = (lower_bound + upper_bound) / 2;
while upper_bound - lower_bound > 1 {
let mut mid: usize;
while lower_bound < upper_bound {
mid = lower_bound + (upper_bound - lower_bound) / 2;
if self.keys[mid] < ord_key {
lower_bound = mid;
lower_bound = mid + 1;
} else {
upper_bound = mid;
}
mid = (lower_bound + upper_bound) / 2;
}
self.values[mid]
self.values[upper_bound]
}
pub fn lookup(&self, key: T1) -> T2 {
@@ -117,27 +70,30 @@ pub struct CyclingFloatLookupTable<T1, T2>
where
T1: Float,
T2: Float,
FloatOrd<T1>: Ord,
{
lookup_table: FloatLookupTable<T1, T2>,
lower_bound: T1,
upper_bound: T1,
range: T1,
}
impl<T1, T2> CyclingFloatLookupTable<T1, T2>
where
T1: Float,
T2: Float,
FloatOrd<T1>: Ord,
{
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE], lower_bound: T1, upper_bound: T1) -> Self {
CyclingFloatLookupTable {
lookup_table: FloatLookupTable::new(keys, values),
lower_bound,
upper_bound,
lower_bound: lower_bound,
range: upper_bound - lower_bound,
}
}
pub fn lookup(&self, key: T1) -> T2 {
let key = (key % (self.upper_bound - self.lower_bound)) + self.lower_bound;
self.lookup_table.lookup(key)
self.lookup_table.lookup(
(key % self.range) + self.lower_bound
)
}
}
impl_cycling_fbitfbit_lookup_table!(f32, f32);
@@ -146,19 +102,19 @@ impl_cycling_fbitfbit_lookup_table!(f32, f64);
impl_cycling_fbitfbit_lookup_table!(f64, f32);
#[derive(Debug, Clone)]
pub struct EndoSinLookupTable<T>
where
T: Float + FloatConst,
FloatOrd<T>: Ord,
{
lookup_table: CyclingFloatLookupTable<T, T>,
}
impl<T> EndoSinLookupTable<T>
where
T: Float + FloatConst,
FloatOrd<T>: Ord,
{
#[allow(dead_code)]
pub fn lookup(&self, key: T) -> T {
if key < T::zero() {
-self.lookup(-key)
@@ -166,28 +122,30 @@ where
self.lookup_table.lookup(key)
} else if key < T::PI() {
self.lookup_table.lookup(T::PI() - key)
} else {
} else if key < T::TAU() { // obviously, mod is slow so we want to avoid it until this would start recursing deeply
-self.lookup(key - T::PI())
} else {
-self.lookup(key % T::PI())
}
}
}
impl EndoSinLookupTable<f32>
{
pub const fn new() -> Self {
const UPPER_BOUND: f32 = f32_consts::PI;
EndoSinLookupTable {
lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, UPPER_BOUND),
lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(
SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, f32_consts::PI
),
}
}
}
impl EndoSinLookupTable<f64>
{
pub const fn new() -> Self {
let upper_bound = f64_consts::PI;
EndoSinLookupTable {
lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, upper_bound),
lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(
SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, f64_consts::PI
),
}
}
}
@@ -202,9 +160,8 @@ where
}
impl<T> EndoCosLookupTable<T>
where
T: Float + FloatConst + std::fmt::Debug,
T: Float + FloatConst,
{
#[allow(dead_code)]
pub fn lookup(&self, key: T) -> T {
self.lookup_table.lookup(key + T::FRAC_PI_2())
}

View File

@@ -1,6 +1,7 @@
mod const_tables;
pub mod const_tables;
pub mod lookup_table;
pub(crate) mod ordinal_float;
pub use const_tables::*;
pub use lookup_table::*;
include!("config.rs");

View File

@@ -0,0 +1,16 @@
use num_traits::float::Float;
/// Newtype wrapper that gives a float a total-order (`Ord`) implementation so it can be
/// used where `Ord` is required.
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd)]
pub struct FloatOrd<T: Float>(pub T);

impl<T: Float> FloatOrd<T> {
    /// Creates a wrapper around zero.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self(T::zero())
    }
}

impl<T: Float> Eq for FloatOrd<T> {}

impl<T: Float> Ord for FloatOrd<T> {
    /// Compares via `partial_cmp`; pairs that have no partial ordering are reported as
    /// `Equal`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.partial_cmp(other) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        }
    }
}

30
src/macros/lookup.rs Normal file
View File

@@ -0,0 +1,30 @@
mod lookup_table {
    /// Implements a `const` constructor, `new_const`, on
    /// `FloatLookupTable<$key_type, $value_type>`.
    ///
    /// The constructor stores the given key and value arrays as-is. `FloatLookupTable`,
    /// `FloatOrd` and `TABLE_SIZE` must be in scope where the macro is invoked.
    #[macro_export]
    macro_rules! impl_fbitfbit_lookup_table {
        ($key_type:ty, $value_type:ty) => {
            impl FloatLookupTable<$key_type, $value_type> {
                pub const fn new_const(
                    keys: [FloatOrd<$key_type>; TABLE_SIZE],
                    values: [$value_type; TABLE_SIZE],
                ) -> Self {
                    Self { keys, values }
                }
            }
        };
    }

    /// Implements a `const` constructor, `new_const`, on
    /// `CyclingFloatLookupTable<$key_type, $value_type>`.
    ///
    /// The inner table is built with `FloatLookupTable::new_const`; `lower_bound` and
    /// `range` are stored unchanged. `CyclingFloatLookupTable`, `FloatLookupTable`,
    /// `FloatOrd` and `TABLE_SIZE` must be in scope where the macro is invoked.
    #[macro_export]
    macro_rules! impl_cycling_fbitfbit_lookup_table {
        ($key_type:ty, $value_type:ty) => {
            impl CyclingFloatLookupTable<$key_type, $value_type> {
                pub const fn new_const(
                    keys: [FloatOrd<$key_type>; TABLE_SIZE],
                    values: [$value_type; TABLE_SIZE],
                    lower_bound: $key_type,
                    range: $key_type,
                ) -> Self {
                    let lookup_table =
                        FloatLookupTable::<$key_type, $value_type>::new_const(keys, values);
                    Self { lookup_table, lower_bound, range }
                }
            }
        };
    }
}

1
src/macros/mod.rs Normal file
View File

@@ -0,0 +1 @@
mod lookup;

View File

@@ -1,174 +0,0 @@
//tests.rs
use num_traits::Float;
/// Mean absolute percentage difference between two equally long slices.
///
/// For every index the absolute difference `|vector1[i] - vector2[i]|` is
/// divided by the magnitude of `vector1[i]`; a `vector1[i]` of exactly zero is
/// replaced by `T::min_positive_value()` to avoid dividing by zero. The
/// per-element ratios are averaged and multiplied by 100.
///
/// Two empty slices produce `0 / 0`, i.e. NaN for IEEE floats.
///
/// # Panics
///
/// Panics when the slices differ in length, or when the length or the constant
/// 100 cannot be converted to `T`.
fn calculate_percentage_error<T>(vector1: &[T], vector2: &[T]) -> T
where
    T: Float + std::ops::AddAssign,
{
    let n = vector1.len();
    assert_eq!(n, vector2.len(), "Vectors must have equal lengths.");
    let mut total_error = T::zero();
    for (&reference, &other) in vector1.iter().zip(vector2) {
        let diff = (reference - other).abs();
        // Divide by the magnitude: a signed denominator turns the term negative
        // for negative references, letting errors cancel each other out.
        let scale = if reference == T::zero() {
            T::min_positive_value()
        } else {
            reference.abs()
        };
        total_error += diff / scale;
    }
    let average_error = total_error / T::from(n).unwrap();
    average_error * T::from(100).expect("Cannot convert 100 to type T")
}
#[cfg(test)]
mod tests {
mod f64_error {
    use super::super::calculate_percentage_error;
    use crate::*;
    use once_cell::sync::Lazy;

    /// Largest accepted mean percentage error for every approximation below.
    const TOLERANCE: f64 = 2.5;

    /// Sample points -10.000, -9.999, ..., 9.999 (20000 values, step 0.001).
    static X: Lazy<Vec<f64>> =
        Lazy::new(|| (-10000..10000).map(|a| (a as f64) / 1000.).collect());

    /// `fast_pow2` against `2.0f64.powf`.
    #[test]
    fn pow2() -> Result<(), Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X.iter().map(|&x| x.fast_pow2()).collect();
        let reference: Vec<f64> = X.iter().map(|&x| 2.0f64.powf(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        assert!(!percentage_error.is_nan(), "fast_pow2<f64> percentage error is NaN");
        assert!(
            percentage_error < TOLERANCE,
            "fast_pow2<f64> percentage error: {0}",
            percentage_error
        );
        Ok(())
    }

    /// `fast_exp` against `f64::exp`.
    #[test]
    fn exp() -> Result<(), Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X.iter().map(|&x| x.fast_exp()).collect();
        let reference: Vec<f64> = X.iter().map(|&x| x.exp()).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        assert!(!percentage_error.is_nan(), "fast_exp<f64> percentage error is NaN");
        assert!(
            percentage_error < TOLERANCE,
            "fast_exp<f64> percentage error: {0}",
            percentage_error
        );
        Ok(())
    }

    /// `fast_cos` and then `lookup_cos`, each against `f64::cos`.
    #[test]
    fn cos() -> Result<(), Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X.iter().map(|&x| x.fast_cos()).collect();
        let reference: Vec<f64> = X.iter().map(|&x| x.cos()).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        assert!(!percentage_error.is_nan(), "fast_cos<f64> percentage error is NaN");
        assert!(
            percentage_error < TOLERANCE,
            "fast_cos<f64> percentage error: {0}",
            percentage_error
        );

        // Same check for the table-based variant.
        let looked_up: Vec<f64> = X.iter().map(|&x| x.lookup_cos()).collect();
        let reference: Vec<f64> = X.iter().map(|&x| x.cos()).collect();
        let percentage_error = calculate_percentage_error(&looked_up, &reference);
        assert!(!percentage_error.is_nan(), "lookup_cos<f64> percentage error is NaN");
        assert!(
            percentage_error < TOLERANCE,
            "lookup_cos<f64> percentage error: {0}",
            percentage_error
        );
        Ok(())
    }

    /// `fast_sigmoid` against `sigmoid_builtin_f64`.
    #[test]
    fn sigmoid() -> Result<(), Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X.iter().map(|&x| x.fast_sigmoid()).collect();
        let reference: Vec<f64> = X.iter().map(|&x| sigmoid_builtin_f64(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        assert!(!percentage_error.is_nan(), "fast_sigmoid<f64> percentage error is NaN");
        assert!(
            percentage_error < TOLERANCE,
            "fast_sigmoid<f64> percentage error: {0}",
            percentage_error
        );
        Ok(())
    }
}
// Accuracy checks for the `f32` approximations: each test evaluates an
// approximation and a reference over `X` and asserts that the value returned by
// `calculate_percentage_error` stays below `TOLERANCE`.
mod f32_error {
use crate::*;
use super::super::calculate_percentage_error;
use once_cell::sync::Lazy;
// Largest accepted result of `calculate_percentage_error`.
const TOLERANCE: f32 = 2.5;
// Sample points -10.000, -9.999, ..., 9.999 (20000 values, step 0.001),
// built on first access.
static X: Lazy<Vec<f32>> = Lazy::new(|| {
(-10000..10000)
.map(|a| (a as f32) / 1000.)
.collect::<Vec<f32>>()
});
// `fast_pow2` against `2.0f32.powf`.
// A NaN result also fails, because `NaN < TOLERANCE` is false.
#[test]
fn pow2() -> Result<(), Box<dyn std::error::Error>> {
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f32>>(),
&X.iter().map(|&x| 2.0f32.powf(x)).collect::<Vec<f32>>()
) < TOLERANCE
);
Ok(())
}
// `fast_exp` against `f32::exp`.
#[test]
fn exp() -> Result<(), Box<dyn std::error::Error>> {
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.fast_exp()).collect::<Vec<f32>>(),
&X.iter().map(|&x| x.exp()).collect::<Vec<f32>>()
) < TOLERANCE
);
Ok(())
}
// `fast_cos` and then `lookup_cos`, each against `f32::cos`.
#[test]
fn cos() -> Result<(), Box<dyn std::error::Error>> {
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.fast_cos()).collect::<Vec<f32>>(),
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
) < TOLERANCE
);
// lookup
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f32>>(),
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
) < TOLERANCE
);
Ok(())
}
// `fast_sigmoid` against `sigmoid_builtin_f32`.
#[test]
fn sigmoid() -> Result<(), Box<dyn std::error::Error>> {
assert!(
calculate_percentage_error(
&X.iter().map(|&x| x.fast_sigmoid()).collect::<Vec<f32>>(),
&X.iter().map(|&x| sigmoid_builtin_f32(x)).collect::<Vec<f32>>()
) < TOLERANCE
);
Ok(())
}
}
}

View File

@@ -0,0 +1,130 @@
#[cfg(test)]
use super::exact;
use num_traits::Float;
/// Mean absolute percentage difference between two equally long slices.
///
/// For every index the absolute difference `|vector1[i] - vector2[i]|` is
/// divided by the magnitude of `vector1[i]`; a `vector1[i]` of exactly zero is
/// replaced by `T::min_positive_value()` to avoid dividing by zero. The
/// per-element ratios are averaged and multiplied by 100.
///
/// Two empty slices produce `0 / 0`, i.e. NaN for IEEE floats.
///
/// # Panics
///
/// Panics when the slices differ in length, or when the length or the constant
/// 100 cannot be converted to `T`.
fn calculate_percentage_error<T>(vector1: &[T], vector2: &[T]) -> T
where
    T: Float,
{
    let n = vector1.len();
    assert_eq!(n, vector2.len(), "Vectors must have equal lengths.");
    let total_error = vector1
        .iter()
        .zip(vector2)
        .fold(T::zero(), |running_total, (&reference, &other)| {
            let diff = (reference - other).abs();
            // Divide by the magnitude: a signed denominator turns the term
            // negative for negative references, letting errors cancel out.
            let scale = if reference == T::zero() {
                T::min_positive_value()
            } else {
                reference.abs()
            };
            running_total + diff / scale
        });
    let average_error = total_error / T::from(n).unwrap();
    average_error * T::from(100).expect("Cannot convert 100 to type T")
}
/// Turns a computed error value into a `Result`.
///
/// Despite the name nothing panics: when `$x` is NaN the enclosing function
/// returns early (through `?`) with the message `"<varname> is NaN!"`;
/// otherwise the value is printed as `"<varname>: <x>%"` and the macro
/// evaluates to `Ok($x)`. It must be used inside a function whose error type
/// can be built from a `String`.
macro_rules! panic_if_nan_or_print {
    ($x:expr, $varname:expr) => {
        if !$x.is_nan() {
            println!("{}: {}%", $varname, $x);
            Ok($x)
        } else {
            Err(format!("{} is NaN!", $varname))?
        }
    };
}
/// Percentage-error measurements for the `f64` approximations.
///
/// Every function evaluates an approximation and its `exact::f64` counterpart
/// over `X_F64` (brought in by `include!("x.rs")`), feeds both result vectors
/// to `calculate_percentage_error`, and returns the outcome through
/// `panic_if_nan_or_print!`.
pub mod f64 {
    use super::calculate_percentage_error;
    use super::exact;
    use crate::*;
    include!("x.rs");

    /// Error of `fast_pow2` relative to `exact::f64::pow2`.
    pub fn pow2() -> Result<f64, Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X_F64.iter().map(|&x| x.fast_pow2()).collect();
        let reference: Vec<f64> = X_F64.iter().map(|&x| exact::f64::pow2(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_pow2<f64> percentage error")
    }

    /// Error of `fast_exp` relative to `exact::f64::exp`.
    pub fn exp() -> Result<f64, Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X_F64.iter().map(|&x| x.fast_exp()).collect();
        let reference: Vec<f64> = X_F64.iter().map(|&x| exact::f64::exp(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_exp<f64> percentage error")
    }

    /// Error of `fast_cos` relative to `exact::f64::cos`.
    pub fn cos() -> Result<f64, Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X_F64.iter().map(|&x| x.fast_cos()).collect();
        let reference: Vec<f64> = X_F64.iter().map(|&x| exact::f64::cos(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_cos<f64> percentage error")
    }

    /// Error of `lookup_cos` relative to `exact::f64::cos`.
    pub fn cos_lookup() -> Result<f64, Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X_F64.iter().map(|&x| x.lookup_cos()).collect();
        let reference: Vec<f64> = X_F64.iter().map(|&x| exact::f64::cos(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "lookup_cos<f64> percentage error")
    }

    /// Error of `fast_sigmoid` relative to `exact::f64::sigmoid`.
    pub fn sigmoid() -> Result<f64, Box<dyn std::error::Error>> {
        let approximate: Vec<f64> = X_F64.iter().map(|&x| x.fast_sigmoid()).collect();
        let reference: Vec<f64> = X_F64.iter().map(|&x| exact::f64::sigmoid(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_sigmoid<f64> percentage error")
    }
}
/// Percentage-error measurements for the `f32` approximations.
///
/// Every function evaluates an approximation and its `exact::f32` counterpart
/// over `X_F32` (brought in by `include!("x.rs")`), feeds both result vectors
/// to `calculate_percentage_error`, and returns the outcome through
/// `panic_if_nan_or_print!`.
pub mod f32 {
    use super::calculate_percentage_error;
    use super::exact;
    use crate::*;
    include!("x.rs");

    /// Error of `fast_pow2` relative to `exact::f32::pow2`.
    pub fn pow2() -> Result<f32, Box<dyn std::error::Error>> {
        let approximate: Vec<f32> = X_F32.iter().map(|&x| x.fast_pow2()).collect();
        let reference: Vec<f32> = X_F32.iter().map(|&x| exact::f32::pow2(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_pow2<f32> percentage error")
    }

    /// Error of `fast_exp` relative to `exact::f32::exp`.
    pub fn exp() -> Result<f32, Box<dyn std::error::Error>> {
        let approximate: Vec<f32> = X_F32.iter().map(|&x| x.fast_exp()).collect();
        let reference: Vec<f32> = X_F32.iter().map(|&x| exact::f32::exp(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_exp<f32> percentage error")
    }

    /// Error of `fast_cos` relative to `exact::f32::cos`.
    pub fn cos() -> Result<f32, Box<dyn std::error::Error>> {
        let approximate: Vec<f32> = X_F32.iter().map(|&x| x.fast_cos()).collect();
        let reference: Vec<f32> = X_F32.iter().map(|&x| exact::f32::cos(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_cos<f32> percentage error")
    }

    /// Error of `lookup_cos` relative to `exact::f32::cos`.
    pub fn cos_lookup() -> Result<f32, Box<dyn std::error::Error>> {
        let approximate: Vec<f32> = X_F32.iter().map(|&x| x.lookup_cos()).collect();
        let reference: Vec<f32> = X_F32.iter().map(|&x| exact::f32::cos(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "lookup_cos<f32> percentage error")
    }

    /// Error of `fast_sigmoid` relative to `exact::f32::sigmoid`.
    pub fn sigmoid() -> Result<f32, Box<dyn std::error::Error>> {
        let approximate: Vec<f32> = X_F32.iter().map(|&x| x.fast_sigmoid()).collect();
        let reference: Vec<f32> = X_F32.iter().map(|&x| exact::f32::sigmoid(x)).collect();
        let percentage_error = calculate_percentage_error(&approximate, &reference);
        panic_if_nan_or_print!(percentage_error, "fast_sigmoid<f32> percentage error")
    }
}

View File

@@ -0,0 +1,3 @@
// NOTE(review): presumably these describe the sample grid used by the accuracy
// tests (the generated `x.rs`) — confirm against the code that consumes them.
// Lower edge of the sampled interval.
const X_MIN: f64 = -10.0;
// Upper edge of the sampled interval.
const X_MAX: f64 = 10.0;
// Number of sample points.
const X_SIZE: usize = 20000;

View File

@@ -0,0 +1,35 @@
/// Reference (`std`) implementations for `f64`, used as ground truth.
pub mod f64 {
    /// 2 raised to the power `n`.
    pub fn pow2(n: f64) -> f64 {
        let base = 2.0f64;
        base.powf(n)
    }

    /// e raised to the power `n`.
    pub fn exp(n: f64) -> f64 {
        f64::exp(n)
    }

    /// Cosine of `n`.
    pub fn cos(n: f64) -> f64 {
        f64::cos(n)
    }

    /// Logistic sigmoid `1 / (1 + e^-n)`.
    pub fn sigmoid(n: f64) -> f64 {
        let denominator = 1. + (-n).exp();
        denominator.recip()
    }
}
/// Reference (`std`) implementations for `f32`, used as ground truth.
pub mod f32 {
    /// 2 raised to the power `n`.
    pub fn pow2(n: f32) -> f32 {
        let base = 2.0f32;
        base.powf(n)
    }

    /// e raised to the power `n`.
    pub fn exp(n: f32) -> f32 {
        f32::exp(n)
    }

    /// Cosine of `n`.
    pub fn cos(n: f32) -> f32 {
        f32::cos(n)
    }

    /// Logistic sigmoid `1 / (1 + e^-n)`.
    pub fn sigmoid(n: f32) -> f32 {
        let denominator = 1. + (-n).exp();
        denominator.recip()
    }
}

View File

@@ -0,0 +1,4 @@
mod exact;
mod comparisons;
pub use comparisons::*;

6
src/tests/accuracy/x.rs Normal file

File diff suppressed because one or more lines are too long

3
src/tests/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
#![cfg(test)]
mod accuracy;
mod tolerance;

7
src/tests/tolerance.json Normal file
View File

@@ -0,0 +1,7 @@
{
"pow2_fast": 2.0,
"exp_fast": 2.0,
"cos_fast": 1.0,
"cos_lk": 1.0,
"sigmoid_fast": 1.0
}

50
src/tests/tolerance.rs Normal file
View File

@@ -0,0 +1,50 @@
use super::accuracy;
use serde_json;
/// Reads the entry named `key` from the embedded `tolerance.json` and
/// deserialises it into `T`.
///
/// The JSON text is compiled into the binary via `include_str!` and parsed on
/// every call.
///
/// # Errors
///
/// Returns the `serde_json::Error` raised while parsing the document or while
/// converting the selected entry into `T`.
/// NOTE(review): a key that is absent from the file presumably ends up as a
/// conversion error rather than a panic — confirm.
fn get_tolerance<T>(key: &str) -> Result<T, serde_json::Error>
where
    T: serde::de::DeserializeOwned,
{
    let table: serde_json::Value = serde_json::from_str(include_str!("tolerance.json"))?;
    let entry = table[key].clone();
    serde_json::value::from_value(entry)
}
/// Generates a `#[test]` named `$test_name` that runs `$function` and checks
/// its result against a tolerance.
///
/// The test name doubles as the key handed to `get_tolerance`, so it must
/// match an entry of the tolerance table. Errors from either the lookup or
/// `$function` fail the test through `?`. `get_tolerance` and `$function` are
/// resolved at the invocation site.
macro_rules! test_within_tolerance {
    ($function:ident, $t:ty, $test_name:ident) => {
        #[test]
        fn $test_name() -> Result<(), Box<dyn std::error::Error>> {
            // Look the limit up first, then measure.
            let tolerance = get_tolerance::<$t>(stringify!($test_name))?;
            let percentage_error: $t = $function()?;
            assert!(percentage_error < tolerance);
            Ok(())
        }
    };
}
// Tolerance tests for the `f64` measurements in `accuracy::f64`.
// Each invocation generates one `#[test]`; the third argument is both the test
// name and the key passed to `get_tolerance`.
mod f64 {
use super::{accuracy, get_tolerance};
use accuracy::f64::*;
test_within_tolerance!(pow2, f64, pow2_fast);
test_within_tolerance!(exp, f64, exp_fast);
test_within_tolerance!(cos, f64, cos_fast);
test_within_tolerance!(cos_lookup, f64, cos_lk);
test_within_tolerance!(sigmoid, f64, sigmoid_fast);
}
// Tolerance tests for the `f32` measurements in `accuracy::f32`.
// Each invocation generates one `#[test]`; the third argument is both the test
// name and the key passed to `get_tolerance`, so the `f32` tests share their
// keys with the `f64` tests.
mod f32 {
use super::{accuracy, get_tolerance};
use accuracy::f32::*;
test_within_tolerance!(pow2, f32, pow2_fast);
test_within_tolerance!(exp, f32, exp_fast);
test_within_tolerance!(cos, f32, cos_fast);
test_within_tolerance!(cos_lookup, f32, cos_lk);
test_within_tolerance!(sigmoid, f32, sigmoid_fast);
}

70
test.rs Normal file
View File

@@ -0,0 +1,70 @@
use std::fs::read;
use std::f64::consts as f64_consts;
use bincode::deserialize;
use once_cell::sync::Lazy;
use ndarray::prelude::*;
use optimize::*;
use num_traits::Float;
/// Mean absolute percentage difference between two equally long slices.
///
/// For every index the absolute difference `|vector1[i] - vector2[i]|` is
/// divided by the magnitude of `vector1[i]`; a `vector1[i]` of exactly zero is
/// replaced by `T::min_positive_value()` to avoid dividing by zero. The
/// per-element ratios are averaged and multiplied by 100.
///
/// Two empty slices produce `0 / 0`, i.e. NaN for IEEE floats.
///
/// # Panics
///
/// Panics when the slices differ in length, or when the length or the constant
/// 100 cannot be converted to `T`.
fn calculate_percentage_error<T>(vector1: &[T], vector2: &[T]) -> T
where
    T: Float + std::ops::AddAssign,
{
    let n = vector1.len();
    assert_eq!(n, vector2.len(), "Vectors must have equal lengths.");
    let mut total_error = T::zero();
    for (&reference, &other) in vector1.iter().zip(vector2) {
        let diff = (reference - other).abs();
        // Divide by the magnitude: a signed denominator turns the term negative
        // for negative references, letting errors cancel each other out.
        let scale = if reference == T::zero() {
            T::min_positive_value()
        } else {
            reference.abs()
        };
        total_error += diff / scale;
    }
    let average_error = total_error / T::from(n).unwrap();
    average_error * T::from(100).expect("Cannot convert 100 to type T")
}
/// Approximates `e^x` by constructing the bit pattern of the result directly.
///
/// `x` is rescaled by `LOG2_E`, clamped from below at `clipp_thresh`, shifted
/// by `clipp_shift`, scaled by `v_scale`, cast to `u64` and reinterpreted as
/// an `f64`.
///
/// Reference parameter values (identical to the optimiser's starting guess):
/// `clipp_thresh = -180335.51911105003`, `v_scale = 4524653012949098.0`,
/// `clipp_shift = 1018.1563534409383`.
fn fast_exp(x: f64, clipp_thresh: f64, v_scale: f64, clipp_shift: f64) -> f64 {
    let clamped_exponent = (f64_consts::LOG2_E * x).max(clipp_thresh);
    // The float-to-integer `as` cast saturates, so a negative product maps to 0.
    let bits = (v_scale * (clamped_exponent + clipp_shift)) as u64;
    f64::from_bits(bits)
}
/// Reference values compared against by `objective`, read from `tmp/Y.bin`
/// and decoded with `bincode` on first access.
///
/// Declared `static` rather than `const`: a `const` is instantiated afresh at
/// every use, so each access would build a new `Lazy` and read and decode the
/// file again.
///
/// # Panics
///
/// The first access panics if the file cannot be read or decoded.
static Y: Lazy<Vec<f64>> = Lazy::new(|| {
    let bytes = read("tmp/Y.bin").expect("failed to read tmp/Y.bin");
    deserialize(&bytes).expect("failed to deserialize tmp/Y.bin as Vec<f64>")
});
/// Objective for the parameter search: the percentage error of `fast_exp`
/// against the reference values in `Y`.
///
/// `args` holds, in order, `clipp_thresh`, `v_scale` and `clipp_shift`. The
/// approximation is evaluated at -10.000, -9.999, ..., 9.999 (20000 points).
fn objective(args: ArrayView1<f64>) -> f64 {
    let (clipp_thresh, v_scale, clipp_shift) = (args[0], args[1], args[2]);
    let predictions: Vec<f64> = (-10000..10000)
        .map(|a| fast_exp((a as f64) / 1000., clipp_thresh, v_scale, clipp_shift))
        .collect();
    calculate_percentage_error(&*Y, &predictions)
}
/// Runs a Nelder-Mead search over the three `fast_exp` parameters, starting
/// from the current best-known values, and prints the result.
fn optimize_params() {
    // Starting point: clipp_thresh, v_scale, clipp_shift.
    let initial_guess: Array1<f64> =
        Array1::from_vec(vec![-180335.51911105003, 4524653012949098.0, 1018.1563534409383]);

    // Configure the minimiser through its builder.
    let solver = NelderMeadBuilder::default()
        .xtol(1e-6f64)
        .ftol(1e-6f64)
        .maxiter(50000)
        .build()
        .unwrap();

    let optimum = solver.minimize(objective, initial_guess.view());
    println!("Final optimized arguments: {}", optimum);
}