diff --git a/mwatershed/__init__.py b/mwatershed/__init__.py index 13995df..6200502 100644 --- a/mwatershed/__init__.py +++ b/mwatershed/__init__.py @@ -17,8 +17,9 @@ def agglom( seeds: Optional[np.ndarray] = None, edges: Optional[list[tuple[bool, int, int]]] = None, strides: Optional[list[list[int]]] = None, + randomized_strides: bool = False, ): - return agglom_rs(affinities, offsets, seeds, edges, strides) + return agglom_rs(affinities, offsets, seeds, edges, strides, randomized_strides) __all__ = ["agglom", "cluster"] diff --git a/src/lib.rs b/src/lib.rs index e5c70ec..6a2d3dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,6 +11,8 @@ use pyo3::prelude::*; use pyo3::wrap_pyfunction; use pyo3::PyResult; +use ndarray_rand::rand; + use ordered_float::NotNan; use std::collections::{HashMap, HashSet}; use std::convert::TryInto; @@ -26,6 +28,7 @@ pub fn get_edges( offsets: Vec>, seeds: &Array, strides: Option>>, + randomized_strides: bool, ) -> (Vec, HashSet) { // let (_, array_shape) = get_dims::(seeds.dim(), 0); let offsets: Vec<[isize; D]> = offsets @@ -52,40 +55,59 @@ pub fn get_edges( .for_each(|(offset_index, (offset, stride))| { let all_offset_affs = affinities.index_axis(Axis(0), offset_index); let offset_affs = all_offset_affs.slice_each_axis(|ax| { + let step = if randomized_strides { + 1 + } else { + stride[ax.axis.index()].try_into().unwrap() + }; Slice::new( std::cmp::max(0, -offset[ax.axis.index()]), Some(std::cmp::min( ax.len as isize, (ax.len as isize) - offset[ax.axis.index()], )), - stride[ax.axis.index()].try_into().unwrap(), + step, ) }); let u_seeds = seeds.slice_each_axis(|ax| { + let step = if randomized_strides { + 1 + } else { + stride[ax.axis.index()].try_into().unwrap() + }; Slice::new( std::cmp::max(0, -offset[ax.axis.index()]), Some(std::cmp::min( ax.len as isize, (ax.len as isize) - offset[ax.axis.index()], )), - stride[ax.axis.index()].try_into().unwrap(), + step, ) }); let v_seeds = seeds.slice_each_axis(|ax| { + let step = if randomized_strides { + 1 + } else { + stride[ax.axis.index()].try_into().unwrap() + }; Slice::new( std::cmp::max(0, offset[ax.axis.index()]), Some(std::cmp::min( ax.len as isize, (ax.len as isize) + offset[ax.axis.index()], )), - stride[ax.axis.index()].try_into().unwrap(), + step, ) }); offset_affs.indexed_iter().for_each(|(index, aff)| { - let u = u_seeds[&index]; - let v = v_seeds[&index]; - affs.push(NotNan::new(aff.abs()).expect("Cannot handle `nan` affinities")); - edges.push(AgglomEdge(aff > &0.0, u, v)) + if !randomized_strides + || rand::random::() < 1.0 / stride.iter().product::() as f32 + { + let u = u_seeds[&index]; + let v = v_seeds[&index]; + affs.push(NotNan::new(aff.abs()).expect("Cannot handle `nan` affinities")); + edges.push(AgglomEdge(aff > &0.0, u, v)) + } }); }); let agglom_edges: Vec = affs @@ -116,6 +138,7 @@ pub fn agglomerate( mut edges: Vec, mut seeds: Array, strides: Option>>, + randomized_strides: bool, ) -> Array { // relabel to consecutive ids let mut lookup = HashMap::new(); @@ -143,7 +166,7 @@ pub fn agglomerate( // main algorithm let (sorted_edges, mut filtered_background) = - get_edges::(affinities, offsets, &seeds, strides); + get_edges::(affinities, offsets, &seeds, strides, randomized_strides); edges.extend(sorted_edges); lookup.values().for_each(|node_id| { filtered_background.remove(node_id); @@ -225,6 +248,7 @@ fn agglom_rs<'py>( seeds: Option<&PyArrayDyn>, edges: Option>, strides: Option>>, + randomized_strides: Option, ) -> PyResult<&'py PyArrayDyn> { let affinities = unsafe { affinities.as_array() }.to_owned(); let seeds = match seeds { @@ -238,12 +262,54 @@ fn agglom_rs<'py>( .map(|(pos, u, v)| AgglomEdge(pos, u, v)) .collect(); let result = match dim { - 1 => agglomerate::<1>(&affinities, offsets, edges, seeds, strides), - 2 => agglomerate::<2>(&affinities, offsets, edges, seeds, strides), - 3 => agglomerate::<3>(&affinities, offsets, edges, seeds, strides), - 4 => agglomerate::<4>(&affinities, offsets, edges, seeds, strides), - 5 => agglomerate::<5>(&affinities, offsets, edges, seeds, strides), - 6 => agglomerate::<6>(&affinities, offsets, edges, seeds, strides), + 1 => agglomerate::<1>( + &affinities, + offsets, + edges, + seeds, + strides, + randomized_strides.unwrap_or(false), + ), + 2 => agglomerate::<2>( + &affinities, + offsets, + edges, + seeds, + strides, + randomized_strides.unwrap_or(false), + ), + 3 => agglomerate::<3>( + &affinities, + offsets, + edges, + seeds, + strides, + randomized_strides.unwrap_or(false), + ), + 4 => agglomerate::<4>( + &affinities, + offsets, + edges, + seeds, + strides, + randomized_strides.unwrap_or(false), + ), + 5 => agglomerate::<5>( + &affinities, + offsets, + edges, + seeds, + strides, + randomized_strides.unwrap_or(false), + ), + 6 => agglomerate::<6>( + &affinities, + offsets, + edges, + seeds, + strides, + randomized_strides.unwrap_or(false), + ), _ => panic!["Only 1-6 dimensional arrays supported"], }; Ok(result.into_pyarray(_py)) @@ -308,7 +374,7 @@ mod tests { - 0.5; let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, 1], vec![1, 0]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None); + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None, false); let ids = components .clone() .into_iter() @@ -352,7 +418,8 @@ mod tests { let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, 1], vec![1, 0]]; let strides = vec![vec![2, 1], vec![1, 2]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, Some(strides)); + let components = + agglomerate::<2>(&affinities, offsets, vec![], seeds, Some(strides), false); let ids = components .clone() .into_iter() @@ -396,7 +463,7 @@ mod tests { - 0.5; let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, -1], vec![-1, 0]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None); + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None, false); let ids = components .clone() .into_iter() @@ -440,7 +507,7 @@ mod tests { - 0.5; let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn(); let offsets = vec![vec![0, -1], vec![-1, 0]]; - let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None); + let components = agglomerate::<2>(&affinities, offsets, vec![], seeds, None, false); let ids = components .clone() .into_iter()