Skip to content

Commit

Permalink
add test for filtering out background
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Dec 29, 2023
1 parent 28ae2f5 commit f29f4f6
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 8 deletions.
11 changes: 11 additions & 0 deletions src/clustering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,17 @@ impl Clustering {
*x = self.positives.clusters.find(*x);
});
}

/// Map a vector of node ids to their representative node ids, replacing
/// nodes that have been filtered out with 0.
pub fn filter_map(&self, seeds: &mut Array<usize, IxDyn>, filtered_nodes: HashSet<usize>) {
seeds
.iter_mut()
.for_each(|x| match filtered_nodes.contains(x) {
false => *x = self.positives.clusters.find(*x),
true => *x = self.positives.clusters.len(),
});
}
}

#[cfg(test)]
Expand Down
72 changes: 64 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn get_edges<const D: usize>(
affinities: &Array<f64, IxDyn>,
offsets: Vec<Vec<isize>>,
seeds: &Array<usize, IxDyn>,
) -> Vec<AgglomEdge> {
) -> (Vec<AgglomEdge>, HashSet<usize>) {
// let (_, array_shape) = get_dims::<D>(seeds.dim(), 0);
let offsets: Vec<[isize; D]> = offsets
.into_iter()
Expand Down Expand Up @@ -91,10 +91,13 @@ pub fn get_edges<const D: usize>(
to_filter.remove(v);
}
});
agglom_edges
.into_iter()
.filter(|edge| !(to_filter.contains(&edge.1) || to_filter.contains(&edge.2)))
.collect()
(
agglom_edges
.into_iter()
.filter(|edge| !(to_filter.contains(&edge.1) || to_filter.contains(&edge.2)))
.collect(),
to_filter,
)
}

pub fn agglomerate<const D: usize>(
Expand Down Expand Up @@ -122,18 +125,21 @@ pub fn agglomerate<const D: usize>(
});

// main algorithm
let sorted_edges = get_edges::<D>(affinities, offsets, &seeds);
let (sorted_edges, mut filtered_background) = get_edges::<D>(affinities, offsets, &seeds);
edges.extend(sorted_edges);
lookup.values().for_each(|node_id| {
filtered_background.remove(node_id);
});

let mut clustering = Clustering::new(seeds.len());

clustering.process_edges(edges);
clustering.map(&mut seeds);
clustering.filter_map(&mut seeds, filtered_background);

// TODO: Fix seed handling
// now we have to remap seeded entries back onto the original ids
let mut rev_lookup = HashMap::with_capacity(lookup.len());
//rev_lookup.insert(0, seeds.len());
rev_lookup.insert(seeds.len(), 0);
lookup.iter().for_each(|(seed, id)| {
let rep_id = clustering.positives.clusters.find(*id);
if *seed != rep_id {
Expand Down Expand Up @@ -337,6 +343,56 @@ mod tests {
assert!(!ids.contains(&0), "{:?}", components);
assert!(ids.len() == 4, "{:?}", components);
}

/// Seeds
/// 1 2 0
/// 4 0 0
/// 0 0 0
///
/// Affs
/// offset [0, -1]
/// 0 0 0
/// 0 0 0
/// 0 0 0
///
/// offset [-1, 0]
/// 0 0 0
/// 0 0 0
/// 0 0 0
///
/// Expected Components
/// 1 2 0
/// 4 0 0
/// 0 0 0
///
#[test]
fn test_filtered_background() {
let affinities = array![
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
]
.into_dyn()
- 0.5;
let seeds = array![[1, 2, 0], [4, 0, 0], [0, 0, 0]].into_dyn();
let offsets = vec![vec![0, -1], vec![-1, 0]];
let components = agglomerate::<2>(&affinities, offsets, vec![], seeds);
let ids = components
.clone()
.into_iter()
.unique()
.collect::<Vec<usize>>();
for id in [1, 2, 4].iter() {
assert!(ids.contains(id), "{:?}", components);
}
assert!(ids.contains(&0), "{:?}", components);
assert!(ids.len() == 4, "{:?}", components);
assert!(
components.iter().counts().get(&0).unwrap() == &6,
"{:?}",
components
);
}

#[test]
fn test_cluster() {
let edges = vec![
Expand Down

0 comments on commit f29f4f6

Please sign in to comment.