diff --git a/rust/lib.rs b/rust/lib.rs index cc47458ac..02a6e1cbd 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1,4 +1,4 @@ -use ndarray::ArrayView1; +use ndarray::{ArrayView1, Axis}; use numpy::{PyArray2, PyReadonlyArray1}; use pyo3::{pymodule, types::PyModule, PyResult, Python}; @@ -41,11 +41,14 @@ fn _utils_rs<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> { let bls_out = unsafe { PyArray2::new(py, (2, nbls), false) }; let mut _bls_out = unsafe { bls_out.as_array_mut() }; - (0..nbls).for_each(|cnt| { - let ants = bl_fn(&bls_array[cnt]); - _bls_out[[0, cnt]] = ants[0]; - _bls_out[[1, cnt]] = ants[1]; - }); + _bls_out + .axis_iter_mut(Axis(1)) + .zip(bls_array) + .for_each(|(mut ant_array, bl)| { + let ants = bl_fn(bl); + ant_array[0] = ants[0]; + ant_array[1] = ants[1]; + }); bls_out }