From d049162c18bf124ef6e52de86ee061324d674dc5 Mon Sep 17 00:00:00 2001 From: Cian Hughes Date: Mon, 13 Nov 2023 15:11:14 +0000 Subject: [PATCH] Quick, unoptimized impl of fast_tan --- plot.rs | 246 ++++++++++++++++++++++++++++++ src/fastmath.rs | 30 ++++ src/lookup/lookup_table.rs | 2 - src/tests/accuracy/comparisons.rs | 16 ++ src/tests/accuracy/exact.rs | 8 + src/tests/tolerance.json | 1 + src/tests/tolerance.rs | 2 + 7 files changed, 303 insertions(+), 2 deletions(-) create mode 100755 plot.rs diff --git a/plot.rs b/plot.rs new file mode 100755 index 0000000..c754547 --- /dev/null +++ b/plot.rs @@ -0,0 +1,246 @@ +#!/home/cianh/.cargo/bin/run-cargo-script +// cargo-deps: plotters, num-traits, fastmath={ path = "." } + +// A simple script for visualizing plots during development. + +extern crate plotters; +extern crate num_traits; +extern crate fastmath; + +use std::rc::Rc; +use std::error::Error; +use fastmath::*; +use plotters::prelude::*; +use num_traits::Float; +use plotters::coord::types::RangedCoordf64; + +mod exact { + include!("/home/cianh/Programming/Git_Projects/fastmath/src/tests/accuracy/exact.rs"); +} + +fn calculate_percentage_error(vector1: &[T], vector2: &[T]) -> T + where T: Float + std::ops::AddAssign, +{ + let n = vector1.len(); + assert_eq!(n, vector2.len(), "Vectors must have equal lengths."); + + let mut total_error = T::zero(); + for i in 0..n { + let diff = (vector1[i] - vector2[i]).abs(); + let error = diff / if vector1[i] == T::zero() { T::min_positive_value() } else { vector1[i] }; + total_error += error; + } + + let average_error = total_error / T::from(n).unwrap(); + let percentage_error = average_error * T::from(100).expect("Cannot convert 100 to type T"); + percentage_error +} + +#[derive(Clone)] +enum ValidFloatFunction { + F32(Rc f32>), + F64(Rc f64>), +} +impl ValidFloatFunction { + fn new_f32_func(func: impl Fn(f32) -> f32 + 'static) -> Self { + ValidFloatFunction::F32(Rc::new(func)) + } + + fn new_f64_func(func: impl Fn(f64) -> f64 + 'static) -> Self { + ValidFloatFunction::F64(Rc::new(func)) + } + + fn plot_function_f32( + self: Self, + chart: &mut ChartContext<'_, DB, Cartesian2d>, + color: &RGBColor, + label: &str + ) -> Result<(), Box> + where + DB: DrawingBackend, + { + let function = match self { + ValidFloatFunction::F32(f) => f.clone(), + _ => panic!("Function is not f32"), + }; + let series_color = RGBColor(color.0, color.1, color.2); + chart.draw_series(LineSeries::new( + (-1000..=1000).map(|x| x as f32 / 100.0).map(|x| (x as f64, function(x) as f64)), + series_color, // Applying color + )) + .unwrap() + .label(label) // Setting label + .legend(move |(x,y)| PathElement::new(vec![(x,y), (x + 20,y)], &series_color)); // Drawing a legend element + + Ok(()) + } + + fn plot_function_f64( + self: Self, + chart: &mut ChartContext<'_, DB, Cartesian2d>, + color: &RGBColor, + label: &str + ) -> Result<(), Box> + where + DB: DrawingBackend, + { + let function = match self { + ValidFloatFunction::F64(f) => f.clone(), + _ => panic!("Function is not f64"), + }; + + let series_color = RGBColor(color.0, color.1, color.2); + chart.draw_series(LineSeries::new( + (-1000..=1000).map(|x| x as f64 / 100.0).map(|x| (x, function(x))), + series_color, // Applying color + )) + .unwrap() + .label(label) // Setting label + .legend(move |(x,y)| PathElement::new(vec![(x,y), (x + 20,y)], &series_color)); // Drawing a legend element + + Ok(()) + } +} +impl Default for ValidFloatFunction { + fn default() -> Self { + ValidFloatFunction::F32(Rc::new(|x: f32| -> f32 { x })) + } +} +impl Into for Rc f32> { + fn into(self) -> ValidFloatFunction { + ValidFloatFunction::F32(self) + } +} +impl Into for Rc f64> { + fn into(self) -> ValidFloatFunction { + ValidFloatFunction::F64(self) + } +} + +fn plot(functions: Vec, yrange: (f64, f64), labels: Vec, output: &String) -> Result<(), Box> { + let color_palette: Vec = (0..functions.len()).map(|i| { + let hue = (i as f64) / (functions.len() as f64); + ViridisRGB::get_color(hue) + }).collect(); + let filename = format!("{}.png", *output); + let root = BitMapBackend::new(&filename, (1280, 960)).into_drawing_area(); + root.fill(&WHITE)?; + + let (ymin, ymax) = yrange; + let mut chart = ChartBuilder::on(&root) + .caption(output, ("Arial", 50).into_font()) + .margin(5) + .x_label_area_size(30) + .y_label_area_size(30) + .build_cartesian_2d(-5f64..5f64, ymin..ymax)?; + + chart.configure_mesh().draw()?; + + for ((function, color), label) in functions.iter().zip(color_palette.iter()).zip(labels.iter()) { + match function { + ValidFloatFunction::F32(_) => function.clone().plot_function_f32(&mut chart, color, &label)?, + ValidFloatFunction::F64(_) => function.clone().plot_function_f64(&mut chart, color, &label)?, + } + } + + chart.configure_series_labels() + .background_style(&WHITE.mix(0.8)) + .border_style(&BLACK) + .draw()?; + + root.present()?; + + Ok(()) +} + +fn main() -> Result<(), Box> { + // pow2 + println!("Plotting pow2"); + plot( + vec![ + ValidFloatFunction::new_f64_func(exact::f64::pow2), + ValidFloatFunction::new_f32_func(f32::fast_pow2), + ValidFloatFunction::new_f64_func(f64::fast_pow2), + ], + (0f64, 10f64), + vec![ + String::from("exact::pow2"), + String::from("f32::fast_pow2"), + String::from("f64::fast_pow2"), + ], + &String::from("tmp/pow2") + )?; + // exp + println!("Plotting exp"); + plot( + vec![ + ValidFloatFunction::new_f64_func(exact::f64::exp), + ValidFloatFunction::new_f32_func(f32::fast_exp), + ValidFloatFunction::new_f64_func(f64::fast_exp), + ], + (0f64, 10f64), + vec![ + String::from("exact::exp"), + String::from("f32::fast_exp"), + String::from("f64::fast_exp"), + ], + &String::from("tmp/exp") + )?; + // sin + println!("Plotting sin"); + plot( + vec![ + ValidFloatFunction::new_f64_func(exact::f64::sin), + ValidFloatFunction::new_f32_func(f32::fast_sin), + ValidFloatFunction::new_f64_func(f64::fast_sin), + ValidFloatFunction::new_f32_func(f32::lookup_sin), + ValidFloatFunction::new_f64_func(f64::lookup_sin), + ], + (-1.5f64, 1.5f64), + vec![ + String::from("exact::sin"), + String::from("f32::fast_sin"), + String::from("f64::fast_sin"), + String::from("f32::lookup_sin"), + String::from("f64::lookup_sin"), + ], + &String::from("tmp/sin") + )?; + // cos + println!("Plotting cos"); + plot( + vec![ + ValidFloatFunction::new_f64_func(exact::f64::cos), + ValidFloatFunction::new_f32_func(f32::fast_cos), + ValidFloatFunction::new_f64_func(f64::fast_cos), + ValidFloatFunction::new_f32_func(f32::lookup_cos), + ValidFloatFunction::new_f64_func(f64::lookup_cos), + ], + (-1.5f64, 1.5f64), + vec![ + String::from("exact::cos"), + String::from("f32::fast_cos"), + String::from("f64::fast_cos"), + String::from("f32::lookup_cos"), + String::from("f64::lookup_cos"), + ], + &String::from("tmp/cos") + )?; + // sigmoid + println!("Plotting sigmoid"); + plot( + vec![ + ValidFloatFunction::new_f64_func(exact::f64::sigmoid), + ValidFloatFunction::new_f32_func(f32::fast_sigmoid), + ValidFloatFunction::new_f64_func(f64::fast_sigmoid), + ], + (-1.5f64, 1.5f64), + vec![ + String::from("exact::sigmoid"), + String::from("f32::fast_sigmoid"), + String::from("f64::fast_sigmoid"), + ], + &String::from("tmp/sigmoid") + )?; + Ok(()) +} \ No newline at end of file diff --git a/src/fastmath.rs b/src/fastmath.rs index 855aa6c..d4171f4 100644 --- a/src/fastmath.rs +++ b/src/fastmath.rs @@ -138,6 +138,36 @@ impl FastSin for f64 { } } +pub trait FastTan { // tan(x) = sin(x) / cos(x) + fn fast_tan(self: Self) -> Self; +} +impl FastTan for f32 { + #[inline] + fn fast_tan(self: Self) -> f32 { + let qpprox_cos = + 1.0 - f32_consts::FRAC_2_PI * + ((((self + f32_consts::PI).abs()) % f32_consts::TAU) - f32_consts::PI).abs(); + let qpprox_sin = + 1.0 - f32_consts::FRAC_2_PI * + ((((self + f32_consts::FRAC_PI_2).abs()) % f32_consts::TAU) - f32_consts::PI).abs(); + ((qpprox_sin * (1.0 + f32_consts::FRAC_PI_6)) - (qpprox_sin.powi(3) * f32_consts::FRAC_PI_6)) / + ((qpprox_cos * (1.0 + f32_consts::FRAC_PI_6)) - (qpprox_cos.powi(3) * f32_consts::FRAC_PI_6)) + } +} +impl FastTan for f64 { + #[inline] + fn fast_tan(self: Self) -> f64 { + let qpprox_cos = + 1.0 - f64_consts::FRAC_2_PI * + ((((self + f64_consts::PI).abs()) % f64_consts::TAU) - f64_consts::PI).abs(); + let qpprox_sin = + 1.0 - f64_consts::FRAC_2_PI * + ((((self + f64_consts::FRAC_PI_2).abs()) % f64_consts::TAU) - f64_consts::PI).abs(); + ((qpprox_sin * (1.0 + f64_consts::FRAC_PI_6)) - (qpprox_sin.powi(3) * f64_consts::FRAC_PI_6)) / + ((qpprox_cos * (1.0 + f64_consts::FRAC_PI_6)) - (qpprox_cos.powi(3) * f64_consts::FRAC_PI_6)) + } +} + pub trait FastExp { fn fast_exp(self: Self) -> Self; } diff --git a/src/lookup/lookup_table.rs b/src/lookup/lookup_table.rs index 2133926..64e40ec 100644 --- a/src/lookup/lookup_table.rs +++ b/src/lookup/lookup_table.rs @@ -13,8 +13,6 @@ use crate::{ use crate::lookup::TABLE_SIZE; use crate::lookup::const_tables::*; -// TODO: Test phf for lookup tables - pub trait ToIterator: IntoIterator {} impl ToIterator for I where I: IntoIterator {} diff --git a/src/tests/accuracy/comparisons.rs b/src/tests/accuracy/comparisons.rs index fddd650..bf8e662 100644 --- a/src/tests/accuracy/comparisons.rs +++ b/src/tests/accuracy/comparisons.rs @@ -88,6 +88,14 @@ pub mod f64 { panic_if_nan_or_print!(percentage_error, "lookup_sin percentage error") } + pub fn tan() -> Result> { + let percentage_error = calculate_percentage_error( + &X_F64.iter().map(|&x| x.fast_tan()).collect::>(), + &X_F64.iter().map(|&x| exact::f64::tan(x)).collect::>() + ); + panic_if_nan_or_print!(percentage_error, "fast_tan percentage error") + } + pub fn sigmoid() -> Result> { let percentage_error = calculate_percentage_error( &X_F64.iter().map(|&x| x.fast_sigmoid()).collect::>(), @@ -152,6 +160,14 @@ pub mod f32 { panic_if_nan_or_print!(percentage_error, "lookup_sin percentage error") } + pub fn tan() -> Result> { + let percentage_error = calculate_percentage_error( + &X_F32.iter().map(|&x| x.fast_tan()).collect::>(), + &X_F32.iter().map(|&x| exact::f32::tan(x)).collect::>() + ); + panic_if_nan_or_print!(percentage_error, "fast_tan percentage error") + } + pub fn sigmoid() -> Result> { let percentage_error = calculate_percentage_error( &X_F32.iter().map(|&x| x.fast_sigmoid()).collect::>(), diff --git a/src/tests/accuracy/exact.rs b/src/tests/accuracy/exact.rs index 7a53830..8a964c3 100644 --- a/src/tests/accuracy/exact.rs +++ b/src/tests/accuracy/exact.rs @@ -15,6 +15,10 @@ pub mod f64 { n.sin() } + pub fn tan(n: f64) -> f64 { + n.tan() + } + pub fn sigmoid(n: f64) -> f64 { (1. + (-n).exp()).recip() } @@ -37,6 +41,10 @@ pub mod f32 { n.sin() } + pub fn tan(n: f32) -> f32 { + n.tan() + } + pub fn sigmoid(n: f32) -> f32 { (1. + (-n).exp()).recip() } diff --git a/src/tests/tolerance.json b/src/tests/tolerance.json index 3ee9ddb..effa5e6 100644 --- a/src/tests/tolerance.json +++ b/src/tests/tolerance.json @@ -5,5 +5,6 @@ "cos_lk": 1.0, "sin_fast": 1.0, "sin_lk": 1.0, + "tan_fast": 1.0, "sigmoid_fast": 1.0 } \ No newline at end of file diff --git a/src/tests/tolerance.rs b/src/tests/tolerance.rs index 4a0c806..c3e47ed 100644 --- a/src/tests/tolerance.rs +++ b/src/tests/tolerance.rs @@ -37,6 +37,7 @@ mod f64 { test_within_tolerance!(cos_lookup, f64, cos_lk); test_within_tolerance!(sin, f64, sin_fast); test_within_tolerance!(sin_lookup, f64, sin_lk); + test_within_tolerance!(tan, f64, tan_fast); test_within_tolerance!(sigmoid, f64, sigmoid_fast); } @@ -50,5 +51,6 @@ mod f32 { test_within_tolerance!(cos_lookup, f32, cos_lk); test_within_tolerance!(sin, f32, sin_fast); test_within_tolerance!(sin_lookup, f32, sin_lk); + test_within_tolerance!(tan, f32, tan_fast); test_within_tolerance!(sigmoid, f32, sigmoid_fast); } \ No newline at end of file