mirror of
https://github.com/Cian-H/fastmath.git
synced 2025-12-22 14:12:01 +00:00
Dev checkpoint - testing lookup tables
This commit is contained in:
10
.gitignore
vendored
10
.gitignore
vendored
@@ -1,4 +1,8 @@
|
||||
/target
|
||||
/.vscode
|
||||
/tmp
|
||||
Cargo.lock
|
||||
debug/
|
||||
target/
|
||||
**/*.rs.bk
|
||||
*.pdb
|
||||
.vscode/
|
||||
tmp/
|
||||
*.ipynb
|
||||
151
Cargo.lock
generated
151
Cargo.lock
generated
@@ -120,6 +120,17 @@ version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
version = "2.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2674ec482fbc38012cf31e6c42ba0177b431a0cb6f15fe40efa5aab1bda516f6"
|
||||
dependencies = [
|
||||
"is-terminal",
|
||||
"lazy_static",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.5.1"
|
||||
@@ -213,7 +224,7 @@ checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a"
|
||||
dependencies = [
|
||||
"errno-dragonfly",
|
||||
"libc",
|
||||
"windows-sys",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -232,9 +243,12 @@ version = "0.1.0"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"criterion",
|
||||
"log",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"simple_logger",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -257,7 +271,7 @@ checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"rustix",
|
||||
"windows-sys",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -284,6 +298,12 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.147"
|
||||
@@ -336,6 +356,15 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_threads"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.18.0"
|
||||
@@ -455,7 +484,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -501,15 +530,27 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.103"
|
||||
version = "1.0.106"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d03b412469450d4404fe8499a268edd7f8b79fecb074b0d812ad64ca21f4031b"
|
||||
checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simple_logger"
|
||||
version = "4.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2230cd5c29b815c9b699fb610b49a5ed65588f3509d9f0108be3a885da629333"
|
||||
dependencies = [
|
||||
"colored",
|
||||
"log",
|
||||
"time",
|
||||
"windows-sys 0.42.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.26"
|
||||
@@ -521,6 +562,35 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"libc",
|
||||
"num_threads",
|
||||
"serde",
|
||||
"time-core",
|
||||
"time-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time-core"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
|
||||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
|
||||
dependencies = [
|
||||
"time-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
@@ -642,6 +712,21 @@ version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.42.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm 0.42.2",
|
||||
"windows_aarch64_msvc 0.42.2",
|
||||
"windows_i686_gnu 0.42.2",
|
||||
"windows_i686_msvc 0.42.2",
|
||||
"windows_x86_64_gnu 0.42.2",
|
||||
"windows_x86_64_gnullvm 0.42.2",
|
||||
"windows_x86_64_msvc 0.42.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.48.0"
|
||||
@@ -657,51 +742,93 @@ version = "0.48.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05d4b17490f70499f20b9e791dcf6a299785ce8af4d709018206dc5b4953e95f"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
"windows_aarch64_gnullvm 0.48.0",
|
||||
"windows_aarch64_msvc 0.48.0",
|
||||
"windows_i686_gnu 0.48.0",
|
||||
"windows_i686_msvc 0.48.0",
|
||||
"windows_x86_64_gnu 0.48.0",
|
||||
"windows_x86_64_gnullvm 0.48.0",
|
||||
"windows_x86_64_msvc 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.48.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.42.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.48.0"
|
||||
|
||||
@@ -7,9 +7,11 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bincode = "1.3.3"
|
||||
log = "0.4.19"
|
||||
num-traits = "0.2.15"
|
||||
once_cell = "1.18.0"
|
||||
serde = {version = "1.0.171", features = ["derive"] }
|
||||
simple_logger = "4.2.0"
|
||||
|
||||
[build-dependencies]
|
||||
bincode = "1.3.3"
|
||||
@@ -26,5 +28,11 @@ bench = true
|
||||
name = "bench"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "devbench"
|
||||
harness = false
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.5.1"
|
||||
serde = "1.0.171"
|
||||
serde_json = "1.0.106"
|
||||
|
||||
@@ -5,18 +5,23 @@ use fastmath::*;
|
||||
use criterion::{Criterion, BenchmarkGroup, measurement::WallTime};
|
||||
use criterion::{black_box, criterion_group, criterion_main};
|
||||
|
||||
pub mod exact {
|
||||
include!("../src/tests/accuracy/exact.rs");
|
||||
}
|
||||
include!("../src/tests/accuracy/x.rs");
|
||||
|
||||
fn pow2_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[f32]) {
|
||||
group.bench_function("f64_fast", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f64_builtin_fn", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| 2.0f64.powf(black_box(x))).collect::<Vec<f64>>())
|
||||
b.iter(|| x_f64.iter().map(|&x| exact::f64::pow2(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f32_fast", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_pow2()).collect::<Vec<f32>>())
|
||||
});
|
||||
group.bench_function("f32_builtin_fn", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| 2.0f32.powf(black_box(x))).collect::<Vec<f32>>())
|
||||
b.iter(|| x_f32.iter().map(|&x| exact::f32::pow2(black_box(x))).collect::<Vec<f32>>())
|
||||
});
|
||||
}
|
||||
|
||||
@@ -25,13 +30,13 @@ fn exp_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_exp()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f64_builtin", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).exp()).collect::<Vec<f64>>())
|
||||
b.iter(|| x_f64.iter().map(|&x| exact::f64::exp(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f32_fast", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_exp()).collect::<Vec<f32>>())
|
||||
});
|
||||
group.bench_function("f32_builtin", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).exp()).collect::<Vec<f32>>())
|
||||
b.iter(|| x_f32.iter().map(|&x| exact::f32::exp(black_box(x))).collect::<Vec<f32>>())
|
||||
});
|
||||
}
|
||||
|
||||
@@ -43,7 +48,7 @@ fn cos_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).lookup_cos()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f64_builtin", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).cos()).collect::<Vec<f64>>())
|
||||
b.iter(|| x_f64.iter().map(|&x| exact::f64::cos(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f32_fast", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_cos()).collect::<Vec<f32>>())
|
||||
@@ -52,7 +57,7 @@ fn cos_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32: &[
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).lookup_cos()).collect::<Vec<f32>>())
|
||||
});
|
||||
group.bench_function("f32_builtin", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).cos()).collect::<Vec<f32>>())
|
||||
b.iter(|| x_f32.iter().map(|&x| exact::f32::cos(black_box(x))).collect::<Vec<f32>>())
|
||||
});
|
||||
}
|
||||
|
||||
@@ -61,42 +66,35 @@ fn sigmoid_benchmarks(group: &mut BenchmarkGroup<WallTime>, x_f64: &[f64], x_f32
|
||||
b.iter(|| x_f64.iter().map(|&x| black_box(x).fast_sigmoid()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f64_builtin", |b| {
|
||||
b.iter(|| x_f64.iter().map(|&x| sigmoid_builtin_f64(black_box(x))).collect::<Vec<f64>>())
|
||||
b.iter(|| x_f64.iter().map(|&x| exact::f64::sigmoid(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("f32_fast", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| black_box(x).fast_sigmoid()).collect::<Vec<f32>>())
|
||||
});
|
||||
group.bench_function("f32_builtin", |b| {
|
||||
b.iter(|| x_f32.iter().map(|&x| sigmoid_builtin_f32(black_box(x))).collect::<Vec<f32>>())
|
||||
b.iter(|| x_f32.iter().map(|&x| exact::f32::sigmoid(black_box(x))).collect::<Vec<f32>>())
|
||||
});
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
// Prepare x values for testing functions
|
||||
let x_f64 = (-10000..10000)
|
||||
.map(|a| (a as f64) / 1000.)
|
||||
.collect::<Vec<f64>>();
|
||||
let x_f32 = (-10000..10000)
|
||||
.map(|a| (a as f32) / 1000.)
|
||||
.collect::<Vec<f32>>();
|
||||
// to ensure tests are fair, we need to instantiate the lookup tables
|
||||
1.0f64.lookup_cos();
|
||||
1.0f32.lookup_cos();
|
||||
// Then, tests can begin
|
||||
let mut group = c.benchmark_group("pow2");
|
||||
pow2_benchmarks(&mut group, &x_f64, &x_f32);
|
||||
pow2_benchmarks(&mut group, &X_F64, &X_F32);
|
||||
group.finish();
|
||||
|
||||
let mut group = c.benchmark_group("exp");
|
||||
exp_benchmarks(&mut group, &x_f64, &x_f32);
|
||||
exp_benchmarks(&mut group, &X_F64, &X_F32);
|
||||
group.finish();
|
||||
|
||||
let mut group = c.benchmark_group("cos");
|
||||
cos_benchmarks(&mut group, &x_f64, &x_f32);
|
||||
cos_benchmarks(&mut group, &X_F64, &X_F32);
|
||||
group.finish();
|
||||
|
||||
let mut group = c.benchmark_group("sigmoid");
|
||||
sigmoid_benchmarks(&mut group, &x_f64, &x_f32);
|
||||
sigmoid_benchmarks(&mut group, &X_F64, &X_F32);
|
||||
group.finish();
|
||||
}
|
||||
|
||||
|
||||
46
benches/devbench.rs
Normal file
46
benches/devbench.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
#![allow(dead_code, unused_imports)]
|
||||
|
||||
extern crate fastmath;
|
||||
|
||||
use fastmath::*;
|
||||
use criterion::{Criterion, BenchmarkGroup, measurement::WallTime};
|
||||
use criterion::{black_box, criterion_group, criterion_main};
|
||||
use std::f32::consts as f32_consts;
|
||||
use std::f64::consts as f64_consts;
|
||||
|
||||
pub mod exact {
|
||||
include!("../src/tests/accuracy/exact.rs");
|
||||
}
|
||||
include!("../src/tests/accuracy/x.rs");
|
||||
|
||||
fn dev_cos(x: f64) -> f64 {
|
||||
const ONE: f64 = 1.0;
|
||||
let v = ((((x + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI).abs();
|
||||
let qpprox = ONE - f64_consts::FRAC_2_PI * v;
|
||||
qpprox + f64_consts::FRAC_PI_6 * qpprox * (ONE - qpprox * qpprox)
|
||||
}
|
||||
|
||||
fn devbench(group: &mut BenchmarkGroup<WallTime>) {
|
||||
group.bench_function("dev_cos", |b| {
|
||||
b.iter(|| X_F64.iter().map(|&x| dev_cos(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("fast_cos", |b| {
|
||||
b.iter(|| X_F64.iter().map(|&x| black_box(x).fast_cos()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("lookup_cos", |b| {
|
||||
b.iter(|| X_F64.iter().map(|&x| black_box(x).lookup_cos()).collect::<Vec<f64>>())
|
||||
});
|
||||
group.bench_function("builtin_cos", |b| {
|
||||
b.iter(|| X_F64.iter().map(|&x| exact::f64::cos(black_box(x))).collect::<Vec<f64>>())
|
||||
});
|
||||
}
|
||||
|
||||
fn criterion_benchmark(c: &mut Criterion) {
|
||||
// Then, tests can begin
|
||||
let mut group = c.benchmark_group("devbench");
|
||||
devbench(&mut group);
|
||||
group.finish();
|
||||
}
|
||||
|
||||
criterion_group!(devbenches, criterion_benchmark);
|
||||
criterion_main!(devbenches);
|
||||
158
build.rs
158
build.rs
@@ -6,126 +6,39 @@ mod precalculate_lookup_tables {
|
||||
use std::fs::{create_dir_all, File};
|
||||
use std::io::Write;
|
||||
include!("src/lookup/config.rs");
|
||||
// use bincode::serialize;
|
||||
|
||||
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let data = serialize(&EndoSinLookupTable::<f32>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/sin_f32.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
// let data = serialize(&EndoSinLookupTable::<f64>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/sin_f64.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
// fn precalculate_cos_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let data = serialize(&EndoCosLookupTable::<f32>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/cos_f32.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
// let data = serialize(&EndoCosLookupTable::<f64>::new(PRECISION))?;
|
||||
// let mut file = File::create("src/lookup/data/cos_f64.bin")?;
|
||||
// file.write_all(&data)?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
|
||||
// let half_step: f32 = step / 2.0;
|
||||
|
||||
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)) - half_step
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)).sin()
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
|
||||
// let half_step: f64 = step / 2.0;
|
||||
|
||||
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)) - half_step
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)).sin()
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
include!("src/lookup/ordinal_float.rs");
|
||||
|
||||
macro_rules! precalculate_sin_tables {
|
||||
() => {{
|
||||
let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
|
||||
let half_step: f32 = step / 2.0;
|
||||
|
||||
let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
(step * (i as f32)) - half_step
|
||||
}).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
let keys: [FloatOrd<f32>; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
FloatOrd( (step * (i as f32)) - half_step )
|
||||
}).collect::<Vec<FloatOrd<f32>>>().try_into().unwrap_or([FloatOrd::new(); TABLE_SIZE]);
|
||||
let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
(step * (i as f32)).sin()
|
||||
}).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
let data = format!("pub const SIN_F32_KEYS: [f32; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
let data = format!("pub(crate) const SIN_F32_KEYS: [FloatOrd<f32>; {}] = {:?};\npub const SIN_F32_VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
let mut file = File::create("src/lookup/data/sin_f32.rs").expect("Failed to create sin_f32.rs");
|
||||
file.write_all(data.as_bytes()).expect("Failed to write sin_f32.rs");
|
||||
let mut file = File::create("src/lookup/data/sin_f32.rs")?;
|
||||
file.write_all(data.as_bytes())?;
|
||||
|
||||
let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
|
||||
let half_step: f64 = step / 2.0;
|
||||
|
||||
let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
(step * (i as f64)) - half_step
|
||||
}).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
let keys: [FloatOrd<f64>; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
FloatOrd( (step * (i as f64)) - half_step )
|
||||
}).collect::<Vec<FloatOrd<f64>>>().try_into().unwrap_or([FloatOrd::new(); TABLE_SIZE]);
|
||||
let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
(step * (i as f64)).sin()
|
||||
}).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
let data = format!("pub const SIN_F64_KEYS: [f64; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
let data = format!("pub const SIN_F64_KEYS: [FloatOrd<f64>; {}] = {:?};\npub const SIN_F64_VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
let mut file = File::create("src/lookup/data/sin_f64.rs").expect("Failed to create sin_f64.rs");
|
||||
file.write_all(data.as_bytes()).expect("Failed to write sin_f64.rs");
|
||||
let mut file = File::create("src/lookup/data/sin_f64.rs")?;
|
||||
file.write_all(data.as_bytes())?;
|
||||
}};
|
||||
}
|
||||
// fn precalculate_sin_tables() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// let step: f32 = f32_consts::FRAC_PI_2 / TABLE_SIZE as f32;
|
||||
// let half_step: f32 = step / 2.0;
|
||||
|
||||
// let keys: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)) - half_step
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let values: [f32; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f32)).sin()
|
||||
// }).collect::<Vec<f32>>().try_into().unwrap_or([0.0f32; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f32; {}] = {:?};\npub const VALUES: [f32; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f32.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// let step: f64 = f64_consts::FRAC_PI_2 / TABLE_SIZE as f64;
|
||||
// let half_step: f64 = step / 2.0;
|
||||
|
||||
// let keys: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)) - half_step
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let values: [f64; TABLE_SIZE] = (0..TABLE_SIZE).map(|i| {
|
||||
// (step * (i as f64)).sin()
|
||||
// }).collect::<Vec<f64>>().try_into().unwrap_or([0.0f64; TABLE_SIZE]);
|
||||
// let data = format!("pub const KEYS: [f64; {}] = {:?};\npub const VALUES: [f64; {}] = {:?};\n", TABLE_SIZE, keys, TABLE_SIZE, values);
|
||||
|
||||
// let mut file = File::create("src/lookup/data/sin_f64.rs")?;
|
||||
// file.write_all(data.as_bytes())?;
|
||||
|
||||
// Ok(())
|
||||
// }
|
||||
|
||||
pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
|
||||
create_dir_all("src/lookup/data")?;
|
||||
@@ -137,8 +50,49 @@ mod precalculate_lookup_tables {
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
precalculate_lookup_tables::generate()?;
|
||||
mod precalculate_test_tables {
|
||||
use std::fs::{create_dir_all, File};
|
||||
use std::io::Write;
|
||||
include!("src/tests/accuracy/config.rs");
|
||||
|
||||
macro_rules! precalculate_test_tables {
|
||||
() => {{
|
||||
let scaling: f32 = X_SIZE as f32 / (X_MAX as f32 - X_MIN as f32);
|
||||
let x_f32: [f32; X_SIZE] =
|
||||
(0..X_SIZE)
|
||||
.map(|a| ((a as f32) / scaling) + X_MIN as f32)
|
||||
.collect::<Vec<f32>>()
|
||||
.try_into().map_err(|_| "Failed to convert Vec<f32> to [f32; X_SIZE]")?;
|
||||
|
||||
let scaling: f64 = X_SIZE as f64 / (X_MAX as f64 - X_MIN as f64);
|
||||
let x_f64: [f64; X_SIZE] =
|
||||
(0..X_SIZE)
|
||||
.map(|a| ((a as f64) / scaling) + X_MIN as f64)
|
||||
.collect::<Vec<f64>>()
|
||||
.try_into().map_err(|_| "Failed to convert Vec<f64> to [f64; X_SIZE]")?;
|
||||
|
||||
let data = format!(
|
||||
"#[cfg(test)]\n#[allow(dead_code)]\npub const X_F32: [f32; {}] = {:?};\n#[cfg(test)]\n#[allow(dead_code)]\npub const X_F64: [f64; {}] = {:?};",
|
||||
X_SIZE, x_f32, X_SIZE, x_f64
|
||||
);
|
||||
|
||||
let mut file = File::create("src/tests/accuracy/x.rs")?;
|
||||
file.write_all(data.as_bytes())?;
|
||||
}};
|
||||
}
|
||||
|
||||
pub fn generate() -> Result<(), Box<dyn std::error::Error>> {
|
||||
create_dir_all("src/tests/accuracy")?;
|
||||
|
||||
precalculate_test_tables!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
precalculate_lookup_tables::generate()?;
|
||||
precalculate_test_tables::generate()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
73
src/bit_magic.rs
Normal file
73
src/bit_magic.rs
Normal file
@@ -0,0 +1,73 @@
|
||||
fn bool_to_full_byte<T>(b: bool) -> T
|
||||
where
|
||||
T: WrappingSub + One + From<u8> + std::ops::Not<Output=T>,
|
||||
|
||||
{
|
||||
!(
|
||||
( T::from( unsafe { std::mem::transmute::<bool, u8>(b) } ) )
|
||||
.wrapping_sub(&T::one())
|
||||
)
|
||||
}
|
||||
|
||||
pub trait AndBool {
|
||||
fn and(self: Self, b: bool) -> Self;
|
||||
}
|
||||
impl AndBool for f32 {
|
||||
fn and(self: f32, b: bool) -> f32 {
|
||||
let b_byte: u32 = bool_to_full_byte(b);
|
||||
unsafe {
|
||||
std::mem::transmute(
|
||||
std::mem::transmute::<f32, u32>(self) & b_byte
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AndBool for f64 {
|
||||
fn and(self: f64, b: bool) -> f64 {
|
||||
let b_byte: u64 = bool_to_full_byte(b);
|
||||
unsafe {
|
||||
std::mem::transmute(
|
||||
std::mem::transmute::<f64, u64>(self) & b_byte
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait GetSign {
|
||||
fn sign(self: Self) -> Self;
|
||||
}
|
||||
impl GetSign for f32 {
|
||||
fn sign(self: f32) -> f32 {
|
||||
let x_bytes: [u8; 4] = unsafe { std::mem::transmute(self) };
|
||||
unsafe {
|
||||
std::mem::transmute(
|
||||
[
|
||||
0u8,
|
||||
0u8,
|
||||
128u8,
|
||||
(x_bytes[3] & 128u8) | 63u8,
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
impl GetSign for f64 {
|
||||
fn sign(self: f64) -> f64 {
|
||||
let x_bytes: [u8; 8] = unsafe { std::mem::transmute(self) };
|
||||
unsafe {
|
||||
std::mem::transmute(
|
||||
[
|
||||
0u8,
|
||||
0u8,
|
||||
0u8,
|
||||
0u8,
|
||||
0u8,
|
||||
0u8,
|
||||
240u8,
|
||||
(x_bytes[7] & 128u8) | 63u8,
|
||||
]
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
//! A collection of fast (often approximate) mathematical functions for accelerating mathematical functions
|
||||
|
||||
// Optimisation note: lookup tables become faster when calculation takes > ~1ms
|
||||
// Optimisation note: lookup tables become faster when calculation takes > ~400us
|
||||
|
||||
use std::f32::consts as f32_consts;
|
||||
use std::f64::consts as f64_consts;
|
||||
// use crate::lookup::*;
|
||||
use crate::lookup::lookup_table::EndoCosLookupTable;
|
||||
use crate::lookup::{EndoCosLookupTable, EndoSinLookupTable};
|
||||
|
||||
const SIN_LOOKUP_F32: EndoSinLookupTable<f32> = EndoSinLookupTable::<f32>::new();
|
||||
const SIN_LOOKUP_F64: EndoSinLookupTable<f64> = EndoSinLookupTable::<f64>::new();
|
||||
const COS_LOOKUP_F32: EndoCosLookupTable<f32> = EndoCosLookupTable::<f32>::new();
|
||||
const COS_LOOKUP_F64: EndoCosLookupTable<f64> = EndoCosLookupTable::<f64>::new();
|
||||
|
||||
@@ -18,6 +19,25 @@ const V_SCALE_F32: f32 = 8388608.0; // the largest possible mantissa of an f32
|
||||
const V_SCALE_F64: f64 = 4503599627370496.0; // the largest possible mantissa of an f64
|
||||
|
||||
|
||||
pub trait LookupSin {
|
||||
fn lookup_sin(self: Self) -> Self;
|
||||
}
|
||||
impl LookupSin for f64 {
|
||||
#[inline]
|
||||
fn lookup_sin(self: Self) -> f64 {
|
||||
// Look up the value in the table
|
||||
SIN_LOOKUP_F64.lookup(self)
|
||||
}
|
||||
}
|
||||
impl LookupSin for f32 {
|
||||
#[inline]
|
||||
fn lookup_sin(self: Self) -> f32 {
|
||||
// Look up the value in the table
|
||||
SIN_LOOKUP_F32.lookup(self)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
pub trait LookupCos {
|
||||
fn lookup_cos(self: Self) -> Self;
|
||||
}
|
||||
@@ -119,14 +139,3 @@ impl FastSigmoid for f64 {
|
||||
(ONE + (-self).fast_exp()).recip()
|
||||
}
|
||||
}
|
||||
|
||||
// functions for testing the accuracy of fast functions against builtin functions
|
||||
#[inline]
|
||||
pub fn sigmoid_builtin_f32(p: f32) -> f32 {
|
||||
(1. + (-p).exp()).recip()
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn sigmoid_builtin_f64(p: f64) -> f64 {
|
||||
(1. + (-p).exp()).recip()
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
#![allow(unused_imports)]
|
||||
|
||||
pub mod lookup;
|
||||
mod fastmath;
|
||||
pub mod macros;
|
||||
|
||||
mod fastmath;
|
||||
pub use fastmath::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
pub(crate) mod tests;
|
||||
@@ -1,4 +1,5 @@
|
||||
// lookup/const_tables.rs
|
||||
use crate::lookup::ordinal_float::FloatOrd;
|
||||
|
||||
include!("data/sin_f32.rs");
|
||||
include!("data/sin_f64.rs");
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,75 +1,25 @@
|
||||
use std::f32::consts as f32_consts;
|
||||
use std::f64::consts as f64_consts;
|
||||
use std::cmp::Ordering;
|
||||
use num_traits::identities::One;
|
||||
use num_traits::ops::wrapping::WrappingSub;
|
||||
use num_traits::float::{Float, FloatConst};
|
||||
|
||||
use crate::{
|
||||
impl_fbitfbit_lookup_table,
|
||||
impl_cycling_fbitfbit_lookup_table,
|
||||
};
|
||||
use crate::lookup::TABLE_SIZE;
|
||||
use crate::lookup::ordinal_float::FloatOrd;
|
||||
use crate::lookup::const_tables::*;
|
||||
|
||||
// This function should never be used in a non-const context.
|
||||
// It only exists as a workaround for the fact that const fn's cannot use iterators.
|
||||
const fn make_ordinal<T: Float>(
|
||||
input: [T; TABLE_SIZE],
|
||||
mut map_target: [FloatOrd<T>; TABLE_SIZE],
|
||||
) -> () {
|
||||
// let mut map_target = [FloatOrd::<T>::new(); TABLE_SIZE];
|
||||
let mut index = 0;
|
||||
while index < TABLE_SIZE {
|
||||
map_target[index] = FloatOrd(input[index]);
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// The following macros are to minimise the amount of boilerplate for static types on the lookup tables.
|
||||
macro_rules! impl_fbitfbit_lookup_table {
|
||||
($key_type:ty, $value_type:ty) => {
|
||||
impl FloatLookupTable<$key_type, $value_type> {
|
||||
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE]) -> Self {
|
||||
let ord_keys: [FloatOrd<$key_type>; TABLE_SIZE] = [FloatOrd(0.0 as $key_type); TABLE_SIZE];
|
||||
make_ordinal(keys, ord_keys);
|
||||
FloatLookupTable {
|
||||
keys: ord_keys,
|
||||
values,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_cycling_fbitfbit_lookup_table {
|
||||
($key_type:ty, $value_type:ty) => {
|
||||
impl CyclingFloatLookupTable<$key_type, $value_type> {
|
||||
pub const fn new_const(keys: [$key_type; TABLE_SIZE], values: [$value_type; TABLE_SIZE], lower_bound: $key_type, upper_bound: $key_type) -> Self {
|
||||
CyclingFloatLookupTable {
|
||||
lookup_table: FloatLookupTable::<$key_type, $value_type>::new_const(keys, values),
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
#[derive(Default, Debug,Clone, Copy, PartialEq, PartialOrd)]
|
||||
pub struct FloatOrd<T: Float>(pub T);
|
||||
impl<T: Float> FloatOrd<T> {
|
||||
pub fn new() -> Self {
|
||||
FloatOrd(T::zero())
|
||||
}
|
||||
}
|
||||
impl<T: Float> Eq for FloatOrd<T> {}
|
||||
impl<T: Float> Ord for FloatOrd<T> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FloatLookupTable<T1, T2>
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
FloatOrd<T1>: Ord,
|
||||
{
|
||||
keys: [FloatOrd<T1>; TABLE_SIZE],
|
||||
values: [T2; TABLE_SIZE],
|
||||
@@ -78,6 +28,7 @@ impl<T1, T2> FloatLookupTable<T1, T2>
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
FloatOrd<T1>: Ord,
|
||||
{
|
||||
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE]) -> Self {
|
||||
FloatLookupTable {
|
||||
@@ -86,20 +37,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_next(&self, key: T1) -> T2 {
|
||||
pub fn get_next(&self, key: T1) -> T2
|
||||
{
|
||||
let ord_key = FloatOrd(key);
|
||||
let mut lower_bound = 0;
|
||||
let mut upper_bound = self.keys.len() - 1;
|
||||
let mut mid = (lower_bound + upper_bound) / 2;
|
||||
while upper_bound - lower_bound > 1 {
|
||||
let mut mid: usize;
|
||||
|
||||
while lower_bound < upper_bound {
|
||||
mid = lower_bound + (upper_bound - lower_bound) / 2;
|
||||
if self.keys[mid] < ord_key {
|
||||
lower_bound = mid;
|
||||
lower_bound = mid + 1;
|
||||
} else {
|
||||
upper_bound = mid;
|
||||
}
|
||||
mid = (lower_bound + upper_bound) / 2;
|
||||
}
|
||||
self.values[mid]
|
||||
self.values[upper_bound]
|
||||
}
|
||||
|
||||
pub fn lookup(&self, key: T1) -> T2 {
|
||||
@@ -117,27 +70,30 @@ pub struct CyclingFloatLookupTable<T1, T2>
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
FloatOrd<T1>: Ord,
|
||||
{
|
||||
lookup_table: FloatLookupTable<T1, T2>,
|
||||
lower_bound: T1,
|
||||
upper_bound: T1,
|
||||
range: T1,
|
||||
}
|
||||
impl<T1, T2> CyclingFloatLookupTable<T1, T2>
|
||||
where
|
||||
T1: Float,
|
||||
T2: Float,
|
||||
FloatOrd<T1>: Ord,
|
||||
{
|
||||
pub fn new(keys: [T1; TABLE_SIZE], values: [T2; TABLE_SIZE], lower_bound: T1, upper_bound: T1) -> Self {
|
||||
CyclingFloatLookupTable {
|
||||
lookup_table: FloatLookupTable::new(keys, values),
|
||||
lower_bound,
|
||||
upper_bound,
|
||||
lower_bound: lower_bound,
|
||||
range: upper_bound - lower_bound,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lookup(&self, key: T1) -> T2 {
|
||||
let key = (key % (self.upper_bound - self.lower_bound)) + self.lower_bound;
|
||||
self.lookup_table.lookup(key)
|
||||
self.lookup_table.lookup(
|
||||
(key % self.range) + self.lower_bound
|
||||
)
|
||||
}
|
||||
}
|
||||
impl_cycling_fbitfbit_lookup_table!(f32, f32);
|
||||
@@ -146,19 +102,19 @@ impl_cycling_fbitfbit_lookup_table!(f32, f64);
|
||||
impl_cycling_fbitfbit_lookup_table!(f64, f32);
|
||||
|
||||
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EndoSinLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst,
|
||||
FloatOrd<T>: Ord,
|
||||
{
|
||||
lookup_table: CyclingFloatLookupTable<T, T>,
|
||||
}
|
||||
impl<T> EndoSinLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst,
|
||||
FloatOrd<T>: Ord,
|
||||
{
|
||||
#[allow(dead_code)]
|
||||
pub fn lookup(&self, key: T) -> T {
|
||||
if key < T::zero() {
|
||||
-self.lookup(-key)
|
||||
@@ -166,28 +122,30 @@ where
|
||||
self.lookup_table.lookup(key)
|
||||
} else if key < T::PI() {
|
||||
self.lookup_table.lookup(T::PI() - key)
|
||||
} else {
|
||||
} else if key < T::TAU() { // obviously, mod is slow so we want to avoid it until this would start recursing deeply
|
||||
-self.lookup(key - T::PI())
|
||||
} else {
|
||||
-self.lookup(key % T::PI())
|
||||
}
|
||||
}
|
||||
}
|
||||
impl EndoSinLookupTable<f32>
|
||||
{
|
||||
pub const fn new() -> Self {
|
||||
const UPPER_BOUND: f32 = f32_consts::PI;
|
||||
|
||||
EndoSinLookupTable {
|
||||
lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, UPPER_BOUND),
|
||||
lookup_table: CyclingFloatLookupTable::<f32, f32>::new_const(
|
||||
SIN_F32_KEYS, SIN_F32_VALUES, 0.0f32, f32_consts::PI
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl EndoSinLookupTable<f64>
|
||||
{
|
||||
pub const fn new() -> Self {
|
||||
let upper_bound = f64_consts::PI;
|
||||
|
||||
EndoSinLookupTable {
|
||||
lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, upper_bound),
|
||||
lookup_table: CyclingFloatLookupTable::<f64, f64>::new_const(
|
||||
SIN_F64_KEYS, SIN_F64_VALUES, 0.0f64, f64_consts::PI
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -202,9 +160,8 @@ where
|
||||
}
|
||||
impl<T> EndoCosLookupTable<T>
|
||||
where
|
||||
T: Float + FloatConst + std::fmt::Debug,
|
||||
T: Float + FloatConst,
|
||||
{
|
||||
#[allow(dead_code)]
|
||||
pub fn lookup(&self, key: T) -> T {
|
||||
self.lookup_table.lookup(key + T::FRAC_PI_2())
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod const_tables;
|
||||
pub mod const_tables;
|
||||
pub mod lookup_table;
|
||||
pub(crate) mod ordinal_float;
|
||||
|
||||
pub use const_tables::*;
|
||||
pub use lookup_table::*;
|
||||
|
||||
include!("config.rs");
|
||||
16
src/lookup/ordinal_float.rs
Normal file
16
src/lookup/ordinal_float.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
use num_traits::float::Float;
|
||||
|
||||
#[derive(Default, Debug, Clone, Copy, PartialEq, PartialOrd)]
|
||||
pub struct FloatOrd<T: Float>(pub T);
|
||||
impl<T: Float> FloatOrd<T> {
|
||||
#[allow(dead_code)]
|
||||
pub fn new() -> Self {
|
||||
FloatOrd(T::zero())
|
||||
}
|
||||
}
|
||||
impl<T: Float> Eq for FloatOrd<T> {}
|
||||
impl<T: Float> Ord for FloatOrd<T> {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
self.partial_cmp(other).unwrap_or(std::cmp::Ordering::Equal)
|
||||
}
|
||||
}
|
||||
30
src/macros/lookup.rs
Normal file
30
src/macros/lookup.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
mod lookup_table {
|
||||
#[macro_export]
|
||||
macro_rules! impl_fbitfbit_lookup_table {
|
||||
($key_type:ty, $value_type:ty) => {
|
||||
impl FloatLookupTable<$key_type, $value_type> {
|
||||
pub const fn new_const(keys: [FloatOrd<$key_type>; TABLE_SIZE], values: [$value_type; TABLE_SIZE]) -> Self {
|
||||
FloatLookupTable {
|
||||
keys: keys,
|
||||
values: values,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_cycling_fbitfbit_lookup_table {
|
||||
($key_type:ty, $value_type:ty) => {
|
||||
impl CyclingFloatLookupTable<$key_type, $value_type> {
|
||||
pub const fn new_const(keys: [FloatOrd<$key_type>; TABLE_SIZE], values: [$value_type; TABLE_SIZE], lower_bound: $key_type, range: $key_type) -> Self {
|
||||
CyclingFloatLookupTable {
|
||||
lookup_table: FloatLookupTable::<$key_type, $value_type>::new_const(keys, values),
|
||||
lower_bound: lower_bound,
|
||||
range: range,
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
1
src/macros/mod.rs
Normal file
1
src/macros/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod lookup;
|
||||
174
src/tests.rs
174
src/tests.rs
@@ -1,174 +0,0 @@
|
||||
//tests.rs
|
||||
|
||||
use num_traits::Float;
|
||||
|
||||
fn calculate_percentage_error<T>(vector1: &[T], vector2: &[T]) -> T
|
||||
where T: Float + std::ops::AddAssign,
|
||||
{
|
||||
let n = vector1.len();
|
||||
assert_eq!(n, vector2.len(), "Vectors must have equal lengths.");
|
||||
|
||||
let mut total_error = T::zero();
|
||||
for i in 0..n {
|
||||
let diff = (vector1[i] - vector2[i]).abs();
|
||||
let error = diff / if vector1[i] == T::zero() { T::min_positive_value() } else { vector1[i] };
|
||||
total_error += error;
|
||||
}
|
||||
|
||||
let average_error = total_error / T::from(n).unwrap();
|
||||
let percentage_error = average_error * T::from(100).expect("Cannot convert 100 to type T");
|
||||
percentage_error
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
mod f64_error {
|
||||
use crate::*;
|
||||
use super::super::calculate_percentage_error;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
const TOLERANCE: f64 = 2.5;
|
||||
|
||||
static X: Lazy<Vec<f64>> = Lazy::new(|| {
|
||||
(-10000..10000)
|
||||
.map(|a| (a as f64) / 1000.)
|
||||
.collect::<Vec<f64>>()
|
||||
});
|
||||
|
||||
#[test]
|
||||
fn pow2() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| 2.0f64.powf(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "fast_pow2<f64> percentage error is NaN");
|
||||
assert!(
|
||||
percentage_error < TOLERANCE,
|
||||
"fast_pow2<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exp() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_exp()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| x.exp()).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "fast_exp<f64> percentage error is NaN");
|
||||
assert!(
|
||||
percentage_error < TOLERANCE,
|
||||
"fast_exp<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_cos()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "fast_cos<f64> percentage error is NaN");
|
||||
assert!(
|
||||
percentage_error < TOLERANCE,
|
||||
"fast_cos<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
// lookup
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "lookup_cos<f64> percentage error is NaN");
|
||||
assert!(
|
||||
percentage_error < TOLERANCE,
|
||||
"lookup_cos<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sigmoid() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_sigmoid()).collect::<Vec<f64>>(),
|
||||
&X.iter().map(|&x| sigmoid_builtin_f64(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
assert!(!percentage_error.is_nan(), "fast_sigmoid<f64> percentage error is NaN");
|
||||
assert!(
|
||||
percentage_error < TOLERANCE,
|
||||
"fast_sigmoid<f64> percentage error: {0}",
|
||||
percentage_error
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
mod f32_error {
|
||||
use crate::*;
|
||||
use super::super::calculate_percentage_error;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
const TOLERANCE: f32 = 2.5;
|
||||
|
||||
static X: Lazy<Vec<f32>> = Lazy::new(|| {
|
||||
(-10000..10000)
|
||||
.map(|a| (a as f32) / 1000.)
|
||||
.collect::<Vec<f32>>()
|
||||
});
|
||||
|
||||
#[test]
|
||||
fn pow2() -> Result<(), Box<dyn std::error::Error>> {
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_pow2()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| 2.0f32.powf(x)).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn exp() -> Result<(), Box<dyn std::error::Error>> {
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_exp()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| x.exp()).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cos() -> Result<(), Box<dyn std::error::Error>> {
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_cos()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
// lookup
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.lookup_cos()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| x.cos()).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sigmoid() -> Result<(), Box<dyn std::error::Error>> {
|
||||
assert!(
|
||||
calculate_percentage_error(
|
||||
&X.iter().map(|&x| x.fast_sigmoid()).collect::<Vec<f32>>(),
|
||||
&X.iter().map(|&x| sigmoid_builtin_f32(x)).collect::<Vec<f32>>()
|
||||
) < TOLERANCE
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
130
src/tests/accuracy/comparisons.rs
Normal file
130
src/tests/accuracy/comparisons.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
#[cfg(test)]
|
||||
|
||||
use super::exact;
|
||||
|
||||
use num_traits::Float;
|
||||
|
||||
fn calculate_percentage_error<T>(vector1: &[T], vector2: &[T]) -> T
|
||||
where T: Float,
|
||||
{
|
||||
let n = vector1.len();
|
||||
assert_eq!(n, vector2.len(), "Vectors must have equal lengths.");
|
||||
|
||||
let mut total_error = T::zero();
|
||||
for i in 0..n {
|
||||
let diff = (vector1[i] - vector2[i]).abs();
|
||||
let error = diff / if vector1[i] == T::zero() { T::min_positive_value() } else { vector1[i] };
|
||||
total_error = total_error + error;
|
||||
}
|
||||
|
||||
let average_error = total_error / T::from(n).unwrap();
|
||||
let percentage_error = average_error * T::from(100).expect("Cannot convert 100 to type T");
|
||||
percentage_error
|
||||
}
|
||||
|
||||
macro_rules! panic_if_nan_or_print {
|
||||
($x:expr, $varname:expr) => {
|
||||
if $x.is_nan() {
|
||||
Err(format!("{} is NaN!", $varname))?
|
||||
} else {
|
||||
println!("{}: {}%", $varname, $x);
|
||||
Ok($x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod f64 {
|
||||
use crate::*;
|
||||
use super::exact;
|
||||
use super::calculate_percentage_error;
|
||||
|
||||
include!("x.rs");
|
||||
|
||||
pub fn pow2() -> Result<f64, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F64.iter().map(|&x| x.fast_pow2()).collect::<Vec<f64>>(),
|
||||
&X_F64.iter().map(|&x| exact::f64::pow2(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_pow2<f64> percentage error")
|
||||
}
|
||||
|
||||
pub fn exp() -> Result<f64, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F64.iter().map(|&x| x.fast_exp()).collect::<Vec<f64>>(),
|
||||
&X_F64.iter().map(|&x| exact::f64::exp(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_exp<f64> percentage error")
|
||||
}
|
||||
|
||||
pub fn cos() -> Result<f64, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F64.iter().map(|&x| x.fast_cos()).collect::<Vec<f64>>(),
|
||||
&X_F64.iter().map(|&x| exact::f64::cos(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_cos<f64> percentage error")
|
||||
}
|
||||
|
||||
pub fn cos_lookup() -> Result<f64, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F64.iter().map(|&x| x.lookup_cos()).collect::<Vec<f64>>(),
|
||||
&X_F64.iter().map(|&x| exact::f64::cos(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "lookup_cos<f64> percentage error")
|
||||
}
|
||||
|
||||
pub fn sigmoid() -> Result<f64, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F64.iter().map(|&x| x.fast_sigmoid()).collect::<Vec<f64>>(),
|
||||
&X_F64.iter().map(|&x| exact::f64::sigmoid(x)).collect::<Vec<f64>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_sigmoid<f64> percentage error")
|
||||
}
|
||||
}
|
||||
|
||||
pub mod f32 {
|
||||
use crate::*;
|
||||
use super::exact;
|
||||
use super::calculate_percentage_error;
|
||||
|
||||
include!("x.rs");
|
||||
|
||||
pub fn pow2() -> Result<f32, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F32.iter().map(|&x| x.fast_pow2()).collect::<Vec<f32>>(),
|
||||
&X_F32.iter().map(|&x| exact::f32::pow2(x)).collect::<Vec<f32>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_pow2<f32> percentage error")
|
||||
}
|
||||
|
||||
pub fn exp() -> Result<f32, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F32.iter().map(|&x| x.fast_exp()).collect::<Vec<f32>>(),
|
||||
&X_F32.iter().map(|&x| exact::f32::exp(x)).collect::<Vec<f32>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_exp<f32> percentage error")
|
||||
}
|
||||
|
||||
pub fn cos() -> Result<f32, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F32.iter().map(|&x| x.fast_cos()).collect::<Vec<f32>>(),
|
||||
&X_F32.iter().map(|&x| exact::f32::cos(x)).collect::<Vec<f32>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_cos<f32> percentage error")
|
||||
}
|
||||
|
||||
pub fn cos_lookup() -> Result<f32, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F32.iter().map(|&x| x.lookup_cos()).collect::<Vec<f32>>(),
|
||||
&X_F32.iter().map(|&x| exact::f32::cos(x)).collect::<Vec<f32>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "lookup_cos<f32> percentage error")
|
||||
}
|
||||
|
||||
pub fn sigmoid() -> Result<f32, Box<dyn std::error::Error>> {
|
||||
let percentage_error = calculate_percentage_error(
|
||||
&X_F32.iter().map(|&x| x.fast_sigmoid()).collect::<Vec<f32>>(),
|
||||
&X_F32.iter().map(|&x| exact::f32::sigmoid(x)).collect::<Vec<f32>>()
|
||||
);
|
||||
panic_if_nan_or_print!(percentage_error, "fast_sigmoid<f32> percentage error")
|
||||
}
|
||||
}
|
||||
3
src/tests/accuracy/config.rs
Normal file
3
src/tests/accuracy/config.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
const X_MIN: f64 = -10.0;
|
||||
const X_MAX: f64 = 10.0;
|
||||
const X_SIZE: usize = 20000;
|
||||
35
src/tests/accuracy/exact.rs
Normal file
35
src/tests/accuracy/exact.rs
Normal file
@@ -0,0 +1,35 @@
|
||||
pub mod f64 {
|
||||
pub fn pow2(n: f64) -> f64 {
|
||||
2.0f64.powf(n)
|
||||
}
|
||||
|
||||
pub fn exp(n: f64) -> f64 {
|
||||
n.exp()
|
||||
}
|
||||
|
||||
pub fn cos(n: f64) -> f64 {
|
||||
n.cos()
|
||||
}
|
||||
|
||||
pub fn sigmoid(n: f64) -> f64 {
|
||||
(1. + (-n).exp()).recip()
|
||||
}
|
||||
}
|
||||
|
||||
pub mod f32 {
|
||||
pub fn pow2(n: f32) -> f32 {
|
||||
2.0f32.powf(n)
|
||||
}
|
||||
|
||||
pub fn exp(n: f32) -> f32 {
|
||||
n.exp()
|
||||
}
|
||||
|
||||
pub fn cos(n: f32) -> f32 {
|
||||
n.cos()
|
||||
}
|
||||
|
||||
pub fn sigmoid(n: f32) -> f32 {
|
||||
(1. + (-n).exp()).recip()
|
||||
}
|
||||
}
|
||||
4
src/tests/accuracy/mod.rs
Normal file
4
src/tests/accuracy/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
mod exact;
|
||||
|
||||
mod comparisons;
|
||||
pub use comparisons::*;
|
||||
6
src/tests/accuracy/x.rs
Normal file
6
src/tests/accuracy/x.rs
Normal file
File diff suppressed because one or more lines are too long
3
src/tests/mod.rs
Normal file
3
src/tests/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
#![cfg(test)]
|
||||
mod accuracy;
|
||||
mod tolerance;
|
||||
7
src/tests/tolerance.json
Normal file
7
src/tests/tolerance.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"pow2_fast": 2.0,
|
||||
"exp_fast": 2.0,
|
||||
"cos_fast": 1.0,
|
||||
"cos_lk": 1.0,
|
||||
"sigmoid_fast": 1.0
|
||||
}
|
||||
50
src/tests/tolerance.rs
Normal file
50
src/tests/tolerance.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use super::accuracy;
|
||||
use serde_json;
|
||||
|
||||
fn get_tolerance<T>(key: &str) -> Result<T, serde_json::Error>
|
||||
where
|
||||
T: serde::de::DeserializeOwned
|
||||
{
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_str(
|
||||
include_str!("tolerance.json")
|
||||
)?;
|
||||
let value = serde_json::value::from_value(
|
||||
json[key].clone()
|
||||
)?;
|
||||
Ok(value)
|
||||
}
|
||||
|
||||
macro_rules! test_within_tolerance {
|
||||
($function:ident, $t:ty, $test_name:ident) => {
|
||||
#[test]
|
||||
fn $test_name() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let tolerance: $t = get_tolerance::<$t>(stringify!($test_name))?;
|
||||
let percentage_error: $t = $function()?;
|
||||
assert!(percentage_error < tolerance);
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
mod f64 {
|
||||
use super::{accuracy, get_tolerance};
|
||||
use accuracy::f64::*;
|
||||
|
||||
test_within_tolerance!(pow2, f64, pow2_fast);
|
||||
test_within_tolerance!(exp, f64, exp_fast);
|
||||
test_within_tolerance!(cos, f64, cos_fast);
|
||||
test_within_tolerance!(cos_lookup, f64, cos_lk);
|
||||
test_within_tolerance!(sigmoid, f64, sigmoid_fast);
|
||||
}
|
||||
|
||||
mod f32 {
|
||||
use super::{accuracy, get_tolerance};
|
||||
use accuracy::f32::*;
|
||||
|
||||
test_within_tolerance!(pow2, f32, pow2_fast);
|
||||
test_within_tolerance!(exp, f32, exp_fast);
|
||||
test_within_tolerance!(cos, f32, cos_fast);
|
||||
test_within_tolerance!(cos_lookup, f32, cos_lk);
|
||||
test_within_tolerance!(sigmoid, f32, sigmoid_fast);
|
||||
}
|
||||
70
test.rs
Normal file
70
test.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
use std::fs::read;
|
||||
use std::f64::consts as f64_consts;
|
||||
use bincode::deserialize;
|
||||
use once_cell::sync::Lazy;
|
||||
use ndarray::prelude::*;
|
||||
use optimize::*;
|
||||
use num_traits::Float;
|
||||
|
||||
|
||||
fn calculate_percentage_error<T>(vector1: &[T], vector2: &[T]) -> T
|
||||
where T: Float + std::ops::AddAssign,
|
||||
{
|
||||
let n = vector1.len();
|
||||
assert_eq!(n, vector2.len(), "Vectors must have equal lengths.");
|
||||
|
||||
let mut total_error = T::zero();
|
||||
for i in 0..n {
|
||||
let diff = (vector1[i] - vector2[i]).abs();
|
||||
let error = diff / if vector1[i] == T::zero() { T::min_positive_value() } else { vector1[i] };
|
||||
total_error += error;
|
||||
}
|
||||
|
||||
let average_error = total_error / T::from(n).unwrap();
|
||||
let percentage_error = average_error * T::from(100).expect("Cannot convert 100 to type T");
|
||||
percentage_error
|
||||
}
|
||||
|
||||
|
||||
fn fast_exp(x: f64, clipp_thresh: f64, v_scale: f64, clipp_shift: f64) -> f64 {
|
||||
// const CLIPP_THRESH: f64 = -180335.51911105003;
|
||||
// const V_SCALE: f64 = 4524653012949098.0;
|
||||
// const CLIPP_SHIFT: f64 = 1018.1563534409383;
|
||||
let scaled_p = f64_consts::LOG2_E * x;
|
||||
let clipp = scaled_p.max(clipp_thresh);
|
||||
let v = (v_scale * (clipp + clipp_shift)) as u64;
|
||||
f64::from_bits(v)
|
||||
}
|
||||
|
||||
const Y: Lazy<Vec<f64>> = Lazy::new(|| { deserialize(&read("tmp/Y.bin").unwrap()).unwrap() } );
|
||||
|
||||
fn objective(args: ArrayView1<f64>) -> f64 {
|
||||
let clipp_thresh: f64 = args[0];
|
||||
let v_scale: f64 = args[1];
|
||||
let clipp_shift: f64 = args[2];
|
||||
|
||||
let X: Vec<f64> = (-10000..10000)
|
||||
.map(|a| (a as f64) / 1000.)
|
||||
.collect::<Vec<f64>>();
|
||||
let Y_hat: Vec<f64> = X.iter().map(|&x| fast_exp(x, clipp_thresh, v_scale, clipp_shift)).collect::<Vec<f64>>();
|
||||
calculate_percentage_error(&(*Y), &Y_hat)
|
||||
}
|
||||
|
||||
fn optimize_params() {
|
||||
// Create a minimizer using the builder pattern.
|
||||
let minimizer = NelderMeadBuilder::default()
|
||||
.xtol(1e-6f64)
|
||||
.ftol(1e-6f64)
|
||||
.maxiter(50000)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
// Set the starting guess
|
||||
let args: Array1<f64> = Array1::from_vec(vec![-180335.51911105003, 4524653012949098.0, 1018.1563534409383]);
|
||||
|
||||
// Run the optimization
|
||||
let ans = minimizer.minimize(objective, args.view());
|
||||
|
||||
// Print the optimized values
|
||||
println!("Final optimized arguments: {}", ans);
|
||||
}
|
||||
Reference in New Issue
Block a user