From 6d93f961c06df555f713fc6330a5c4636fbebcd7 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 3 Nov 2023 13:34:31 -0700 Subject: [PATCH 01/56] Refactor matching into separate Matcher and Matched classes --- src/traccuracy/_run_metrics.py | 6 +-- src/traccuracy/cli.py | 16 +++--- src/traccuracy/matchers/__init__.py | 6 +-- src/traccuracy/matchers/_ctc.py | 55 ++++++++++--------- src/traccuracy/matchers/_iou.py | 27 ++++++---- src/traccuracy/matchers/_matched.py | 77 ++++++++++++++++++++------- tests/bench.py | 10 ++-- tests/matchers/test_ctc.py | 10 ++-- tests/matchers/test_iou.py | 9 ++-- tests/metrics/test_ctc_metrics.py | 4 +- tests/metrics/test_divisions.py | 13 +---- tests/track_errors/test_ctc_errors.py | 44 +++++++-------- tests/track_errors/test_divisions.py | 58 +++++++------------- 13 files changed, 172 insertions(+), 163 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 9c643ca9..2c120c52 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -6,14 +6,14 @@ from typing import Dict, List, Optional, Type from traccuracy import TrackingGraph - from traccuracy.matchers._matched import Matched + from traccuracy.matchers._matched import Matcher from traccuracy.metrics._base import Metric def run_metrics( gt_data: "TrackingGraph", pred_data: "TrackingGraph", - matcher: "Type[Matched]", + matcher: "Type[Matcher]", metrics: "List[Type[Metric]]", matcher_kwargs: "Optional[Dict]" = None, metrics_kwargs: "Optional[Dict]" = None, # weights @@ -42,7 +42,7 @@ def run_metrics( """ if matcher_kwargs is None: matcher_kwargs = {} - matched = matcher(gt_data, pred_data, **matcher_kwargs) + matched = matcher(**matcher_kwargs).compute_mapping(gt_data, pred_data) validate_matched_data(matched, metrics) metric_kwarg_dict = { m_class: get_relevant_kwargs(m_class, metrics_kwargs) for m_class in metrics diff --git a/src/traccuracy/cli.py b/src/traccuracy/cli.py index a7ce4c00..4a5d9ff8 100644 --- a/src/traccuracy/cli.py +++ b/src/traccuracy/cli.py @@ -49,7 +49,7 @@ def run_ctc( Raises ValueError: if any --loader besides ctc is passed. """ - from traccuracy.matchers import CTCMatched + from traccuracy.matchers import CTCMatcher from traccuracy.metrics import CTCMetrics if loader != "ctc": @@ -57,7 +57,7 @@ def run_ctc( f"Only cell tracking challenge (ctc) loader is available, but {loader} was passed." ) gt_data, pred_data = load_all_ctc(gt_dir, pred_dir, gt_track_path, pred_track_path) - result = run_metrics(gt_data, pred_data, CTCMatched, [CTCMetrics]) + result = run_metrics(gt_data, pred_data, CTCMatcher, [CTCMetrics]) with open(out_path, "w") as fp: json.dump(result, fp) logger.info(f'TRA: {result["CTCMetrics"]["TRA"]}') @@ -109,7 +109,7 @@ def run_aogm( Raises ValueError: if any --loader besides ctc is passed. """ - from traccuracy.matchers import CTCMatched + from traccuracy.matchers import CTCMatcher from traccuracy.metrics import AOGMMetrics if loader != "ctc": @@ -120,7 +120,7 @@ def run_aogm( result = run_metrics( gt_data, pred_data, - CTCMatched, + CTCMatcher, [AOGMMetrics], metrics_kwargs={ "vertex_ns_weight": vertex_ns_weight, @@ -173,7 +173,7 @@ def run_divisions_on_iou( Raises ValueError: if any --loader besides ctc is passed. 
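A minimal sketch of the call pattern this patch introduces, assuming `gt_data` and `pred_data` are already-loaded `TrackingGraph` objects: the matcher class is instantiated with its kwargs first, then `compute_mapping` is run on the data.

```python
from traccuracy._run_metrics import run_metrics  # may also be re-exported at the package root
from traccuracy.matchers import CTCMatcher
from traccuracy.metrics import CTCMetrics

# run_metrics instantiates the matcher with matcher_kwargs and then calls
# compute_mapping(gt_data, pred_data) on it, per the diff above.
result = run_metrics(gt_data, pred_data, CTCMatcher, [CTCMetrics])
print(result["CTCMetrics"]["TRA"])

# The same matching step can also be run directly:
matched = CTCMatcher().compute_mapping(gt_data, pred_data)
```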
""" - from traccuracy.matchers import IOUMatched + from traccuracy.matchers import IOUMatcher from traccuracy.metrics import DivisionMetrics if loader != "ctc": @@ -185,7 +185,7 @@ def run_divisions_on_iou( result = run_metrics( gt_data, pred_data, - IOUMatched, + IOUMatcher, [DivisionMetrics], matcher_kwargs={"iou_threshold": match_threshold}, metrics_kwargs={ @@ -232,7 +232,7 @@ def run_divisions_on_ctc( Raises ValueError: if any --loader besides ctc is passed. """ - from traccuracy.matchers import CTCMatched + from traccuracy.matchers import CTCMatcher from traccuracy.metrics import DivisionMetrics if loader != "ctc": @@ -244,7 +244,7 @@ def run_divisions_on_ctc( result = run_metrics( gt_data, pred_data, - CTCMatched, + CTCMatcher, [DivisionMetrics], metrics_kwargs={ "frame_buffer": frame_buffer_tuple, diff --git a/src/traccuracy/matchers/__init__.py b/src/traccuracy/matchers/__init__.py index 91e46d58..e99c7ce4 100644 --- a/src/traccuracy/matchers/__init__.py +++ b/src/traccuracy/matchers/__init__.py @@ -26,7 +26,7 @@ write a matching function that matches two arbitrary tracking solutions. """ from ._compute_overlap import get_labels_with_overlap -from ._ctc import CTCMatched -from ._iou import IOUMatched +from ._ctc import CTCMatcher +from ._iou import IOUMatcher -__all__ = ["CTCMatched", "IOUMatched", "get_labels_with_overlap"] +__all__ = ["CTCMatcher", "IOUMatcher", "get_labels_with_overlap"] diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index ed4980bb..3920af60 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -1,48 +1,46 @@ +from typing import TYPE_CHECKING + import networkx as nx import numpy as np from tqdm import tqdm -from traccuracy._tracking_graph import TrackingGraph +if TYPE_CHECKING: + from traccuracy._tracking_graph import TrackingGraph from ._compute_overlap import get_labels_with_overlap -from ._matched import Matched +from ._matched import Matched, Matcher + +class CTCMatcher(Matcher): + """Match graph nodes based on measure used in cell tracking challenge benchmarking. -class CTCMatched(Matched): - def compute_mapping(self): - mapping = self._match_ctc() - return mapping + A computed marker (segmentation) is matched to a reference marker if the computed + marker covers a majority of the reference marker. - def _match_ctc(self): - """Match graph nodes based on measure used in cell tracking challenge benchmarking. + Each reference marker can therefore only be matched to one computed marker, but + multiple reference markers can be assigned to a single computed marker. - A computed marker (segmentation) is matched to a reference marker if the computed - marker covers a majority of the reference marker. + See https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0144959 + for complete details. + """ - Each reference marker can therefore only be matched to one computed marker, but - multiple reference markers can be assigned to a single computed marker. + def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + """Run ctc matching - See https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0144959 - for complete details. 
+ Args: + gt_graph (TrackingGraph): Tracking graph object for the gt + pred_graph (TrackingGraph): Tracking graph object for the pred Returns: - list[(gt_node, pred_node)]: list of tuples where each tuple contains a gt node - and pred node + Matched: Matched data object containing the CTC mapping Raises: - ValueError: gt and pred must be a TrackingGraph object ValueError: GT and pred segmentations must be the same shape """ - if not isinstance(self.gt_graph, TrackingGraph) or not isinstance( - self.pred_graph, TrackingGraph - ): - raise ValueError( - "Input data must be a TrackingData object with a graph and segmentations" - ) - gt = self.gt_graph - pred = self.pred_graph - gt_label_key = self.gt_graph.label_key - pred_label_key = self.pred_graph.label_key + gt = gt_graph + pred = pred_graph + gt_label_key = gt_graph.label_key + pred_label_key = pred_graph.label_key G_gt, mask_gt = gt, gt.segmentation G_pred, mask_pred = pred, pred.segmentation @@ -93,7 +91,8 @@ def _match_ctc(self): mapping.append( (gt_label_to_id[gt_label], pred_label_to_id[pred_label]) ) - return mapping + + return Matched(gt_graph, pred_graph, mapping) def detection_test(gt_blob: "np.ndarray", comp_blob: "np.ndarray") -> int: diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 374291aa..5667a586 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -4,7 +4,7 @@ from traccuracy._tracking_graph import TrackingGraph from ._compute_overlap import get_labels_with_overlap -from ._matched import Matched +from ._matched import Matched, Matcher def _match_nodes(gt, res, threshold=1): @@ -98,29 +98,36 @@ def match_iou(gt, pred, threshold=0.6): return mapper -class IOUMatched(Matched): - def __init__(self, gt_graph, pred_graph, iou_threshold=0.6): +class IOUMatcher(Matcher): + def __init__(self, iou_threshold=0.6): """Constructs a mapping between gt and pred nodes using the IoU of the segmentations Lower values for iou_threshold will be more permissive of imperfect matches Args: - gt_graph (TrackingGraph): TrackingGraph for the ground truth with segmentations - pred_graph (TrackingGraph): TrackingGraph for the prediction with segmentations iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6. 
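For intuition, a toy IoU calculation on two binary masks (illustrative only; the matcher derives IoU from label overlaps internally):

```python
import numpy as np

gt_mask = np.array([1, 1, 0], dtype=bool)
pred_mask = np.array([0, 1, 0], dtype=bool)
# Intersection is 1 pixel, union is 2 pixels.
iou = np.logical_and(gt_mask, pred_mask).sum() / np.logical_or(gt_mask, pred_mask).sum()
print(iou)  # 0.5 -- below the default iou_threshold of 0.6, so no match
```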
+ """ + self.iou_threshold = iou_threshold + + def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + """Computes IOU mapping for a set of grpahs + + Args: + gt_graph (TrackingGraph): Tracking graph object for the gt with segmentation data + pred_graph (TrackingGraph): Tracking graph object for the pred with segmentation data Raises: ValueError: Segmentation data must be provided for both gt and pred data - """ - self.iou_threshold = iou_threshold + Returns: + Matched: Matched data object containing IOU mapping + """ # Check that segmentations exist in the data if gt_graph.segmentation is None or pred_graph.segmentation is None: raise ValueError( "Segmentation data must be provided for both gt and pred data" ) - super().__init__(gt_graph, pred_graph) + mapping = match_iou(gt_graph, pred_graph, threshold=self.iou_threshold) - def compute_mapping(self): - return match_iou(self.gt_graph, self.pred_graph, threshold=self.iou_threshold) + return Matched(gt_graph, pred_graph, mapping) diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py index 85399848..9cee8d02 100644 --- a/src/traccuracy/matchers/_matched.py +++ b/src/traccuracy/matchers/_matched.py @@ -1,45 +1,82 @@ +import copy import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import Any -if TYPE_CHECKING: - from traccuracy._tracking_graph import TrackingGraph +from traccuracy._tracking_graph import TrackingGraph logger = logging.getLogger(__name__) -class Matched(ABC): - def __init__(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): - """Matched class which takes TrackingData objects for gt and pred, and computes matching. +class Matcher(ABC): + """The Matcher base class provides a wrapper around the compute_mapping method - Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched. - The Matched objects will store both gt and pred data, as well as the mapping, - and any additional private attributes that may be needed/used e.g. detection matrices. + Each Matcher subclass will implement its own kwargs as needed. 
+ In use, the Matcher object will be initialized with kwargs prior to running compute_mapping + on a particular dataset + """ + + def compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + """Run the matching on a given set of gt and pred TrackingGraph and returns a Matched object + with a new copy of each TrackingGraph Args: gt_graph (TrackingGraph): Tracking graph object for the gt pred_graph (TrackingGraph): Tracking graph object for the pred + + Returns: + matched (Matched): Matched data object + + Raises: + ValueError: gt and pred must be a TrackingGraph object """ - self.gt_graph = gt_graph - self.pred_graph = pred_graph + if not isinstance(gt_graph, TrackingGraph) or not isinstance( + pred_graph, TrackingGraph + ): + raise ValueError( + "Input data must be a TrackingData object with a graph and segmentations" + ) - self.mapping = self.compute_mapping() + matched = self._compute_mapping(gt_graph, pred_graph) # Report matching performance - total_gt = len(self.gt_graph.nodes()) - matched_gt = len({m[0] for m in self.mapping}) - total_pred = len(self.pred_graph.nodes()) - matched_pred = len({m[1] for m in self.mapping}) + total_gt = len(matched.gt_graph.nodes()) + matched_gt = len({m[0] for m in matched.mapping}) + total_pred = len(matched.pred_graph.nodes()) + matched_pred = len({m[1] for m in matched.mapping}) logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.") logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.") - @abstractmethod - def compute_mapping(self): - """Computes a mapping of nodes in gt to nodes in pred + return matched - The mapping must be a list of tuples, e.g. [(gt_node, pred_node)] + @abstractmethod + def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + """Computes a mapping of nodes in gt to nodes in pred and returns a Matched object Raises: NotImplementedError """ raise NotImplementedError + + +class Matched: + """Matched data class which stores TrackingGraph objects for gt and pred + and the computed mapping + + Each TrackingGraph will be a new copy on the original object + + Args: + gt_graph (TrackingGraph): Tracking graph object for the gt + pred_graph (TrackingGraph): Tracking graph object for the pred + + """ + + def __init__( + self, + gt_graph: "TrackingGraph", + pred_graph: "TrackingGraph", + mapping: list[tuple[Any, Any]], + ): + self.gt_graph = copy.deepcopy(gt_graph) + self.pred_graph = copy.deepcopy(pred_graph) + self.mapping = mapping diff --git a/tests/bench.py b/tests/bench.py index 6068e459..34e8a1be 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -5,7 +5,7 @@ import pytest from traccuracy.loaders import load_ctc_data -from traccuracy.matchers import CTCMatched, IOUMatched +from traccuracy.matchers import CTCMatcher, IOUMatcher from traccuracy.metrics import CTCMetrics, DivisionMetrics ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -51,12 +51,12 @@ def pred_data(): @pytest.fixture(scope="module") def ctc_matched(gt_data, pred_data): - return CTCMatched(gt_data, pred_data) + return CTCMatcher().compute_mapping(gt_data, pred_data) @pytest.fixture(scope="module") def iou_matched(gt_data, pred_data): - return IOUMatched(gt_data, pred_data, iou_threshold=0.1) + return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data, pred_data) def test_load_gt_data(benchmark): @@ -88,7 +88,7 @@ def test_load_pred_data(benchmark): def test_ctc_matched(benchmark, gt_data, pred_data): - benchmark(CTCMatched, gt_data, pred_data) + 
benchmark(CTCMatcher().compute_mapping, gt_data, pred_data) @pytest.mark.timeout(300) @@ -118,7 +118,7 @@ def run_compute(): def test_iou_matched(benchmark, gt_data, pred_data): - benchmark(IOUMatched, gt_data, pred_data, iou_threshold=0.5) + benchmark(IOUMatcher(iou_threshold=0.1).compute_mapping, gt_data, pred_data) def test_iou_div_metrics(benchmark, iou_matched): diff --git a/tests/matchers/test_ctc.py b/tests/matchers/test_ctc.py index 02040098..e210d79c 100644 --- a/tests/matchers/test_ctc.py +++ b/tests/matchers/test_ctc.py @@ -2,19 +2,17 @@ import numpy as np import pytest from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._ctc import CTCMatched +from traccuracy.matchers._ctc import CTCMatcher from tests.test_utils import get_annotated_movie def test_match_ctc(): - # Bad input - with pytest.raises(ValueError): - CTCMatched("not tracking data", "not tracking data") + matcher = CTCMatcher() # shapes don't match with pytest.raises(ValueError): - CTCMatched( + matcher.compute_mapping( TrackingGraph(nx.DiGraph(), segmentation=np.zeros((5, 10, 10))), TrackingGraph(nx.DiGraph(), segmentation=np.zeros((5, 10, 5))), ) @@ -37,7 +35,7 @@ def test_match_ctc(): attrs[f"{i}_{t}"] = {"t": t, "y": 0, "x": 0, "segmentation_id": i} nx.set_node_attributes(g, attrs) - matched = CTCMatched( + matched = matcher.compute_mapping( TrackingGraph(g, segmentation=movie), TrackingGraph(g, segmentation=movie), ) diff --git a/tests/matchers/test_iou.py b/tests/matchers/test_iou.py index 2cfa171e..edf942b7 100644 --- a/tests/matchers/test_iou.py +++ b/tests/matchers/test_iou.py @@ -2,7 +2,7 @@ import numpy as np import pytest from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._iou import IOUMatched, _match_nodes, match_iou +from traccuracy.matchers._iou import IOUMatcher, _match_nodes, match_iou from tests.test_utils import get_annotated_image, get_movie_with_graph @@ -68,8 +68,10 @@ def test__init__(self): track_graph = get_movie_with_graph() data = TrackingGraph(track_graph.graph) + matcher = IOUMatcher() + with pytest.raises(ValueError): - IOUMatched(data, data) + matcher.compute_mapping(data, data) def test_compute_mapping(self): # Test 2d data @@ -79,7 +81,8 @@ def test_compute_mapping(self): ndims=3, n_frames=n_frames, n_labels=n_labels ) - matched = IOUMatched(gt_graph=track_graph, pred_graph=track_graph) + matcher = IOUMatcher() + matched = matcher.compute_mapping(gt_graph=track_graph, pred_graph=track_graph) # Check for correct number of pairs assert len(matched.mapping) == n_frames * n_labels diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index fb5bd601..7ca4942c 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -1,4 +1,4 @@ -from traccuracy.matchers._ctc import CTCMatched +from traccuracy.matchers._ctc import CTCMatcher from traccuracy.metrics._ctc import CTCMetrics from tests.test_utils import get_movie_with_graph @@ -10,7 +10,7 @@ def test_compute_mapping(): n_labels = 3 track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) - matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph) + matched = CTCMatcher().compute_mapping(gt_graph=track_graph, pred_graph=track_graph) metric = CTCMetrics(matched) assert metric.results assert "TRA" in metric.results diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index e2679115..d46537b4 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ 
-5,21 +5,12 @@ from tests.test_utils import get_division_graphs -class DummyMatched(Matched): - def __init__(self, gt_data, pred_data, mapper): - self.mapper = mapper - super().__init__(gt_data, pred_data) - - def compute_mapping(self): - return self.mapper - - def test_DivisionMetrics(): g_gt, g_pred, mapper = get_division_graphs() - matched = DummyMatched( + matched = Matched( TrackingGraph(g_gt), TrackingGraph(g_pred), - mapper=mapper, + mapper, ) frame_buffer = (0, 1, 2) diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index ceba4989..98be167c 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -5,11 +5,6 @@ from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors -class DummyMatched(Matched): - def compute_mapping(self): - return [] - - def test_get_vertex_errors(): comp_ids = [3, 7, 10] comp_ids_2 = list(np.asarray(comp_ids) + 1) @@ -39,27 +34,26 @@ def test_get_vertex_errors(): ) G_comp = TrackingGraph(comp_g) - matched_data = DummyMatched(G_gt, G_comp) - matched_data.mapping = mapping + matched_data = Matched(G_gt, G_comp, mapping) get_vertex_errors(matched_data) - assert len(G_comp.get_nodes_with_flag(NodeAttr.NON_SPLIT)) == 1 - assert len(G_comp.get_nodes_with_flag(NodeAttr.TRUE_POS)) == 3 - assert len(G_comp.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 - assert len(G_gt.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)) == 1 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 - assert gt_g.nodes[15][NodeAttr.FALSE_NEG] - assert not gt_g.nodes[17][NodeAttr.FALSE_NEG] + assert matched_data.gt_graph.graph.nodes[15][NodeAttr.FALSE_NEG] + assert not matched_data.gt_graph.graph.nodes[17][NodeAttr.FALSE_NEG] - assert comp_g.nodes[3][NodeAttr.NON_SPLIT] - assert not comp_g.nodes[7][NodeAttr.NON_SPLIT] + assert matched_data.pred_graph.graph.nodes[3][NodeAttr.NON_SPLIT] + assert not matched_data.pred_graph.graph.nodes[7][NodeAttr.NON_SPLIT] - assert comp_g.nodes[7][NodeAttr.TRUE_POS] - assert not comp_g.nodes[3][NodeAttr.TRUE_POS] + assert matched_data.pred_graph.graph.nodes[7][NodeAttr.TRUE_POS] + assert not matched_data.pred_graph.graph.nodes[3][NodeAttr.TRUE_POS] - assert comp_g.nodes[10][NodeAttr.FALSE_POS] - assert not comp_g.nodes[7][NodeAttr.FALSE_POS] + assert matched_data.pred_graph.graph.nodes[10][NodeAttr.FALSE_POS] + assert not matched_data.pred_graph.graph.nodes[7][NodeAttr.FALSE_POS] def test_assign_edge_errors(): @@ -95,13 +89,12 @@ def test_assign_edge_errors(): ) G_gt = TrackingGraph(gt_g) - matched_data = DummyMatched(G_gt, G_comp) - matched_data.mapping = mapping + matched_data = Matched(G_gt, G_comp, mapping) get_edge_errors(matched_data) - assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS] - assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG] + assert matched_data.pred_graph.graph.edges[(7, 8)][EdgeAttr.FALSE_POS] + assert matched_data.gt_graph.graph.edges[(17, 18)][EdgeAttr.FALSE_NEG] def test_assign_edge_errors_semantics(): @@ -136,9 +129,8 @@ def test_assign_edge_errors_semantics(): # Define mapping with all nodes matching except for 2_3 in comp mapping = [(n, n) for n in gt.nodes] - matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp)) - matched_data.mapping = mapping + matched_data = 
Matched(TrackingGraph(gt), TrackingGraph(comp), mapping) get_edge_errors(matched_data) - assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] + assert matched_data.pred_graph.graph.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 0a10bf66..bbb937fc 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -14,11 +14,6 @@ from tests.test_utils import get_division_graphs -class DummyMatched(Matched): - def compute_mapping(self): - return [] - - @pytest.fixture def g(): """ @@ -51,22 +46,19 @@ def g(): def test_classify_divisions_tp(g): # Define mapper assuming all nodes match mapper = [(n, n) for n in g.nodes] - g_gt = TrackingGraph(g.copy()) - g_pred = TrackingGraph(g.copy()) - matched_data = DummyMatched(g_gt, g_pred) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g.copy()), TrackingGraph(g.copy()), mapper) # Test true positive _classify_divisions(matched_data) - assert len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] - assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 + assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes()["2_2"] + assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes()["2_2"] # Check division flag - assert g_gt.division_annotations - assert g_pred.division_annotations + assert matched_data.gt_graph.division_annotations + assert matched_data.pred_graph.division_annotations def test_classify_divisions_fp(g): @@ -84,17 +76,14 @@ def test_classify_divisions_fp(g): nx.set_node_attributes(h, {"5_3": {"t": 3, "x": 0, "y": 0}}) mapper = [(n, n) for n in h.nodes] - g_gt = TrackingGraph(g) - g_pred = TrackingGraph(h) - matched_data = DummyMatched(g_gt, g_pred) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g), TrackingGraph(h), mapper) _classify_divisions(matched_data) - assert len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert NodeAttr.FP_DIV in g_pred.nodes()["1_2"] - assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] - assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 + assert NodeAttr.FP_DIV in matched_data.pred_graph.nodes()["1_2"] + assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes()["2_2"] + assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes()["2_2"] def test_classify_divisions_fn(g): @@ -107,16 +96,13 @@ def test_classify_divisions_fn(g): h.remove_nodes_from(["3_3", "4_3"]) mapper = [(n, n) for n in h.nodes] - g_gt = TrackingGraph(g) - g_pred = TrackingGraph(h) - matched_data = DummyMatched(g_gt, g_pred) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g), TrackingGraph(h), mapper) _classify_divisions(matched_data) - assert len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert len(g_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 - assert NodeAttr.FN_DIV in g_gt.nodes()["2_2"] + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 + assert NodeAttr.FN_DIV in matched_data.gt_graph.nodes()["2_2"] @pytest.fixture @@ -177,8 +163,7 @@ def test_no_change(self): g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True 
g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) # buffer of 1, no change new_matched = _correct_shifted_divisions(matched_data, n_frames=1) @@ -195,8 +180,7 @@ def test_fn_early(self): g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) # buffer of 3, corrections new_matched = _correct_shifted_divisions(matched_data, n_frames=3) @@ -214,8 +198,7 @@ def test_fp_early(self): g_pred.nodes["1_1"][NodeAttr.FP_DIV] = True g_gt.nodes["1_3"][NodeAttr.FN_DIV] = True - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) # buffer of 3, corrections new_matched = _correct_shifted_divisions(matched_data, n_frames=3) @@ -232,8 +215,7 @@ def test_evaluate_division_events(): g_gt, g_pred, mapper = get_division_graphs() frame_buffer = (0, 1, 2) - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) results = _evaluate_division_events(matched_data, frame_buffer=frame_buffer) From 8cb0794e19595ec4ac21458d7a1c37903e697734 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 3 Nov 2023 14:43:20 -0700 Subject: [PATCH 02/56] Add from future import annotations for matched --- src/traccuracy/matchers/_matched.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py index 9cee8d02..04041fb2 100644 --- a/src/traccuracy/matchers/_matched.py +++ b/src/traccuracy/matchers/_matched.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import logging from abc import ABC, abstractmethod @@ -16,7 +18,7 @@ class Matcher(ABC): on a particular dataset """ - def compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Run the matching on a given set of gt and pred TrackingGraph and returns a Matched object with a new copy of each TrackingGraph @@ -50,7 +52,7 @@ def compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph" return matched @abstractmethod - def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Computes a mapping of nodes in gt to nodes in pred and returns a Matched object Raises: @@ -73,8 +75,8 @@ class Matched: def __init__( self, - gt_graph: "TrackingGraph", - pred_graph: "TrackingGraph", + gt_graph: TrackingGraph, + pred_graph: TrackingGraph, mapping: list[tuple[Any, Any]], ): self.gt_graph = copy.deepcopy(gt_graph) From 7b4156280613216c70c7729026c5617912d609e3 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 3 Nov 2023 14:49:29 -0700 Subject: [PATCH 03/56] Fix docs build by specifying traccuracy.TrackingGraph in docstring --- src/traccuracy/matchers/_matched.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py index 
04041fb2..3682908d 100644 --- a/src/traccuracy/matchers/_matched.py +++ b/src/traccuracy/matchers/_matched.py @@ -23,8 +23,8 @@ def compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): with a new copy of each TrackingGraph Args: - gt_graph (TrackingGraph): Tracking graph object for the gt - pred_graph (TrackingGraph): Tracking graph object for the pred + gt_graph (traccuracy.TrackingGraph): Tracking graph object for the gt + pred_graph (traccuracy.TrackingGraph): Tracking graph object for the pred Returns: matched (Matched): Matched data object @@ -68,8 +68,8 @@ class Matched: Each TrackingGraph will be a new copy on the original object Args: - gt_graph (TrackingGraph): Tracking graph object for the gt - pred_graph (TrackingGraph): Tracking graph object for the pred + gt_graph (traccuracy.TrackingGraph): Tracking graph object for the gt + pred_graph (traccuracy.TrackingGraph): Tracking graph object for the pred """ From 847858407bc60fe4865895b5a28db703585797a8 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 3 Nov 2023 16:55:45 -0700 Subject: [PATCH 04/56] Add return typing to Matcher methods --- src/traccuracy/matchers/_matched.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py index 3682908d..b7d33ac4 100644 --- a/src/traccuracy/matchers/_matched.py +++ b/src/traccuracy/matchers/_matched.py @@ -18,7 +18,9 @@ class Matcher(ABC): on a particular dataset """ - def compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): + def compute_mapping( + self, gt_graph: TrackingGraph, pred_graph: TrackingGraph + ) -> Matched: """Run the matching on a given set of gt and pred TrackingGraph and returns a Matched object with a new copy of each TrackingGraph @@ -52,7 +54,9 @@ def compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): return matched @abstractmethod - def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): + def _compute_mapping( + self, gt_graph: TrackingGraph, pred_graph: TrackingGraph + ) -> Matched: """Computes a mapping of nodes in gt to nodes in pred and returns a Matched object Raises: From 98b26fc42ca22f8598e47746d91129e9b6f619b4 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 3 Nov 2023 17:15:21 -0700 Subject: [PATCH 05/56] Refactor Metrics class to pass data into compute method not constructor --- src/traccuracy/_run_metrics.py | 4 +-- src/traccuracy/metrics/_base.py | 26 ++++++---------- src/traccuracy/metrics/_ctc.py | 29 ++++++++---------- src/traccuracy/metrics/_divisions.py | 45 +++++++++++++++------------- tests/metrics/test_ctc_metrics.py | 12 ++++---- tests/metrics/test_divisions.py | 3 +- 6 files changed, 55 insertions(+), 64 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 9c643ca9..09a96f7f 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -50,6 +50,6 @@ def run_metrics( results = {} for _metric in metrics: relevant_kwargs = metric_kwarg_dict[_metric] - result = _metric(matched, **relevant_kwargs) - results[_metric.__name__] = result.results + result = _metric(**relevant_kwargs).compute(matched) + results[_metric.__name__] = result return results diff --git a/src/traccuracy/metrics/_base.py b/src/traccuracy/metrics/_base.py index dd763678..c5d1acac 100644 --- a/src/traccuracy/metrics/_base.py +++ b/src/traccuracy/metrics/_base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc 
import ABC, abstractmethod from typing import TYPE_CHECKING @@ -6,32 +8,22 @@ class Metric(ABC): + """The base class for Metrics + + Data should be passed directly into the compute method + Kwargs should be specified in the constructor + """ + # Mapping criteria needs_one_to_one = False supports_one_to_many = False supports_many_to_one = False supports_many_to_many = False - def __init__(self, matched_data: "Matched"): - """Add Matched class which takes TrackingData objects for gt and pred,and computes matching - - Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched. - The Matched objects will store both gt and pred data, as well as the mapping, - and any additional private attributes that may be needed/used e.g. detection matrices. - Metric subclasses will take keyword arguments to set the weights of various error counts. - - Args: - matched_data (Matched): Matched object for set of GT and Pred data - """ - self.data = matched_data - self.results = self.compute() - @abstractmethod - def compute(self) -> dict: + def compute(self, matched: Matched) -> dict: """The compute methods of Metric objects return a dictionary with counts and statistics. - They may make use of TrackingEvents objects but do not have to. - Raises: NotImplementedError diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index a5f0a84f..877a4f7a 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -14,7 +14,6 @@ class AOGMMetrics(Metric): def __init__( self, - matched_data: "Matched", vertex_ns_weight=1, vertex_fp_weight=1, vertex_fn_weight=1, @@ -32,22 +31,19 @@ def __init__( "fn": edge_fn_weight, "ws": edge_ws_weight, } - super().__init__(matched_data) - def compute(self): - evaluate_ctc_events(self.data) + def compute(self, data: "Matched"): + evaluate_ctc_events(data) vertex_error_counts = { - "ns": len(self.data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)), - "fp": len(self.data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)), - "fn": len(self.data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)), + "ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)), + "fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)), + "fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)), } edge_error_counts = { - "ws": len( - self.data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC) - ), - "fp": len(self.data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)), - "fn": len(self.data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)), + "ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)), + "fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)), + "fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)), } error_sum = get_weighted_error_sum( vertex_error_counts, @@ -71,7 +67,7 @@ def compute(self): class CTCMetrics(AOGMMetrics): - def __init__(self, matched_data: "Matched"): + def __init__(self): vertex_weight_ns = 5 vertex_weight_fn = 10 vertex_weight_fp = 1 @@ -80,7 +76,6 @@ def __init__(self, matched_data: "Matched"): edge_weight_fn = 1.5 edge_weight_ws = 1 super().__init__( - matched_data, vertex_ns_weight=vertex_weight_ns, vertex_fp_weight=vertex_weight_fp, vertex_fn_weight=vertex_weight_fn, @@ -89,9 +84,9 @@ def __init__(self, matched_data: "Matched"): edge_ws_weight=edge_weight_ws, ) - def compute(self): + def compute(self, data: "Matched"): # AOGM-0 is the cost of creating the gt graph from scratch - gt_graph = self.data.gt_graph.graph + gt_graph 
= data.gt_graph.graph n_nodes = gt_graph.number_of_nodes() n_edges = gt_graph.number_of_edges() aogm_0 = n_nodes * self.v_weights["fn"] + n_edges * self.e_weights["fn"] @@ -101,7 +96,7 @@ def compute(self): + f" {n_edges} edges with {self.v_weights['fn']} vertex FN weight and" + f" {self.e_weights['fn']} edge FN weight" ) - errors = super().compute() + errors = super().compute(data) aogm = errors["AOGM"] tra = 1 - min(aogm, aogm_0) / aogm_0 errors["TRA"] = tra diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 42492202..63aeec95 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -32,45 +32,50 @@ of the early division, by advancing along the graph to find nodes in the same frame as the late division daughters. """ +from __future__ import annotations +from typing import TYPE_CHECKING from traccuracy._tracking_graph import NodeAttr from traccuracy.track_errors.divisions import _evaluate_division_events from ._base import Metric +if TYPE_CHECKING: + from ._base import Matched + class DivisionMetrics(Metric): + """Classify division events and provide the following summary metrics + + - Division Recall + - Division Precision + - Division F1 Score + - Mitotic Branching Correctness: TP / (TP + FP + FN) as defined by Ulicna, K., + Vallardi, G., Charras, G. & Lowe, A. R. Automated deep lineage tree analysis + using a Bayesian single cell tracking approach. Frontiers in Computer Science + 3, 734559 (2021). + + Args: + matched_data (Matched): Matched object for set of GT and Pred data + Must meet the `needs_one_to_one` criteria + frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames + to tolerate in correct_shifted_divisions. Defaults to (0). + """ + needs_one_to_one = True - def __init__(self, matched_data, frame_buffer=(0,)): - """Classify division events and provide the following summary metrics - - - Division Recall - - Division Precision - - Division F1 Score - - Mitotic Branching Correctness: TP / (TP + FP + FN) as defined by Ulicna, K., - Vallardi, G., Charras, G. & Lowe, A. R. Automated deep lineage tree analysis - using a Bayesian single cell tracking approach. Frontiers in Computer Science - 3, 734559 (2021). - - Args: - matched_data (Matched): Matched object for set of GT and Pred data - Must meet the `needs_one_to_one` criteria - frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames - to tolerate in correct_shifted_divisions. Defaults to (0). 
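The resulting call pattern, sketched below: kwargs such as `frame_buffer` go to the metric constructor, and the `Matched` object goes to `compute()`. This assumes `matched` comes from a matcher as in the earlier patches and satisfies the one-to-one criterion required for divisions.

```python
from traccuracy.metrics import CTCMetrics, DivisionMetrics

ctc_results = CTCMetrics().compute(matched)   # plain results dict, no stored state
div_results = DivisionMetrics(frame_buffer=(0, 1, 2)).compute(matched)
print(ctc_results["TRA"], ctc_results["DET"])
```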
- """ + def __init__(self, frame_buffer=(0,)): self.frame_buffer = frame_buffer - super().__init__(matched_data) - def compute(self): + def compute(self, data: Matched): """Runs `_evaluate_division_events` and calculates summary metrics for each frame buffer Returns: dict: Returns a nested dictionary with one dictionary per frame buffer value """ div_annotations = _evaluate_division_events( - self.data, + data, frame_buffer=self.frame_buffer, ) diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index fb5bd601..849f2656 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -11,9 +11,9 @@ def test_compute_mapping(): track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph) - metric = CTCMetrics(matched) - assert metric.results - assert "TRA" in metric.results - assert "DET" in metric.results - assert metric.results["TRA"] == 1 - assert metric.results["DET"] == 1 + results = CTCMetrics().compute(matched) + assert results + assert "TRA" in results + assert "DET" in results + assert results["TRA"] == 1 + assert results["DET"] == 1 diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index e2679115..accc1439 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -23,8 +23,7 @@ def test_DivisionMetrics(): ) frame_buffer = (0, 1, 2) - metrics = DivisionMetrics(matched, frame_buffer=frame_buffer) - results = metrics.compute() + results = DivisionMetrics(frame_buffer=frame_buffer).compute(matched) for name, r in results.items(): buffer = int(name[-1:]) From 5ed1f400c83c00935b3fb2613adb6ef319285350 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sat, 4 Nov 2023 22:26:13 +0900 Subject: [PATCH 06/56] added definitions and just copies from laptrack code --- src/traccuracy/metrics/_track_matching.py | 185 ++++++++++++++++++++++ 1 file changed, 185 insertions(+) create mode 100644 src/traccuracy/metrics/_track_matching.py diff --git a/src/traccuracy/metrics/_track_matching.py b/src/traccuracy/metrics/_track_matching.py new file mode 100644 index 00000000..f3c18d87 --- /dev/null +++ b/src/traccuracy/metrics/_track_matching.py @@ -0,0 +1,185 @@ +"""This submodule implements routines for Track Purity (TP) and Target Effectiveness (TE) scores. + +Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): + +- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts by the total frame counts for T^g_j. The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. + +- TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. 
+""" + +from typing import Dict +from typing import Optional +from typing import Sequence + +import networkx as nx +import numpy as np +import pandas as pd + + +def _add_split_edges(track_df, split_df): + track_df2 = track_df.copy() + for _, row in split_df.iterrows(): + p = ( + track_df[track_df["track_id"] == row["parent_track_id"]] + .sort_values("frame", ascending=True) + .iloc[-1] + ) + p["track_id"] = row["child_track_id"] + track_df2 = pd.concat([track_df2, pd.DataFrame(p).T]) + track_df2 = track_df2.sort_values(["frame", "index"]).reset_index(drop=True) + return track_df2 + + +def _df_to_edges(track_df): + track_edgess = [] + for _, grp in track_df.groupby("track_id"): + track_edges = [] + nodes = list(grp.sort_values("frame").iterrows()) + for (_, row1), (_, row2) in zip(nodes[:-1], nodes[1:]): + track_edges.append( + (tuple(row1[["frame", "index"]]), tuple(row2[["frame", "index"]])) + ) + track_edgess.append(track_edges) + return track_edgess + + +def _calc_overlap_score(reference_edgess, overlap_edgess): + correct_count = 0 + for reference_edges in reference_edgess: + overlaps = [ + len(set(reference_edges) & set(overlap_edges)) + for overlap_edges in overlap_edgess + ] + max_overlap = max(overlaps) + correct_count += max_overlap + total_count = sum([len(reference_edges) for reference_edges in reference_edgess]) + return correct_count / total_count if total_count > 0 else -1 + + +def calc_scores( + true_edges: EdgeType, + predicted_edges: EdgeType, + exclude_true_edges: EdgeType = [], + include_frames: Optional[Sequence[Int]] = None, + track_scores: bool = True, +) -> Dict[str, float]: + r""" + Calculate track prediction scores. + + Parameters + ---------- + true_edges : EdgeType + The list of true edges. Assumes ((frame1,index1), (frame2,index2)) for each edge. + + predicted_edges : EdgeType + The list of predicted edges. See `true_edges` for the format. + + exclude_true_edges : EdgeType, default [] + The list of true edges to be excluded from "\*_ratio". See `true_edges` for the format. + + include_frames : Optional[List[Int]], default None + The list of frames to include in the score calculation. If None, all the frames are included. + + track_scores : bool, default True + If True, the function calculates track_purity, target_effectiveness and mitotic_branching_correctness. + + Returns + ------- + score_dict : Dict[str,float] + The scores in the dict form. The keys are: + + - "Jaccard_index": (number of TP edges) / (number of TP edges + number of FP edges + number of FN edges) + - "true_positive_rate": (number of TP edges) / (number of TP edges + number of FN edges) + - "precision": (number of TP edges) / (number of TP edges + number of FP edges) + - "track_purity" : the track purity. + - "target_effectiveness" : the target effectiveness. + - "mitotic_branching_correctness" : the number of divisions that were correctly predicted. 
+ """ + # return the count o + + if include_frames is None: + include_frames = list(range(np.max([e[0][0] for e in true_edges]) + 1)) + true_edges_included = [e for e in true_edges if e[0][0] in include_frames] + predicted_edges_included = [e for e in predicted_edges if e[0][0] in include_frames] + + if len(list(predicted_edges)) == 0: + return { + "Jaccard_index": 0, + "true_positive_rate": 0, + "precision": 0, + "track_purity": 0, + "target_effectiveness": 0, + "mitotic_branching_correctness": 0, + } + else: + + if track_scores: + ################ calculate track scores ################# + gt_tree = nx.from_edgelist(order_edges(true_edges), create_using=nx.DiGraph) + pred_tree = nx.from_edgelist( + order_edges(predicted_edges), create_using=nx.DiGraph + ) + gt_track_df, gt_split_df, _gt_merge_df = convert_tree_to_dataframe(gt_tree) + pred_track_df, pred_split_df, _pred_merge_df = convert_tree_to_dataframe( + pred_tree + ) + gt_track_df = gt_track_df.reset_index() + pred_track_df = pred_track_df.reset_index() + + gt_track_df = _add_split_edges(gt_track_df, gt_split_df) + pred_track_df = _add_split_edges(pred_track_df, pred_split_df) + gt_edgess = _df_to_edges(gt_track_df) + pred_edgess = _df_to_edges(pred_track_df) + + filter_edges = ( + lambda e: e[0][0] in include_frames and e not in exclude_true_edges + ) + pred_edgess = [ + [e for e in edges if filter_edges(e)] for edges in pred_edgess + ] + gt_edgess = [[e for e in edges if filter_edges(e)] for edges in gt_edgess] + track_purity = _calc_overlap_score(pred_edgess, gt_edgess) + target_effectiveness = _calc_overlap_score(gt_edgess, pred_edgess) + + ################ calculate division recovery ################# + def get_children(m): + return list(gt_tree.successors(m)) + + dividing_nodes = [m for m in gt_tree.nodes() if len(get_children(m)) > 1] + dividing_nodes = [m for m in dividing_nodes if m[0] in include_frames] + mitotic_branching_correctness_count = 0 + total_count = 0 + for m in dividing_nodes: + children = get_children(m) + + def check_match_children(edges): + return all([(n, m) in edges or (m, n) in edges for n in children]) + + excluded = check_match_children(exclude_true_edges) + if not excluded: + if check_match_children(predicted_edges): + mitotic_branching_correctness_count += 1 + total_count += 1 + + if total_count > 0: + mitotic_branching_correctness = ( + mitotic_branching_correctness_count / total_count + ) + else: + mitotic_branching_correctness = -1 + else: + track_purity = -1 + target_effectiveness = -1 + mitotic_branching_correctness = -1 + + ################ calculate edge overlaps ################# + te = set(true_edges_included) - set(exclude_true_edges) + pe = set(predicted_edges_included) - set(exclude_true_edges) + return { + "Jaccard_index": len(te & pe) / len(te | pe), + "true_positive_rate": len(te & pe) / len(te), + "precision": len(te & pe) / len(pe), + "track_purity": track_purity, + "target_effectiveness": target_effectiveness, + "mitotic_branching_correctness": mitotic_branching_correctness, + } From ed06f20840de5b906983b3152f0e0656389fb10b Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 5 Nov 2023 11:48:41 +0900 Subject: [PATCH 07/56] trying to organize the script in Metric class --- src/traccuracy/metrics/_track_matching.py | 185 ---------------------- src/traccuracy/metrics/_track_overlap.py | 82 ++++++++++ 2 files changed, 82 insertions(+), 185 deletions(-) delete mode 100644 src/traccuracy/metrics/_track_matching.py create mode 100644 src/traccuracy/metrics/_track_overlap.py diff --git 
a/src/traccuracy/metrics/_track_matching.py b/src/traccuracy/metrics/_track_matching.py deleted file mode 100644 index f3c18d87..00000000 --- a/src/traccuracy/metrics/_track_matching.py +++ /dev/null @@ -1,185 +0,0 @@ -"""This submodule implements routines for Track Purity (TP) and Target Effectiveness (TE) scores. - -Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): - -- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts by the total frame counts for T^g_j. The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. - -- TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. -""" - -from typing import Dict -from typing import Optional -from typing import Sequence - -import networkx as nx -import numpy as np -import pandas as pd - - -def _add_split_edges(track_df, split_df): - track_df2 = track_df.copy() - for _, row in split_df.iterrows(): - p = ( - track_df[track_df["track_id"] == row["parent_track_id"]] - .sort_values("frame", ascending=True) - .iloc[-1] - ) - p["track_id"] = row["child_track_id"] - track_df2 = pd.concat([track_df2, pd.DataFrame(p).T]) - track_df2 = track_df2.sort_values(["frame", "index"]).reset_index(drop=True) - return track_df2 - - -def _df_to_edges(track_df): - track_edgess = [] - for _, grp in track_df.groupby("track_id"): - track_edges = [] - nodes = list(grp.sort_values("frame").iterrows()) - for (_, row1), (_, row2) in zip(nodes[:-1], nodes[1:]): - track_edges.append( - (tuple(row1[["frame", "index"]]), tuple(row2[["frame", "index"]])) - ) - track_edgess.append(track_edges) - return track_edgess - - -def _calc_overlap_score(reference_edgess, overlap_edgess): - correct_count = 0 - for reference_edges in reference_edgess: - overlaps = [ - len(set(reference_edges) & set(overlap_edges)) - for overlap_edges in overlap_edgess - ] - max_overlap = max(overlaps) - correct_count += max_overlap - total_count = sum([len(reference_edges) for reference_edges in reference_edgess]) - return correct_count / total_count if total_count > 0 else -1 - - -def calc_scores( - true_edges: EdgeType, - predicted_edges: EdgeType, - exclude_true_edges: EdgeType = [], - include_frames: Optional[Sequence[Int]] = None, - track_scores: bool = True, -) -> Dict[str, float]: - r""" - Calculate track prediction scores. - - Parameters - ---------- - true_edges : EdgeType - The list of true edges. Assumes ((frame1,index1), (frame2,index2)) for each edge. - - predicted_edges : EdgeType - The list of predicted edges. See `true_edges` for the format. - - exclude_true_edges : EdgeType, default [] - The list of true edges to be excluded from "\*_ratio". See `true_edges` for the format. - - include_frames : Optional[List[Int]], default None - The list of frames to include in the score calculation. If None, all the frames are included. - - track_scores : bool, default True - If True, the function calculates track_purity, target_effectiveness and mitotic_branching_correctness. - - Returns - ------- - score_dict : Dict[str,float] - The scores in the dict form. 
The keys are: - - - "Jaccard_index": (number of TP edges) / (number of TP edges + number of FP edges + number of FN edges) - - "true_positive_rate": (number of TP edges) / (number of TP edges + number of FN edges) - - "precision": (number of TP edges) / (number of TP edges + number of FP edges) - - "track_purity" : the track purity. - - "target_effectiveness" : the target effectiveness. - - "mitotic_branching_correctness" : the number of divisions that were correctly predicted. - """ - # return the count o - - if include_frames is None: - include_frames = list(range(np.max([e[0][0] for e in true_edges]) + 1)) - true_edges_included = [e for e in true_edges if e[0][0] in include_frames] - predicted_edges_included = [e for e in predicted_edges if e[0][0] in include_frames] - - if len(list(predicted_edges)) == 0: - return { - "Jaccard_index": 0, - "true_positive_rate": 0, - "precision": 0, - "track_purity": 0, - "target_effectiveness": 0, - "mitotic_branching_correctness": 0, - } - else: - - if track_scores: - ################ calculate track scores ################# - gt_tree = nx.from_edgelist(order_edges(true_edges), create_using=nx.DiGraph) - pred_tree = nx.from_edgelist( - order_edges(predicted_edges), create_using=nx.DiGraph - ) - gt_track_df, gt_split_df, _gt_merge_df = convert_tree_to_dataframe(gt_tree) - pred_track_df, pred_split_df, _pred_merge_df = convert_tree_to_dataframe( - pred_tree - ) - gt_track_df = gt_track_df.reset_index() - pred_track_df = pred_track_df.reset_index() - - gt_track_df = _add_split_edges(gt_track_df, gt_split_df) - pred_track_df = _add_split_edges(pred_track_df, pred_split_df) - gt_edgess = _df_to_edges(gt_track_df) - pred_edgess = _df_to_edges(pred_track_df) - - filter_edges = ( - lambda e: e[0][0] in include_frames and e not in exclude_true_edges - ) - pred_edgess = [ - [e for e in edges if filter_edges(e)] for edges in pred_edgess - ] - gt_edgess = [[e for e in edges if filter_edges(e)] for edges in gt_edgess] - track_purity = _calc_overlap_score(pred_edgess, gt_edgess) - target_effectiveness = _calc_overlap_score(gt_edgess, pred_edgess) - - ################ calculate division recovery ################# - def get_children(m): - return list(gt_tree.successors(m)) - - dividing_nodes = [m for m in gt_tree.nodes() if len(get_children(m)) > 1] - dividing_nodes = [m for m in dividing_nodes if m[0] in include_frames] - mitotic_branching_correctness_count = 0 - total_count = 0 - for m in dividing_nodes: - children = get_children(m) - - def check_match_children(edges): - return all([(n, m) in edges or (m, n) in edges for n in children]) - - excluded = check_match_children(exclude_true_edges) - if not excluded: - if check_match_children(predicted_edges): - mitotic_branching_correctness_count += 1 - total_count += 1 - - if total_count > 0: - mitotic_branching_correctness = ( - mitotic_branching_correctness_count / total_count - ) - else: - mitotic_branching_correctness = -1 - else: - track_purity = -1 - target_effectiveness = -1 - mitotic_branching_correctness = -1 - - ################ calculate edge overlaps ################# - te = set(true_edges_included) - set(exclude_true_edges) - pe = set(predicted_edges_included) - set(exclude_true_edges) - return { - "Jaccard_index": len(te & pe) / len(te | pe), - "true_positive_rate": len(te & pe) / len(te), - "precision": len(te & pe) / len(pe), - "track_purity": track_purity, - "target_effectiveness": target_effectiveness, - "mitotic_branching_correctness": mitotic_branching_correctness, - } diff --git 
a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py new file mode 100644 index 00000000..afc0f608 --- /dev/null +++ b/src/traccuracy/metrics/_track_overlap.py @@ -0,0 +1,82 @@ +"""This submodule implements routines for Track Purity (TP) and Target Effectiveness (TE) scores. + +Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): + +- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts by the total frame counts for T^g_j. The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. + +- TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. +""" + +from typing import TYPE_CHECKING +from typing import Dict +from typing import Optional +from typing import Sequence + +import networkx as nx +import numpy as np +import pandas as pd + +from ._base import Metric + +if TYPE_CHECKING: + from ._base import Matched + +class TrackOverlapMetrics(Metric): + supports_many_to_one = True + + def __init__(self, matched_data: "Matched"): + super().__init__(matched_data) + + def compute(self): + + self.data.gt_graph + + # requires tracklets that also have the splitting and merging edges + gt_tree = nx.from_edgelist(order_edges(true_edges), create_using=nx.DiGraph) + pred_tree = nx.from_edgelist( + order_edges(predicted_edges), create_using=nx.DiGraph + ) + gt_track_df, gt_split_df, _gt_merge_df = convert_tree_to_dataframe(gt_tree) + pred_track_df, pred_split_df, _pred_merge_df = convert_tree_to_dataframe( + pred_tree + ) + gt_track_df = gt_track_df.reset_index() + pred_track_df = pred_track_df.reset_index() + + gt_track_df = _add_split_edges(gt_track_df, gt_split_df) + pred_track_df = _add_split_edges(pred_track_df, pred_split_df) + + # edgess are a list of edges, grouped by the tracks + gt_edgess = _df_to_edges(gt_track_df) + pred_edgess = _df_to_edges(pred_track_df) + + # filter out edges that are not in the include_frames and exclude_true_edges + filter_edges = ( + lambda e: e[0][0] in include_frames and e not in exclude_true_edges + ) + pred_edgess = [ + [e for e in edges if filter_edges(e)] for edges in pred_edgess + ] + gt_edgess = [[e for e in edges if filter_edges(e)] for edges in gt_edgess] + + # calculate track purity and target effectiveness + track_purity = _calc_overlap_score(pred_edgess, gt_edgess) + target_effectiveness = _calc_overlap_score(gt_edgess, pred_edgess) + return { + "track_purity": track_purity, + "target_effectiveness": target_effectiveness, + } + +def _calc_overlap_score(reference_edgess, overlap_edgess): + correct_count = 0 + for reference_edges in reference_edgess: + overlaps = [ + len(set(reference_edges) & set(overlap_edges)) + for overlap_edges in overlap_edgess + ] + max_overlap = max(overlaps) + correct_count += max_overlap + total_count = sum([len(reference_edges) for reference_edges in reference_edgess]) + return correct_count / total_count if total_count > 0 else -1 + + From 2ddc8fc5727066f9a26a7b3b623c716ffa1e3620 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 5 Nov 2023 12:25:11 +0900 Subject: [PATCH 08/56] organized scripts --- src/traccuracy/_tracking_graph.py | 18 ++++-- src/traccuracy/metrics/_track_overlap.py | 73 +++++++++++------------- 2 files changed, 47 insertions(+), 44 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 
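A toy check of the `_calc_overlap_score` helper defined above: each reference track is credited with its best single-track overlap, normalized by the total reference edge count.

```python
from traccuracy.metrics._track_overlap import _calc_overlap_score

reference_tracks = [[("a", "b"), ("b", "c")], [("d", "e")]]  # edge lists per track
overlap_tracks = [[("a", "b"), ("b", "c")]]
score = _calc_overlap_score(reference_tracks, overlap_tracks)
print(score)  # (2 + 0) / 3 = 0.6667
```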
1153ea6e..ab502894 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -616,7 +616,7 @@ def get_edge_attribute(self, _id, attr): return False return self.graph.edges[_id][attr] - def get_tracklets(self): + def get_tracklets(self, include_intertrack_edges:bool=False): """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. Tracklet is defined as all connected components between divisions (daughter to next @@ -626,11 +626,21 @@ def get_tracklets(self): graph_copy = self.graph.copy() # Remove all intertrack edges from a copy of the original graph + removed_edges = [] for parent in self.get_divisions(): for daughter in self.get_succs(parent): graph_copy.remove_edge(parent, daughter) + removed_edges.append((parent, daughter)) + # Extract subgraphs (aka tracklets) and return as new track graphs - return [ - self.get_subgraph(g) for g in nx.weakly_connected_components(graph_copy) - ] + tracklets = nx.weakly_connected_components(graph_copy) + + if include_intertrack_edges: + # Add back intertrack edges + for tracklet in tracklets: + for parent, daughter in removed_edges: + if daughter in tracklet.nodes: + tracklet.add_edge(parent, daughter) + + return [self.get_subgraph(g) for g in tracklets] \ No newline at end of file diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index afc0f608..d82f5518 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -8,16 +8,13 @@ """ from typing import TYPE_CHECKING -from typing import Dict -from typing import Optional -from typing import Sequence +from typing import List, Tuple, Any -import networkx as nx -import numpy as np -import pandas as pd +from traccuracy._tracking_graph import TrackingGraph from ._base import Metric + if TYPE_CHECKING: from ._base import Matched @@ -28,55 +25,51 @@ def __init__(self, matched_data: "Matched"): super().__init__(matched_data) def compute(self): - - self.data.gt_graph # requires tracklets that also have the splitting and merging edges - gt_tree = nx.from_edgelist(order_edges(true_edges), create_using=nx.DiGraph) - pred_tree = nx.from_edgelist( - order_edges(predicted_edges), create_using=nx.DiGraph - ) - gt_track_df, gt_split_df, _gt_merge_df = convert_tree_to_dataframe(gt_tree) - pred_track_df, pred_split_df, _pred_merge_df = convert_tree_to_dataframe( - pred_tree - ) - gt_track_df = gt_track_df.reset_index() - pred_track_df = pred_track_df.reset_index() - - gt_track_df = _add_split_edges(gt_track_df, gt_split_df) - pred_track_df = _add_split_edges(pred_track_df, pred_split_df) - # edgess are a list of edges, grouped by the tracks - gt_edgess = _df_to_edges(gt_track_df) - pred_edgess = _df_to_edges(pred_track_df) + gt_tracklets = self.data.gt_graph.get_tracklets(include_intertrack_edges=True) + pred_tracklets = self.data.pred_graph.get_tracklets(include_intertrack_edges=True) - # filter out edges that are not in the include_frames and exclude_true_edges - filter_edges = ( - lambda e: e[0][0] in include_frames and e not in exclude_true_edges - ) - pred_edgess = [ - [e for e in edges if filter_edges(e)] for edges in pred_edgess - ] - gt_edgess = [[e for e in edges if filter_edges(e)] for edges in gt_edgess] + gt_pred_mapping = self.data.mapping + pred_gt_mapping = [(pred_node, gt_node) for gt_node, pred_node in gt_pred_mapping] # calculate track purity and target effectiveness - track_purity = _calc_overlap_score(pred_edgess, gt_edgess) - target_effectiveness = 
_calc_overlap_score(gt_edgess, pred_edgess) + track_purity = _calc_overlap_score(pred_tracklets, gt_tracklets, pred_gt_mapping) + target_effectiveness = _calc_overlap_score(gt_tracklets, pred_tracklets, gt_pred_mapping) return { "track_purity": track_purity, "target_effectiveness": target_effectiveness, } -def _calc_overlap_score(reference_edgess, overlap_edgess): + +def _calc_overlap_score(reference_tracklets: List[TrackingGraph], + overlap_tracklets: List[TrackingGraph], + mapping: List[Tuple[Any, Any]]): + """ Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. + + Args: + reference_tracklets (List[TrackingGraph]): The reference tracklets + overlap_tracklets (List[TrackingGraph]): The tracklets that overlap with the reference tracklets + mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes and the overlap tracklet nodes + + + """ correct_count = 0 - for reference_edges in reference_edgess: + total_count = 0 + # iterate over the reference tracklets + for reference_tracklet in reference_tracklets: + # find the overlap tracklet with the largest overlap + reference_tracklet_nodes_mapped = [ + n_to for (n_from, n_to) in mapping if n_from in reference_tracklet.nodes() + ] overlaps = [ - len(set(reference_edges) & set(overlap_edges)) - for overlap_edges in overlap_edgess + len(set(reference_tracklet_nodes_mapped) & set(overlap_tracklet.nodes())) + for overlap_tracklet in overlap_tracklets ] max_overlap = max(overlaps) correct_count += max_overlap - total_count = sum([len(reference_edges) for reference_edges in reference_edgess]) - return correct_count / total_count if total_count > 0 else -1 + total_count += len(reference_tracklet_nodes_mapped) + return correct_count / total_count if total_count > 0 else -1 From c314331d765efe93facedd383de648f682aabd84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Nov 2023 03:27:00 +0000 Subject: [PATCH 09/56] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/_tracking_graph.py | 7 ++-- src/traccuracy/metrics/_track_overlap.py | 53 ++++++++++++++---------- 2 files changed, 33 insertions(+), 27 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index ab502894..cfb9cb48 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -616,7 +616,7 @@ def get_edge_attribute(self, _id, attr): return False return self.graph.edges[_id][attr] - def get_tracklets(self, include_intertrack_edges:bool=False): + def get_tracklets(self, include_intertrack_edges: bool = False): """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. 
Tracklet is defined as all connected components between divisions (daughter to next @@ -631,11 +631,10 @@ def get_tracklets(self, include_intertrack_edges:bool=False): for daughter in self.get_succs(parent): graph_copy.remove_edge(parent, daughter) removed_edges.append((parent, daughter)) - # Extract subgraphs (aka tracklets) and return as new track graphs tracklets = nx.weakly_connected_components(graph_copy) - + if include_intertrack_edges: # Add back intertrack edges for tracklet in tracklets: @@ -643,4 +642,4 @@ def get_tracklets(self, include_intertrack_edges:bool=False): if daughter in tracklet.nodes: tracklet.add_edge(parent, daughter) - return [self.get_subgraph(g) for g in tracklets] \ No newline at end of file + return [self.get_subgraph(g) for g in tracklets] diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index d82f5518..1a6cbce4 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -7,60 +7,68 @@ - TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. """ -from typing import TYPE_CHECKING -from typing import List, Tuple, Any +from typing import TYPE_CHECKING, Any, List, Tuple from traccuracy._tracking_graph import TrackingGraph from ._base import Metric - if TYPE_CHECKING: from ._base import Matched + class TrackOverlapMetrics(Metric): supports_many_to_one = True - + def __init__(self, matched_data: "Matched"): super().__init__(matched_data) def compute(self): - # requires tracklets that also have the splitting and merging edges # edgess are a list of edges, grouped by the tracks gt_tracklets = self.data.gt_graph.get_tracklets(include_intertrack_edges=True) - pred_tracklets = self.data.pred_graph.get_tracklets(include_intertrack_edges=True) - + pred_tracklets = self.data.pred_graph.get_tracklets( + include_intertrack_edges=True + ) + gt_pred_mapping = self.data.mapping - pred_gt_mapping = [(pred_node, gt_node) for gt_node, pred_node in gt_pred_mapping] - + pred_gt_mapping = [ + (pred_node, gt_node) for gt_node, pred_node in gt_pred_mapping + ] + # calculate track purity and target effectiveness - track_purity = _calc_overlap_score(pred_tracklets, gt_tracklets, pred_gt_mapping) - target_effectiveness = _calc_overlap_score(gt_tracklets, pred_tracklets, gt_pred_mapping) + track_purity = _calc_overlap_score( + pred_tracklets, gt_tracklets, pred_gt_mapping + ) + target_effectiveness = _calc_overlap_score( + gt_tracklets, pred_tracklets, gt_pred_mapping + ) return { "track_purity": track_purity, "target_effectiveness": target_effectiveness, } -def _calc_overlap_score(reference_tracklets: List[TrackingGraph], - overlap_tracklets: List[TrackingGraph], - mapping: List[Tuple[Any, Any]]): - """ Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. - - Args: - reference_tracklets (List[TrackingGraph]): The reference tracklets - overlap_tracklets (List[TrackingGraph]): The tracklets that overlap with the reference tracklets - mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes and the overlap tracklet nodes +def _calc_overlap_score( + reference_tracklets: List[TrackingGraph], + overlap_tracklets: List[TrackingGraph], + mapping: List[Tuple[Any, Any]], +): + """Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. 
+ + Args: + reference_tracklets (List[TrackingGraph]): The reference tracklets + overlap_tracklets (List[TrackingGraph]): The tracklets that overlap with the reference tracklets + mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes and the overlap tracklet nodes + - """ correct_count = 0 total_count = 0 # iterate over the reference tracklets for reference_tracklet in reference_tracklets: # find the overlap tracklet with the largest overlap - reference_tracklet_nodes_mapped = [ + reference_tracklet_nodes_mapped = [ n_to for (n_from, n_to) in mapping if n_from in reference_tracklet.nodes() ] overlaps = [ @@ -72,4 +80,3 @@ def _calc_overlap_score(reference_tracklets: List[TrackingGraph], total_count += len(reference_tracklet_nodes_mapped) return correct_count / total_count if total_count > 0 else -1 - From 6e46f0db70bf0d58f6aadc391b6a33a9044200eb Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 5 Nov 2023 12:28:37 +0900 Subject: [PATCH 10/56] solving problems --- src/traccuracy/metrics/_track_overlap.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index d82f5518..edd62bda 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -2,7 +2,11 @@ Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): -- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts by the total frame counts for T^g_j. The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. +- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k + that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts + by the total frame counts for T^g_j. + The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, + weighted by the length of the tracks. - TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. """ @@ -51,7 +55,7 @@ def _calc_overlap_score(reference_tracklets: List[TrackingGraph], Args: reference_tracklets (List[TrackingGraph]): The reference tracklets overlap_tracklets (List[TrackingGraph]): The tracklets that overlap with the reference tracklets - mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes and the overlap tracklet nodes + mapping (List[Tuple[Any, Any]]): Mapping between the reference nodes and the overlap nodes """ From a68ce54403a87f63fd4982613638a9628ea9a372 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Nov 2023 03:30:01 +0000 Subject: [PATCH 11/56] style(pre-commit.ci): auto fixes [...] 
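
Stripped of the graph machinery, the TE/TP score that the preceding patches assemble reduces to a small routine. The sketch below is a hypothetical toy version, not the traccuracy API: it assumes tracklets are given as plain edge lists and that ground truth and prediction share node ids (an identity mapping), which sidesteps the mapping logic handled later in this series.

    # Minimal sketch of the overlap score under the assumptions above:
    # for each reference tracklet, find the overlap tracklet sharing the
    # most edges with it, then divide the summed best overlaps by the
    # total number of reference edges.
    def overlap_score(reference_tracklets, overlap_tracklets):
        correct = 0
        total = 0
        for ref_edges in reference_tracklets:
            best = max(
                (len(set(ref_edges) & set(ov_edges)) for ov_edges in overlap_tracklets),
                default=0,
            )
            correct += best
            total += len(ref_edges)
        return correct / total if total > 0 else -1

    gt = [[((0, 0), (1, 0)), ((1, 0), (2, 0))]]      # one GT tracklet, two edges
    pred = [[((0, 0), (1, 0))], [((1, 0), (2, 0))]]  # prediction split in two
    print(overlap_score(gt, pred))  # target effectiveness: 1 / 2 = 0.5
    print(overlap_score(pred, gt))  # track purity: 2 / 2 = 1.0

Weighting by track length falls out naturally here: long reference tracklets contribute more edges to the denominator, matching the weighted-mean definition in the module docstring.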
--- src/traccuracy/metrics/_track_overlap.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index ee0ed580..702b19db 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -2,10 +2,10 @@ Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): -- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k - that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts - by the total frame counts for T^g_j. - The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, +- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k + that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts + by the total frame counts for T^g_j. + The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. - TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. @@ -58,13 +58,13 @@ def _calc_overlap_score( overlap_tracklets: List[TrackingGraph], mapping: List[Tuple[Any, Any]], ): - """Calculate weighted sum of the length of the longest overlap tracklet + """Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. Args: reference_tracklets (List[TrackingGraph]): The reference tracklets overlap_tracklets (List[TrackingGraph]): The tracklets that overlap - mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes + mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes and the overlap tracklet nodes """ From 8efa60e23e79e1663afad1a0153cc84bb35ab697 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 5 Nov 2023 12:31:28 +0900 Subject: [PATCH 12/56] fixed pre-commit problem --- src/traccuracy/metrics/_track_overlap.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 702b19db..3d1789f1 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -3,8 +3,8 @@ Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): - TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k - that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts - by the total frame counts for T^g_j. + that overlaps with T^g_j in the largest number of the frames and then dividing + the overlap frame counts by the total frame counts for T^g_j. The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. From a7ff3a349c3c8831380dcf0a010457bd891f91af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Nov 2023 03:31:39 +0000 Subject: [PATCH 13/56] style(pre-commit.ci): auto fixes [...] 
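
The tests arriving in [PATCH 14/56] below exercise TrackingGraph.get_tracklets, whose core idea is to cut every division edge (parent to daughter) and then take weakly connected components. A standalone sketch of that idea with plain networkx, using a hypothetical toy lineage rather than a fixture from this repo:

    import networkx as nx

    # one division at node "1_0"; node ids are "frame_label" strings
    graph = nx.DiGraph([("0_0", "1_0"), ("1_0", "2_0"), ("1_0", "2_1"), ("2_1", "3_1")])
    pruned = graph.copy()
    for parent in graph.nodes:
        if graph.out_degree(parent) > 1:  # dividing cell
            for daughter in graph.successors(parent):
                pruned.remove_edge(parent, daughter)

    tracklets = list(nx.weakly_connected_components(pruned))
    print(tracklets)  # e.g. [{'0_0', '1_0'}, {'2_0'}, {'2_1', '3_1'}]

The include_intertrack_edges flag added in [PATCH 08/56] then optionally restores the removed division edges to the tracklets they lead into.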
--- src/traccuracy/metrics/_track_overlap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 3d1789f1..9d4e79ca 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -3,7 +3,7 @@ Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): - TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k - that overlaps with T^g_j in the largest number of the frames and then dividing + that overlaps with T^g_j in the largest number of the frames and then dividing the overlap frame counts by the total frame counts for T^g_j. The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, weighted by the length of the tracks. From 9d125b86639c29844447e1cd2fa7cbd131619aa3 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 5 Nov 2023 14:12:53 +0900 Subject: [PATCH 14/56] test added but failing --- src/traccuracy/_tracking_graph.py | 5 +- tests/metrics/test_divisions.py | 12 +-- tests/metrics/test_track_overlap_metrics.py | 86 +++++++++++++++++++++ tests/test_utils.py | 10 +++ 4 files changed, 100 insertions(+), 13 deletions(-) create mode 100644 tests/metrics/test_track_overlap_metrics.py diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index cfb9cb48..8ddf52e5 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -636,10 +636,11 @@ def get_tracklets(self, include_intertrack_edges: bool = False): tracklets = nx.weakly_connected_components(graph_copy) if include_intertrack_edges: + tracklets = list(tracklets) # Add back intertrack edges for tracklet in tracklets: for parent, daughter in removed_edges: - if daughter in tracklet.nodes: - tracklet.add_edge(parent, daughter) + if daughter in tracklet: + tracklet.add(parent) return [self.get_subgraph(g) for g in tracklets] diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index e2679115..b037e96c 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -1,17 +1,7 @@ from traccuracy import TrackingGraph -from traccuracy.matchers._matched import Matched from traccuracy.metrics._divisions import DivisionMetrics -from tests.test_utils import get_division_graphs - - -class DummyMatched(Matched): - def __init__(self, gt_data, pred_data, mapper): - self.mapper = mapper - super().__init__(gt_data, pred_data) - - def compute_mapping(self): - return self.mapper +from tests.test_utils import DummyMatched, get_division_graphs def test_DivisionMetrics(): diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py new file mode 100644 index 00000000..df56e20f --- /dev/null +++ b/tests/metrics/test_track_overlap_metrics.py @@ -0,0 +1,86 @@ +import networkx as nx +import pytest +from traccuracy import TrackingGraph +from traccuracy.metrics._track_overlap import TrackOverlapMetrics + +from tests.test_utils import DummyMatched + + +@pytest.fixture +def test_trees(): + # 0 - 1 - 2 - 3 - 4 - 5 + # | + # - 3 - 4 - 5 + # + # 1 - 2 - 3 - 4 + true_edges = [ + ((0, 0), (1, 0)), + ((1, 0), (2, 0)), + ((2, 0), (3, 0)), + ((3, 0), (4, 0)), + ((4, 0), (5, 0)), + ((2, 0), (3, 1)), + ((3, 1), (4, 1)), + ((4, 1), (5, 1)), + ((1, 2), (2, 2)), + ((2, 2), (3, 2)), + ((3, 2), (4, 2)), + ] + + # 0 - 1 - 2 - 3 - 4 - 5 + # | + # - 3 + # - 4 - 5 + # | + # 1 - 2 - 3 - 4 + pred_edges = [ + ((0, 0), (1, 
0)), + ((1, 0), (2, 0)), + ((2, 0), (3, 0)), + ((4, 0), (5, 0)), + ((2, 0), (3, 1)), + ((1, 2), (2, 2)), + ((2, 2), (3, 2)), + ((3, 2), (4, 1)), + ((4, 1), (5, 1)), + ] + + def to_str(x): + return "_".join([str(i) for i in x]) + + def to_tree(x): + return nx.from_edgelist( + [(to_str(n1), to_str(n2)) for n1, n2 in x], create_using=nx.DiGraph + ) + + true_tree = to_tree(true_edges) + pred_tree = to_tree(pred_edges) + + attrs = {} + for node in true_tree.nodes: + attrs[node] = {"t": int(node.split("_")[0]), "x": 0, "y": 0} + nx.set_node_attributes(true_tree, attrs) + attrs = {} + for node in pred_tree.nodes: + attrs[node] = {"t": int(node.split("_")[0]), "x": 0, "y": 0} + nx.set_node_attributes(pred_tree, attrs) + + mapping = [(n, n) for n in true_tree.nodes] + return true_tree, pred_tree, mapping + + +def test_track_overlap_metrics(test_trees) -> None: + g_gt, g_pred, mapping = test_trees + matched = DummyMatched( + TrackingGraph(g_gt), + TrackingGraph(g_pred), + mapper=mapping, + ) + + metric = TrackOverlapMetrics(matched) + assert metric.results + + assert metric.results == { + "track_purity": 7 / 9, + "target_effectiveness": 6 / 11, + } diff --git a/tests/test_utils.py b/tests/test_utils.py index 1dbf6d7c..6586886f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import numpy as np import skimage as sk from traccuracy._tracking_graph import TrackingGraph +from traccuracy.matchers._matched import Matched def get_annotated_image(img_size=256, num_labels=3, sequential=True, seed=1): @@ -148,3 +149,12 @@ def get_division_graphs(): mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")] return G1, G2, mapper + + +class DummyMatched(Matched): + def __init__(self, gt_data, pred_data, mapper): + self.mapper = mapper + super().__init__(gt_data, pred_data) + + def compute_mapping(self): + return self.mapper From 66ee1a292f264889686d4d9811df11f277a08c76 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sun, 5 Nov 2023 14:34:00 +0900 Subject: [PATCH 15/56] test running --- src/traccuracy/metrics/_track_overlap.py | 19 ++++++++++++++----- tests/metrics/test_track_overlap_metrics.py | 4 ++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 9d4e79ca..2d2b9e7e 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -11,6 +11,7 @@ - TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. 
""" +from itertools import product from typing import TYPE_CHECKING, Any, List, Tuple from traccuracy._tracking_graph import TrackingGraph @@ -71,17 +72,25 @@ def _calc_overlap_score( correct_count = 0 total_count = 0 # iterate over the reference tracklets + + def map_node(node): + return [n_to for (n_from, n_to) in mapping if n_from == node] + for reference_tracklet in reference_tracklets: # find the overlap tracklet with the largest overlap - reference_tracklet_nodes_mapped = [ - n_to for (n_from, n_to) in mapping if n_from in reference_tracklet.nodes() - ] + reference_tracklet_edges_mapped = [] + for node1, node2 in reference_tracklet.edges(): + mapped_nodes1 = map_node(node1) + mapped_nodes2 = map_node(node2) + if mapped_nodes1 and mapped_nodes2: + for n1, n2 in product(mapped_nodes1, mapped_nodes2): + reference_tracklet_edges_mapped.append((n1, n2)) overlaps = [ - len(set(reference_tracklet_nodes_mapped) & set(overlap_tracklet.nodes())) + len(set(reference_tracklet_edges_mapped) & set(overlap_tracklet.edges())) for overlap_tracklet in overlap_tracklets ] max_overlap = max(overlaps) correct_count += max_overlap - total_count += len(reference_tracklet_nodes_mapped) + total_count += len(reference_tracklet_edges_mapped) return correct_count / total_count if total_count > 0 else -1 diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index df56e20f..643bf134 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -27,12 +27,12 @@ def test_trees(): ((3, 2), (4, 2)), ] - # 0 - 1 - 2 - 3 - 4 - 5 + # 0 - 1 - 2 - 3 4 - 5 # | # - 3 # - 4 - 5 # | - # 1 - 2 - 3 - 4 + # 1 - 2 - 3 - pred_edges = [ ((0, 0), (1, 0)), ((1, 0), (2, 0)), From 7868916dca0cdbf108344d0ea025ccdf40a73d52 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Mon, 6 Nov 2023 09:38:10 +0900 Subject: [PATCH 16/56] made it possible to exclude division edges --- src/traccuracy/_tracking_graph.py | 8 ++++++-- src/traccuracy/metrics/_track_overlap.py | 18 +++++++++++++++--- tests/metrics/test_track_overlap_metrics.py | 8 ++++++++ 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 8ddf52e5..082e0ac3 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -616,11 +616,15 @@ def get_edge_attribute(self, _id, attr): return False return self.graph.edges[_id][attr] - def get_tracklets(self, include_intertrack_edges: bool = False): + def get_tracklets(self, include_division_edges: bool = False): """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. Tracklet is defined as all connected components between divisions (daughter to next parent). Tracklets can also start or end with a non-dividing cell. + + Args: + include_division_edges (bool, optional): If True, include edges at division. 
+ """ graph_copy = self.graph.copy() @@ -635,7 +639,7 @@ def get_tracklets(self, include_intertrack_edges: bool = False): # Extract subgraphs (aka tracklets) and return as new track graphs tracklets = nx.weakly_connected_components(graph_copy) - if include_intertrack_edges: + if include_division_edges: tracklets = list(tracklets) # Add back intertrack edges for tracklet in tracklets: diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 2d2b9e7e..cb4be3b8 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -25,15 +25,27 @@ class TrackOverlapMetrics(Metric): supports_many_to_one = True - def __init__(self, matched_data: "Matched"): + def __init__(self, matched_data: "Matched", include_division_edges: bool = True): + """Calculate metrics for longest track overlaps. + + - Target Effectiveness + - Track Purity + + Args: + matched_data (Matched): Matched object for set of GT and Pred data + include_division_edges (bool, optional): If True, include edges at division. + """ + self.include_division_edges = include_division_edges super().__init__(matched_data) def compute(self): # requires tracklets that also have the splitting and merging edges # edgess are a list of edges, grouped by the tracks - gt_tracklets = self.data.gt_graph.get_tracklets(include_intertrack_edges=True) + gt_tracklets = self.data.gt_graph.get_tracklets( + include_division_edges=self.include_division_edges + ) pred_tracklets = self.data.pred_graph.get_tracklets( - include_intertrack_edges=True + include_division_edges=self.include_division_edges ) gt_pred_mapping = self.data.mapping diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index 643bf134..a83ae32d 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -84,3 +84,11 @@ def test_track_overlap_metrics(test_trees) -> None: "track_purity": 7 / 9, "target_effectiveness": 6 / 11, } + + metric = TrackOverlapMetrics(matched, include_division_edges=False) + assert metric.results + + assert metric.results == { + "track_purity": 5 / 7, + "target_effectiveness": 6 / 9, + } From a83b11d8c16f950be167a3c03cc8304fd85adf8d Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 7 Nov 2023 14:06:14 -0800 Subject: [PATCH 17/56] Correct benchmarking calls --- tests/bench.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/bench.py b/tests/bench.py index 6068e459..df541640 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -94,7 +94,7 @@ def test_ctc_matched(benchmark, gt_data, pred_data): @pytest.mark.timeout(300) def test_ctc_metrics(benchmark, ctc_matched): def run_compute(): - return CTCMetrics(copy.deepcopy(ctc_matched)).compute() + return CTCMetrics().compute(copy.deepcopy(ctc_matched)) ctc_results = benchmark.pedantic(run_compute, rounds=1, iterations=1) @@ -108,7 +108,7 @@ def run_compute(): def test_ctc_div_metrics(benchmark, ctc_matched): def run_compute(): - return DivisionMetrics(copy.deepcopy(ctc_matched)).compute() + return DivisionMetrics().compute(copy.deepcopy(ctc_matched)) div_results = benchmark(run_compute) @@ -123,7 +123,7 @@ def test_iou_matched(benchmark, gt_data, pred_data): def test_iou_div_metrics(benchmark, iou_matched): def run_compute(): - return DivisionMetrics(copy.deepcopy(iou_matched)).compute() + return DivisionMetrics().compute(copy.deepcopy(iou_matched)) div_results = benchmark(run_compute) From 
e1cde0d53f3e547ad21c70dafe361c3c989a557c Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 8 Nov 2023 22:23:52 +0900 Subject: [PATCH 18/56] moved docstring --- src/traccuracy/metrics/_track_overlap.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index cb4be3b8..7fa61096 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -23,24 +23,26 @@ class TrackOverlapMetrics(Metric): - supports_many_to_one = True + """Calculate metrics for longest track overlaps. - def __init__(self, matched_data: "Matched", include_division_edges: bool = True): - """Calculate metrics for longest track overlaps. + - Target Effectiveness: fraction of longest overlapping prediction + tracklets onto each GT tracklet + - Track Purity : fraction of longest overlapping GT + tracklets onto each prediction tracklet + + Args: + matched_data (Matched): Matched object for set of GT and Pred data + include_division_edges (bool, optional): If True, include edges at division. - - Target Effectiveness - - Track Purity + """ + + supports_many_to_one = True - Args: - matched_data (Matched): Matched object for set of GT and Pred data - include_division_edges (bool, optional): If True, include edges at division. - """ + def __init__(self, matched_data: "Matched", include_division_edges: bool = True): self.include_division_edges = include_division_edges super().__init__(matched_data) def compute(self): - # requires tracklets that also have the splitting and merging edges - # edgess are a list of edges, grouped by the tracks gt_tracklets = self.data.gt_graph.get_tracklets( include_division_edges=self.include_division_edges ) From f14b5a186366c6c0708f1cb3f4ec1adcb19a62ae Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Thu, 9 Nov 2023 09:44:09 +0900 Subject: [PATCH 19/56] fixed one-to-multi logic --- src/traccuracy/metrics/_track_overlap.py | 39 +++++++++++++++--------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 7fa61096..0ccac857 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -26,9 +26,9 @@ class TrackOverlapMetrics(Metric): """Calculate metrics for longest track overlaps. - Target Effectiveness: fraction of longest overlapping prediction - tracklets onto each GT tracklet + tracklets on each GT tracklet - Track Purity : fraction of longest overlapping GT - tracklets onto each prediction tracklet + tracklets on each prediction tracklet Args: matched_data (Matched): Matched object for set of GT and Pred data @@ -57,10 +57,10 @@ def compute(self): # calculate track purity and target effectiveness track_purity = _calc_overlap_score( - pred_tracklets, gt_tracklets, pred_gt_mapping + pred_tracklets, gt_tracklets, gt_pred_mapping ) target_effectiveness = _calc_overlap_score( - gt_tracklets, pred_tracklets, gt_pred_mapping + gt_tracklets, pred_tracklets, pred_gt_mapping ) return { "track_purity": track_purity, @@ -71,7 +71,7 @@ def compute(self): def _calc_overlap_score( reference_tracklets: List[TrackingGraph], overlap_tracklets: List[TrackingGraph], - mapping: List[Tuple[Any, Any]], + overlap_reference_mapping: List[Tuple[Any, Any]], ): """Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. 
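
The hunk that follows is the half of this fix that maps each tracklet edge through a possibly one-to-many node mapping before comparing edge sets. As a standalone illustration of that step (toy node ids and a hypothetical mapping, not data from this repo):

    from itertools import product

    # one source node maps to two targets
    mapping = [("a1", "b1"), ("a1", "b2"), ("a2", "b3")]

    def map_node(node):
        return [dst for src, dst in mapping if src == node]

    edge = ("a1", "a2")
    mapped_edges = list(product(map_node(edge[0]), map_node(edge[1])))
    print(mapped_edges)  # [('b1', 'b3'), ('b2', 'b3')]

An edge survives the mapping only if both endpoints map to something, which is exactly the `if mapped_nodes1 and mapped_nodes2` guard in the code below.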
@@ -87,24 +87,33 @@ def _calc_overlap_score(
     total_count = 0
     # iterate over the reference tracklets
 
-    def map_node(node):
-        return [n_to for (n_from, n_to) in mapping if n_from == node]
+    def map_node(overlap_node):
+        return [
+            n_reference
+            for (n_overlap, n_reference) in overlap_reference_mapping
+            if n_overlap == overlap_node
+        ]
 
-    for reference_tracklet in reference_tracklets:
-        # find the overlap tracklet with the largest overlap
-        reference_tracklet_edges_mapped = []
-        for node1, node2 in reference_tracklet.edges():
+    # calculate all overlapping edges mapped onto GT ids
+    overlap_tracklets_edges_mapped = []
+    for overlap_tracklet in overlap_tracklets:
+        edges = []
+        for node1, node2 in overlap_tracklet.edges():
             mapped_nodes1 = map_node(node1)
             mapped_nodes2 = map_node(node2)
             if mapped_nodes1 and mapped_nodes2:
                 for n1, n2 in product(mapped_nodes1, mapped_nodes2):
-                    reference_tracklet_edges_mapped.append((n1, n2))
+                    edges.append((n1, n2))
+        overlap_tracklets_edges_mapped.append(edges)
+
+    for reference_tracklet in reference_tracklets:
+        # find the overlap tracklet with the largest overlap
         overlaps = [
-            len(set(reference_tracklet_edges_mapped) & set(overlap_tracklet.edges()))
-            for overlap_tracklet in overlap_tracklets
+            len(set(reference_tracklet.edges()) & set(edges))
+            for edges in overlap_tracklets_edges_mapped
         ]
         max_overlap = max(overlaps)
         correct_count += max_overlap
-        total_count += len(reference_tracklet_edges_mapped)
+        total_count += len(reference_tracklet.edges())
 
     return correct_count / total_count if total_count > 0 else -1

From 59a77c4656dec28171e16b9f7a8ba0264df24295 Mon Sep 17 00:00:00 2001
From: Yohsuke Fukai
Date: Fri, 10 Nov 2023 11:10:26 +0900
Subject: [PATCH 20/56] test data added but test failing

---
 tests/metrics/test_track_overlap_metrics.py | 173 ++++++++++++--------
 1 file changed, 106 insertions(+), 67 deletions(-)

diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py
index a83ae32d..8e96ef50 100644
--- a/tests/metrics/test_track_overlap_metrics.py
+++ b/tests/metrics/test_track_overlap_metrics.py
@@ -1,3 +1,5 @@
+from copy import deepcopy
+
 import networkx as nx
 import pytest
 from traccuracy import TrackingGraph
@@ -6,71 +8,110 @@ from tests.test_utils import DummyMatched
+def add_frame(tree):
+    attrs = {}
+    for node in tree.nodes:
+        attrs[node] = {"t": int(node.split("_")[0]), "x": 0, "y": 0}
+    nx.set_node_attributes(tree, attrs)
+    return tree
+
+
+TEST_TREES = [
+    {
+        "name": "simple1",
+        "gt_edges": [
+            # 0 - 0 - 0 - 0 - 0 - 0
+            #         |
+            #          - 1 - 1 - 1
+            #
+            #     2 - 2 - 2 - 2
+            ("0_0", "1_0"),
+            ("1_0", "2_0"),
+            ("2_0", "3_0"),
+            ("3_0", "4_0"),
+            ("4_0", "5_0"),
+
("2_0", "3_1"), + ("3_1", "4_1"), + ("4_1", "5_1"), + ("1_2", "2_2"), + ("2_2", "3_2"), + ("3_2", "4_2"), + ], + "pred_edges": [ + # 0 - 0 - 0 - 0 0 - 0 + # | + # - 1 + # - 1 - 1 + # | + # 2 - 2 - 2 - + ("0_0", "1_0"), + ("1_0", "2_0"), + ("2_0", "3_0"), + ("4_0", "5_0"), + ("2_0", "3_1"), + ("1_2", "2_2"), + ("2_2", "3_2"), + ("3_2", "4_1"), + ("4_1", "5_1"), + ("4_1", "5_1"), + ("4_1", "5_1"), + ], + "results_with_division_edges": { + "track_purity": 7 / 9, + "target_effectiveness": 6 / 11, + }, + "results_without_division_edges": { + "track_purity": 5 / 7, + "target_effectiveness": 6 / 9, + }, + }, + # { + # "name" : "overlap", + # "gt_edges" : [ + # ("0_0", "1_0"), + # ("1_0", "2_0"), + # ("2_0", "3_0"), + # ("1_0", "2_1"), + # ("2_1", "3_1"), + # ], + # + # } +] - attrs = {} - for node in true_tree.nodes: - attrs[node] = {"t": int(node.split("_")[0]), "x": 0, "y": 0} - nx.set_node_attributes(true_tree, attrs) - attrs = {} - for node in pred_tree.nodes: - attrs[node] = {"t": int(node.split("_")[0]), "x": 0, "y": 0} - nx.set_node_attributes(pred_tree, attrs) +simple2 = deepcopy(TEST_TREES[0]) +simple2["name"] = "simple2" +# 0 - 0 - 0 - 0 0 - 0 +# | +# - 1 +# - 1 - 1 +# | +# 2 - 2 - 2 - +# | +# - 3 - 3 +simple2["pred_edges"].extend( + [ + ("2_2", "3_3"), + ("3_3", "4_3"), + ] +) +simple2["results_with_division_edges"] = { + "track_purity": 7 / 11, + "target_effectiveness": 6 / 11, +} +simple2["results_without_division_edges"] = { + "track_purity": 5 / 9, + "target_effectiveness": 6 / 9, +} +TEST_TREES.append(simple2) +assert TEST_TREES[0] != TEST_TREES[1] - mapping = [(n, n) for n in true_tree.nodes] - return true_tree, pred_tree, mapping +@pytest.mark.parametrize("data", TEST_TREES) +def test_track_overlap_metrics(data) -> None: + g_gt = add_frame(nx.from_edgelist(data["gt_edges"], create_using=nx.DiGraph)) + g_pred = add_frame(nx.from_edgelist(data["pred_edges"], create_using=nx.DiGraph)) + mapping = [(n, n) for n in g_gt.nodes] -def test_track_overlap_metrics(test_trees) -> None: - g_gt, g_pred, mapping = test_trees matched = DummyMatched( TrackingGraph(g_gt), TrackingGraph(g_pred), @@ -80,15 +121,13 @@ def test_track_overlap_metrics(test_trees) -> None: metric = TrackOverlapMetrics(matched) assert metric.results - assert metric.results == { - "track_purity": 7 / 9, - "target_effectiveness": 6 / 11, - } + assert ( + metric.results == data["results_with_division_edges"] + ), f"{data['name']} failed with division edges" metric = TrackOverlapMetrics(matched, include_division_edges=False) assert metric.results - assert metric.results == { - "track_purity": 5 / 7, - "target_effectiveness": 6 / 9, - } + assert ( + metric.results == data["results_without_division_edges"] + ), f"{data['name']} failed without division edges" From 3d60c234695573f6ee614c78297c926c9e0259f2 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Fri, 10 Nov 2023 11:14:48 +0900 Subject: [PATCH 21/56] gt value fixed --- tests/metrics/test_track_overlap_metrics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index 8e96ef50..9edb9fa1 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -96,11 +96,11 @@ def add_frame(tree): ) simple2["results_with_division_edges"] = { "track_purity": 7 / 11, - "target_effectiveness": 6 / 11, + "target_effectiveness": 5 / 11, } simple2["results_without_division_edges"] = { - "track_purity": 5 / 9, - 
"target_effectiveness": 6 / 9, + "track_purity": 5 / 7, + "target_effectiveness": 5 / 9, } TEST_TREES.append(simple2) assert TEST_TREES[0] != TEST_TREES[1] From d90654bb318dac54608d57950978bac19ce64e55 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Fri, 10 Nov 2023 21:13:05 +0900 Subject: [PATCH 22/56] test running --- tests/metrics/test_track_overlap_metrics.py | 60 ++++++++++++++++----- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index 9edb9fa1..53d48d2d 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -65,17 +65,50 @@ def add_frame(tree): "target_effectiveness": 6 / 9, }, }, - # { - # "name" : "overlap", - # "gt_edges" : [ - # ("0_0", "1_0"), - # ("1_0", "2_0"), - # ("2_0", "3_0"), - # ("1_0", "2_1"), - # ("2_1", "3_1"), - # ], - # - # } + { + "name": "overlap", + # 0 - 0 - 0 - 0 + # | + # - 1 - 1 + "gt_edges": [ + ("0_0", "1_0"), + ("1_0", "2_0"), + ("2_0", "3_0"), + ("1_0", "2_1"), + ("2_1", "3_1"), + ], + # 0 - 0 - 0 + # | + # - 1 - 1 + # 2 - 2 - 2 + # (2 and 1 overlap) + "pred_edges": [ + ("0_0", "1_0"), + ("1_0", "2_0"), + ("1_0", "2_1"), + ("2_1", "3_1"), + ("1_2", "2_2"), + ("2_2", "3_2"), + ], + "mapping": [ # GT to pred mapping + ("0_0", "0_0"), + ("1_0", "1_0"), + ("2_0", "2_0"), + ("3_0", "3_0"), + ("2_1", "2_1"), + ("3_1", "3_1"), + ("2_1", "2_2"), + ("3_1", "3_2"), + ], + "results_with_division_edges": { + "track_purity": 5 / 6, + "target_effectiveness": 4 / 5, + }, + "results_without_division_edges": { + "track_purity": 3 / 4, + "target_effectiveness": 2 / 3, + }, + }, ] simple2 = deepcopy(TEST_TREES[0]) @@ -110,7 +143,10 @@ def add_frame(tree): def test_track_overlap_metrics(data) -> None: g_gt = add_frame(nx.from_edgelist(data["gt_edges"], create_using=nx.DiGraph)) g_pred = add_frame(nx.from_edgelist(data["pred_edges"], create_using=nx.DiGraph)) - mapping = [(n, n) for n in g_gt.nodes] + if "mapping" in data: + mapping = data["mapping"] + else: + mapping = [(n, n) for n in g_gt.nodes] matched = DummyMatched( TrackingGraph(g_gt), From b2ae4526c5b9861bfa644867ab9511ca4496a398 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Fri, 10 Nov 2023 21:15:35 +0900 Subject: [PATCH 23/56] inverse test running --- tests/metrics/test_track_overlap_metrics.py | 27 +++++++++++++++------ 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index 53d48d2d..72e297fa 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -140,7 +140,8 @@ def add_frame(tree): @pytest.mark.parametrize("data", TEST_TREES) -def test_track_overlap_metrics(data) -> None: +@pytest.mark.parametrize("inverse", [False, True]) +def test_track_overlap_metrics(data, inverse) -> None: g_gt = add_frame(nx.from_edgelist(data["gt_edges"], create_using=nx.DiGraph)) g_pred = add_frame(nx.from_edgelist(data["pred_edges"], create_using=nx.DiGraph)) if "mapping" in data: @@ -148,6 +149,10 @@ def test_track_overlap_metrics(data) -> None: else: mapping = [(n, n) for n in g_gt.nodes] + if inverse: + g_gt, g_pred = g_pred, g_gt + mapping = [(b, a) for a, b in mapping] + matched = DummyMatched( TrackingGraph(g_gt), TrackingGraph(g_pred), @@ -157,13 +162,21 @@ def test_track_overlap_metrics(data) -> None: metric = TrackOverlapMetrics(matched) assert metric.results - assert ( - metric.results == 
data["results_with_division_edges"] - ), f"{data['name']} failed with division edges" + expected = data["results_with_division_edges"] + if inverse: + expected = { + "track_purity": expected["target_effectiveness"], + "target_effectiveness": expected["track_purity"], + } + assert metric.results == expected, f"{data['name']} failed with division edges" metric = TrackOverlapMetrics(matched, include_division_edges=False) assert metric.results - assert ( - metric.results == data["results_without_division_edges"] - ), f"{data['name']} failed without division edges" + expected = data["results_without_division_edges"] + if inverse: + expected = { + "track_purity": expected["target_effectiveness"], + "target_effectiveness": expected["track_purity"], + } + assert metric.results == expected, f"{data['name']} failed without division edges" From 655f7b1560374c2338432098cd0a0537c98ee00e Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 10 Nov 2023 10:53:28 -0800 Subject: [PATCH 24/56] Rename matchers._matched to matchers._base --- src/traccuracy/_run_metrics.py | 2 +- src/traccuracy/matchers/{_matched.py => _base.py} | 0 src/traccuracy/matchers/_ctc.py | 2 +- src/traccuracy/matchers/_iou.py | 2 +- src/traccuracy/metrics/_base.py | 2 +- src/traccuracy/track_errors/_ctc.py | 2 +- src/traccuracy/track_errors/divisions.py | 2 +- tests/metrics/test_divisions.py | 2 +- tests/track_errors/test_ctc_errors.py | 2 +- tests/track_errors/test_divisions.py | 2 +- 10 files changed, 9 insertions(+), 9 deletions(-) rename src/traccuracy/matchers/{_matched.py => _base.py} (100%) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 2c120c52..19fb02df 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Type from traccuracy import TrackingGraph - from traccuracy.matchers._matched import Matcher + from traccuracy.matchers._base import Matcher from traccuracy.metrics._base import Metric diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_base.py similarity index 100% rename from src/traccuracy/matchers/_matched.py rename to src/traccuracy/matchers/_base.py diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 3920af60..5b1c48e7 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -7,8 +7,8 @@ if TYPE_CHECKING: from traccuracy._tracking_graph import TrackingGraph +from ._base import Matched, Matcher from ._compute_overlap import get_labels_with_overlap -from ._matched import Matched, Matcher class CTCMatcher(Matcher): diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 5667a586..4b2cdec8 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -3,8 +3,8 @@ from traccuracy._tracking_graph import TrackingGraph +from ._base import Matched, Matcher from ._compute_overlap import get_labels_with_overlap -from ._matched import Matched, Matcher def _match_nodes(gt, res, threshold=1): diff --git a/src/traccuracy/metrics/_base.py b/src/traccuracy/metrics/_base.py index dd763678..4d80d651 100644 --- a/src/traccuracy/metrics/_base.py +++ b/src/traccuracy/metrics/_base.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from traccuracy.matchers._matched import Matched + from traccuracy.matchers._base import Matched class Metric(ABC): diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 642399df..5388c9f0 100644 
--- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -7,7 +7,7 @@ from traccuracy import EdgeAttr, NodeAttr if TYPE_CHECKING: - from traccuracy.matchers._matched import Matched + from traccuracy.matchers._base import Matched logger = logging.getLogger(__name__) diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index 9a432fe8..24f64bf0 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -8,7 +8,7 @@ from traccuracy._utils import find_gt_node_matches, find_pred_node_matches if TYPE_CHECKING: - from traccuracy.matchers._matched import Matched + from traccuracy.matchers._base import Matched logger = logging.getLogger(__name__) diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index d46537b4..935f8dde 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -1,5 +1,5 @@ from traccuracy import TrackingGraph -from traccuracy.matchers._matched import Matched +from traccuracy.matchers._base import Matched from traccuracy.metrics._divisions import DivisionMetrics from tests.test_utils import get_division_graphs diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 98be167c..e90e05da 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,7 +1,7 @@ import networkx as nx import numpy as np from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph -from traccuracy.matchers._matched import Matched +from traccuracy.matchers._base import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index bbb937fc..72a6e95a 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -2,7 +2,7 @@ import numpy as np import pytest from traccuracy import NodeAttr, TrackingGraph -from traccuracy.matchers._matched import Matched +from traccuracy.matchers._base import Matched from traccuracy.track_errors.divisions import ( _classify_divisions, _correct_shifted_divisions, From f80d0d213cba4a8412dbfc64e1128fbc184ec23d Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 10 Nov 2023 12:50:11 -0800 Subject: [PATCH 25/56] Move graph copying into Matcher.compute_mapping instead of in Matched --- src/traccuracy/matchers/_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/traccuracy/matchers/_base.py b/src/traccuracy/matchers/_base.py index b7d33ac4..f10c8ee0 100644 --- a/src/traccuracy/matchers/_base.py +++ b/src/traccuracy/matchers/_base.py @@ -41,7 +41,10 @@ def compute_mapping( "Input data must be a TrackingData object with a graph and segmentations" ) - matched = self._compute_mapping(gt_graph, pred_graph) + # Copy graphs to avoid possible changes to graphs while computing mapping + matched = self._compute_mapping( + copy.deepcopy(gt_graph), copy.deepcopy(pred_graph) + ) # Report matching performance total_gt = len(matched.gt_graph.nodes()) @@ -74,7 +77,8 @@ class Matched: Args: gt_graph (traccuracy.TrackingGraph): Tracking graph object for the gt pred_graph (traccuracy.TrackingGraph): Tracking graph object for the pred - + mapping (list[tuple[Any, Any]]): List of tuples where each tuple maps + a gt node to a pred node """ def __init__( @@ -83,6 +87,6 @@ def __init__( pred_graph: TrackingGraph, mapping: list[tuple[Any, Any]], ): - 
self.gt_graph = copy.deepcopy(gt_graph) - self.pred_graph = copy.deepcopy(pred_graph) + self.gt_graph = gt_graph + self.pred_graph = pred_graph self.mapping = mapping From 8c3ca39b35baa2771e4259be2152461b8e266ddb Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Sat, 11 Nov 2023 18:34:06 +0900 Subject: [PATCH 26/56] updated mapping implementation --- src/traccuracy/metrics/_track_overlap.py | 50 +++++++++++++-------- tests/metrics/test_track_overlap_metrics.py | 8 +++- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 0ccac857..60491809 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -11,8 +11,8 @@ - TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. """ -from itertools import product -from typing import TYPE_CHECKING, Any, List, Tuple +from itertools import groupby, product +from typing import TYPE_CHECKING, Any, Dict, List, Tuple from traccuracy._tracking_graph import TrackingGraph @@ -22,6 +22,26 @@ from ._base import Matched +def _mapping_to_dict(mapping: List[Tuple[Any, Any]]) -> Dict[Any, List[Any]]: + """Convert mapping list of tuples to dictionary. + + Args: + mapping (List[Tuple[Any, Any]]): Mapping list of tuples + + Returns: + Dict[Any, List[Any]]: Mapping dictionary + + """ + + def get_from_val(x): + return x[0] + + return { + k: [v[1] for v in vs] + for k, vs in groupby(sorted(mapping, key=get_from_val), key=get_from_val) + } + + class TrackOverlapMetrics(Metric): """Calculate metrics for longest track overlaps. @@ -50,10 +70,10 @@ def compute(self): include_division_edges=self.include_division_edges ) - gt_pred_mapping = self.data.mapping - pred_gt_mapping = [ - (pred_node, gt_node) for gt_node, pred_node in gt_pred_mapping - ] + gt_pred_mapping = _mapping_to_dict(self.data.mapping) + pred_gt_mapping = _mapping_to_dict( + [(pred_node, gt_node) for gt_node, pred_node in self.data.mapping] + ) # calculate track purity and target effectiveness track_purity = _calc_overlap_score( @@ -71,7 +91,7 @@ def compute(self): def _calc_overlap_score( reference_tracklets: List[TrackingGraph], overlap_tracklets: List[TrackingGraph], - overlap_reference_mapping: List[Tuple[Any, Any]], + overlap_reference_mapping: Dict[Any, List[Any]], ): """Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. 
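
For reference, the groupby conversion introduced by _mapping_to_dict above behaves like this standalone sketch (the toy pairs mirror the unit test added later in this patch; duplicate target entries are kept deliberately):

    from itertools import groupby

    mapping = [("1", "2"), ("2", "3"), ("1", "3"), ("2", "3")]
    mapping.sort(key=lambda pair: pair[0])  # groupby only groups consecutive keys
    as_dict = {
        k: [dst for _, dst in pairs]
        for k, pairs in groupby(mapping, key=lambda pair: pair[0])
    }
    print(as_dict)  # {'1': ['2', '3'], '2': ['3', '3']}

Turning the pair list into a dict lets the node lookup in _calc_overlap_score become a constant-time .get in the next hunk, instead of a scan over the whole mapping for every node.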
@@ -79,28 +99,20 @@ def _calc_overlap_score( Args: reference_tracklets (List[TrackingGraph]): The reference tracklets overlap_tracklets (List[TrackingGraph]): The tracklets that overlap - mapping (List[Tuple[Any, Any]]): Mapping between the reference tracklet nodes - and the overlap tracklet nodes + overlap_reference_mapping (Dict[Any, List[Any]]): Mapping as a dict + from the overlap tracklet nodes to the reference tracklet nodes """ correct_count = 0 total_count = 0 - # iterate over the reference tracklets - - def map_node(overlap_node): - return [ - n_reference - for (n_overlap, n_reference) in overlap_reference_mapping - if n_overlap == overlap_node - ] # calculate all overlapping edges mapped onto GT ids overlap_tracklets_edges_mapped = [] for overlap_tracklet in overlap_tracklets: edges = [] for node1, node2 in overlap_tracklet.edges(): - mapped_nodes1 = map_node(node1) - mapped_nodes2 = map_node(node2) + mapped_nodes1 = overlap_reference_mapping.get(node1, []) + mapped_nodes2 = overlap_reference_mapping.get(node2, []) if mapped_nodes1 and mapped_nodes2: for n1, n2 in product(mapped_nodes1, mapped_nodes2): edges.append((n1, n2)) diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index 72e297fa..d8253db2 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -3,7 +3,7 @@ import networkx as nx import pytest from traccuracy import TrackingGraph -from traccuracy.metrics._track_overlap import TrackOverlapMetrics +from traccuracy.metrics._track_overlap import TrackOverlapMetrics, _mapping_to_dict from tests.test_utils import DummyMatched @@ -180,3 +180,9 @@ def test_track_overlap_metrics(data, inverse) -> None: "target_effectiveness": expected["track_purity"], } assert metric.results == expected, f"{data['name']} failed without division edges" + + +def test_mapping_to_dict(): + mapping = [("1", "2"), ("2", "3"), ("1", "3"), ("2", "3")] + mapping_dict = _mapping_to_dict(mapping) + assert mapping_dict == {"1": ["2", "3"], "2": ["3", "3"]} From 51eebf4e7156dceddc8e75e6ca1897438ac35ce1 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Mon, 13 Nov 2023 12:13:29 -0800 Subject: [PATCH 27/56] Eliminate use of DummyMatched in tests --- tests/metrics/test_track_overlap_metrics.py | 7 +++---- tests/test_utils.py | 10 ---------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index d8253db2..a0cd70c4 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -3,10 +3,9 @@ import networkx as nx import pytest from traccuracy import TrackingGraph +from traccuracy.matchers._base import Matched from traccuracy.metrics._track_overlap import TrackOverlapMetrics, _mapping_to_dict -from tests.test_utils import DummyMatched - def add_frame(tree): attrs = {} @@ -153,10 +152,10 @@ def test_track_overlap_metrics(data, inverse) -> None: g_gt, g_pred = g_pred, g_gt mapping = [(b, a) for a, b in mapping] - matched = DummyMatched( + matched = Matched( TrackingGraph(g_gt), TrackingGraph(g_pred), - mapper=mapping, + mapping, ) metric = TrackOverlapMetrics(matched) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6586886f..1dbf6d7c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,6 @@ import numpy as np import skimage as sk from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._matched import Matched def 
get_annotated_image(img_size=256, num_labels=3, sequential=True, seed=1): @@ -149,12 +148,3 @@ def get_division_graphs(): mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")] return G1, G2, mapper - - -class DummyMatched(Matched): - def __init__(self, gt_data, pred_data, mapper): - self.mapper = mapper - super().__init__(gt_data, pred_data) - - def compute_mapping(self): - return self.mapper From 9a94ebcc32ee6c9a159fd69a65022c9e7cb87159 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Mon, 13 Nov 2023 12:22:22 -0800 Subject: [PATCH 28/56] Update track overlap with new metrics API --- src/traccuracy/metrics/_track_overlap.py | 37 ++++++++++----------- tests/metrics/test_track_overlap_metrics.py | 14 ++++---- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 60491809..628cfae4 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -10,19 +10,20 @@ - TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. """ +from __future__ import annotations from itertools import groupby, product -from typing import TYPE_CHECKING, Any, Dict, List, Tuple - -from traccuracy._tracking_graph import TrackingGraph +from typing import TYPE_CHECKING, Any from ._base import Metric if TYPE_CHECKING: + from traccuracy._tracking_graph import TrackingGraph + from ._base import Matched -def _mapping_to_dict(mapping: List[Tuple[Any, Any]]) -> Dict[Any, List[Any]]: +def _mapping_to_dict(mapping: list[tuple[Any, Any]]) -> dict[Any, list[Any]]: """Convert mapping list of tuples to dictionary. Args: @@ -51,28 +52,26 @@ class TrackOverlapMetrics(Metric): tracklets on each prediction tracklet Args: - matched_data (Matched): Matched object for set of GT and Pred data include_division_edges (bool, optional): If True, include edges at division. """ supports_many_to_one = True - def __init__(self, matched_data: "Matched", include_division_edges: bool = True): + def __init__(self, include_division_edges: bool = True): self.include_division_edges = include_division_edges - super().__init__(matched_data) - def compute(self): - gt_tracklets = self.data.gt_graph.get_tracklets( + def compute(self, matched: Matched) -> dict: + gt_tracklets = matched.gt_graph.get_tracklets( include_division_edges=self.include_division_edges ) - pred_tracklets = self.data.pred_graph.get_tracklets( + pred_tracklets = matched.pred_graph.get_tracklets( include_division_edges=self.include_division_edges ) - gt_pred_mapping = _mapping_to_dict(self.data.mapping) + gt_pred_mapping = _mapping_to_dict(matched.mapping) pred_gt_mapping = _mapping_to_dict( - [(pred_node, gt_node) for gt_node, pred_node in self.data.mapping] + [(pred_node, gt_node) for gt_node, pred_node in matched.mapping] ) # calculate track purity and target effectiveness @@ -89,18 +88,18 @@ def compute(self): def _calc_overlap_score( - reference_tracklets: List[TrackingGraph], - overlap_tracklets: List[TrackingGraph], - overlap_reference_mapping: Dict[Any, List[Any]], + reference_tracklets: list[TrackingGraph], + overlap_tracklets: list[TrackingGraph], + overlap_reference_mapping: dict[Any, list[Any]], ): """Calculate weighted sum of the length of the longest overlap tracklet for each reference tracklet. 
Args: - reference_tracklets (List[TrackingGraph]): The reference tracklets - overlap_tracklets (List[TrackingGraph]): The tracklets that overlap - overlap_reference_mapping (Dict[Any, List[Any]]): Mapping as a dict - from the overlap tracklet nodes to the reference tracklet nodes + reference_tracklets (List[TrackingGraph]): The reference tracklets + overlap_tracklets (List[TrackingGraph]): The tracklets that overlap + overlap_reference_mapping (Dict[Any, List[Any]]): Mapping as a dict + from the overlap tracklet nodes to the reference tracklet nodes """ correct_count = 0 diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index d8253db2..c6852b70 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -159,8 +159,9 @@ def test_track_overlap_metrics(data, inverse) -> None: mapper=mapping, ) - metric = TrackOverlapMetrics(matched) - assert metric.results + metric = TrackOverlapMetrics() + results = metric.compute(matched) + assert results expected = data["results_with_division_edges"] if inverse: @@ -168,10 +169,11 @@ def test_track_overlap_metrics(data, inverse) -> None: "track_purity": expected["target_effectiveness"], "target_effectiveness": expected["track_purity"], } - assert metric.results == expected, f"{data['name']} failed with division edges" + assert results == expected, f"{data['name']} failed with division edges" - metric = TrackOverlapMetrics(matched, include_division_edges=False) - assert metric.results + metric = TrackOverlapMetrics(include_division_edges=False) + results = metric.compute(matched) + assert results expected = data["results_without_division_edges"] if inverse: @@ -179,7 +181,7 @@ def test_track_overlap_metrics(data, inverse) -> None: "track_purity": expected["target_effectiveness"], "target_effectiveness": expected["track_purity"], } - assert metric.results == expected, f"{data['name']} failed without division edges" + assert results == expected, f"{data['name']} failed without division edges" def test_mapping_to_dict(): From e16f339db43f9954df140d2260a20791d544635a Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Mon, 13 Nov 2023 13:06:06 -0800 Subject: [PATCH 29/56] Fix docstring typo --- src/traccuracy/matchers/_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traccuracy/matchers/_base.py b/src/traccuracy/matchers/_base.py index f10c8ee0..bbfe07d0 100644 --- a/src/traccuracy/matchers/_base.py +++ b/src/traccuracy/matchers/_base.py @@ -72,7 +72,7 @@ class Matched: """Matched data class which stores TrackingGraph objects for gt and pred and the computed mapping - Each TrackingGraph will be a new copy on the original object + Each TrackingGraph will be a new copy of the original object Args: gt_graph (traccuracy.TrackingGraph): Tracking graph object for the gt From 4cf32e056b9225245d6c31fb84d15287120512b2 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Mon, 13 Nov 2023 13:12:32 -0800 Subject: [PATCH 30/56] Import Matched in matchers init and change imports accordingly --- src/traccuracy/matchers/__init__.py | 3 ++- src/traccuracy/metrics/_base.py | 2 +- src/traccuracy/metrics/_ctc.py | 2 +- src/traccuracy/metrics/_track_overlap.py | 2 +- src/traccuracy/track_errors/_ctc.py | 2 +- src/traccuracy/track_errors/divisions.py | 2 +- tests/metrics/test_divisions.py | 2 +- tests/metrics/test_track_overlap_metrics.py | 2 +- tests/track_errors/test_ctc_errors.py | 2 +- tests/track_errors/test_divisions.py | 2 +- 10 files changed, 11 
insertions(+), 10 deletions(-) diff --git a/src/traccuracy/matchers/__init__.py b/src/traccuracy/matchers/__init__.py index e99c7ce4..e7a641b1 100644 --- a/src/traccuracy/matchers/__init__.py +++ b/src/traccuracy/matchers/__init__.py @@ -25,8 +25,9 @@ While we specify ground truth and prediction, it is possible to write a matching function that matches two arbitrary tracking solutions. """ +from ._base import Matched from ._compute_overlap import get_labels_with_overlap from ._ctc import CTCMatcher from ._iou import IOUMatcher -__all__ = ["CTCMatcher", "IOUMatcher", "get_labels_with_overlap"] +__all__ = ["CTCMatcher", "IOUMatcher", "get_labels_with_overlap", "Matched"] diff --git a/src/traccuracy/metrics/_base.py b/src/traccuracy/metrics/_base.py index 4d80d651..e2636799 100644 --- a/src/traccuracy/metrics/_base.py +++ b/src/traccuracy/metrics/_base.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from traccuracy.matchers._base import Matched + from traccuracy.matchers import Matched class Metric(ABC): diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index a5f0a84f..83d272ac 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -6,7 +6,7 @@ from ._base import Metric if TYPE_CHECKING: - from ._base import Matched + from traccuracy.matchers import Matched class AOGMMetrics(Metric): diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index 60491809..c2cae052 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -19,7 +19,7 @@ from ._base import Metric if TYPE_CHECKING: - from ._base import Matched + from traccuracy.matchers import Matched def _mapping_to_dict(mapping: List[Tuple[Any, Any]]) -> Dict[Any, List[Any]]: diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 5388c9f0..1a109a83 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -7,7 +7,7 @@ from traccuracy import EdgeAttr, NodeAttr if TYPE_CHECKING: - from traccuracy.matchers._base import Matched + from traccuracy.matchers import Matched logger = logging.getLogger(__name__) diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index 24f64bf0..d2b359ec 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -8,7 +8,7 @@ from traccuracy._utils import find_gt_node_matches, find_pred_node_matches if TYPE_CHECKING: - from traccuracy.matchers._base import Matched + from traccuracy.matchers import Matched logger = logging.getLogger(__name__) diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index 935f8dde..e07ea1aa 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -1,5 +1,5 @@ from traccuracy import TrackingGraph -from traccuracy.matchers._base import Matched +from traccuracy.matchers import Matched from traccuracy.metrics._divisions import DivisionMetrics from tests.test_utils import get_division_graphs diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py index a0cd70c4..4a08b37f 100644 --- a/tests/metrics/test_track_overlap_metrics.py +++ b/tests/metrics/test_track_overlap_metrics.py @@ -3,7 +3,7 @@ import networkx as nx import pytest from traccuracy import TrackingGraph -from traccuracy.matchers._base import Matched +from traccuracy.matchers import Matched from 
traccuracy.metrics._track_overlap import TrackOverlapMetrics, _mapping_to_dict diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index e90e05da..53d4a7f7 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,7 +1,7 @@ import networkx as nx import numpy as np from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph -from traccuracy.matchers._base import Matched +from traccuracy.matchers import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 72a6e95a..6538e644 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -2,7 +2,7 @@ import numpy as np import pytest from traccuracy import NodeAttr, TrackingGraph -from traccuracy.matchers._base import Matched +from traccuracy.matchers import Matched from traccuracy.track_errors.divisions import ( _classify_divisions, _correct_shifted_divisions, From 44a39523fbdf5bc7dfcd2c0cb863b2d6b6427ad0 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 10:38:34 -0800 Subject: [PATCH 31/56] Fix docs crossreferencing errors for Matched --- src/traccuracy/_run_metrics.py | 2 +- src/traccuracy/_utils.py | 2 +- src/traccuracy/matchers/_ctc.py | 2 +- src/traccuracy/metrics/_track_overlap.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 19fb02df..38f9ebd6 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -29,7 +29,7 @@ def run_metrics( Args: gt_data (TrackingData): ground truth graph and optionally segmentation pred_data (TrackingData): predicted graph and optionally segmentation - matcher (Matched): matching class to use to create correspondence + matcher (traccuracy.matchers.Matched): matching class to use to create correspondence metrics (List[Metric]): list of metrics to compute as class names matcher_kwargs (optional, dict): Dictionary of keyword argument for the matcher class diff --git a/src/traccuracy/_utils.py b/src/traccuracy/_utils.py index 7fce666e..3b6c0413 100644 --- a/src/traccuracy/_utils.py +++ b/src/traccuracy/_utils.py @@ -25,7 +25,7 @@ def validate_matched_data(matched_data, metrics): """Validate that given matcher supports requirements of each metric. Args: - matched_data (Matched): matching class with mapping between gt and pred + matched_data (traccuracy.matcher.Matched): matching class with mapping between gt and pred metrics (List[Metric]): list of metrics to compute as class names """ ... 
diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 5b1c48e7..ca245e36 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -32,7 +32,7 @@ def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph pred_graph (TrackingGraph): Tracking graph object for the pred Returns: - Matched: Matched data object containing the CTC mapping + traccuracy.matchers.Matched: Matched data object containing the CTC mapping Raises: ValueError: GT and pred segmentations must be the same shape diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py index c2cae052..84567fe6 100644 --- a/src/traccuracy/metrics/_track_overlap.py +++ b/src/traccuracy/metrics/_track_overlap.py @@ -51,7 +51,7 @@ class TrackOverlapMetrics(Metric): tracklets on each prediction tracklet Args: - matched_data (Matched): Matched object for set of GT and Pred data + matched_data (traccuracy.matchers.Matched): Matched object for set of GT and Pred data include_division_edges (bool, optional): If True, include edges at division. """ From 16be9ca47727dd17345905a2e24b7b46ea2e9a6b Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 10:54:18 -0800 Subject: [PATCH 32/56] One more docstring reference fix --- src/traccuracy/track_errors/_ctc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 1a109a83..775690ef 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -26,7 +26,7 @@ def get_vertex_errors(matched_data: "Matched"): Parameters ---------- - matched_data: Matched + matched_data: traccuracy.matchers.Matched Matched data object containing gt and pred graphs with their associated mapping """ comp_graph = matched_data.pred_graph From 0c7910c249416a486cf0002566cee5587c3369ef Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 12:59:50 -0800 Subject: [PATCH 33/56] Correct import of Matched --- src/traccuracy/metrics/_divisions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 63aeec95..569da84b 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -42,7 +42,7 @@ from ._base import Metric if TYPE_CHECKING: - from ._base import Matched + from traccuracy.matchers import Matched class DivisionMetrics(Metric): From 2d86b27d31f0c9f4c9d0025db66ee80ada04ca8c Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 16:37:20 -0800 Subject: [PATCH 34/56] Remove validate_matched_data function since it has not been implemented --- src/traccuracy/_run_metrics.py | 3 +-- src/traccuracy/_utils.py | 10 ---------- 2 files changed, 1 insertion(+), 12 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 4b7a96dc..033de6de 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from traccuracy._utils import get_relevant_kwargs, validate_matched_data +from traccuracy._utils import get_relevant_kwargs if TYPE_CHECKING: from typing import Dict, List, Optional, Type @@ -43,7 +43,6 @@ def run_metrics( if matcher_kwargs is None: matcher_kwargs = {} matched = matcher(**matcher_kwargs).compute_mapping(gt_data, pred_data) - validate_matched_data(matched, metrics) metric_kwarg_dict = { m_class: get_relevant_kwargs(m_class, 
metrics_kwargs) for m_class in metrics } diff --git a/src/traccuracy/_utils.py b/src/traccuracy/_utils.py index 3b6c0413..add9550c 100644 --- a/src/traccuracy/_utils.py +++ b/src/traccuracy/_utils.py @@ -21,16 +21,6 @@ def find_pred_node_matches(matches, pred_node): return [pair[0] for pair in matches if pair[1] == pred_node] -def validate_matched_data(matched_data, metrics): - """Validate that given matcher supports requirements of each metric. - - Args: - matched_data (traccuracy.matcher.Matched): matching class with mapping between gt and pred - metrics (List[Metric]): list of metrics to compute as class names - """ - ... - - def get_relevant_kwargs(metric_class, kwargs): """Get all params in kwargs that are valid for given metric class. From 0c809f0bf768f9e65463ece1222313553e3f039d Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 16:46:50 -0800 Subject: [PATCH 35/56] Refactor run_metrics to use instantiated objects for matchers and metrics --- src/traccuracy/_run_metrics.py | 44 +++++++++++----------------------- src/traccuracy/_utils.py | 31 ------------------------ 2 files changed, 14 insertions(+), 61 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 033de6de..264c025d 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,9 +1,7 @@ from typing import TYPE_CHECKING -from traccuracy._utils import get_relevant_kwargs - if TYPE_CHECKING: - from typing import Dict, List, Optional, Type + from typing import Dict, List from traccuracy import TrackingGraph from traccuracy.matchers._base import Matcher @@ -13,42 +11,28 @@ def run_metrics( gt_data: "TrackingGraph", pred_data: "TrackingGraph", - matcher: "Type[Matcher]", - metrics: "List[Type[Metric]]", - matcher_kwargs: "Optional[Dict]" = None, - metrics_kwargs: "Optional[Dict]" = None, # weights + matcher: "Matcher", + metrics: "List[Metric]", ) -> "Dict": """Compute given metrics on data using the given matcher. - An error will be thrown if the given matcher is not compatible with - all metrics in the given list. The returned result dictionary will - contain all metrics computed by the given Metric classes, as well as - general summary numbers e.g. false positive/false negative detection - and edge counts. + The returned result dictionary will contain all metrics computed by + the given Metric classes, as well as general summary numbers + e.g. false positive/false negative detection and edge counts. Args: - gt_data (TrackingData): ground truth graph and optionally segmentation - pred_data (TrackingData): predicted graph and optionally segmentation - matcher (traccuracy.matchers.Matched): matching class to use to create correspondence - metrics (List[Metric]): list of metrics to compute as class names - matcher_kwargs (optional, dict): Dictionary of keyword argument for the - matcher class - metric_kwargs (optional, dict): Dictionary of any keyword args for the - Metric classes + gt_data (TrackingGraph): ground truth graph and optionally segmentation + pred_data (TrackingGraph): predicted graph and optionally segmentation + matcher (traccuracy.matchers.Matcher): instantiated matcher object + metrics (List[Metric]): list of instantiated metrics objects to compute Returns: Dict: dictionary of metrics indexed by metric name. Dictionary will be - nested for metrics that return multiple values.
""" - if matcher_kwargs is None: - matcher_kwargs = {} - matched = matcher(**matcher_kwargs).compute_mapping(gt_data, pred_data) - metric_kwarg_dict = { - m_class: get_relevant_kwargs(m_class, metrics_kwargs) for m_class in metrics - } + matched = matcher.compute_mapping(gt_data, pred_data) results = {} for _metric in metrics: - relevant_kwargs = metric_kwarg_dict[_metric] - result = _metric(**relevant_kwargs).compute(matched) - results[_metric.__name__] = result + result = _metric.compute(matched) + results[_metric.__class__.__name__] = result return results diff --git a/src/traccuracy/_utils.py b/src/traccuracy/_utils.py index add9550c..599a43b6 100644 --- a/src/traccuracy/_utils.py +++ b/src/traccuracy/_utils.py @@ -1,6 +1,3 @@ -import inspect - - def find_gt_node_matches(matches, gt_node): """For a given gt node, finds all pred nodes that are matches @@ -19,31 +16,3 @@ def find_pred_node_matches(matches, pred_node): pred_node (hashable): pred node ID """ return [pair[0] for pair in matches if pair[1] == pred_node] - - -def get_relevant_kwargs(metric_class, kwargs): - """Get all params in kwargs that are valid for given metric class. - - If required parameters are not satisfied, an error is raised. - - Args: - metric_class (Metric): class name of metric to check - kwargs (dict): dictionary of keyword arguments to validate - """ - sig = inspect.signature(metric_class) - relevant_kwargs = {} - missing_args = [] - for param in sig.parameters.values(): - name = param.name - is_required = (param.default is param.empty) and name != "matched_data" - if kwargs and name in kwargs: - relevant_kwargs[name] = kwargs[name] - elif is_required: - missing_args.append(name) - if missing_args: - raise ValueError( - f"Metric class {metric_class.__name__} is missing required" - + f" arguments: {missing_args}. Add arguments to" - + " `run_metrics` or consider skipping this metric." - ) - return relevant_kwargs From 4b802608a54a906811f1383f5216cbee01dee482 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 16:54:20 -0800 Subject: [PATCH 36/56] Fix calls to run_metrics in cli --- src/traccuracy/cli.py | 38 ++++++++++++++++---------------------- 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/src/traccuracy/cli.py b/src/traccuracy/cli.py index 4a5d9ff8..b4164db8 100644 --- a/src/traccuracy/cli.py +++ b/src/traccuracy/cli.py @@ -57,7 +57,7 @@ def run_ctc( f"Only cell tracking challenge (ctc) loader is available, but {loader} was passed." 
) gt_data, pred_data = load_all_ctc(gt_dir, pred_dir, gt_track_path, pred_track_path) - result = run_metrics(gt_data, pred_data, CTCMatcher, [CTCMetrics]) + result = run_metrics(gt_data, pred_data, CTCMatcher(), [CTCMetrics()]) with open(out_path, "w") as fp: json.dump(result, fp) logger.info(f'TRA: {result["CTCMetrics"]["TRA"]}') @@ -120,16 +120,17 @@ def run_aogm( result = run_metrics( gt_data, pred_data, - CTCMatcher, - [AOGMMetrics], - metrics_kwargs={ - "vertex_ns_weight": vertex_ns_weight, - "vertex_fp_weight": vertex_fp_weight, - "vertex_fn_weight": vertex_fn_weight, - "edge_fp_weight": edge_fp_weight, - "edge_fn_weight": edge_fn_weight, - "edge_ws_weight": edge_ws_weight, - }, + CTCMatcher(), + [ + AOGMMetrics( + vertex_ns_weight=vertex_ns_weight, + vertex_fp_weight=vertex_fp_weight, + vertex_fn_weight=vertex_fn_weight, + edge_fp_weight=edge_fp_weight, + edge_fn_weight=edge_fn_weight, + edge_ws_weight=edge_ws_weight, + ) + ], ) with open(out_path, "w") as fp: json.dump(result, fp) @@ -185,12 +186,8 @@ def run_divisions_on_iou( result = run_metrics( gt_data, pred_data, - IOUMatcher, - [DivisionMetrics], - matcher_kwargs={"iou_threshold": match_threshold}, - metrics_kwargs={ - "frame_buffer": frame_buffer_tuple, - }, + IOUMatcher(iou_threshold=match_threshold), + [DivisionMetrics(frame_buffer=frame_buffer_tuple)], ) with open(out_path, "w") as fp: json.dump(result, fp) @@ -244,11 +241,8 @@ def run_divisions_on_ctc( result = run_metrics( gt_data, pred_data, - CTCMatcher, - [DivisionMetrics], - metrics_kwargs={ - "frame_buffer": frame_buffer_tuple, - }, + CTCMatcher(), + [DivisionMetrics(frame_buffer=frame_buffer_tuple)], ) with open(out_path, "w") as fp: json.dump(result, fp) From b364238513bb940cf144bf0b547f0335d4c2b7af Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 16:59:32 -0800 Subject: [PATCH 37/56] Add minimal test case for run_metrics --- tests/test_run_metrics.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 tests/test_run_metrics.py diff --git a/tests/test_run_metrics.py b/tests/test_run_metrics.py new file mode 100644 index 00000000..7d0db69f --- /dev/null +++ b/tests/test_run_metrics.py @@ -0,0 +1,15 @@ +from test_utils import get_movie_with_graph +from traccuracy import run_metrics +from traccuracy.matchers import CTCMatcher +from traccuracy.metrics import CTCMetrics + + +def test_run_metrics(): + graph = get_movie_with_graph() + + metric = CTCMetrics() + matcher = CTCMatcher() + + results = run_metrics(graph, graph, matcher, [metric]) + assert isinstance(results, dict) + assert "CTCMetrics" in results From f54eaefb1272fb6c58cc9b9752cd4cd920240cc8 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 17:02:19 -0800 Subject: [PATCH 38/56] Fix docs cross referencing --- src/traccuracy/_run_metrics.py | 6 +++--- src/traccuracy/metrics/_divisions.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 264c025d..0c9b39be 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -21,9 +21,9 @@ def run_metrics( e.g. false positive/false negative detection and edge counts. 
Args: - gt_data (TrackingGraph): ground truth graph and optionally segmentation - pred_data (TrackingGraph): predicted graph and optionally segmentation - matcher (traccuracy.matchers.Matcher): instantiated matcher object + gt_data (traccuracy.TrackingGraph): ground truth graph and optionally segmentation + pred_data (traccuracy.TrackingGraph): predicted graph and optionally segmentation + matcher (Matcher): instantiated matcher object metrics (List[Metric]): list of instantiated metrics objects to compute Returns: diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 569da84b..70aebe48 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -57,7 +57,7 @@ class DivisionMetrics(Metric): 3, 734559 (2021). Args: - matched_data (Matched): Matched object for set of GT and Pred data + matched_data (traccuracy.matchers.Matched): Matched object for set of GT and Pred data Must meet the `needs_one_to_one` criteria frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames to tolerate in correct_shifted_divisions. Defaults to (0). From d82f6de6bb896dd50c2c582e18a4bbe906830663 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Tue, 14 Nov 2023 17:04:59 -0800 Subject: [PATCH 39/56] Fix import for test utils --- tests/test_run_metrics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_run_metrics.py b/tests/test_run_metrics.py index 7d0db69f..1c9dfd22 100644 --- a/tests/test_run_metrics.py +++ b/tests/test_run_metrics.py @@ -1,8 +1,9 @@ -from test_utils import get_movie_with_graph from traccuracy import run_metrics from traccuracy.matchers import CTCMatcher from traccuracy.metrics import CTCMetrics +from tests.test_utils import get_movie_with_graph + def test_run_metrics(): graph = get_movie_with_graph() From 0c788594df2a1ae90b789a25389d604806751b70 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 15 Nov 2023 11:51:39 -0800 Subject: [PATCH 40/56] Move docstring under class declaration --- src/traccuracy/matchers/_iou.py | 13 +++++++------ src/traccuracy/metrics/__init__.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 4b2cdec8..189c4e6d 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -99,14 +99,15 @@ def match_iou(gt, pred, threshold=0.6): class IOUMatcher(Matcher): - def __init__(self, iou_threshold=0.6): - """Constructs a mapping between gt and pred nodes using the IoU of the segmentations + """Constructs a mapping between gt and pred nodes using the IoU of the segmentations - Lower values for iou_threshold will be more permissive of imperfect matches + Lower values for iou_threshold will be more permissive of imperfect matches - Args: - iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6. - """ + Args: + iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6. 
+ """ + + def __init__(self, iou_threshold=0.6): self.iou_threshold = iou_threshold def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): diff --git a/src/traccuracy/metrics/__init__.py b/src/traccuracy/metrics/__init__.py index f5facc93..109130a3 100644 --- a/src/traccuracy/metrics/__init__.py +++ b/src/traccuracy/metrics/__init__.py @@ -1,4 +1,5 @@ from ._ctc import AOGMMetrics, CTCMetrics from ._divisions import DivisionMetrics +from ._track_overlap import TrackOverlapMetrics -__all__ = ["CTCMetrics", "DivisionMetrics", "AOGMMetrics"] +__all__ = ["CTCMetrics", "DivisionMetrics", "AOGMMetrics", "TrackOverlapMetrics"] From b42318edd1c16c01d89e57693af510d3a6e5d46e Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Wed, 15 Nov 2023 16:23:06 -0800 Subject: [PATCH 41/56] Support metrics with different parameters --- src/traccuracy/_run_metrics.py | 22 ++++++++---- tests/test_run_metrics.py | 64 ++++++++++++++++++++++++++++++---- 2 files changed, 73 insertions(+), 13 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 0c9b39be..6ef28703 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,12 +1,12 @@ from typing import TYPE_CHECKING +# from traccuracy import TrackingGraph +# from traccuracy.matchers._base import Matcher +# from traccuracy.metrics._base import Metric + if TYPE_CHECKING: from typing import Dict, List - from traccuracy import TrackingGraph - from traccuracy.matchers._base import Matcher - from traccuracy.metrics._base import Metric - def run_metrics( gt_data: "TrackingGraph", @@ -30,9 +30,19 @@ def run_metrics( Dict: dictionary of metrics indexed by metric name. Dictionary will be nested for metrics that return multiple values. """ + # if not isinstance(gt_data, TrackingGraph) or not isinstance(pred_data, TrackingGraph): + # raise TypeError("gt_data and pred_data must be TrackingGraph objects") + + # if not isinstance(matcher, Matcher): + # raise TypeError("matcher must be an instantiated Matcher object") + + # if not all([isinstance(m, Metric) for m in metrics]): + # raise TypeError("metrics must be a list of instantiated Metric objects") + matched = matcher.compute_mapping(gt_data, pred_data) - results = {} + results = [] for _metric in metrics: result = _metric.compute(matched) - results[_metric.__class__.__name__] = result + report = {_metric.__class__.__name__: result, "parameters": _metric.__dict__} + results.append(report) return results diff --git a/tests/test_run_metrics.py b/tests/test_run_metrics.py index 1c9dfd22..89f43056 100644 --- a/tests/test_run_metrics.py +++ b/tests/test_run_metrics.py @@ -1,16 +1,66 @@ from traccuracy import run_metrics -from traccuracy.matchers import CTCMatcher -from traccuracy.metrics import CTCMetrics +from traccuracy.matchers._base import Matched, Matcher +from traccuracy.metrics._base import Metric from tests.test_utils import get_movie_with_graph +class DummyMetric(Metric): + def compute(self, matched): + return {} + + +class DummyMetricParam(Metric): + def __init__(self, param="value"): + self.param = param + + def compute(self, matched): + return {} + + +class DummyMatcher(Matcher): + def __init__(self, mapping=[]): + self.mapping = mapping + + def _compute_mapping(self, gt_graph, pred_graph): + return Matched(gt_graph, pred_graph, self.mapping) + + def test_run_metrics(): graph = get_movie_with_graph() + mapping = [(n, n) for n in graph.nodes()] + + # # Check matcher input -- not instantiated + # with pytest.raises(TypeError): + 
# run_metrics(graph, graph, DummyMatcher, [DummyMetric()]) + + # # Check matcher input -- wrong type + # with pytest.raises(TypeError): + # run_metrics(graph, graph, 'rando', DummyMetric()) + + # # Check metric input -- not instantiated + # with pytest.raises(TypeError): + # run_metrics(graph, graph, DummyMatcher(), [DummyMetric]) + + # # Check metric input -- wrong type + # with pytest.raises(TypeError): + # run_metrics(graph, graph, DummyMatcher(), [DummyMetric(), 'rando']) - metric = CTCMetrics() - matcher = CTCMatcher() + # One metric + results = run_metrics(graph, graph, DummyMatcher(mapping), [DummyMetric()]) + assert isinstance(results, list) + assert len(results) == 1 + assert "DummyMetric" in results[0] - results = run_metrics(graph, graph, matcher, [metric]) - assert isinstance(results, dict) - assert "CTCMetrics" in results + # Duplicate metric with different params + results = run_metrics( + graph, + graph, + DummyMatcher(mapping), + [DummyMetricParam("param1"), DummyMetricParam("param2")], + ) + assert len(results) == 2 + assert "DummyMetricParam" in results[0] + assert results[0]["parameters"] == {"param": "param1"} + assert "DummyMetricParam" in results[1] + assert results[1]["parameters"] == {"param": "param2"} From ded3c1c0b42b53912858e4619dd90fa6fa2d0a80 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 11:40:52 -0800 Subject: [PATCH 42/56] Add pull request templates --- .github/PULL_REQUEST_TEMPLATE/general.md | 30 +++++++++++++++++++ .../new_matcher_metric.md | 18 +++++++++++ 2 files changed, 48 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE/general.md create mode 100644 .github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md diff --git a/.github/PULL_REQUEST_TEMPLATE/general.md b/.github/PULL_REQUEST_TEMPLATE/general.md new file mode 100644 index 00000000..70267dd0 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/general.md @@ -0,0 +1,30 @@ +# Proposed Change +Briefly describe the contribution. If it resolves an issue or feature request, be sure to link to that issue. + +# Types of Changes +What types of changes does your code introduce? Put an x in the boxes that apply. +- [ ] Bugfix (non-breaking change which fixes an issue) +- [ ] New feature or enhancement +- [ ] Documentation update +- [ ] Tests and benchmarks +- [ ] Maintenance (e.g. dependencies, CI, releases, etc.) + +Which topics does your change affect? Put an x in the boxes that apply. +- [ ] Loaders +- [ ] Matchers +- [ ] Track Errors +- [ ] Metrics +- [ ] Core functionality (e.g. `TrackingGraph`, `run_metrics`, `cli`, etc.) + +# Checklist +Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code. + +- [ ] I have read the developer/contributing docs. +- [ ] I have added tests that prove that my feature works in various situations or tests the bugfix (if appropriate). +- [ ] I have checked that I maintained or improved code coverage. +- [ ] I have checked the benchmarking action to verify that my changes did not adversely affect performance. +- [ ] I have written docstrings and checked that they render correctly in the Read The Docs build (created after the PR is opened). +- [ ] I have updated the general documentation including Metric descriptions and example notebooks if necessary. 
+ + # Further Comments +If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc... \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md new file mode 100644 index 00000000..b85ff09a --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md @@ -0,0 +1,18 @@ +# Proposed Matcher or Metric Addition +- [ ] Matcher +- [ ] Metric + +Briefly describe your new Matcher or Metric class, including links to publication or other source code if relevant. A full description should be included in the documentation. If it resolves a feature request, be sure to link to that issue. + +# Checklist +Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code. + +- [ ] I have read the developer/contributing docs. +- [ ] I have added tests that prove that my feature works in various situations. +- [ ] I have checked that I maintained or improved code coverage. +- [ ] I have added benchmarking functions for my change in `tests/bench.py`. +- [ ] I have added a page to the documentation with a complete description of my matcher/metric including any references. +- [ ] I have written docstrings and checked that they render correctly in the Read The Docs build (created after the PR is opened). + +# Further Comments +If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc... \ No newline at end of file From 1e72ee77ed9f34f6cd218628dfd9581d9f320490 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 11:44:34 -0800 Subject: [PATCH 43/56] Add config for pr templates --- .github/PULL_REQUEST_TEMPLATE/general.md | 9 +++++++++ .github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/.github/PULL_REQUEST_TEMPLATE/general.md b/.github/PULL_REQUEST_TEMPLATE/general.md index 70267dd0..b37fe15f 100644 --- a/.github/PULL_REQUEST_TEMPLATE/general.md +++ b/.github/PULL_REQUEST_TEMPLATE/general.md @@ -1,3 +1,12 @@ +--- +name: General Pull Request +about: Bugfixes, enhancements, documentation, etc. +title: '' +labels: '' +assignees: '' + +--- + # Proposed Change Briefly describe the contribution. If it resolves an issue or feature request, be sure to link to that issue.
diff --git a/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md index b85ff09a..fe3f4188 100644 --- a/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md +++ b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md @@ -1,3 +1,12 @@ +--- +name: New Matcher or Metric +about: A new Matcher or Metric class +title: '' +labels: '' +assignees: '' + +--- + # Proposed Matcher or Metric Addition - [ ] Matcher - [ ] Metric From e197fc31349da82eac64c775d3267308c213a56d Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 11:51:28 -0800 Subject: [PATCH 44/56] New issue templates --- .../bugs.md} | 18 ++++++++-- .github/ISSUE_TEMPLATE/features.md | 36 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) rename .github/{ISSUE_TEMPLATE.md => ISSUE_TEMPLATE/bugs.md} (57%) create mode 100644 .github/ISSUE_TEMPLATE/features.md diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/bugs.md similarity index 57% rename from .github/ISSUE_TEMPLATE.md rename to .github/ISSUE_TEMPLATE/bugs.md index 3fd3d74e..8d86e4a8 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE/bugs.md @@ -1,15 +1,29 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + * traccuracy version: * Python version: * Operating System: -### Description +# Description Describe what you were trying to get done. Tell us what happened, what went wrong, and what you expected to happen. -### What I Did +# What I Did ``` Paste the command(s) you ran and the output. If there was a crash, please include the traceback here. ``` + +# Severity +- [ ] Unusable +- [ ] Annoying, but still functional +- [ ] Very minor \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/features.md b/.github/ISSUE_TEMPLATE/features.md new file mode 100644 index 00000000..d1c44dd1 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/features.md @@ -0,0 +1,36 @@ +--- +name: Feature +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +# Description + +Please describe the feature that you would like to see implemented in `traccuracy`. + +# Topics + +What types of changes are you suggesting? Put an x in the boxes that apply. +- [ ] New feature or enhancement +- [ ] Documentation update +- [ ] Tests and benchmarks +- [ ] Maintenance (e.g. dependencies, CI, releases, etc.) + +Which topics does your change affect? Put an x in the boxes that apply. +- [ ] Loaders +- [ ] Matchers +- [ ] Track Errors +- [ ] Metrics +- [ ] Core functionality (e.g. `TrackingGraph`, `run_metrics`, `cli`, etc.) + +# Priority +- [ ] This is an essential feature +- [ ] Nice to have +- [ ] Future idea + +# Are you interested in contributing? +- [ ] Yes! :tada: +- [ ] No :slightly_frowning_face: \ No newline at end of file From 4c67f66d772d09d1c695772e2396e58155db2683 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 11:52:06 -0800 Subject: [PATCH 45/56] Minimal example --- .github/ISSUE_TEMPLATE/bugs.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ISSUE_TEMPLATE/bugs.md b/.github/ISSUE_TEMPLATE/bugs.md index 8d86e4a8..61b0bc6f 100644 --- a/.github/ISSUE_TEMPLATE/bugs.md +++ b/.github/ISSUE_TEMPLATE/bugs.md @@ -16,7 +16,7 @@ assignees: '' Describe what you were trying to get done. Tell us what happened, what went wrong, and what you expected to happen. -# What I Did +# Minimal example to reproduce the bug ``` Paste the command(s) you ran and the output. 
From dbb25b3518ec1495962d576a55e96eba4d413db7 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 12:00:53 -0800 Subject: [PATCH 46/56] Create general pr template --- .../general.md => PULL_REQUEST_TEMPLATE.md} | 2 ++ 1 file changed, 2 insertions(+) rename .github/{PULL_REQUEST_TEMPLATE/general.md => PULL_REQUEST_TEMPLATE.md} (92%) diff --git a/.github/PULL_REQUEST_TEMPLATE/general.md b/.github/PULL_REQUEST_TEMPLATE.md similarity index 92% rename from .github/PULL_REQUEST_TEMPLATE/general.md rename to .github/PULL_REQUEST_TEMPLATE.md index b37fe15f..e97d7e26 100644 --- a/.github/PULL_REQUEST_TEMPLATE/general.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -7,6 +7,8 @@ assignees: '' --- +If you are implementing a new matcher or metric, please append this `&template=new_matcher_metric.md` to your url to load the correct template. + # Proposed Change Briefly describe the contribution. If it resolves an issue or feature request, be sure to link to that issue. From b6bef63a89b365392c40eb7c9c2e964834a022fb Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 12:02:23 -0800 Subject: [PATCH 47/56] Remove Pr template headers --- .github/PULL_REQUEST_TEMPLATE.md | 9 --------- .github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md | 9 --------- 2 files changed, 18 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index e97d7e26..77c7b114 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,12 +1,3 @@ ---- -name: General Pull Request -about: Bugfixes, enhancements, documentation, etc. -title: '' -labels: '' -assignees: '' - ---- - If you are implementing a new matcher or metric, please append this `&template=new_matcher_metric.md` to your url to load the correct template. 
# Proposed Change diff --git a/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md index fe3f4188..b85ff09a 100644 --- a/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md +++ b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md @@ -1,12 +1,3 @@ ---- -name: New Matcher or Metric -about: A new Matcher or Metric class -title: '' -labels: '' -assignees: '' - ---- - # Proposed Matcher or Metric Addition - [ ] Matcher - [ ] Metric From 541f9c0c80c39964cab6e0f1bae4cd53fbbced20 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 18:40:20 -0800 Subject: [PATCH 48/56] Fix circular imports and reenable input validation --- src/traccuracy/__init__.py | 4 +++- src/traccuracy/_run_metrics.py | 20 +++++++++++--------- tests/test_run_metrics.py | 32 ++++++++++++++++++-------------- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/traccuracy/__init__.py b/src/traccuracy/__init__.py index 6fbbeeb8..bf617211 100644 --- a/src/traccuracy/__init__.py +++ b/src/traccuracy/__init__.py @@ -6,7 +6,9 @@ except PackageNotFoundError: # pragma: no cover __version__ = "uninstalled" -from ._run_metrics import run_metrics from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph +# must go after TrackingGraph to avoid circular imports +from ._run_metrics import run_metrics # isort:skip + __all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"] diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 6ef28703..9093e5b2 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING -# from traccuracy import TrackingGraph -# from traccuracy.matchers._base import Matcher -# from traccuracy.metrics._base import Metric +from traccuracy import TrackingGraph +from traccuracy.matchers._base import Matcher +from traccuracy.metrics._base import Metric if TYPE_CHECKING: from typing import Dict, List @@ -30,14 +30,16 @@ def run_metrics( Dict: dictionary of metrics indexed by metric name. Dictionary will be nested for metrics that return multiple values. 
""" - # if not isinstance(gt_data, TrackingGraph) or not isinstance(pred_data, TrackingGraph): - # raise TypeError("gt_data and pred_data must be TrackingGraph objects") + if not isinstance(gt_data, TrackingGraph) or not isinstance( + pred_data, TrackingGraph + ): + raise TypeError("gt_data and pred_data must be TrackingGraph objects") - # if not isinstance(matcher, Matcher): - # raise TypeError("matcher must be an instantiated Matcher object") + if not isinstance(matcher, Matcher): + raise TypeError("matcher must be an instantiated Matcher object") - # if not all([isinstance(m, Metric) for m in metrics]): - # raise TypeError("metrics must be a list of instantiated Metric objects") + if not all(isinstance(m, Metric) for m in metrics): + raise TypeError("metrics must be a list of instantiated Metric objects") matched = matcher.compute_mapping(gt_data, pred_data) results = [] diff --git a/tests/test_run_metrics.py b/tests/test_run_metrics.py index 89f43056..4ca6ae67 100644 --- a/tests/test_run_metrics.py +++ b/tests/test_run_metrics.py @@ -1,3 +1,4 @@ +import pytest from traccuracy import run_metrics from traccuracy.matchers._base import Matched, Matcher from traccuracy.metrics._base import Metric @@ -19,8 +20,11 @@ def compute(self, matched): class DummyMatcher(Matcher): - def __init__(self, mapping=[]): - self.mapping = mapping + def __init__(self, mapping=None): + if mapping: + self.mapping = mapping + else: + self.mapping = [] def _compute_mapping(self, gt_graph, pred_graph): return Matched(gt_graph, pred_graph, self.mapping) @@ -30,21 +34,21 @@ def test_run_metrics(): graph = get_movie_with_graph() mapping = [(n, n) for n in graph.nodes()] - # # Check matcher input -- not instantiated - # with pytest.raises(TypeError): - # run_metrics(graph, graph, DummyMatcher, [DummyMetric()]) + # Check matcher input -- not instantiated + with pytest.raises(TypeError): + run_metrics(graph, graph, DummyMatcher, [DummyMetric()]) - # # Check matcher input -- wrong type - # with pytest.raises(TypeError): - # run_metrics(graph, graph, 'rando', DummyMetric()) + # Check matcher input -- wrong type + with pytest.raises(TypeError): + run_metrics(graph, graph, "rando", DummyMetric()) - # # Check metric input -- not instantiated - # with pytest.raises(TypeError): - # run_metrics(graph, graph, DummyMatcher(), [DummyMetric]) + # Check metric input -- not instantiated + with pytest.raises(TypeError): + run_metrics(graph, graph, DummyMatcher(), [DummyMetric]) - # # Check metric input -- wrong type - # with pytest.raises(TypeError): - # run_metrics(graph, graph, DummyMatcher(), [DummyMetric(), 'rando']) + # Check metric input -- wrong type + with pytest.raises(TypeError): + run_metrics(graph, graph, DummyMatcher(), [DummyMetric(), "rando"]) # One metric results = run_metrics(graph, graph, DummyMatcher(mapping), [DummyMetric()]) From e127a544a10c08753079972f4928fb56f3029fb9 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 18:49:14 -0800 Subject: [PATCH 49/56] Improve output structure --- src/traccuracy/_run_metrics.py | 7 ++++--- tests/test_run_metrics.py | 10 +++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 9093e5b2..92f6c852 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -13,7 +13,7 @@ def run_metrics( pred_data: "TrackingGraph", matcher: "Matcher", metrics: "List[Metric]", -) -> "Dict": +) -> "List[Dict]": """Compute given metrics on data using the given matcher. 
The returned result dictionary will contain all metrics computed by @@ -45,6 +45,7 @@ def run_metrics( results = [] for _metric in metrics: result = _metric.compute(matched) - report = {_metric.__class__.__name__: result, "parameters": _metric.__dict__} - results.append(report) + metric_dict = _metric.__dict__ + metric_dict["name"] = _metric.__class__.__name__ + results.append({"results": result, "metric": metric_dict}) return results diff --git a/tests/test_run_metrics.py b/tests/test_run_metrics.py index 4ca6ae67..5972e9f8 100644 --- a/tests/test_run_metrics.py +++ b/tests/test_run_metrics.py @@ -54,7 +54,7 @@ def test_run_metrics(): results = run_metrics(graph, graph, DummyMatcher(mapping), [DummyMetric()]) assert isinstance(results, list) assert len(results) == 1 - assert "DummyMetric" in results[0] + assert results[0]["metric"]["name"] == "DummyMetric" # Duplicate metric with different params results = run_metrics( @@ -64,7 +64,7 @@ def test_run_metrics(): [DummyMetricParam("param1"), DummyMetricParam("param2")], ) assert len(results) == 2 - assert "DummyMetricParam" in results[0] - assert results[0]["parameters"] == {"param": "param1"} - assert "DummyMetricParam" in results[1] - assert results[1]["parameters"] == {"param": "param2"} + assert results[0]["metric"]["name"] == "DummyMetricParam" + assert results[0]["metric"].get("param") == "param1" + assert results[1]["metric"]["name"] == "DummyMetricParam" + assert results[1]["metric"].get("param") == "param2" From 1fbd519e1d84fe8ad9ee9f0fef730578b00c0905 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 18:53:56 -0800 Subject: [PATCH 50/56] Correct usage of run_metrics output in cli --- src/traccuracy/cli.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/traccuracy/cli.py b/src/traccuracy/cli.py index b4164db8..ce178b83 100644 --- a/src/traccuracy/cli.py +++ b/src/traccuracy/cli.py @@ -60,8 +60,8 @@ def run_ctc( result = run_metrics(gt_data, pred_data, CTCMatcher(), [CTCMetrics()]) with open(out_path, "w") as fp: json.dump(result, fp) - logger.info(f'TRA: {result["CTCMetrics"]["TRA"]}') - logger.info(f'DET: {result["CTCMetrics"]["DET"]}') + logger.info(f'TRA: {result[0]["results"]["TRA"]}') + logger.info(f'DET: {result[0]["results"]["DET"]}') @app.command() @@ -134,7 +134,7 @@ def run_aogm( ) with open(out_path, "w") as fp: json.dump(result, fp) - logger.info(f'AOGM: {result["AOGMMetrics"]["AOGM"]}') + logger.info(f'AOGM: {result[0]["results"]["AOGM"]}') @app.command() @@ -192,7 +192,7 @@ def run_divisions_on_iou( with open(out_path, "w") as fp: json.dump(result, fp) res_str = "" - for frame_buffer, res_dict in result["DivisionMetrics"].items(): + for frame_buffer, res_dict in result[0]["results"].items(): res_str += f'{frame_buffer} F1: {res_dict["Division F1"]}\n' logger.info(res_str) @@ -247,7 +247,7 @@ def run_divisions_on_ctc( with open(out_path, "w") as fp: json.dump(result, fp) res_str = "" - for frame_buffer, res_dict in result["DivisionMetrics"].items(): + for frame_buffer, res_dict in result[0]["results"].items(): res_str += f'{frame_buffer} F1: {res_dict["Division F1"]}\n' logger.info(res_str) From ba5be035d9a10a4d432513747b81c09130062b36 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 18:58:53 -0800 Subject: [PATCH 51/56] Update example notebook after several api changes --- examples/ctc.ipynb | 170 ++++++++++++++++++++++++--------------------- 1 file changed, 91 insertions(+), 79 deletions(-) diff --git a/examples/ctc.ipynb b/examples/ctc.ipynb index 
f2758a65..330e3c68 100644 --- a/examples/ctc.ipynb +++ b/examples/ctc.ipynb @@ -24,7 +24,7 @@ "\n", "from traccuracy import run_metrics\n", "from traccuracy.loaders import load_ctc_data\n", - "from traccuracy.matchers import CTCMatched, IOUMatched\n", + "from traccuracy.matchers import CTCMatcher, IOUMatcher\n", "from traccuracy.metrics import CTCMetrics, DivisionMetrics\n", "\n", "pp = pprint.PrettyPrinter(indent=4)" @@ -63,7 +63,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "Fluo-N2DL-HeLa.zip: 191MB [00:18, 10.2MB/s] \n" + "Fluo-N2DL-HeLa.zip: 0.00B [00:00, ?B/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fluo-N2DL-HeLa.zip: 191MB [00:15, 12.1MB/s] \n" ] } ], @@ -96,8 +103,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 374.71it/s]\n", - "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 824.06it/s]\n" + "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 388.26it/s]\n", + "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 640.22it/s]\n" ] } ], @@ -130,48 +137,50 @@ "name": "stderr", "output_type": "stream", "text": [ - "Matching frames: 100%|██████████| 92/92 [00:13<00:00, 6.65it/s]\n", - "Evaluating nodes: 100%|██████████| 92/92 [00:00<00:00, 10573.68it/s]\n", - "Evaluating edges: 100%|██████████| 8535/8535 [00:06<00:00, 1359.15it/s]\n" + "Matching frames: 100%|██████████| 92/92 [00:00<00:00, 93.42it/s] \n", + "Evaluating nodes: 100%|██████████| 8600/8600 [00:00<00:00, 721911.19it/s]\n", + "Evaluating FP edges: 100%|██████████| 8535/8535 [00:00<00:00, 968440.00it/s]\n", + "Evaluating FN edges: 100%|██████████| 8562/8562 [00:00<00:00, 1054425.71it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "{ 'CTCMetrics': { 'AOGM': 631.5,\n", - " 'DET': 0.9954855886097927,\n", - " 'TRA': 0.9936361895740329,\n", - " 'fn_edges': 87,\n", - " 'fn_nodes': 39,\n", - " 'fp_edges': 60,\n", - " 'fp_nodes': 0,\n", - " 'ns_nodes': 0,\n", - " 'ws_edges': 51},\n", - " 'DivisionMetrics': { 'Frame Buffer 0': { 'Division F1': 0.76,\n", - " 'Division Precision': 0.7169811320754716,\n", - " 'Division Recall': 0.8085106382978723,\n", - " 'False Negative Divisions': 18,\n", - " 'False Positive Divisions': 30,\n", - " 'Mitotic Branching Correctness': 0.6129032258064516,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76},\n", - " 'Frame Buffer 1': { 'Division F1': 0.76,\n", - " 'Division Precision': 0.7169811320754716,\n", - " 'Division Recall': 0.8085106382978723,\n", - " 'False Negative Divisions': 18,\n", - " 'False Positive Divisions': 30,\n", - " 'Mitotic Branching Correctness': 0.6129032258064516,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76},\n", - " 'Frame Buffer 2': { 'Division F1': 0.76,\n", - " 'Division Precision': 0.7169811320754716,\n", - " 'Division Recall': 0.8085106382978723,\n", - " 'False Negative Divisions': 18,\n", - " 'False Positive Divisions': 30,\n", - " 'Mitotic Branching Correctness': 0.6129032258064516,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76}}}\n" + "[ { 'metric': { 'e_weights': {'fn': 1.5, 'fp': 1, 'ws': 1},\n", + " 'name': 'CTCMetrics',\n", + " 'v_weights': {'fn': 10, 'fp': 1, 'ns': 5}},\n", + " 'results': { 'AOGM': 627.5,\n", + " 'DET': 0.9954855886097927,\n", + " 'TRA': 0.993676498745377,\n", + " 'fn_edges': 87,\n", + " 'fn_nodes': 39,\n", + " 'fp_edges': 60,\n", + " 'fp_nodes': 0,\n", + " 'ns_nodes': 0,\n", + " 'ws_edges': 47}},\n", + " { 'metric': {'frame_buffer': (0, 
1, 2), 'name': 'DivisionMetrics'},\n", + " 'results': { 'Frame Buffer 0': { 'Division F1': 0.76,\n", + " 'Division Precision': 0.7169811320754716,\n", + " 'Division Recall': 0.8085106382978723,\n", + " 'False Negative Divisions': 18,\n", + " 'False Positive Divisions': 30,\n", + " 'Mitotic Branching Correctness': 0.6129032258064516,\n", + " 'True Positive Divisions': 76},\n", + " 'Frame Buffer 1': { 'Division F1': 0.76,\n", + " 'Division Precision': 0.7169811320754716,\n", + " 'Division Recall': 0.8085106382978723,\n", + " 'False Negative Divisions': 18,\n", + " 'False Positive Divisions': 30,\n", + " 'Mitotic Branching Correctness': 0.6129032258064516,\n", + " 'True Positive Divisions': 76},\n", + " 'Frame Buffer 2': { 'Division F1': 0.76,\n", + " 'Division Precision': 0.7169811320754716,\n", + " 'Division Recall': 0.8085106382978723,\n", + " 'False Negative Divisions': 18,\n", + " 'False Positive Divisions': 30,\n", + " 'Mitotic Branching Correctness': 0.6129032258064516,\n", + " 'True Positive Divisions': 76}}}]\n" ] } ], @@ -179,11 +188,8 @@ "ctc_results = run_metrics(\n", " gt_data=gt_data, \n", " pred_data=pred_data, \n", - " matcher=CTCMatched, \n", - " metrics=[CTCMetrics, DivisionMetrics],\n", - " metrics_kwargs=dict(\n", - " frame_buffer=(0,1,2)\n", - " )\n", + " matcher=CTCMatcher(), \n", + " metrics=[CTCMetrics(), DivisionMetrics(frame_buffer=(0,1,2))],\n", ")\n", "pp.pprint(ctc_results)" ] @@ -198,37 +204,42 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Matching frames: 100%|██████████| 92/92 [00:15<00:00, 6.03it/s]\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "{ 'DivisionMetrics': { 'Frame Buffer 0': { 'Division F1': 0,\n", - " 'Division Precision': 0.0,\n", - " 'Division Recall': 0.0,\n", - " 'False Negative Divisions': 94,\n", - " 'False Positive Divisions': 93,\n", - " 'Mitotic Branching Correctness': 0.0,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 0},\n", - " 'Frame Buffer 1': { 'Division F1': 0.44837758112094395,\n", - " 'Division Precision': 0.44970414201183434,\n", - " 'Division Recall': 0.4470588235294118,\n", - " 'False Negative Divisions': 94,\n", - " 'False Positive Divisions': 93,\n", - " 'Mitotic Branching Correctness': 0.2889733840304182,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76},\n", - " 'Frame Buffer 2': { 'Division F1': 0.44837758112094395,\n", - " 'Division Precision': 0.44970414201183434,\n", - " 'Division Recall': 0.4470588235294118,\n", - " 'False Negative Divisions': 94,\n", - " 'False Positive Divisions': 93,\n", - " 'Mitotic Branching Correctness': 0.2889733840304182,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76}}}\n" + "[ { 'metric': {'frame_buffer': (0, 1, 2), 'name': 'DivisionMetrics'},\n", + " 'results': { 'Frame Buffer 0': { 'Division F1': 0.711340206185567,\n", + " 'Division Precision': 0.69,\n", + " 'Division Recall': 0.7340425531914894,\n", + " 'False Negative Divisions': 25,\n", + " 'False Positive Divisions': 31,\n", + " 'Mitotic Branching Correctness': 0.552,\n", + " 'True Positive Divisions': 69},\n", + " 'Frame Buffer 1': { 'Division F1': 0.711340206185567,\n", + " 'Division Precision': 0.69,\n", + " 'Division Recall': 0.7340425531914894,\n", + " 'False Negative Divisions': 25,\n", + " 'False Positive Divisions': 31,\n", + " 'Mitotic Branching Correctness': 0.552,\n", + " 'True Positive Divisions': 69},\n", + 
" 'Frame Buffer 2': { 'Division F1': 0.711340206185567,\n", + " 'Division Precision': 0.69,\n", + " 'Division Recall': 0.7340425531914894,\n", + " 'False Negative Divisions': 25,\n", + " 'False Positive Divisions': 31,\n", + " 'Mitotic Branching Correctness': 0.552,\n", + " 'True Positive Divisions': 69}}}]\n" ] } ], @@ -236,17 +247,18 @@ "iou_results = run_metrics(\n", " gt_data=gt_data, \n", " pred_data=pred_data, \n", - " matcher=IOUMatched, \n", - " matcher_kwargs=dict(\n", - " iou_threshold=0.5\n", - " ),\n", - " metrics=[DivisionMetrics], \n", - " metrics_kwargs=dict(\n", - " frame_buffer=(0,1,2)\n", - " )\n", + " matcher=IOUMatcher(iou_threshold=0.1), \n", + " metrics=[DivisionMetrics(frame_buffer=(0,1,2))], \n", ")\n", "pp.pprint(iou_results)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -265,7 +277,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.12" }, "orig_nbformat": 4 }, From 749cbe540f091b5b55cd780e754370da87ed2e0f Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Fri, 17 Nov 2023 19:18:35 -0800 Subject: [PATCH 52/56] Correct args in docstring for division metrics --- src/traccuracy/metrics/_divisions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 569da84b..5bac4542 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -57,8 +57,6 @@ class DivisionMetrics(Metric): 3, 734559 (2021). Args: - matched_data (Matched): Matched object for set of GT and Pred data - Must meet the `needs_one_to_one` criteria frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames to tolerate in correct_shifted_divisions. Defaults to (0). """ @@ -71,6 +69,10 @@ def __init__(self, frame_buffer=(0,)): def compute(self, data: Matched): """Runs `_evaluate_division_events` and calculates summary metrics for each frame buffer + Args: + matched_data (traccuracy.matchers.Matched): Matched object for set of GT and Pred data + Must meet the `needs_one_to_one` criteria + Returns: dict: Returns a nested dictionary with one dictionary per frame buffer value """ From b1ba7b485ab7bc5277ac63be99a1250490af7ebc Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Mon, 20 Nov 2023 10:49:53 -0800 Subject: [PATCH 53/56] Update docstring with new output for run metrics --- src/traccuracy/_run_metrics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 92f6c852..a5fb6b94 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -27,8 +27,7 @@ def run_metrics( metrics (List[Metric]): list of instantiated metrics objects to compute Returns: - Dict: dictionary of metrics indexed by metric name. Dictionary will be - nested for metrics that return multiple values. 
+ List[Dict]: List of dictionaries with one dictionary per Metric object """ if not isinstance(gt_data, TrackingGraph) or not isinstance( pred_data, TrackingGraph From 3e64444241195387f6bd4fd0a27aee6c5aca8584 Mon Sep 17 00:00:00 2001 From: msschwartz21 Date: Mon, 20 Nov 2023 10:52:08 -0800 Subject: [PATCH 54/56] Fix indentation error --- src/traccuracy/metrics/_divisions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 5bac4542..4bd494fa 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -70,8 +70,8 @@ def compute(self, data: Matched): """Runs `_evaluate_division_events` and calculates summary metrics for each frame buffer Args: - matched_data (traccuracy.matchers.Matched): Matched object for set of GT and Pred data - Must meet the `needs_one_to_one` criteria + matched_data (traccuracy.matchers.Matched): Matched object for set of GT and Pred data + Must meet the `needs_one_to_one` criteria Returns: dict: Returns a nested dictionary with one dictionary per frame buffer value From 3aea4c149ffe16cb91d81e3983f5624a4a651f15 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Mon, 20 Nov 2023 15:24:13 -0500 Subject: [PATCH 55/56] fix import loop --- src/traccuracy/__init__.py | 4 +--- src/traccuracy/_run_metrics.py | 2 +- src/traccuracy/track_errors/_ctc.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/traccuracy/__init__.py b/src/traccuracy/__init__.py index bf617211..6fbbeeb8 100644 --- a/src/traccuracy/__init__.py +++ b/src/traccuracy/__init__.py @@ -6,9 +6,7 @@ except PackageNotFoundError: # pragma: no cover __version__ = "uninstalled" +from ._run_metrics import run_metrics from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph -# must go after TrackingGraph to avoid circular imports -from ._run_metrics import run_metrics # isort:skip - __all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"] diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index a5fb6b94..8e328922 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from traccuracy import TrackingGraph +from traccuracy._tracking_graph import TrackingGraph from traccuracy.matchers._base import Matcher from traccuracy.metrics._base import Metric diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 775690ef..27bdc8dc 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -4,7 +4,7 @@ from tqdm import tqdm -from traccuracy import EdgeAttr, NodeAttr +from traccuracy._tracking_graph import EdgeAttr, NodeAttr if TYPE_CHECKING: from traccuracy.matchers import Matched From 2ed94df23575a8a93cc82e82adb0c2fd9315d691 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 14:00:59 +1100 Subject: [PATCH 56/56] ci(pre-commit.ci): autoupdate (#113) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/crate-ci/typos: v1.16.21 → v1.16.23](https://github.com/crate-ci/typos/compare/v1.16.21...v1.16.23) - [github.com/astral-sh/ruff-pre-commit: v0.1.3 → v0.1.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.1.3...v0.1.4) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Draga Doncila Pop 
<17995243+DragaDoncila@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49d55888..aed3a561 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,12 @@ ci: repos: - repo: https://github.com/crate-ci/typos - rev: v1.16.21 + rev: v1.16.23 hooks: - id: typos - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.4 hooks: - id: ruff args: [--fix]
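
The notebook diffs above exercise the refactored entry point end to end: matchers and metrics are passed as instantiated objects, so the old matcher_kwargs/metrics_kwargs indirection disappears. A minimal sketch of the new calling convention (assuming gt_data and pred_data are TrackingGraph objects loaded as in the notebook; an illustration, not the canonical example):

    from traccuracy import run_metrics
    from traccuracy.matchers import CTCMatcher, IOUMatcher
    from traccuracy.metrics import CTCMetrics, DivisionMetrics

    # Configuration now lives on the objects themselves: frame_buffer on the
    # metric, iou_threshold on the matcher.
    ctc_results = run_metrics(
        gt_data=gt_data,
        pred_data=pred_data,
        matcher=CTCMatcher(),
        metrics=[CTCMetrics(), DivisionMetrics(frame_buffer=(0, 1, 2))],
    )
    iou_results = run_metrics(
        gt_data=gt_data,
        pred_data=pred_data,
        matcher=IOUMatcher(iou_threshold=0.1),
        metrics=[DivisionMetrics(frame_buffer=(0, 1, 2))],
    )

    # Per PATCH 53, the return value is a list with one dict per Metric
    # object, each holding a 'metric' header and a 'results' payload.
    for entry in iou_results:
        print(entry["metric"]["name"], list(entry["results"]))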
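
PATCH 52 moves the Matched argument documentation onto DivisionMetrics.compute, which consumes a Matched object that must satisfy the needs_one_to_one criteria. A hedged sketch of driving that method directly, without run_metrics, assuming the same gt_data and pred_data as above, that the matcher's compute_mapping(gt_data, pred_data) builds the Matched object, and that compute returns the per-frame-buffer dict shown in the notebook output:

    from traccuracy.matchers import IOUMatcher
    from traccuracy.metrics import DivisionMetrics

    # The matcher builds the GT-to-prediction node mapping...
    matched = IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data, pred_data)
    # ...and the metric consumes it. DivisionMetrics requires the mapping to
    # be one-to-one, per the docstring corrected in PATCH 52.
    results = DivisionMetrics(frame_buffer=(0, 1, 2)).compute(matched)
    print(results["Frame Buffer 0"]["Division F1"])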
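
The import-loop fix in PATCH 55 follows a general rule: inside a package, modules should import names from the submodule that defines them rather than from the package root, so a partially initialized __init__.py is never re-entered. A runnable sketch with a throwaway package (hypothetical names, not traccuracy code) showing why the fixed style survives any import order in __init__.py:

    import pathlib
    import sys
    import tempfile

    # Build a tiny package on disk whose __init__ imports the consumer module
    # *before* the module that defines the class it needs.
    root = pathlib.Path(tempfile.mkdtemp())
    (root / "pkg").mkdir()
    (root / "pkg" / "__init__.py").write_text(
        "from ._uses_thing import helper\nfrom ._thing import Thing\n"
    )
    (root / "pkg" / "_thing.py").write_text("class Thing:\n    pass\n")
    # Fixed style: import from the defining submodule. The broken style,
    # 'from pkg import Thing', would fail here with a "partially initialized
    # module" ImportError because pkg/__init__.py has not finished executing.
    (root / "pkg" / "_uses_thing.py").write_text(
        "from pkg._thing import Thing\n\ndef helper():\n    return Thing()\n"
    )

    sys.path.insert(0, str(root))
    import pkg

    print(pkg.helper())  # succeeds despite the unfavorable import order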