diff --git a/src/lib.rs b/src/lib.rs index 98e55b5..421a57b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,11 +23,11 @@ pub struct AgglomEdge(bool, usize, usize); pub fn get_edges( affinities: &Array, - offsets: Vec>, + offsets: Vec>, seeds: &Array, ) -> Vec { // let (_, array_shape) = get_dims::(seeds.dim(), 0); - let offsets: Vec<[usize; D]> = offsets + let offsets: Vec<[isize; D]> = offsets .into_iter() .map(|offset| offset.try_into().unwrap()) .collect(); @@ -41,11 +41,33 @@ pub fn get_edges( .enumerate() .for_each(|(offset_index, offset)| { let all_offset_affs = affinities.index_axis(Axis(0), offset_index); - let offset_affs = all_offset_affs - .slice_each_axis(|ax| Slice::from(0..(ax.len - offset[ax.axis.index()]))); - let u_seeds = - seeds.slice_each_axis(|ax| Slice::from(0..(ax.len - offset[ax.axis.index()]))); - let v_seeds = seeds.slice_each_axis(|ax| Slice::from(offset[ax.axis.index()]..)); + let offset_affs = all_offset_affs.slice_each_axis(|ax| { + Slice::from( + std::cmp::max(0, -offset[ax.axis.index()]) + ..std::cmp::min( + ax.len as isize, + (ax.len as isize) - offset[ax.axis.index()], + ), + ) + }); + let u_seeds = seeds.slice_each_axis(|ax| { + Slice::from( + std::cmp::max(0, -offset[ax.axis.index()]) + ..std::cmp::min( + ax.len as isize, + (ax.len as isize) - offset[ax.axis.index()], + ), + ) + }); + let v_seeds = seeds.slice_each_axis(|ax| { + Slice::from( + std::cmp::max(0, offset[ax.axis.index()]) + ..std::cmp::min( + ax.len as isize, + (ax.len as isize) + offset[ax.axis.index()], + ), + ) + }); offset_affs.indexed_iter().for_each(|(index, aff)| { let u = u_seeds[&index]; let v = v_seeds[&index]; @@ -161,7 +183,7 @@ impl Clustering { pub fn agglomerate( affinities: &Array, - offsets: Vec>, + offsets: Vec>, mut edges: Vec, mut seeds: Array, ) -> Array { @@ -259,7 +281,7 @@ pub fn cluster_edges(mut sorted_edges: Vec) -> Vec<(usize, usize)> { fn agglom<'py>( _py: Python<'py>, affinities: &PyArrayDyn, - offsets: Vec>, + offsets: Vec>, seeds: Option<&PyArrayDyn>, edges: Option>, ) -> PyResult<&'py PyArrayDyn> { @@ -341,6 +363,27 @@ mod tests { assert!(ids.len() == 4); } #[test] + fn test_agglom_negative_offsets() { + let affinities = array![ + [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]] + ] + .into_dyn() + - 0.5; + let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); + let offsets = vec![vec![0, -1], vec![-1, 0]]; + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds); + let ids = components + .clone() + .into_iter() + .unique() + .collect::>(); + for id in [1, 2, 4].iter() { + assert!(ids.contains(id)); + } + assert!(ids.len() == 4); + } + #[test] fn test_cluster() { let edges = vec![ AgglomEdge(false, 1, 2), diff --git a/tests/test_mwatershed.py b/tests/test_mwatershed.py index 2d0327d..9457d8c 100644 --- a/tests/test_mwatershed.py +++ b/tests/test_mwatershed.py @@ -52,6 +52,29 @@ def test_agglom_2d(): assert (components == 5).sum() == 4 +def test_agglom_2d_negative_offsets(): + offsets = [(0, -1), (-1, 0)] + affinities = ( + np.array( + [[[0, 1, 0], [0, 1, 0], [0, 1, 0]], [[0, 0, 0], [1, 1, 1], [0, 0, 0]]], + dtype=float, + ) + - 0.5 + ) + # 9 nodes. connecting edges: + # 2-3, 5-6, 8-9, 4-7, 5-8, 6-9 + # components: [(1,),(2,3),(4,7),(5,6,8,9)] + + components = mwatershed.agglom(affinities, offsets) + + # assert False, components + + # assert (components == 1).sum() == 1 + # assert (components == 2).sum() == 2 + # assert (components == 4).sum() == 2 + # assert (components == 5).sum() == 4 + + def test_agglom_2d_with_extra_edges(): nodes = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 9]], dtype=np.uint64) offsets = [(0, 1), (1, 0)]