Properly enforce fixed width when reading CSV

This commit is contained in:
2026-06-19 14:09:44 +01:00
parent c3ba06ceba
commit 5f098185dc
+41 -30
View File
@@ -1,7 +1,7 @@
use csv::ReaderBuilder;
use glob::glob;
use indicatif::ProgressBar;
use ndarray::{concatenate, stack, Array1, Array2, ArrayView1, ArrayView2, Axis, Slice};
use ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis, Slice, concatenate, stack};
use rayon::prelude::*;
use std::fs::File;
use std::path::{Path, PathBuf};
@@ -175,25 +175,32 @@ pub fn read_file(filepath: PathBuf) -> Result<(Array2<f64>, f64, usize)> {
.delimiter(b' ')
.has_headers(false)
.from_reader(file);
let data = rdr
.records()
.collect::<std::result::Result<Vec<csv::StringRecord>, _>>()?
.iter()
.map(|x| {
x.iter()
.map(|y| y.parse::<i64>().map_err(ReadError::ParseIntError))
.collect::<Result<Vec<i64>>>()
})
.collect::<Result<Vec<_>>>()?;
let length = data.len();
let width = data[0].len(); // WARNING: assumes fixed width columns!
let mut arr: Array2<f64> = Array2::zeros((length, width));
for (data_row, mut arr_row) in data.iter().zip(arr.axis_iter_mut(Axis(0))) {
for (data_i, arr_i) in data_row.iter().zip(arr_row.iter_mut()) {
*arr_i = *data_i as f64
let mut flat_data = Vec::new();
let mut width = 0;
let mut length = 0;
for result in rdr.records() {
let record = result?;
length += 1;
if width == 0 {
width = record.len();
} else if record.len() != width {
return Err(ReadError::MiscError(format!(
"Inconsistent row width: expected {}, found {}",
width,
record.len()
)));
}
for field in record.iter() {
flat_data.push(field.parse::<f64>()?);
}
}
let arr = Array2::from_shape_vec((length, width), flat_data)?;
Ok((arr, z, length))
}
@@ -236,7 +243,7 @@ mod tests {
};
use std::path::PathBuf;
use tar::Archive;
use tempfile::{tempdir, TempDir};
use tempfile::{TempDir, tempdir};
use xz::read::XzDecoder;
const TEST_DATA_FILE: &str = "tests/correct.tar.xz";
@@ -382,12 +389,14 @@ mod tests {
let true_x = unpack_test_data("correct_x_out.flex").unwrap();
let mut test_x = Array::range(-1000000., 1000000., 1.);
test_x.par_map_inplace(rust_fn::correct_x);
assert!(true_x
.iter()
.zip(test_x.iter())
.all(|(&x, y): (&f64, &f64)| -> bool {
x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps())
}));
assert!(
true_x
.iter()
.zip(test_x.iter())
.all(|(&x, y): (&f64, &f64)| -> bool {
x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps())
})
);
}
#[test]
@@ -405,11 +414,13 @@ mod tests {
let true_y = unpack_test_data("correct_y_out.flex").unwrap();
let mut test_y = Array::range(-1000000., 1000000., 1.);
test_y.par_map_inplace(rust_fn::correct_y);
assert!(true_y
.iter()
.zip(test_y.iter())
.all(|(&x, y): (&f64, &f64)| -> bool {
x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps())
}));
assert!(
true_y
.iter()
.zip(test_y.iter())
.all(|(&x, y): (&f64, &f64)| -> bool {
x.ulps_eq(y, f64::default_epsilon(), f64::default_max_ulps())
})
);
}
}