Skip to content

Commit

Permalink
use ndarray axis iter
Browse files Browse the repository at this point in the history
  • Loading branch information
mkolopanis committed Nov 3, 2023
1 parent 21d855f commit ec62237
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ndarray::ArrayView1;
use ndarray::{ArrayView1, Axis};
use numpy::{PyArray2, PyReadonlyArray1};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

Expand Down Expand Up @@ -41,11 +41,14 @@ fn _utils_rs<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
let bls_out = unsafe { PyArray2::new(py, (2, nbls), false) };
let mut _bls_out = unsafe { bls_out.as_array_mut() };

(0..nbls).for_each(|cnt| {
let ants = bl_fn(&bls_array[cnt]);
_bls_out[[0, cnt]] = ants[0];
_bls_out[[1, cnt]] = ants[1];
});
_bls_out
.axis_iter_mut(Axis(1))
.zip(bls_array)
.for_each(|(mut ant_array, bl)| {
let ants = bl_fn(bl);
ant_array[0] = ants[0];
ant_array[1] = ants[1];
});

bls_out
}
Expand Down

0 comments on commit ec62237

Please sign in to comment.