diff --git a/src/rust_fn/mod.rs b/src/rust_fn/mod.rs index 8cb36e2..f12d712 100644 --- a/src/rust_fn/mod.rs +++ b/src/rust_fn/mod.rs @@ -1,7 +1,7 @@ use csv::ReaderBuilder; use glob::glob; use indicatif::ProgressBar; -use ndarray::{concatenate, stack, Array1, Array2, ArrayView1, ArrayView2, Axis, Slice}; +use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Slice, concatenate, stack}; use rayon::prelude::*; use std::fs::File; use std::path::{Path, PathBuf}; @@ -175,25 +175,32 @@ pub fn read_file(filepath: PathBuf) -> Result<(Array2, f64, usize)> { .delimiter(b' ') .has_headers(false) .from_reader(file); - let data = rdr - .records() - .collect::, _>>()? - .iter() - .map(|x| { - x.iter() - .map(|y| y.parse::().map_err(ReadError::ParseIntError)) - .collect::>>() - }) - .collect::>>()?; - let length = data.len(); - let width = data[0].len(); // WARNING: assumes fixed width columns! - let mut arr: Array2 = Array2::zeros((length, width)); - for (data_row, mut arr_row) in data.iter().zip(arr.axis_iter_mut(Axis(0))) { - for (data_i, arr_i) in data_row.iter().zip(arr_row.iter_mut()) { - *arr_i = *data_i as f64 + + let mut flat_data = Vec::new(); + let mut width = 0; + let mut length = 0; + + for result in rdr.records() { + let record = result?; + length += 1; + + if width == 0 { + width = record.len(); + } else if record.len() != width { + return Err(ReadError::MiscError(format!( + "Inconsistent row width: expected {}, found {}", + width, + record.len() + ))); + } + + for field in record.iter() { + flat_data.push(field.parse::()?); } } + let arr = Array2::from_shape_vec((length, width), flat_data)?; + Ok((arr, z, length)) } @@ -236,7 +243,7 @@ mod tests { }; use std::path::PathBuf; use tar::Archive; - use tempfile::{tempdir, TempDir}; + use tempfile::{TempDir, tempdir}; use xz::read::XzDecoder; const TEST_DATA_FILE: &str = "tests/correct.tar.xz"; @@ -382,12 +389,14 @@ mod tests { let true_x = unpack_test_data("correct_x_out.flex").unwrap(); let mut test_x = Array::range(-1000000., 1000000., 1.); test_x.par_map_inplace(rust_fn::correct_x); - assert!(true_x - .iter() - .zip(test_x.iter()) - .all(|(&x, y): (&f64, &f64)| -> bool { - x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps()) - })); + assert!( + true_x + .iter() + .zip(test_x.iter()) + .all(|(&x, y): (&f64, &f64)| -> bool { + x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps()) + }) + ); } #[test] @@ -405,11 +414,13 @@ mod tests { let true_y = unpack_test_data("correct_y_out.flex").unwrap(); let mut test_y = Array::range(-1000000., 1000000., 1.); test_y.par_map_inplace(rust_fn::correct_y); - assert!(true_y - .iter() - .zip(test_y.iter()) - .all(|(&x, y): (&f64, &f64)| -> bool { - x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps()) - })); + assert!( + true_y + .iter() + .zip(test_y.iter()) + .all(|(&x, y): (&f64, &f64)| -> bool { + x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps()) + }) + ); } }