diff --git a/rust/lib.rs b/rust/lib.rs index fe85af587..cd3905b55 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -1,5 +1,5 @@ -use ndarray::{Array, Array2, ArrayView1, Axis, Ix2}; -use numpy::{IntoPyArray, PyArray2, PyReadonlyArray1}; +use ndarray::{ArrayView1, ArrayViewMut2, Axis}; +use numpy::{PyArray, PyArray2, PyReadonlyArray1}; use pyo3::{pymodule, types::PyModule, PyResult, Python}; #[inline] @@ -23,8 +23,7 @@ fn bls_to_ants_2_147_483_648(bl: &u64) -> [u64; 2] { const BLS_2_147_483_648: u64 = 2_u64.pow(16) + 2_u64.pow(22); const BLS_2048: u64 = 2_u64.pow(16); -fn _baseline_to_antnums(bls_array: ArrayView1) -> Array2 { - let nbls = bls_array.len(); +fn _baseline_to_antnums(bls_array: ArrayView1, mut bls_out: ArrayViewMut2) { let bls_min = bls_array.fold(bls_array[0], |x, y| x.min(*y)); let bl_fn = if bls_min >= BLS_2_147_483_648 { @@ -34,7 +33,6 @@ fn _baseline_to_antnums(bls_array: ArrayView1) -> Array2 { } else { bls_to_ants_256 }; - let mut bls_out = Array::::zeros((2, nbls)); bls_out .axis_iter_mut(Axis(1)) .zip(bls_array) @@ -43,8 +41,6 @@ fn _baseline_to_antnums(bls_array: ArrayView1) -> Array2 { ant_array[0] = ants[0]; ant_array[1] = ants[1]; }); - - bls_out } #[pymodule] @@ -54,14 +50,17 @@ fn _utils_rs<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> { py: Python<'py>, bls: PyReadonlyArray1<'py, u64>, ) -> &'py PyArray2 { - _baseline_to_antnums(bls.as_array()).into_pyarray(py) + let nbls = bls.shape()[0]; + let bls_out = unsafe { PyArray::new(py, (2, nbls), false) }; + _baseline_to_antnums(bls.as_array(), unsafe { bls_out.as_array_mut() }); + bls_out } Ok(()) } #[cfg(test)] mod test { - use ndarray::Array; + use ndarray::{Array, Ix2}; use super::*; @@ -80,7 +79,10 @@ mod test { ) .unwrap(); - let antnums = _baseline_to_antnums((&bls).into()); + let nbls = bls.len(); + let mut antnums = Array::::zeros((2, nbls)); + + _baseline_to_antnums((&bls).into(), (&mut antnums).into()); assert_eq!(ants, antnums) } @@ -99,7 +101,10 @@ mod test { ) .unwrap(); - let antnums = _baseline_to_antnums((&bls).into()); + let nbls = bls.len(); + let mut antnums = Array::::zeros((2, nbls)); + + _baseline_to_antnums((&bls).into(), (&mut antnums).into()); assert_eq!(ants, antnums) } @@ -118,7 +123,10 @@ mod test { ) .unwrap(); - let antnums = _baseline_to_antnums((&bls).into()); + let nbls = bls.len(); + let mut antnums = Array::::zeros((2, nbls)); + + _baseline_to_antnums((&bls).into(), (&mut antnums).into()); assert_eq!(ants, antnums) } }