From 2674200992a010fe151835a3aeedba1c0265b874 Mon Sep 17 00:00:00 2001 From: William Patton Date: Wed, 6 Mar 2024 11:12:00 -0500 Subject: [PATCH] add support for strides for offsets --- mwatershed/__init__.py | 3 +- src/lib.rs | 118 +++++++++++++++++++++++++++++++---------- tests/test_agglom.py | 29 ++++++++++ 3 files changed, 120 insertions(+), 30 deletions(-) create mode 100644 tests/test_agglom.py diff --git a/mwatershed/__init__.py b/mwatershed/__init__.py index b13aa70..7a140b8 100644 --- a/mwatershed/__init__.py +++ b/mwatershed/__init__.py @@ -16,8 +16,9 @@ def agglom( offsets: list[list[int]], seeds: Optional[np.ndarray] = None, edges: Optional[list[tuple[bool, int, int]]] = None, + strides: Optional[list[list[int]]] = None ): - return agglom_rs(affinities, offsets, seeds, edges) + return agglom_rs(affinities, offsets, seeds, edges, strides) __all__ = ["agglom", "cluster"] diff --git a/src/lib.rs b/src/lib.rs index 239758a..34394cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,6 +25,7 @@ pub fn get_edges( affinities: &Array, offsets: Vec>, seeds: &Array, + strides: Option>>, ) -> (Vec, HashSet) { // let (_, array_shape) = get_dims::(seeds.dim(), 0); let offsets: Vec<[isize; D]> = offsets @@ -38,36 +39,46 @@ pub fn get_edges( let mut to_filter: HashSet = HashSet::from_iter(seeds.iter().copied()); + let strides = strides.unwrap_or_else(|| { + (0..offsets.len()) // one all-ones stride per offset (not per axis), so zip() below drops no offsets + .map(|_| (0..D).map(|_| 1).collect()) + .collect::>>() + }); + offsets .iter() + .zip(strides.iter()) .enumerate() - .for_each(|(offset_index, offset)| { + .for_each(|(offset_index, (offset, stride))| { let all_offset_affs = affinities.index_axis(Axis(0), offset_index); let offset_affs = all_offset_affs.slice_each_axis(|ax| { - Slice::from( - std::cmp::max(0, -offset[ax.axis.index()]) - ..std::cmp::min( - ax.len as isize, - (ax.len as isize) - offset[ax.axis.index()], - ), + Slice::new( + std::cmp::max(0, -offset[ax.axis.index()]), + Some(std::cmp::min( + ax.len as isize, + (ax.len as isize) - 
offset[ax.axis.index()], + )), + stride[ax.axis.index()].try_into().unwrap(), ) }); let u_seeds = seeds.slice_each_axis(|ax| { - Slice::from( - std::cmp::max(0, -offset[ax.axis.index()]) - ..std::cmp::min( - ax.len as isize, - (ax.len as isize) - offset[ax.axis.index()], - ), + Slice::new( + std::cmp::max(0, -offset[ax.axis.index()]), + Some(std::cmp::min( + ax.len as isize, + (ax.len as isize) - offset[ax.axis.index()], + )), + stride[ax.axis.index()].try_into().unwrap(), ) }); let v_seeds = seeds.slice_each_axis(|ax| { - Slice::from( - std::cmp::max(0, offset[ax.axis.index()]) - ..std::cmp::min( - ax.len as isize, - (ax.len as isize) + offset[ax.axis.index()], - ), + Slice::new( + std::cmp::max(0, offset[ax.axis.index()]), + Some(std::cmp::min( + ax.len as isize, + (ax.len as isize) + offset[ax.axis.index()], + )), + stride[ax.axis.index()].try_into().unwrap(), ) }); offset_affs.indexed_iter().for_each(|(index, aff)| { @@ -104,6 +115,7 @@ pub fn agglomerate( offsets: Vec>, mut edges: Vec, mut seeds: Array, + strides: Option>>, ) -> Array { // relabel to consecutive ids let mut lookup = HashMap::new(); @@ -130,7 +142,8 @@ pub fn agglomerate( }); // main algorithm - let (sorted_edges, mut filtered_background) = get_edges::(affinities, offsets, &seeds); + let (sorted_edges, mut filtered_background) = + get_edges::(affinities, offsets, &seeds, strides); edges.extend(sorted_edges); lookup.values().for_each(|node_id| { filtered_background.remove(node_id); @@ -211,6 +224,7 @@ fn agglom_rs<'py>( offsets: Vec>, seeds: Option<&PyArrayDyn>, edges: Option>, + strides: Option>>, ) -> PyResult<&'py PyArrayDyn> { let affinities = unsafe { affinities.as_array() }.to_owned(); let seeds = match seeds { @@ -224,12 +238,12 @@ fn agglom_rs<'py>( .map(|(pos, u, v)| AgglomEdge(pos, u, v)) .collect(); let result = match dim { - 1 => agglomerate::<1>(&affinities, offsets, edges, seeds), - 2 => agglomerate::<2>(&affinities, offsets, edges, seeds), - 3 => agglomerate::<3>(&affinities, 
offsets, edges, seeds), - 4 => agglomerate::<4>(&affinities, offsets, edges, seeds), - 5 => agglomerate::<5>(&affinities, offsets, edges, seeds), - 6 => agglomerate::<6>(&affinities, offsets, edges, seeds), + 1 => agglomerate::<1>(&affinities, offsets, edges, seeds, strides), + 2 => agglomerate::<2>(&affinities, offsets, edges, seeds, strides), + 3 => agglomerate::<3>(&affinities, offsets, edges, seeds, strides), + 4 => agglomerate::<4>(&affinities, offsets, edges, seeds, strides), + 5 => agglomerate::<5>(&affinities, offsets, edges, seeds, strides), + 6 => agglomerate::<6>(&affinities, offsets, edges, seeds, strides), _ => panic!["Only 1-6 dimensional arrays supported"], }; Ok(result.into_pyarray(_py)) @@ -256,6 +270,8 @@ fn mwatershed(_py: Python, m: &PyModule) -> PyResult<()> { #[cfg(test)] mod tests { + use std::vec; + use super::*; use itertools::Itertools; use ndarray::array; @@ -292,7 +308,7 @@ mod tests { - 0.5; let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, 1], vec![1, 0]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds); + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None); let ids = components .clone() .into_iter() @@ -304,6 +320,50 @@ mod tests { assert!(!ids.contains(&0), "{:?}", components); assert!(ids.len() == 4, "{:?}", components); } + /// Seeds + /// 1 2 0 + /// 4 0 0 + /// 0 0 0 + /// + /// Affs + /// offset [0, 1] + /// 0 1 0 + /// 0 1 0 + /// 0 1 0 + /// + /// offset [1, 0] + /// 0 0 0 + /// 1 1 1 + /// 0 0 0 + /// + /// Expected Components + /// 1 2 2 + /// 4 0 x + /// 4 x x + /// + #[test] + fn test_agglom_with_strides() { + let affinities = array![ + [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]] + ] + .into_dyn() + - 0.5; + let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); + let offsets = vec![vec![0, 1], vec![1, 0]]; + let strides = vec![vec![2, 1], vec![1, 2]]; + let 
components = agglomerate::<2>(&affinities, offsets, vec![], seeds, Some(strides)); + let ids = components + .clone() + .into_iter() + .unique() + .collect::>(); + for id in [1, 2, 4].iter() { + assert!(ids.contains(id), "{:?}", components); + } + assert!(ids.contains(&0), "{:?}", components); + assert!(ids.len() == 5, "{:?}", components); + } /// Seeds /// 1 2 0 @@ -336,7 +396,7 @@ mod tests { - 0.5; let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, -1], vec![-1, 0]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds); + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None); let ids = components .clone() .into_iter() @@ -380,7 +440,7 @@ mod tests { - 0.5; let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, -1], vec![-1, 0]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds); + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None); let ids = components .clone() .into_iter() diff --git a/tests/test_agglom.py b/tests/test_agglom.py new file mode 100644 index 0000000..73af9f5 --- /dev/null +++ b/tests/test_agglom.py @@ -0,0 +1,29 @@ +import numpy as np + +import mwatershed + + +def test_agglom_2d_with_strides(): + offsets = [(0, 1), (1, 0)] + strides = [(1, 2), (2, 1)] + affinities = ( + np.array( + [[[0, 1, 0], [0, 1, 0], [0, 1, 0]], [[0, 0, 0], [1, 1, 1], [0, 0, 0]]], + dtype=float, + ) + - 0.5 + ) + # 9 nodes. connecting edges: + # 2-3, 5-6, 8-9, 4-7, 5-8, 6-9 + # components: [(1,),(2,3),(4,7),(5,6,8,9)] + + components = mwatershed.agglom(affinities, offsets, strides=strides) + + _, counts = np.unique(components, return_counts=True) + counts = sorted(counts) + assert len(counts) == 4 + + assert counts[0] == 1 + assert counts[1] == 2 + assert counts[2] == 2 + assert counts[3] == 4