Skip to content

Commit

Permalink
init as empty py array
Browse files Browse the repository at this point in the history
  • Loading branch information
mkolopanis committed Nov 3, 2023
1 parent 982edd5 commit 6503682
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use ndarray::{Array, Array2, ArrayView1, Axis, Ix2};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray1};
use ndarray::{ArrayView1, ArrayViewMut2, Axis};
use numpy::{PyArray, PyArray2, PyReadonlyArray1};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

#[inline]
Expand All @@ -23,8 +23,7 @@ fn bls_to_ants_2_147_483_648(bl: &u64) -> [u64; 2] {
const BLS_2_147_483_648: u64 = 2_u64.pow(16) + 2_u64.pow(22);
const BLS_2048: u64 = 2_u64.pow(16);

fn _baseline_to_antnums(bls_array: ArrayView1<u64>) -> Array2<u64> {
let nbls = bls_array.len();
fn _baseline_to_antnums(bls_array: ArrayView1<u64>, mut bls_out: ArrayViewMut2<u64>) {
let bls_min = bls_array.fold(bls_array[0], |x, y| x.min(*y));

let bl_fn = if bls_min >= BLS_2_147_483_648 {
Expand All @@ -34,7 +33,6 @@ fn _baseline_to_antnums(bls_array: ArrayView1<u64>) -> Array2<u64> {
} else {
bls_to_ants_256
};
let mut bls_out = Array::<u64, Ix2>::zeros((2, nbls));
bls_out
.axis_iter_mut(Axis(1))
.zip(bls_array)
Expand All @@ -43,8 +41,6 @@ fn _baseline_to_antnums(bls_array: ArrayView1<u64>) -> Array2<u64> {
ant_array[0] = ants[0];
ant_array[1] = ants[1];
});

bls_out
}

#[pymodule]
Expand All @@ -54,14 +50,17 @@ fn _utils_rs<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
py: Python<'py>,
bls: PyReadonlyArray1<'py, u64>,
) -> &'py PyArray2<u64> {
_baseline_to_antnums(bls.as_array()).into_pyarray(py)
let nbls = bls.shape()[0];
let bls_out = unsafe { PyArray::new(py, (2, nbls), false) };
_baseline_to_antnums(bls.as_array(), unsafe { bls_out.as_array_mut() });
bls_out
}
Ok(())
}

#[cfg(test)]
mod test {
use ndarray::Array;
use ndarray::{Array, Ix2};

use super::*;

Expand All @@ -80,7 +79,10 @@ mod test {
)
.unwrap();

let antnums = _baseline_to_antnums((&bls).into());
let nbls = bls.len();
let mut antnums = Array::<u64, Ix2>::zeros((2, nbls));

_baseline_to_antnums((&bls).into(), (&mut antnums).into());
assert_eq!(ants, antnums)
}

Expand All @@ -99,7 +101,10 @@ mod test {
)
.unwrap();

let antnums = _baseline_to_antnums((&bls).into());
let nbls = bls.len();
let mut antnums = Array::<u64, Ix2>::zeros((2, nbls));

_baseline_to_antnums((&bls).into(), (&mut antnums).into());
assert_eq!(ants, antnums)
}

Expand All @@ -118,7 +123,10 @@ mod test {
)
.unwrap();

let antnums = _baseline_to_antnums((&bls).into());
let nbls = bls.len();
let mut antnums = Array::<u64, Ix2>::zeros((2, nbls));

_baseline_to_antnums((&bls).into(), (&mut antnums).into());
assert_eq!(ants, antnums)
}
}

0 comments on commit 6503682

Please sign in to comment.