diff --git a/.gitignore b/.gitignore
index 56794f8b..ff95ce2b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+docs/dest
+.DS_Store
+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f37bd5b0..395beb32 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -7,6 +7,7 @@
:caption: Examples:
examples/ctc
+ test_cases/test_description
.. toctree::
:maxdepth: 2
diff --git a/docs/source/test_cases/matched_graph/Empty Ground Truth.svg b/docs/source/test_cases/matched_graph/Empty Ground Truth.svg
new file mode 100644
index 00000000..d8a17ea2
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/Empty Ground Truth.svg
@@ -0,0 +1,993 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/Empty Prediction.svg b/docs/source/test_cases/matched_graph/Empty Prediction.svg
new file mode 100644
index 00000000..f0d89e84
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/Empty Prediction.svg
@@ -0,0 +1,991 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/False Negative Edge.svg b/docs/source/test_cases/matched_graph/False Negative Edge.svg
new file mode 100644
index 00000000..62416d33
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/False Negative Edge.svg
@@ -0,0 +1,1402 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/False Negative Node.svg b/docs/source/test_cases/matched_graph/False Negative Node.svg
new file mode 100644
index 00000000..70e21e0f
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/False Negative Node.svg
@@ -0,0 +1,1599 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/False Positive Edge.svg b/docs/source/test_cases/matched_graph/False Positive Edge.svg
new file mode 100644
index 00000000..12c1a3c3
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/False Positive Edge.svg
@@ -0,0 +1,1402 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/False Positive Node.svg b/docs/source/test_cases/matched_graph/False Positive Node.svg
new file mode 100644
index 00000000..b15162d4
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/False Positive Node.svg
@@ -0,0 +1,1563 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/Good Matching.svg b/docs/source/test_cases/matched_graph/Good Matching.svg
new file mode 100644
index 00000000..3f57c462
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/Good Matching.svg
@@ -0,0 +1,1064 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/One Ground Truth to Two Predictions.svg b/docs/source/test_cases/matched_graph/One Ground Truth to Two Predictions.svg
new file mode 100644
index 00000000..7bd3755e
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/One Ground Truth to Two Predictions.svg
@@ -0,0 +1,1595 @@
+
+
+
diff --git a/docs/source/test_cases/matched_graph/Two Ground Truth to One Prediction.svg b/docs/source/test_cases/matched_graph/Two Ground Truth to One Prediction.svg
new file mode 100644
index 00000000..51caa64c
--- /dev/null
+++ b/docs/source/test_cases/matched_graph/Two Ground Truth to One Prediction.svg
@@ -0,0 +1,1563 @@
+
+
+
diff --git a/docs/source/test_cases/segmentation/2d/False Negative.svg b/docs/source/test_cases/segmentation/2d/False Negative.svg
new file mode 100644
index 00000000..7f15f32c
--- /dev/null
+++ b/docs/source/test_cases/segmentation/2d/False Negative.svg
@@ -0,0 +1,1119 @@
+
+
+
diff --git a/docs/source/test_cases/segmentation/2d/False Positive.svg b/docs/source/test_cases/segmentation/2d/False Positive.svg
new file mode 100644
index 00000000..b105fc78
--- /dev/null
+++ b/docs/source/test_cases/segmentation/2d/False Positive.svg
@@ -0,0 +1,1072 @@
+
+
+
diff --git a/docs/source/test_cases/segmentation/2d/Good Segmentation.png b/docs/source/test_cases/segmentation/2d/Good Segmentation.png
new file mode 100644
index 00000000..bd384dd9
Binary files /dev/null and b/docs/source/test_cases/segmentation/2d/Good Segmentation.png differ
diff --git a/docs/source/test_cases/segmentation/2d/Good Segmentation.svg b/docs/source/test_cases/segmentation/2d/Good Segmentation.svg
new file mode 100644
index 00000000..59fad995
--- /dev/null
+++ b/docs/source/test_cases/segmentation/2d/Good Segmentation.svg
@@ -0,0 +1,1161 @@
+
+
+
diff --git a/docs/source/test_cases/segmentation/2d/Oversegmentation.svg b/docs/source/test_cases/segmentation/2d/Oversegmentation.svg
new file mode 100644
index 00000000..518c2b23
--- /dev/null
+++ b/docs/source/test_cases/segmentation/2d/Oversegmentation.svg
@@ -0,0 +1,1174 @@
+
+
+
diff --git a/docs/source/test_cases/segmentation/2d/Undersegmentation.svg b/docs/source/test_cases/segmentation/2d/Undersegmentation.svg
new file mode 100644
index 00000000..596d6939
--- /dev/null
+++ b/docs/source/test_cases/segmentation/2d/Undersegmentation.svg
@@ -0,0 +1,1161 @@
+
+
+
diff --git a/docs/source/test_cases/segmentation/2d/good_seg.png b/docs/source/test_cases/segmentation/2d/good_seg.png
new file mode 100644
index 00000000..b68c7513
Binary files /dev/null and b/docs/source/test_cases/segmentation/2d/good_seg.png differ
diff --git a/docs/source/test_cases/test_description.rst b/docs/source/test_cases/test_description.rst
new file mode 100644
index 00000000..db256304
--- /dev/null
+++ b/docs/source/test_cases/test_description.rst
@@ -0,0 +1,33 @@
+Description of Unit Test Canonical Examples
+===========================================
+
+To facilitate testing, we have provided a suite of canonical examples
+that cover the basic, simple scenarios that can occur in segmentation and
+tracking. Here we describe them and show visualizations of each case.
+
+Matchers should test all the segmentation cases. Metrics should test all the
+tracking cases. The examples are generated by functions in the `tests/examples/`
+directory.
+
+Segmentation Canonical Examples
+-------------------------------
+
+.. image:: segmentation/2d/Good\ Segmentation.svg
+.. image:: segmentation/2d/False\ Negative.svg
+.. image:: segmentation/2d/False\ Positive.svg
+.. image:: segmentation/2d/Oversegmentation.svg
+.. image:: segmentation/2d/Undersegmentation.svg
+
+
+Matched Graph Canonical Examples
+--------------------------------
+
+.. image:: matched_graph/Good\ Matching.svg
+.. image:: matched_graph/False\ Negative\ Node.svg
+.. image:: matched_graph/False\ Negative\ Edge.svg
+.. image:: matched_graph/False\ Positive\ Node.svg
+.. image:: matched_graph/False\ Positive\ Edge.svg
+.. image:: matched_graph/Two\ Ground\ Truth\ to\ One\ Prediction.svg
+.. image:: matched_graph/One\ Ground\ Truth\ to\ Two\ Predictions.svg
+.. image:: matched_graph/Empty\ Ground\ Truth.svg
+.. image:: matched_graph/Empty\ Prediction.svg
\ No newline at end of file
diff --git a/docs/write_graph_test_cases.py b/docs/write_graph_test_cases.py
new file mode 100644
index 00000000..0359616a
--- /dev/null
+++ b/docs/write_graph_test_cases.py
@@ -0,0 +1,105 @@
+import sys
+
+from traccuracy._tracking_graph import TrackingGraph
+
+sys.path.append("../tests/examples")
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+from example_matched_graphs import (
+ empty_gt,
+ empty_pred,
+ fn_edge_matched,
+ fn_node_matched,
+ fp_edge_matched,
+ fp_node_matched,
+ good_matched,
+ one_to_two,
+ two_to_one,
+)
+from matplotlib.patches import Patch
+
+
+def get_loc(graph, node):
+ return graph.graph.nodes[node]["t"], graph.graph.nodes[node]["y"]
+
+
+def plot_graph(ax, graph: TrackingGraph, color="black"):
+ if graph.graph.number_of_nodes() == 0:
+ return 0
+ ids = list(graph.graph.nodes)
+ print(ids)
+ x = [graph.graph.nodes[node]["t"] for node in ids]
+ y = [graph.graph.nodes[node]["y"] for node in ids]
+ ax.scatter(x, y, color=color)
+ for _x, _y, _id in zip(x, y, ids):
+ ax.text(_x + 0.05, _y + 0.05, str(_id))
+
+ for u, v in graph.graph.edges():
+ print(u, v)
+ xs = [graph.graph.nodes[u]["t"], graph.graph.nodes[v]["t"]]
+ ys = [graph.graph.nodes[u]["y"], graph.graph.nodes[v]["y"]]
+ ax.plot(xs, ys, color=color)
+
+ return max(y)
+
+
+def plot_matching(ax, matched, color="grey"):
+ for u, v in matched.mapping:
+ xs = [
+ matched.gt_graph.graph.nodes[u]["t"],
+ matched.pred_graph.graph.nodes[v]["t"],
+ ]
+ ys = [
+ matched.gt_graph.graph.nodes[u]["y"],
+ matched.pred_graph.graph.nodes[v]["y"],
+ ]
+ ax.plot(xs, ys, color=color, linestyle="dashed")
+
+
+def save_matched(examples, title):
+ gt_color = "black"
+ pred_color = "blue"
+ mapping_color = "grey"
+ fig, ax = plt.subplots(1, len(examples) + 1, figsize=(3 * len(examples) + 1, 2))
+ for i, matched in enumerate(examples):
+ axis = ax[i]
+ maxY = plot_graph(axis, matched.gt_graph, color=gt_color)
+ maxY = max([maxY, plot_graph(axis, matched.pred_graph, color=pred_color)])
+ plot_matching(axis, matched, color=mapping_color)
+ axis.set_ybound(-0.5, maxY + 0.5)
+ axis.set_xbound(-0.5, 2.5)
+ axis.set_ylabel("Y Value")
+ axis.set_xlabel("Time")
+
+ handles = [
+ Patch(color=gt_color),
+ Patch(color=pred_color),
+ Patch(color=mapping_color),
+ ]
+ labels = ["Ground Truth", "Prediction", "Mapping"]
+ ax[-1].legend(handles=handles, labels=labels, loc="center")
+ ax[-1].set_axis_off()
+ fig.tight_layout()
+ fig.suptitle(title)
+ fig.savefig(outpath)
+
+
+if __name__ == "__main__":
+ graph_examples = {
+ "Empty Ground Truth": [empty_gt()],
+ "Empty Prediction": [empty_pred()],
+ "Good Matching": [good_matched()],
+ "False Negative Node": [fn_node_matched(t) for t in [0, 1, 2]],
+ "False Negative Edge": [fn_edge_matched(t) for t in [0, 1]],
+ "False Positive Node": [fp_node_matched(t) for t in [0, 1, 2]],
+ "False Positive Edge": [fp_edge_matched(t) for t in [0, 1]],
+ "Two Ground Truth to One Prediction": [two_to_one(t) for t in [0, 1, 2]],
+ "One Ground Truth to Two Predictions": [one_to_two(t) for t in [0, 1, 2]],
+ }
+ outdir = Path("source/test_cases/matched_graph/")
+ print(outdir.exists())
+ for name, matched in graph_examples.items():
+ print(name)
+ outpath = outdir / f"{name}.svg"
+ save_matched(matched, name)
diff --git a/docs/write_seg_test_cases.py b/docs/write_seg_test_cases.py
new file mode 100644
index 00000000..15683716
--- /dev/null
+++ b/docs/write_seg_test_cases.py
@@ -0,0 +1,51 @@
+import sys
+
+sys.path.append("../tests/examples")
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+from example_segmentations import (
+ false_negative_segmentation_2d,
+ false_positive_segmentation_2d,
+ good_segmentation_2d,
+ oversegmentation_2d,
+ undersegmentation_2d,
+)
+from matplotlib.colors import ListedColormap
+from matplotlib.patches import Patch
+
+
+def save_pair(gt, pred, outpath, title):
+ max_label = np.max([gt, pred])
+ colors = ["black", "red", "blue", "green"]
+ colormap = ListedColormap(colors)
+ fig, ax = plt.subplots(1, 2, figsize=(6, 4))
+ ax[0].imshow(gt, cmap=colormap, vmax=4)
+ ax[0].set_title("Ground Truth")
+ # ax[0].set_axis_off()
+ ax[1].imshow(pred, cmap=colormap, vmax=4)
+ ax[1].set_title("Predicted")
+
+ handles = [Patch(color=colors[i]) for i in range(1, max_label + 1)]
+ labels = [str(i) for i in range(1, max_label + 1)]
+ ax[1].legend(handles=handles, labels=labels, title="Label IDs", loc="upper right")
+ fig.suptitle(title)
+ fig.tight_layout()
+ fig.savefig(outpath)
+
+
+if __name__ == "__main__":
+ two_d_examples = {
+ "Good Segmentation": good_segmentation_2d(),
+ "False Positive": false_positive_segmentation_2d(),
+ "False Negative": false_negative_segmentation_2d(),
+ "Oversegmentation": oversegmentation_2d(),
+ "Undersegmentation": undersegmentation_2d(),
+ }
+ outdir = Path("source/test_cases/segmentation/2d")
+ print(outdir.exists())
+ for name, arrs in two_d_examples.items():
+ outpath = outdir / f"{name}.svg"
+ gt, pred = arrs
+ save_pair(gt, pred, outpath, name)
diff --git a/pyproject.toml b/pyproject.toml
index 4c19f8b1..314c9867 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -152,6 +152,9 @@ filterwarnings = [
"ignore:Numba not installed, falling back to slower:UserWarning",
]
addopts = ["--benchmark-min-rounds=1"]
+pythonpath = [
+ "tests"
+]
# https://mypy.readthedocs.io/en/stable/config_file.html
[tool.mypy]
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..a3e16e73
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,6 @@
+import pytest # noqa
+
+pytest_plugins = [
+ "examples.example_segmentations",
+ "examples.example_matched_graphs",
+]
diff --git a/tests/examples/example_matched_graphs.py b/tests/examples/example_matched_graphs.py
new file mode 100644
index 00000000..ce7eab59
--- /dev/null
+++ b/tests/examples/example_matched_graphs.py
@@ -0,0 +1,202 @@
+import networkx as nx
+
+from traccuracy._tracking_graph import TrackingGraph
+from traccuracy.matchers._base import Matched
+
+"""A set of fixtures covering basic graph matching cases over 3 time frames
+Covers edge cases, good matchings, fn nodes, fp nodes, two to one matchings in each
+direction (pred -> gt, gt -> pred), and divisions
+The type of mapping e.g. one to one or one to many is annotated in the docstring
+Please note that the names of the fixtures are just meant to be descriptive and may
+or may not match to a particular error type as described by a set of metrics
+
+example of how to use
+@pytest.mark.parametrize("i", [0, 1, 2], ids=["0", "1", "2"])
+def test_fn_node(i):
+ matched = fn_node_matched(i)
+ assert ...
+"""
+
+
+def basic_graph(node_ids=(1, 2, 3), y_offset=0, frame_key="t", location_keys=("y")):
+ nodes = [
+ (
+ node_ids[0],
+ {
+ frame_key: 0,
+ location_keys[0]: 0 + y_offset,
+ },
+ ),
+ (
+ node_ids[1],
+ {
+ frame_key: 1,
+ location_keys[0]: 0 + y_offset,
+ },
+ ),
+ (
+ node_ids[2],
+ {
+ frame_key: 2,
+ location_keys[0]: 0 + y_offset,
+ },
+ ),
+ ]
+
+ edges = [(node_ids[0], node_ids[1]), (node_ids[1], node_ids[2])]
+ graph = nx.DiGraph()
+ graph.add_nodes_from(nodes)
+ graph.add_edges_from(edges)
+
+ return TrackingGraph(graph, frame_key=frame_key, location_keys=location_keys)
+
+
+# edge cases
+def empty_pred():
+ gt = basic_graph()
+ pred = TrackingGraph(nx.DiGraph())
+ mapping = []
+ return Matched(gt, pred, mapping)
+
+
+def empty_gt():
+ pred = basic_graph()
+ gt = TrackingGraph(nx.DiGraph())
+ mapping = []
+ return Matched(gt, pred, mapping)
+
+
+# good
+def good_matched():
+ """one to one"""
+ gt = basic_graph()
+ pred = basic_graph(node_ids=(4, 5, 6), y_offset=1)
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ return Matched(gt, pred, mapping)
+
+
+# fn_node
+def fn_node_matched(time_to_drop): # 0, 1, or 2
+ """one to one"""
+ gt = basic_graph()
+ pred_node_ids = (4, 5, 6)
+ pred = basic_graph(node_ids=pred_node_ids, y_offset=1)
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ pred.graph.remove_node(pred_node_ids[time_to_drop])
+ del mapping[time_to_drop]
+ return Matched(gt, pred, mapping)
+
+
+# fn_edge
+def fn_edge_matched(edge_to_drop): # 0 or 1
+ """one to one"""
+ gt = basic_graph()
+ pred_node_ids = (4, 5, 6)
+ pred = basic_graph(node_ids=pred_node_ids, y_offset=1)
+ edge = (pred_node_ids[edge_to_drop], pred_node_ids[edge_to_drop + 1])
+ pred.graph.remove_edge(*edge)
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ return Matched(gt, pred, mapping)
+
+
+# fp_node
+def fp_node_matched(time_to_add): # 0, 1, or 2
+ """one to one"""
+ gt = basic_graph()
+ pred_node_ids = (4, 5, 6)
+ pred = basic_graph(node_ids=pred_node_ids, y_offset=1)
+ pred.graph.add_node(7, **{"t": time_to_add, "x": time_to_add, "y": 2})
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ return Matched(gt, pred, mapping)
+
+
+# fp_edge
+def fp_edge_matched(edge_to_add): # 0 or 1
+ """one to one"""
+ gt = basic_graph()
+ pred_node_ids = (4, 5, 6)
+ pred = basic_graph(node_ids=pred_node_ids, y_offset=1)
+ pred.graph.add_node(7, **{"t": edge_to_add, "y": 2})
+ pred.graph.add_node(8, **{"t": edge_to_add + 1, "y": 2})
+ pred.graph.add_edge(7, 8)
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ return Matched(gt, pred, mapping)
+
+
+# two pred to one gt (identity switch)
+def one_to_two(time): # 0, 1, or 2
+ """one to many"""
+ gt_node_ids = (1, 2, 3)
+ gt = basic_graph(node_ids=gt_node_ids, y_offset=1)
+ pred_node_ids = (4, 5, 6)
+ pred = basic_graph(node_ids=pred_node_ids, y_offset=0)
+ pred.graph.add_node(7, **{"t": time, "y": 2})
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ if time == 1:
+ pred.graph.remove_edge(5, 6)
+ pred.graph.add_edge(7, 6)
+ pred.graph.nodes[6]["y"] = 2
+ mapping.append((gt_node_ids[time], 7))
+ return Matched(gt, pred, mapping)
+
+
+# two gt to one pred (non split vertex)
+def two_to_one(time): # 0, 1, or 2
+ """many to one"""
+ gt_node_ids = (1, 2, 3)
+ gt = basic_graph(node_ids=gt_node_ids, y_offset=0)
+ pred_node_ids = (4, 5, 6)
+ pred = basic_graph(node_ids=pred_node_ids, y_offset=1)
+ gt.graph.add_node(7, **{"t": time, "y": 2})
+ mapping = [(1, 4), (2, 5), (3, 6)]
+ if time == 1:
+ gt.graph.remove_edge(1, 2)
+ gt.graph.add_edge(1, 7)
+ gt.graph.nodes[1]["y"] = 2
+ mapping.append((7, pred_node_ids[time]))
+ return Matched(gt, pred, mapping)
+
+
+def get_division_graphs():
+ """
+ G1
+ 2_4
+ 1_0 -- 1_1 -- 1_2 -- 1_3 -<
+ 3_4
+ G2
+ 2_2 -- 2_3 -- 2_4
+ 1_0 -- 1_1 -<
+ 3_2 -- 3_3 -- 3_4
+ """
+
+ G1 = nx.DiGraph()
+ G1.add_edge("1_0", "1_1")
+ G1.add_edge("1_1", "1_2")
+ G1.add_edge("1_2", "1_3")
+ G1.add_edge("1_3", "2_4")
+ G1.add_edge("1_3", "3_4")
+
+ attrs = {}
+ for node in G1.nodes:
+ attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
+ nx.set_node_attributes(G1, attrs)
+
+ G2 = nx.DiGraph()
+ G2.add_edge("1_0", "1_1")
+ # Divide to generate 2 lineage
+ G2.add_edge("1_1", "2_2")
+ G2.add_edge("2_2", "2_3")
+ G2.add_edge("2_3", "2_4")
+ # Divide to generate 3 lineage
+ G2.add_edge("1_1", "3_2")
+ G2.add_edge("3_2", "3_3")
+ G2.add_edge("3_3", "3_4")
+
+ attrs = {}
+ for node in G2.nodes:
+ attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
+ nx.set_node_attributes(G2, attrs)
+
+ mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")]
+
+ return G1, G2, mapper
diff --git a/tests/examples/example_segmentations.py b/tests/examples/example_segmentations.py
new file mode 100644
index 00000000..4b0e61a7
--- /dev/null
+++ b/tests/examples/example_segmentations.py
@@ -0,0 +1,318 @@
+from typing import Any
+
+import numpy as np
+import pytest
+from skimage.draw import disk
+from skimage.measure import regionprops
+
+
+def make_one_cell_2d(
+ label: int = 1,
+ arr_shape: tuple[int, int] = (32, 32),
+ center: tuple[int, int] = (16, 16),
+ radius: int = 7,
+) -> np.ndarray:
+ """Create a 2D numpy array with a single circular cell.
+
+ Args:
+ label (int, optional): Value of mask in the foreground. Defaults to 1.
+ arr_shape (tuple[int, int], optional): The size of the numpy array to return.
+ Defaults to (32, 32).
+ center (tuple[int, int], optional): The center of the cell, in pixels.
+ Defaults to (16, 16).
+ radius (int, optional): The radius of the cell. Defaults to 7.
+
+ Returns:
+ np.array: A numpy array with a circle at the given center with the
+ given label.
+ """
+ im = np.zeros(arr_shape, dtype="int32")
+ rr, cc = disk(center, radius, shape=arr_shape)
+ im[rr, cc] = label
+ return im
+
+
+def make_split_cell_2d(
+ labels=(1, 2), arr_shape=(32, 32), center=(16, 16), radius=9
+) -> np.ndarray:
+ """Create a 2d numpy array with two cells, each half a circle.
+
+ Args:
+ labels (tuple, optional): _description_. Defaults to (1, 2).
+ arr_shape (tuple, optional): _description_. Defaults to (32, 32).
+ center (tuple, optional): _description_. Defaults to (16, 16).
+ radius (int, optional): _description_. Defaults to 7.
+
+ Returns:
+ np.ndarray : A numpy array with two half circles with the given labels.
+ The pixels with y value greater than center will be the second label
+ color, and those with y value lass than or equal to the center
+ will have the first label.
+ """
+ im = np.zeros(arr_shape, dtype="int32")
+ rr, cc = disk(center, radius, shape=arr_shape)
+ im[rr, cc] = labels[0]
+ # get indices where y value greater than center
+ mask = cc > center[1]
+ im[rr[mask], cc[mask]] = labels[1]
+ return im
+
+
+def sphere(
+ center: tuple[int, int, int], radius: int, shape: tuple[int, int, int]
+) -> np.ndarray:
+ """Get a mask of a sphere of a given radius
+
+ Args:
+ center (tuple[int, int, int]): The coordinate of the center of the sphere.
+ radius (int): The radius of the sphere
+ shape (tuple[int, int, int]): The share of the numpy array mask to return.
+
+ Returns:
+ np.ndarray: A boolean array with 1s inside the sphere and 0s outside.
+ """
+ assert len(center) == len(shape)
+ indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index
+ distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1)
+ mask = distance <= radius
+ return mask
+
+
+def make_one_cell_3d(
+ label=1, arr_shape=(32, 32, 32), center=(16, 16, 16), radius=7
+) -> np.ndarray:
+ """Make a numpy array containing a single (spherical) cell in 3d.
+
+ Args:
+ label (int, optional): _description_. Defaults to 1.
+ arr_shape (tuple, optional): _description_. Defaults to (32, 32, 32).
+ center (tuple, optional): _description_. Defaults to (16, 16, 16).
+ radius (int, optional): _description_. Defaults to 7.
+
+ Returns:
+ np.ndarray: A numpy array of the given shape containing a sphere
+ with the given label, radius, and center.
+
+ """
+ im = np.zeros(arr_shape, dtype="int32")
+ mask = sphere(center, radius, shape=arr_shape)
+ im[mask] = label
+ return im
+
+
+def make_split_cell_3d(
+ labels=(1, 2), arr_shape=(32, 32, 32), center=(16, 16, 16), radius=9
+):
+ """Make a numpy array containing two cells, each half a sphere.
+ The pixels with y value less than or equal to the center y value will have
+ the first label, and those with y value greater than the center will
+ have the second label
+
+ Args:
+ labels (tuple, optional): _description_. Defaults to (1, 2).
+ arr_shape (tuple, optional): _description_. Defaults to (32, 32, 32).
+ center (tuple, optional): _description_. Defaults to (16, 16, 16).
+ radius (int, optional): _description_. Defaults to 9.
+
+ Returns:
+ np.ndarray: A numpy array of the given shape containing a sphere
+ with the given radius, and center. Half the sphere has the first,
+ label, and the other half has the second label.
+ """
+ im = np.zeros(arr_shape, dtype="int32")
+ mask = sphere(center, radius, shape=arr_shape)
+ im[mask] = labels[0]
+ # get indices where y value greater than center
+ mask[:, 0 : center[1]] = 0
+ im[mask] = labels[1]
+ return im
+
+
+### CANONICAL 2D SEGMENTATION EXAMPLES ###
+@pytest.fixture()
+def good_segmentation_2d() -> tuple[np.ndarray, np.ndarray]:
+ """A pretty good (but not perfect) pair of segmentations in 2d.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations of
+ a single cell. The segmentations are circles of the same size with
+ a slight offset in x and y.
+ """
+ gt = make_one_cell_2d(label=1, center=(15, 15), radius=9)
+ pred = make_one_cell_2d(label=2, center=(17, 17), radius=9)
+ return gt, pred
+
+
+@pytest.fixture()
+def false_positive_segmentation_2d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt is empty and the prediction has a
+ single cell.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations of
+ a single cell. The gt is empty and the prediction has a single cell.
+ """
+ gt = np.zeros((32, 32), dtype="int32")
+ pred = make_one_cell_2d(label=1, center=(17, 17), radius=9)
+ return gt, pred
+
+
+@pytest.fixture()
+def false_negative_segmentation_2d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt has a single cell and the
+ prediction is empty.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations of
+ a single cell. The pred is empty and the gt has a single cell.
+ """
+ gt = make_one_cell_2d(label=1, center=(15, 15), radius=9)
+ pred = np.zeros((32, 32), dtype="int32")
+ return gt, pred
+
+
+@pytest.fixture()
+def oversegmentation_2d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt has a single cell and the prediction
+ splits that into two cells.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations.
+ The gt has a single circle labeled and the pred splits that circle
+ into two labels.
+ """
+ gt = make_one_cell_2d(label=1, center=(16, 16), radius=9)
+ pred = make_split_cell_2d(labels=(2, 3), center=(16, 16), radius=9)
+ return gt, pred
+
+
+@pytest.fixture()
+def undersegmentation_2d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt has two cells and the prediction
+ merges them into one circular cell.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations.
+ The pred has a single merged circle labeled and the gt has two labels,
+ each half of the circle.
+ """
+ gt = make_split_cell_2d(labels=(1, 2), center=(16, 16), radius=9)
+ pred = make_one_cell_2d(label=3, center=(16, 16), radius=9)
+ return gt, pred
+
+
+### CANONICAL 3D SEGMENTATION EXAMPLES ###
+@pytest.fixture()
+def good_segmentation_3d() -> tuple[np.ndarray, np.ndarray]:
+ """A pretty good (but not perfect) pair of segmentations in 3d.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations of
+ a single cell. The segmentations are circles of the same size with
+ a slight offset in x and y.
+ """
+ gt = make_one_cell_3d(label=1, center=(15, 15, 15), radius=9)
+ pred = make_one_cell_3d(label=2, center=(17, 17, 17), radius=9)
+ return gt, pred
+
+
+@pytest.fixture
+def false_positive_segmentation_3d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt is empty and the prediction has a
+ single cell.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations of
+ a single cell. The gt is empty and the prediction has a single cell.
+ """
+ gt = np.zeros((32, 32, 32), dtype="int32")
+ pred = make_one_cell_3d(label=1, center=(17, 17, 17), radius=9)
+ return gt, pred
+
+
+@pytest.fixture()
+def false_negative_segmentation_3d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt has a single cell and the
+ prediction is empty.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations of
+ a single cell. The pred is empty and the gt has a single cell.
+ """
+ gt = make_one_cell_3d(label=1, center=(15, 15, 15), radius=9)
+ pred = np.zeros((32, 32), dtype="int32")
+ return gt, pred
+
+
+@pytest.fixture()
+def oversegmentation_3d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt has a single cell and the prediction
+ splits that into two cells.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations.
+ The gt has a single circle labeled and the pred splits that circle
+ into two labels.
+ """
+ gt = make_one_cell_3d(label=1, center=(16, 16, 16), radius=9)
+ pred = make_split_cell_3d(labels=(2, 3), center=(16, 16, 16), radius=9)
+ return gt, pred
+
+
+@pytest.fixture()
+def undersegmentation_3d() -> tuple[np.ndarray, np.ndarray]:
+ """A pair of segmentations where the gt has two cells and the prediction
+ merges them into one circular cell.
+
+ Returns:
+ tuple[np.ndarray, np.ndarray]: A pair of (gt, pred) segmentations.
+ The pred has a single merged circle labeled and the gt has two labels,
+ each half of the circle.
+ """
+ gt = make_split_cell_3d(labels=(1, 2), center=(16, 16, 16), radius=9)
+ pred = make_one_cell_3d(label=3, center=(16, 16, 16), radius=9)
+ return gt, pred
+
+
+def nodes_from_segmentation(
+ seg: np.ndarray,
+ frame: int,
+ pos_keys=("y", "x"),
+ frame_key="t",
+ label_key="label_id",
+) -> dict[Any, dict]:
+ """Extract candidate nodes from a segmentation. Also computes specified attributes.
+ Returns a networkx graph with only nodes, and also a dictionary from frames to
+ node_ids for efficient edge adding.
+
+ Args:
+ segmentation (np.ndarray): A numpy array with integer labels, representing one time
+ frame.
+ frame (int): The time frame of this array. Used for making node ids and
+ for populating the attributes dict.
+ pos_keys (tuple[str]): The attribute keys to use to store the positions.
+ frame_key (str, optional): The frame key to use in the attributes dict.
+ Defaults to "t".
+ label_key (str, optional): The label key to use in the attributes dict.
+ Defaults to "label_id"
+
+ Returns:
+ dict[Any, dict]: A dictionary from node_ids to node attributes, which
+ can be used to create a networkx graph using add_nodes_from().
+ Node Ids are currently "_