Skip to content

Commit

Permalink
move some functions to just rust for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mkolopanis committed Nov 3, 2023
1 parent 1521a8f commit 982edd5
Showing 1 changed file with 94 additions and 26 deletions.
120 changes: 94 additions & 26 deletions rust/lib.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
use ndarray::{ArrayView1, Axis};
use numpy::{PyArray2, PyReadonlyArray1};
use ndarray::{Array, Array2, ArrayView1, Axis, Ix2};
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray1};
use pyo3::{pymodule, types::PyModule, PyResult, Python};

#[inline]
fn bls_to_ants_256(bl: &u64) -> [u64; 2] {
let a1 = bl % 256;
[(bl - a1) / 256, a1]
}

#[inline]
fn bls_to_ants_2048(bl: &u64) -> [u64; 2] {
let a1 = (bl - BLS_2048) % 2048;
[(bl - BLS_2048 - a1) / 2048, a1]
}

#[inline]
fn bls_to_ants_2_147_483_648(bl: &u64) -> [u64; 2] {
let a1 = (bl - BLS_2_147_483_648) % 2_147_483_648;
[(bl - BLS_2_147_483_648 - a1) / 2_147_483_648, a1]
Expand All @@ -20,37 +23,102 @@ fn bls_to_ants_2_147_483_648(bl: &u64) -> [u64; 2] {
const BLS_2_147_483_648: u64 = 2_u64.pow(16) + 2_u64.pow(22);
const BLS_2048: u64 = 2_u64.pow(16);

fn _baseline_to_antnums(bls_array: ArrayView1<u64>) -> Array2<u64> {
let nbls = bls_array.len();
let bls_min = bls_array.fold(bls_array[0], |x, y| x.min(*y));

let bl_fn = if bls_min >= BLS_2_147_483_648 {
bls_to_ants_2_147_483_648
} else if bls_min >= BLS_2048 {
bls_to_ants_2048
} else {
bls_to_ants_256
};
let mut bls_out = Array::<u64, Ix2>::zeros((2, nbls));
bls_out
.axis_iter_mut(Axis(1))
.zip(bls_array)
.for_each(|(mut ant_array, bl)| {
let ants = bl_fn(bl);
ant_array[0] = ants[0];
ant_array[1] = ants[1];
});

bls_out
}

#[pymodule]
fn _utils_rs<'py>(_py: Python<'py>, m: &'py PyModule) -> PyResult<()> {
#[pyfn(m)]
fn baseline_to_antnums<'py>(
py: Python<'py>,
bls: PyReadonlyArray1<'py, u64>,
) -> &'py PyArray2<u64> {
let nbls = bls.len();
let bls_array: ArrayView1<u64> = bls.as_array();
let bls_min = bls_array.fold(bls_array[0], |x, y| x.min(*y));

let bl_fn = if bls_min >= BLS_2_147_483_648 {
bls_to_ants_2_147_483_648
} else if bls_min >= BLS_2048 {
bls_to_ants_2048
} else {
bls_to_ants_256
};
let bls_out = unsafe { PyArray2::new(py, (2, nbls), false) };
let mut _bls_out = unsafe { bls_out.as_array_mut() };

_bls_out
.axis_iter_mut(Axis(1))
.zip(bls_array)
.for_each(|(mut ant_array, bl)| {
let ants = bl_fn(bl);
ant_array[0] = ants[0];
ant_array[1] = ants[1];
});

bls_out
_baseline_to_antnums(bls.as_array()).into_pyarray(py)
}
Ok(())
}

#[cfg(test)]
mod test {
use ndarray::Array;

use super::*;

#[test]
fn bls_to_ants256() {
let bls = Array::from_iter(1..50_u64);
let ants = Array::from_shape_vec(
(2, 49),
vec![
0_u64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49,
],
)
.unwrap();

let antnums = _baseline_to_antnums((&bls).into());
assert_eq!(ants, antnums)
}

#[test]
fn bls_to_ants2048() {
let bls = Array::from_iter(1..50_u64) + 2_u64.pow(16);
let ants = Array::from_shape_vec(
(2, 49),
vec![
0_u64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49,
],
)
.unwrap();

let antnums = _baseline_to_antnums((&bls).into());
assert_eq!(ants, antnums)
}

#[test]
fn bls_to_antslarge() {
let bls = Array::from_iter(1..50_u64) + 2_u64.pow(16) + 2_u64.pow(22);
let ants = Array::from_shape_vec(
(2, 49),
vec![
0_u64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49,
],
)
.unwrap();

let antnums = _baseline_to_antnums((&bls).into());
assert_eq!(ants, antnums)
}
}

0 comments on commit 982edd5

Please sign in to comment.