Skip to content

Commit

Permalink
Add put_into for ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
magnusuMET committed Dec 3, 2023
1 parent b93b775 commit 0a6d9e4
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 0 deletions.
63 changes: 63 additions & 0 deletions netcdf/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -855,6 +855,69 @@ impl<'g> Variable<'g> {
self.values_arr_mono(&extents)
}

#[cfg(feature = "ndarray")]
/// Get values into an ndarray
pub fn values_arr_into<T: NcPutGet, E, D>(
&self,
extents: E,
mut out: ndarray::ArrayViewMut<T, D>,
) -> error::Result<()>
where
D: ndarray::Dimension,
E: TryInto<Extents>,
E::Error: Into<error::Error>,
{
let extents = extents.try_into().map_err(|e| e.into())?;

let dims = self.dimensions();
let mut start = Vec::with_capacity(dims.len());
let mut count = Vec::with_capacity(dims.len());
let mut stride = Vec::with_capacity(dims.len());

let mut rem_outshape = out.shape();

for (pos, item) in extents.iter_with_dims(dims)?.enumerate() {
start.push(item.start);
count.push(item.count);
stride.push(item.stride);
if !item.is_an_index {
let cur_dim_len = if let Some((&head, rest)) = rem_outshape.split_first() {
rem_outshape = rest;
head
} else {
return Err(("Output array dimensionality is less than extents").into());
};
if item.count != cur_dim_len {
return Err(format!("Item count (position {pos}) as {} but expected in output was {cur_dim_len}", item.count).into());
}
}
}
if !rem_outshape.is_empty() {
return Err(("Output array dimensionality is larger than extents").into());
}

let slice = if let Some(slice) = out.as_slice_mut() {
slice
} else {
return Err("Output array must be in standard layout".into());
};

assert_eq!(
slice.len(),
count.iter().copied().fold(1, usize::saturating_mul),
"Output size and number of elements to get are not compatible"
);

// Safety:
// start, count, stride are correct length
// slice is valid pointer, with enough space to hold all elements
unsafe {
T::get_vars(self, &start, &count, &stride, slice.as_mut_ptr())?;
}

Ok(())
}

/// Get the fill value of a variable
pub fn fill_value<T: NcPutGet>(&self) -> error::Result<Option<T>> {
if T::NCTYPE != self.vartype {
Expand Down
46 changes: 46 additions & 0 deletions netcdf/tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1696,3 +1696,49 @@ fn ndarray_put() {
put_values!(var, (5..6, .., ..), values, s![.., .., .., 0], Failure);
put_values!(var, (5..15, .., ..), values, s![.., .., .., 0]);
}

#[test]
#[cfg(feature = "ndarray")]
fn ndarray_get_into() {
use ndarray::s;

let d = tempfile::tempdir().unwrap();
let path = d.path().join("get_into.nc");

let mut f = netcdf::create(path).unwrap();
f.add_dimension("d1", 4).unwrap();
f.add_dimension("d2", 5).unwrap();
f.add_dimension("d3", 6).unwrap();

let values = ndarray::Array::<u64, _>::from_shape_fn((4, 5, 6), |(k, j, i)| {
(100 * k + 10 * j + i).try_into().unwrap()
});

let mut var = f.add_variable::<u64>("var", &["d1", "d2", "d3"]).unwrap();

var.put_values_arr(.., values.view()).unwrap();

let mut outarray = ndarray::Array::<u64, _>::zeros((4, 5, 6));

var.values_arr_into(.., outarray.view_mut()).unwrap();
assert_eq!(values, outarray);
outarray.fill(0);

var.values_arr_into((1, .., ..), outarray.slice_mut(s![0, .., ..]))
.unwrap();
assert_eq!(values.slice(s![1, .., ..]), outarray.slice(s![0, .., ..]));
outarray.fill(0);

var.values_arr_into((3, 1, ..), outarray.slice_mut(s![0, 0, ..]))
.unwrap();
assert_eq!(values.slice(s![3, 1, ..]), outarray.slice(s![0, 0, ..]));
outarray.fill(0);

var.values_arr_into((.., .., 1), outarray.slice_mut(s![.., .., 1]))
.unwrap_err();

let mut outarray = ndarray::Array::<u64, _>::zeros((3, 4, 5, 6));
var.values_arr_into((.., .., ..), outarray.slice_mut(s![0, .., .., ..]))
.unwrap();
assert_eq!(values, outarray.slice(s![0, .., .., ..]));
}

0 comments on commit 0a6d9e4

Please sign in to comment.