diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE/bugs.md
similarity index 54%
rename from .github/ISSUE_TEMPLATE.md
rename to .github/ISSUE_TEMPLATE/bugs.md
index 3fd3d74e..61b0bc6f 100644
--- a/.github/ISSUE_TEMPLATE.md
+++ b/.github/ISSUE_TEMPLATE/bugs.md
@@ -1,15 +1,29 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: bug
+assignees: ''
+
+---
+
 * traccuracy version:
 * Python version:
 * Operating System:
 
-### Description
+# Description
 
 Describe what you were trying to get done.
 Tell us what happened, what went wrong, and what you expected to happen.
 
-### What I Did
+# Minimal example to reproduce the bug
 
 ```
 Paste the command(s) you ran and the output.
 If there was a crash, please include the traceback here.
 ```
+
+# Severity
+- [ ] Unusable
+- [ ] Annoying, but still functional
+- [ ] Very minor
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/features.md b/.github/ISSUE_TEMPLATE/features.md
new file mode 100644
index 00000000..d1c44dd1
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/features.md
@@ -0,0 +1,36 @@
+---
+name: Feature
+about: Suggest an idea for this project
+title: ''
+labels: enhancement
+assignees: ''
+
+---
+
+# Description
+
+Please describe the feature that you would like to see implemented in `traccuracy`.
+
+# Topics
+
+What types of changes are you suggesting? Put an x in the boxes that apply.
+- [ ] New feature or enhancement
+- [ ] Documentation update
+- [ ] Tests and benchmarks
+- [ ] Maintenance (e.g. dependencies, CI, releases, etc.)
+
+Which topics does your change affect? Put an x in the boxes that apply.
+- [ ] Loaders
+- [ ] Matchers
+- [ ] Track Errors
+- [ ] Metrics
+- [ ] Core functionality (e.g. `TrackingGraph`, `run_metrics`, `cli`, etc.)
+
+# Priority
+- [ ] This is an essential feature
+- [ ] Nice to have
+- [ ] Future idea
+
+# Are you interested in contributing?
+- [ ] Yes! :tada:
+- [ ] No :slightly_frowning_face:
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..77c7b114
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,32 @@
+If you are implementing a new matcher or metric, please append `&template=new_matcher_metric.md` to the pull request URL to load the correct template.
+
+# Proposed Change
+Briefly describe the contribution. If it resolves an issue or feature request, be sure to link to that issue.
+
+# Types of Changes
+What types of changes does your code introduce? Put an x in the boxes that apply.
+- [ ] Bugfix (non-breaking change which fixes an issue)
+- [ ] New feature or enhancement
+- [ ] Documentation update
+- [ ] Tests and benchmarks
+- [ ] Maintenance (e.g. dependencies, CI, releases, etc.)
+
+Which topics does your change affect? Put an x in the boxes that apply.
+- [ ] Loaders
+- [ ] Matchers
+- [ ] Track Errors
+- [ ] Metrics
+- [ ] Core functionality (e.g. `TrackingGraph`, `run_metrics`, `cli`, etc.)
+
+# Checklist
+Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.
+
+- [ ] I have read the developer/contributing docs.
+- [ ] I have added tests that prove that my feature works in various situations or tests the bugfix (if appropriate).
+- [ ] I have checked that I maintained or improved code coverage.
+- [ ] I have checked the benchmarking action to verify that my changes did not adversely affect performance.
+- [ ] I have written docstrings and checked that they render correctly in the Read The Docs build (created after the PR is opened).
+- [ ] I have updated the general documentation including Metric descriptions and example notebooks if necessary.
+
+# Further Comments
+If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc...
\ No newline at end of file
diff --git a/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md
new file mode 100644
index 00000000..b85ff09a
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE/new_matcher_metric.md
@@ -0,0 +1,18 @@
+# Proposed Matcher or Metric Addition
+- [ ] Matcher
+- [ ] Metric
+
+Briefly describe your new Matcher or Metric class, including links to publication or other source code if relevant. A full description should be included in the documentation. If it resolves a feature request, be sure to link to that issue.
+
+# Checklist
+Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.
+
+- [ ] I have read the developer/contributing docs.
+- [ ] I have added tests that prove that my feature works in various situations.
+- [ ] I have checked that I maintained or improved code coverage.
+- [ ] I have added benchmarking functions for my change in `tests/bench.py`.
+- [ ] I have added a page to the documentation with a complete description of my matcher/metric including any references.
+- [ ] I have written docstrings and checked that they render correctly in the Read The Docs build (created after the PR is opened).
+
+# Further Comments
+If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc...
\ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 49d55888..aed3a561 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,12 +5,12 @@ ci: repos: - repo: https://github.com/crate-ci/typos - rev: v1.16.21 + rev: v1.16.23 hooks: - id: typos - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.4 hooks: - id: ruff args: [--fix] diff --git a/examples/ctc.ipynb b/examples/ctc.ipynb index f2758a65..330e3c68 100644 --- a/examples/ctc.ipynb +++ b/examples/ctc.ipynb @@ -24,7 +24,7 @@ "\n", "from traccuracy import run_metrics\n", "from traccuracy.loaders import load_ctc_data\n", - "from traccuracy.matchers import CTCMatched, IOUMatched\n", + "from traccuracy.matchers import CTCMatcher, IOUMatcher\n", "from traccuracy.metrics import CTCMetrics, DivisionMetrics\n", "\n", "pp = pprint.PrettyPrinter(indent=4)" @@ -63,7 +63,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "Fluo-N2DL-HeLa.zip: 191MB [00:18, 10.2MB/s] \n" + "Fluo-N2DL-HeLa.zip: 0.00B [00:00, ?B/s]" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Fluo-N2DL-HeLa.zip: 191MB [00:15, 12.1MB/s] \n" ] } ], @@ -96,8 +103,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 374.71it/s]\n", - "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 824.06it/s]\n" + "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 388.26it/s]\n", + "Loading TIFFs: 100%|██████████| 92/92 [00:00<00:00, 640.22it/s]\n" ] } ], @@ -130,48 +137,50 @@ "name": "stderr", "output_type": "stream", "text": [ - "Matching frames: 100%|██████████| 92/92 [00:13<00:00, 6.65it/s]\n", - "Evaluating nodes: 100%|██████████| 92/92 [00:00<00:00, 10573.68it/s]\n", - "Evaluating edges: 100%|██████████| 8535/8535 [00:06<00:00, 1359.15it/s]\n" + "Matching frames: 100%|██████████| 92/92 [00:00<00:00, 93.42it/s] \n", + "Evaluating nodes: 100%|██████████| 8600/8600 [00:00<00:00, 721911.19it/s]\n", + "Evaluating FP edges: 100%|██████████| 8535/8535 [00:00<00:00, 968440.00it/s]\n", + "Evaluating FN edges: 100%|██████████| 8562/8562 [00:00<00:00, 1054425.71it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "{ 'CTCMetrics': { 'AOGM': 631.5,\n", - " 'DET': 0.9954855886097927,\n", - " 'TRA': 0.9936361895740329,\n", - " 'fn_edges': 87,\n", - " 'fn_nodes': 39,\n", - " 'fp_edges': 60,\n", - " 'fp_nodes': 0,\n", - " 'ns_nodes': 0,\n", - " 'ws_edges': 51},\n", - " 'DivisionMetrics': { 'Frame Buffer 0': { 'Division F1': 0.76,\n", - " 'Division Precision': 0.7169811320754716,\n", - " 'Division Recall': 0.8085106382978723,\n", - " 'False Negative Divisions': 18,\n", - " 'False Positive Divisions': 30,\n", - " 'Mitotic Branching Correctness': 0.6129032258064516,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76},\n", - " 'Frame Buffer 1': { 'Division F1': 0.76,\n", - " 'Division Precision': 0.7169811320754716,\n", - " 'Division Recall': 0.8085106382978723,\n", - " 'False Negative Divisions': 18,\n", - " 'False Positive Divisions': 30,\n", - " 'Mitotic Branching Correctness': 0.6129032258064516,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76},\n", - " 'Frame Buffer 2': { 'Division F1': 0.76,\n", - " 'Division Precision': 0.7169811320754716,\n", - " 'Division Recall': 0.8085106382978723,\n", - " 'False Negative Divisions': 18,\n", - " 'False Positive Divisions': 30,\n", - " 'Mitotic Branching Correctness': 0.6129032258064516,\n", - " 'Total GT Divisions': 
94,\n", - " 'True Positive Divisions': 76}}}\n" + "[ { 'metric': { 'e_weights': {'fn': 1.5, 'fp': 1, 'ws': 1},\n", + " 'name': 'CTCMetrics',\n", + " 'v_weights': {'fn': 10, 'fp': 1, 'ns': 5}},\n", + " 'results': { 'AOGM': 627.5,\n", + " 'DET': 0.9954855886097927,\n", + " 'TRA': 0.993676498745377,\n", + " 'fn_edges': 87,\n", + " 'fn_nodes': 39,\n", + " 'fp_edges': 60,\n", + " 'fp_nodes': 0,\n", + " 'ns_nodes': 0,\n", + " 'ws_edges': 47}},\n", + " { 'metric': {'frame_buffer': (0, 1, 2), 'name': 'DivisionMetrics'},\n", + " 'results': { 'Frame Buffer 0': { 'Division F1': 0.76,\n", + " 'Division Precision': 0.7169811320754716,\n", + " 'Division Recall': 0.8085106382978723,\n", + " 'False Negative Divisions': 18,\n", + " 'False Positive Divisions': 30,\n", + " 'Mitotic Branching Correctness': 0.6129032258064516,\n", + " 'True Positive Divisions': 76},\n", + " 'Frame Buffer 1': { 'Division F1': 0.76,\n", + " 'Division Precision': 0.7169811320754716,\n", + " 'Division Recall': 0.8085106382978723,\n", + " 'False Negative Divisions': 18,\n", + " 'False Positive Divisions': 30,\n", + " 'Mitotic Branching Correctness': 0.6129032258064516,\n", + " 'True Positive Divisions': 76},\n", + " 'Frame Buffer 2': { 'Division F1': 0.76,\n", + " 'Division Precision': 0.7169811320754716,\n", + " 'Division Recall': 0.8085106382978723,\n", + " 'False Negative Divisions': 18,\n", + " 'False Positive Divisions': 30,\n", + " 'Mitotic Branching Correctness': 0.6129032258064516,\n", + " 'True Positive Divisions': 76}}}]\n" ] } ], @@ -179,11 +188,8 @@ "ctc_results = run_metrics(\n", " gt_data=gt_data, \n", " pred_data=pred_data, \n", - " matcher=CTCMatched, \n", - " metrics=[CTCMetrics, DivisionMetrics],\n", - " metrics_kwargs=dict(\n", - " frame_buffer=(0,1,2)\n", - " )\n", + " matcher=CTCMatcher(), \n", + " metrics=[CTCMetrics(), DivisionMetrics(frame_buffer=(0,1,2))],\n", ")\n", "pp.pprint(ctc_results)" ] @@ -198,37 +204,42 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Matching frames: 100%|██████████| 92/92 [00:15<00:00, 6.03it/s]\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "{ 'DivisionMetrics': { 'Frame Buffer 0': { 'Division F1': 0,\n", - " 'Division Precision': 0.0,\n", - " 'Division Recall': 0.0,\n", - " 'False Negative Divisions': 94,\n", - " 'False Positive Divisions': 93,\n", - " 'Mitotic Branching Correctness': 0.0,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 0},\n", - " 'Frame Buffer 1': { 'Division F1': 0.44837758112094395,\n", - " 'Division Precision': 0.44970414201183434,\n", - " 'Division Recall': 0.4470588235294118,\n", - " 'False Negative Divisions': 94,\n", - " 'False Positive Divisions': 93,\n", - " 'Mitotic Branching Correctness': 0.2889733840304182,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76},\n", - " 'Frame Buffer 2': { 'Division F1': 0.44837758112094395,\n", - " 'Division Precision': 0.44970414201183434,\n", - " 'Division Recall': 0.4470588235294118,\n", - " 'False Negative Divisions': 94,\n", - " 'False Positive Divisions': 93,\n", - " 'Mitotic Branching Correctness': 0.2889733840304182,\n", - " 'Total GT Divisions': 94,\n", - " 'True Positive Divisions': 76}}}\n" + "[ { 'metric': {'frame_buffer': (0, 1, 2), 'name': 'DivisionMetrics'},\n", + " 'results': { 'Frame Buffer 0': { 'Division F1': 0.711340206185567,\n", + " 'Division Precision': 0.69,\n", + " 'Division Recall': 0.7340425531914894,\n", 
+ " 'False Negative Divisions': 25,\n", + " 'False Positive Divisions': 31,\n", + " 'Mitotic Branching Correctness': 0.552,\n", + " 'True Positive Divisions': 69},\n", + " 'Frame Buffer 1': { 'Division F1': 0.711340206185567,\n", + " 'Division Precision': 0.69,\n", + " 'Division Recall': 0.7340425531914894,\n", + " 'False Negative Divisions': 25,\n", + " 'False Positive Divisions': 31,\n", + " 'Mitotic Branching Correctness': 0.552,\n", + " 'True Positive Divisions': 69},\n", + " 'Frame Buffer 2': { 'Division F1': 0.711340206185567,\n", + " 'Division Precision': 0.69,\n", + " 'Division Recall': 0.7340425531914894,\n", + " 'False Negative Divisions': 25,\n", + " 'False Positive Divisions': 31,\n", + " 'Mitotic Branching Correctness': 0.552,\n", + " 'True Positive Divisions': 69}}}]\n" ] } ], @@ -236,17 +247,18 @@ "iou_results = run_metrics(\n", " gt_data=gt_data, \n", " pred_data=pred_data, \n", - " matcher=IOUMatched, \n", - " matcher_kwargs=dict(\n", - " iou_threshold=0.5\n", - " ),\n", - " metrics=[DivisionMetrics], \n", - " metrics_kwargs=dict(\n", - " frame_buffer=(0,1,2)\n", - " )\n", + " matcher=IOUMatcher(iou_threshold=0.1), \n", + " metrics=[DivisionMetrics(frame_buffer=(0,1,2))], \n", ")\n", "pp.pprint(iou_results)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -265,7 +277,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.12" }, "orig_nbformat": 4 }, diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 91c2fe42..7d39c2e5 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,55 +1,47 @@ from __future__ import annotations -from typing import TYPE_CHECKING - -from traccuracy._utils import get_relevant_kwargs, validate_matched_data - -if TYPE_CHECKING: - from traccuracy import TrackingGraph - from traccuracy.matchers._matched import Matched - from traccuracy.metrics._base import Metric +from traccuracy._tracking_graph import TrackingGraph +from traccuracy.matchers._base import Matcher +from traccuracy.metrics._base import Metric def run_metrics( gt_data: TrackingGraph, pred_data: TrackingGraph, - matcher: type[Matched], - metrics: list[type[Metric]], - matcher_kwargs: dict | None = None, - metrics_kwargs: dict | None = None, # weights -) -> dict: + matcher: Matcher, + metrics: list[Metric], +) -> list[dict]: """Compute given metrics on data using the given matcher. - An error will be thrown if the given matcher is not compatible with - all metrics in the given list. The returned result dictionary will - contain all metrics computed by the given Metric classes, as well as - general summary numbers e.g. false positive/false negative detection - and edge counts. + The returned result dictionary will contain all metrics computed by + the given Metric classes, as well as general summary numbers + e.g. false positive/false negative detection and edge counts. 
    Args:
-        gt_data (TrackingData): ground truth graph and optionally segmentation
-        pred_data (TrackingData): predicted graph and optionally segmentation
-        matcher (Matched): matching class to use to create correspondence
-        metrics (List[Metric]): list of metrics to compute as class names
-        matcher_kwargs (optional, dict): Dictionary of keyword argument for the
-            matcher class
-        metric_kwargs (optional, dict): Dictionary of any keyword args for the
-            Metric classes
+        gt_data (traccuracy.TrackingGraph): ground truth graph and optionally segmentation
+        pred_data (traccuracy.TrackingGraph): predicted graph and optionally segmentation
+        matcher (Matcher): instantiated matcher object
+        metrics (List[Metric]): list of instantiated Metric objects to compute

    Returns:
-        Dict: dictionary of metrics indexed by metric name. Dictionary will be
-            nested for metrics that return multiple values.
+        List[Dict]: List of dictionaries with one dictionary per Metric object
    """
-    if matcher_kwargs is None:
-        matcher_kwargs = {}
-    matched = matcher(gt_data, pred_data, **matcher_kwargs)
-    validate_matched_data(matched, metrics)
-    metric_kwarg_dict = {
-        m_class: get_relevant_kwargs(m_class, metrics_kwargs) for m_class in metrics
-    }
-    results = {}
+    if not isinstance(gt_data, TrackingGraph) or not isinstance(
+        pred_data, TrackingGraph
+    ):
+        raise TypeError("gt_data and pred_data must be TrackingGraph objects")
+
+    if not isinstance(matcher, Matcher):
+        raise TypeError("matcher must be an instantiated Matcher object")
+
+    if not all(isinstance(m, Metric) for m in metrics):
+        raise TypeError("metrics must be a list of instantiated Metric objects")
+
+    matched = matcher.compute_mapping(gt_data, pred_data)
+    results = []
    for _metric in metrics:
-        relevant_kwargs = metric_kwarg_dict[_metric]
-        result = _metric(matched, **relevant_kwargs)
-        results[_metric.__name__] = result.results
+        result = _metric.compute(matched)
+        metric_dict = _metric.__dict__
+        metric_dict["name"] = _metric.__class__.__name__
+        results.append({"results": result, "metric": metric_dict})

    return results
diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py
index 1153ea6e..082e0ac3 100644
--- a/src/traccuracy/_tracking_graph.py
+++ b/src/traccuracy/_tracking_graph.py
@@ -616,21 +616,35 @@ def get_edge_attribute(self, _id, attr):
            return False
        return self.graph.edges[_id][attr]

-    def get_tracklets(self):
+    def get_tracklets(self, include_division_edges: bool = False):
        """Gets a list of new TrackingGraph objects containing all tracklets of the current graph.

        A tracklet is defined as a connected component between divisions (daughter to next parent).
        Tracklets can also start or end with a non-dividing cell.
+
+        Args:
+            include_division_edges (bool, optional): If True, include the intertrack (division) edges in the tracklets.
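+                Defaults to False.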
+ """ graph_copy = self.graph.copy() # Remove all intertrack edges from a copy of the original graph + removed_edges = [] for parent in self.get_divisions(): for daughter in self.get_succs(parent): graph_copy.remove_edge(parent, daughter) + removed_edges.append((parent, daughter)) # Extract subgraphs (aka tracklets) and return as new track graphs - return [ - self.get_subgraph(g) for g in nx.weakly_connected_components(graph_copy) - ] + tracklets = nx.weakly_connected_components(graph_copy) + + if include_division_edges: + tracklets = list(tracklets) + # Add back intertrack edges + for tracklet in tracklets: + for parent, daughter in removed_edges: + if daughter in tracklet: + tracklet.add(parent) + + return [self.get_subgraph(g) for g in tracklets] diff --git a/src/traccuracy/_utils.py b/src/traccuracy/_utils.py index 7fce666e..599a43b6 100644 --- a/src/traccuracy/_utils.py +++ b/src/traccuracy/_utils.py @@ -1,6 +1,3 @@ -import inspect - - def find_gt_node_matches(matches, gt_node): """For a given gt node, finds all pred nodes that are matches @@ -19,41 +16,3 @@ def find_pred_node_matches(matches, pred_node): pred_node (hashable): pred node ID """ return [pair[0] for pair in matches if pair[1] == pred_node] - - -def validate_matched_data(matched_data, metrics): - """Validate that given matcher supports requirements of each metric. - - Args: - matched_data (Matched): matching class with mapping between gt and pred - metrics (List[Metric]): list of metrics to compute as class names - """ - ... - - -def get_relevant_kwargs(metric_class, kwargs): - """Get all params in kwargs that are valid for given metric class. - - If required parameters are not satisfied, an error is raised. - - Args: - metric_class (Metric): class name of metric to check - kwargs (dict): dictionary of keyword arguments to validate - """ - sig = inspect.signature(metric_class) - relevant_kwargs = {} - missing_args = [] - for param in sig.parameters.values(): - name = param.name - is_required = (param.default is param.empty) and name != "matched_data" - if kwargs and name in kwargs: - relevant_kwargs[name] = kwargs[name] - elif is_required: - missing_args.append(name) - if missing_args: - raise ValueError( - f"Metric class {metric_class.__name__} is missing required" - + f" arguments: {missing_args}. Add arguments to" - + " `run_metrics` or consider skipping this metric." - ) - return relevant_kwargs diff --git a/src/traccuracy/cli.py b/src/traccuracy/cli.py index a991216c..a8b3249d 100644 --- a/src/traccuracy/cli.py +++ b/src/traccuracy/cli.py @@ -49,7 +49,7 @@ def run_ctc( Raises ValueError: if any --loader besides ctc is passed. """ - from traccuracy.matchers import CTCMatched + from traccuracy.matchers import CTCMatcher from traccuracy.metrics import CTCMetrics if loader != "ctc": @@ -57,11 +57,11 @@ def run_ctc( f"Only cell tracking challenge (ctc) loader is available, but {loader} was passed." ) gt_data, pred_data = load_all_ctc(gt_dir, pred_dir, gt_track_path, pred_track_path) - result = run_metrics(gt_data, pred_data, CTCMatched, [CTCMetrics]) + result = run_metrics(gt_data, pred_data, CTCMatcher(), [CTCMetrics()]) with open(out_path, "w") as fp: json.dump(result, fp) - logger.info(f'TRA: {result["CTCMetrics"]["TRA"]}') - logger.info(f'DET: {result["CTCMetrics"]["DET"]}') + logger.info(f'TRA: {result[0]["results"]["TRA"]}') + logger.info(f'DET: {result[0]["results"]["DET"]}') @app.command() @@ -109,7 +109,7 @@ def run_aogm( Raises ValueError: if any --loader besides ctc is passed. 
""" - from traccuracy.matchers import CTCMatched + from traccuracy.matchers import CTCMatcher from traccuracy.metrics import AOGMMetrics if loader != "ctc": @@ -120,20 +120,21 @@ def run_aogm( result = run_metrics( gt_data, pred_data, - CTCMatched, - [AOGMMetrics], - metrics_kwargs={ - "vertex_ns_weight": vertex_ns_weight, - "vertex_fp_weight": vertex_fp_weight, - "vertex_fn_weight": vertex_fn_weight, - "edge_fp_weight": edge_fp_weight, - "edge_fn_weight": edge_fn_weight, - "edge_ws_weight": edge_ws_weight, - }, + CTCMatcher(), + [ + AOGMMetrics( + vertex_ns_weight=vertex_ns_weight, + vertex_fp_weight=vertex_fp_weight, + vertex_fn_weight=vertex_fn_weight, + edge_fp_weight=edge_fp_weight, + edge_fn_weight=edge_fn_weight, + edge_ws_weight=edge_ws_weight, + ) + ], ) with open(out_path, "w") as fp: json.dump(result, fp) - logger.info(f'AOGM: {result["AOGMMetrics"]["AOGM"]}') + logger.info(f'AOGM: {result[0]["results"]["AOGM"]}') @app.command() @@ -173,7 +174,7 @@ def run_divisions_on_iou( Raises ValueError: if any --loader besides ctc is passed. """ - from traccuracy.matchers import IOUMatched + from traccuracy.matchers import IOUMatcher from traccuracy.metrics import DivisionMetrics if loader != "ctc": @@ -185,17 +186,13 @@ def run_divisions_on_iou( result = run_metrics( gt_data, pred_data, - IOUMatched, - [DivisionMetrics], - matcher_kwargs={"iou_threshold": match_threshold}, - metrics_kwargs={ - "frame_buffer": frame_buffer_tuple, - }, + IOUMatcher(iou_threshold=match_threshold), + [DivisionMetrics(frame_buffer=frame_buffer_tuple)], ) with open(out_path, "w") as fp: json.dump(result, fp) res_str = "" - for frame_buffer, res_dict in result["DivisionMetrics"].items(): + for frame_buffer, res_dict in result[0]["results"].items(): res_str += f'{frame_buffer} F1: {res_dict["Division F1"]}\n' logger.info(res_str) @@ -232,7 +229,7 @@ def run_divisions_on_ctc( Raises ValueError: if any --loader besides ctc is passed. """ - from traccuracy.matchers import CTCMatched + from traccuracy.matchers import CTCMatcher from traccuracy.metrics import DivisionMetrics if loader != "ctc": @@ -244,16 +241,13 @@ def run_divisions_on_ctc( result = run_metrics( gt_data, pred_data, - CTCMatched, - [DivisionMetrics], - metrics_kwargs={ - "frame_buffer": frame_buffer_tuple, - }, + CTCMatcher(), + [DivisionMetrics(frame_buffer=frame_buffer_tuple)], ) with open(out_path, "w") as fp: json.dump(result, fp) res_str = "" - for frame_buffer, res_dict in result["DivisionMetrics"].items(): + for frame_buffer, res_dict in result[0]["results"].items(): res_str += f'{frame_buffer} F1: {res_dict["Division F1"]}\n' logger.info(res_str) diff --git a/src/traccuracy/matchers/__init__.py b/src/traccuracy/matchers/__init__.py index 91e46d58..e7a641b1 100644 --- a/src/traccuracy/matchers/__init__.py +++ b/src/traccuracy/matchers/__init__.py @@ -25,8 +25,9 @@ While we specify ground truth and prediction, it is possible to write a matching function that matches two arbitrary tracking solutions. 
""" +from ._base import Matched from ._compute_overlap import get_labels_with_overlap -from ._ctc import CTCMatched -from ._iou import IOUMatched +from ._ctc import CTCMatcher +from ._iou import IOUMatcher -__all__ = ["CTCMatched", "IOUMatched", "get_labels_with_overlap"] +__all__ = ["CTCMatcher", "IOUMatcher", "get_labels_with_overlap", "Matched"] diff --git a/src/traccuracy/matchers/_base.py b/src/traccuracy/matchers/_base.py new file mode 100644 index 00000000..bbfe07d0 --- /dev/null +++ b/src/traccuracy/matchers/_base.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import copy +import logging +from abc import ABC, abstractmethod +from typing import Any + +from traccuracy._tracking_graph import TrackingGraph + +logger = logging.getLogger(__name__) + + +class Matcher(ABC): + """The Matcher base class provides a wrapper around the compute_mapping method + + Each Matcher subclass will implement its own kwargs as needed. + In use, the Matcher object will be initialized with kwargs prior to running compute_mapping + on a particular dataset + """ + + def compute_mapping( + self, gt_graph: TrackingGraph, pred_graph: TrackingGraph + ) -> Matched: + """Run the matching on a given set of gt and pred TrackingGraph and returns a Matched object + with a new copy of each TrackingGraph + + Args: + gt_graph (traccuracy.TrackingGraph): Tracking graph object for the gt + pred_graph (traccuracy.TrackingGraph): Tracking graph object for the pred + + Returns: + matched (Matched): Matched data object + + Raises: + ValueError: gt and pred must be a TrackingGraph object + """ + if not isinstance(gt_graph, TrackingGraph) or not isinstance( + pred_graph, TrackingGraph + ): + raise ValueError( + "Input data must be a TrackingData object with a graph and segmentations" + ) + + # Copy graphs to avoid possible changes to graphs while computing mapping + matched = self._compute_mapping( + copy.deepcopy(gt_graph), copy.deepcopy(pred_graph) + ) + + # Report matching performance + total_gt = len(matched.gt_graph.nodes()) + matched_gt = len({m[0] for m in matched.mapping}) + total_pred = len(matched.pred_graph.nodes()) + matched_pred = len({m[1] for m in matched.mapping}) + logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.") + logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.") + + return matched + + @abstractmethod + def _compute_mapping( + self, gt_graph: TrackingGraph, pred_graph: TrackingGraph + ) -> Matched: + """Computes a mapping of nodes in gt to nodes in pred and returns a Matched object + + Raises: + NotImplementedError + """ + raise NotImplementedError + + +class Matched: + """Matched data class which stores TrackingGraph objects for gt and pred + and the computed mapping + + Each TrackingGraph will be a new copy of the original object + + Args: + gt_graph (traccuracy.TrackingGraph): Tracking graph object for the gt + pred_graph (traccuracy.TrackingGraph): Tracking graph object for the pred + mapping (list[tuple[Any, Any]]): List of tuples where each tuple maps + a gt node to a pred node + """ + + def __init__( + self, + gt_graph: TrackingGraph, + pred_graph: TrackingGraph, + mapping: list[tuple[Any, Any]], + ): + self.gt_graph = gt_graph + self.pred_graph = pred_graph + self.mapping = mapping diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 745bbc0a..25827ff2 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -1,48 +1,46 @@ +from typing import TYPE_CHECKING + import networkx as nx 
import numpy as np from tqdm import tqdm -from traccuracy._tracking_graph import TrackingGraph +if TYPE_CHECKING: + from traccuracy._tracking_graph import TrackingGraph +from ._base import Matched, Matcher from ._compute_overlap import get_labels_with_overlap -from ._matched import Matched -class CTCMatched(Matched): - def compute_mapping(self): - mapping = self._match_ctc() - return mapping +class CTCMatcher(Matcher): + """Match graph nodes based on measure used in cell tracking challenge benchmarking. + + A computed marker (segmentation) is matched to a reference marker if the computed + marker covers a majority of the reference marker. - def _match_ctc(self): - """Match graph nodes based on measure used in cell tracking challenge benchmarking. + Each reference marker can therefore only be matched to one computed marker, but + multiple reference markers can be assigned to a single computed marker. - A computed marker (segmentation) is matched to a reference marker if the computed - marker covers a majority of the reference marker. + See https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0144959 + for complete details. + """ - Each reference marker can therefore only be matched to one computed marker, but - multiple reference markers can be assigned to a single computed marker. + def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + """Run ctc matching - See https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0144959 - for complete details. + Args: + gt_graph (TrackingGraph): Tracking graph object for the gt + pred_graph (TrackingGraph): Tracking graph object for the pred Returns: - list[(gt_node, pred_node)]: list of tuples where each tuple contains a gt node - and pred node + traccuracy.matchers.Matched: Matched data object containing the CTC mapping Raises: - ValueError: gt and pred must be a TrackingGraph object ValueError: GT and pred segmentations must be the same shape """ - if not isinstance(self.gt_graph, TrackingGraph) or not isinstance( - self.pred_graph, TrackingGraph - ): - raise ValueError( - "Input data must be a TrackingData object with a graph and segmentations" - ) - gt = self.gt_graph - pred = self.pred_graph - gt_label_key = self.gt_graph.label_key - pred_label_key = self.pred_graph.label_key + gt = gt_graph + pred = pred_graph + gt_label_key = gt_graph.label_key + pred_label_key = pred_graph.label_key G_gt, mask_gt = gt, gt.segmentation G_pred, mask_pred = pred, pred.segmentation @@ -93,7 +91,8 @@ def _match_ctc(self): mapping.append( (gt_label_to_id[gt_label], pred_label_to_id[pred_label]) ) - return mapping + + return Matched(gt_graph, pred_graph, mapping) def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int: diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 374291aa..189c4e6d 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -3,8 +3,8 @@ from traccuracy._tracking_graph import TrackingGraph +from ._base import Matched, Matcher from ._compute_overlap import get_labels_with_overlap -from ._matched import Matched def _match_nodes(gt, res, threshold=1): @@ -98,29 +98,37 @@ def match_iou(gt, pred, threshold=0.6): return mapper -class IOUMatched(Matched): - def __init__(self, gt_graph, pred_graph, iou_threshold=0.6): - """Constructs a mapping between gt and pred nodes using the IoU of the segmentations +class IOUMatcher(Matcher): + """Constructs a mapping between gt and pred nodes using the IoU of the segmentations - Lower 
values for iou_threshold will be more permissive of imperfect matches
+    Lower values for iou_threshold will be more permissive of imperfect matches
+
+    Args:
+        iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6.
+    """
+
+    def __init__(self, iou_threshold=0.6):
+        self.iou_threshold = iou_threshold
+
+    def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"):
+        """Computes IOU mapping for a pair of graphs

        Args:
-            gt_graph (TrackingGraph): TrackingGraph for the ground truth with segmentations
-            pred_graph (TrackingGraph): TrackingGraph for the prediction with segmentations
-            iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6.
+            gt_graph (TrackingGraph): Tracking graph object for the gt with segmentation data
+            pred_graph (TrackingGraph): Tracking graph object for the pred with segmentation data

        Raises:
            ValueError: Segmentation data must be provided for both gt and pred data
-        """
-        self.iou_threshold = iou_threshold

+        Returns:
+            Matched: Matched data object containing IOU mapping
+        """
        # Check that segmentations exist in the data
        if gt_graph.segmentation is None or pred_graph.segmentation is None:
            raise ValueError(
                "Segmentation data must be provided for both gt and pred data"
            )
-        super().__init__(gt_graph, pred_graph)
+        mapping = match_iou(gt_graph, pred_graph, threshold=self.iou_threshold)

-    def compute_mapping(self):
-        return match_iou(self.gt_graph, self.pred_graph, threshold=self.iou_threshold)
+        return Matched(gt_graph, pred_graph, mapping)
diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py
deleted file mode 100644
index d69ae7b3..00000000
--- a/src/traccuracy/matchers/_matched.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from __future__ import annotations
-
-import logging
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from traccuracy._tracking_graph import TrackingGraph
-
-logger = logging.getLogger(__name__)
-
-
-class Matched(ABC):
-    def __init__(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
-        """Matched class which takes TrackingData objects for gt and pred, and computes matching.
-
-        Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched.
-        The Matched objects will store both gt and pred data, as well as the mapping,
-        and any additional private attributes that may be needed/used e.g. detection matrices.
-
-        Args:
-            gt_graph (TrackingGraph): Tracking graph object for the gt
-            pred_graph (TrackingGraph): Tracking graph object for the pred
-        """
-        self.gt_graph = gt_graph
-        self.pred_graph = pred_graph
-
-        self.mapping = self.compute_mapping()
-
-        # Report matching performance
-        total_gt = len(self.gt_graph.nodes())
-        matched_gt = len({m[0] for m in self.mapping})
-        total_pred = len(self.pred_graph.nodes())
-        matched_pred = len({m[1] for m in self.mapping})
-        logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.")
-        logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.")
-
-    @abstractmethod
-    def compute_mapping(self):
-        """Computes a mapping of nodes in gt to nodes in pred
-
-        The mapping must be a list of tuples, e.g. [(gt_node, pred_node)]
-
-        Raises:
-            NotImplementedError
-        """
-        raise NotImplementedError
diff --git a/src/traccuracy/metrics/__init__.py b/src/traccuracy/metrics/__init__.py
index f5facc93..109130a3 100644
--- a/src/traccuracy/metrics/__init__.py
+++ b/src/traccuracy/metrics/__init__.py
@@ -1,4 +1,5 @@
 from ._ctc import AOGMMetrics, CTCMetrics
 from ._divisions import DivisionMetrics
+from ._track_overlap import TrackOverlapMetrics

-__all__ = ["CTCMetrics", "DivisionMetrics", "AOGMMetrics"]
+__all__ = ["CTCMetrics", "DivisionMetrics", "AOGMMetrics", "TrackOverlapMetrics"]
diff --git a/src/traccuracy/metrics/_base.py b/src/traccuracy/metrics/_base.py
index d98264a8..758f1d52 100644
--- a/src/traccuracy/metrics/_base.py
+++ b/src/traccuracy/metrics/_base.py
@@ -4,36 +4,26 @@
 from typing import TYPE_CHECKING

 if TYPE_CHECKING:
-    from traccuracy.matchers._matched import Matched
+    from traccuracy.matchers import Matched


 class Metric(ABC):
+    """The base class for Metrics.
+
+    Data should be passed directly into the compute method.
+    Kwargs should be specified in the constructor.
+    """
+
     # Mapping criteria
     needs_one_to_one = False
     supports_one_to_many = False
     supports_many_to_one = False
     supports_many_to_many = False

-    def __init__(self, matched_data: Matched):
-        """Add Matched class which takes TrackingData objects for gt and pred,and computes matching
-
-        Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched.
-        The Matched objects will store both gt and pred data, as well as the mapping,
-        and any additional private attributes that may be needed/used e.g. detection matrices.
-        Metric subclasses will take keyword arguments to set the weights of various error counts.
-
-        Args:
-            matched_data (Matched): Matched object for set of GT and Pred data
-        """
-        self.data = matched_data
-        self.results = self.compute()
-
     @abstractmethod
-    def compute(self) -> dict:
+    def compute(self, matched: Matched) -> dict:
        """The compute methods of Metric objects return a dictionary
        with counts and statistics.
-
-        They may make use of TrackingEvents objects but do not have to. 
- Raises: NotImplementedError diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index 937ad3cf..28b6ad11 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -8,7 +8,7 @@ from ._base import Metric if TYPE_CHECKING: - from ._base import Matched + from traccuracy.matchers import Matched class AOGMMetrics(Metric): @@ -16,7 +16,6 @@ class AOGMMetrics(Metric): def __init__( self, - matched_data: Matched, vertex_ns_weight=1, vertex_fp_weight=1, vertex_fn_weight=1, @@ -34,22 +33,19 @@ def __init__( "fn": edge_fn_weight, "ws": edge_ws_weight, } - super().__init__(matched_data) - def compute(self): - evaluate_ctc_events(self.data) + def compute(self, data: Matched): + evaluate_ctc_events(data) vertex_error_counts = { - "ns": len(self.data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)), - "fp": len(self.data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)), - "fn": len(self.data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)), + "ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)), + "fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)), + "fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)), } edge_error_counts = { - "ws": len( - self.data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC) - ), - "fp": len(self.data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)), - "fn": len(self.data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)), + "ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)), + "fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)), + "fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)), } error_sum = get_weighted_error_sum( vertex_error_counts, @@ -73,7 +69,7 @@ def compute(self): class CTCMetrics(AOGMMetrics): - def __init__(self, matched_data: Matched): + def __init__(self): vertex_weight_ns = 5 vertex_weight_fn = 10 vertex_weight_fp = 1 @@ -82,7 +78,6 @@ def __init__(self, matched_data: Matched): edge_weight_fn = 1.5 edge_weight_ws = 1 super().__init__( - matched_data, vertex_ns_weight=vertex_weight_ns, vertex_fp_weight=vertex_weight_fp, vertex_fn_weight=vertex_weight_fn, @@ -91,9 +86,9 @@ def __init__(self, matched_data: Matched): edge_ws_weight=edge_weight_ws, ) - def compute(self): + def compute(self, data: Matched): # AOGM-0 is the cost of creating the gt graph from scratch - gt_graph = self.data.gt_graph.graph + gt_graph = data.gt_graph.graph n_nodes = gt_graph.number_of_nodes() n_edges = gt_graph.number_of_edges() aogm_0 = n_nodes * self.v_weights["fn"] + n_edges * self.e_weights["fn"] @@ -103,7 +98,7 @@ def compute(self): + f" {n_edges} edges with {self.v_weights['fn']} vertex FN weight and" + f" {self.e_weights['fn']} edge FN weight" ) - errors = super().compute() + errors = super().compute(data) aogm = errors["AOGM"] tra = 1 - min(aogm, aogm_0) / aogm_0 errors["TRA"] = tra diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 42492202..4bd494fa 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -32,45 +32,52 @@ of the early division, by advancing along the graph to find nodes in the same frame as the late division daughters. 
""" +from __future__ import annotations +from typing import TYPE_CHECKING from traccuracy._tracking_graph import NodeAttr from traccuracy.track_errors.divisions import _evaluate_division_events from ._base import Metric +if TYPE_CHECKING: + from traccuracy.matchers import Matched + class DivisionMetrics(Metric): - needs_one_to_one = True + """Classify division events and provide the following summary metrics - def __init__(self, matched_data, frame_buffer=(0,)): - """Classify division events and provide the following summary metrics + - Division Recall + - Division Precision + - Division F1 Score + - Mitotic Branching Correctness: TP / (TP + FP + FN) as defined by Ulicna, K., + Vallardi, G., Charras, G. & Lowe, A. R. Automated deep lineage tree analysis + using a Bayesian single cell tracking approach. Frontiers in Computer Science + 3, 734559 (2021). - - Division Recall - - Division Precision - - Division F1 Score - - Mitotic Branching Correctness: TP / (TP + FP + FN) as defined by Ulicna, K., - Vallardi, G., Charras, G. & Lowe, A. R. Automated deep lineage tree analysis - using a Bayesian single cell tracking approach. Frontiers in Computer Science - 3, 734559 (2021). + Args: + frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames + to tolerate in correct_shifted_divisions. Defaults to (0). + """ - Args: - matched_data (Matched): Matched object for set of GT and Pred data - Must meet the `needs_one_to_one` criteria - frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames - to tolerate in correct_shifted_divisions. Defaults to (0). - """ + needs_one_to_one = True + + def __init__(self, frame_buffer=(0,)): self.frame_buffer = frame_buffer - super().__init__(matched_data) - def compute(self): + def compute(self, data: Matched): """Runs `_evaluate_division_events` and calculates summary metrics for each frame buffer + Args: + matched_data (traccuracy.matchers.Matched): Matched object for set of GT and Pred data + Must meet the `needs_one_to_one` criteria + Returns: dict: Returns a nested dictionary with one dictionary per frame buffer value """ div_annotations = _evaluate_division_events( - self.data, + data, frame_buffer=self.frame_buffer, ) diff --git a/src/traccuracy/metrics/_track_overlap.py b/src/traccuracy/metrics/_track_overlap.py new file mode 100644 index 00000000..f4f1f827 --- /dev/null +++ b/src/traccuracy/metrics/_track_overlap.py @@ -0,0 +1,130 @@ +"""This submodule implements routines for Track Purity (TP) and Target Effectiveness (TE) scores. + +Definitions (Bise et al., 2011; Chen, 2021; Fukai et al., 2022): + +- TE for a single ground truth track T^g_j is calculated by finding the predicted track T^p_k + that overlaps with T^g_j in the largest number of the frames and then dividing + the overlap frame counts by the total frame counts for T^g_j. + The TE for the total dataset is calculated as the mean of TEs for all ground truth tracks, + weighted by the length of the tracks. + +- TP is defined analogously, with T^g_j and T^p_j being swapped in the definition. +""" +from __future__ import annotations + +from itertools import groupby, product +from typing import TYPE_CHECKING, Any + +from ._base import Metric + +if TYPE_CHECKING: + from traccuracy._tracking_graph import TrackingGraph + from traccuracy.matchers import Matched + + +def _mapping_to_dict(mapping: list[tuple[Any, Any]]) -> dict[Any, list[Any]]: + """Convert mapping list of tuples to dictionary. 
+
+    Args:
+        mapping (List[Tuple[Any, Any]]): Mapping list of tuples
+
+    Returns:
+        Dict[Any, List[Any]]: Mapping dictionary
+
+    """
+
+    def get_from_val(x):
+        return x[0]
+
+    return {
+        k: [v[1] for v in vs]
+        for k, vs in groupby(sorted(mapping, key=get_from_val), key=get_from_val)
+    }
+
+
+class TrackOverlapMetrics(Metric):
+    """Calculate metrics for longest track overlaps.
+
+    - Target Effectiveness: fraction of longest overlapping prediction
+      tracklets on each GT tracklet
+    - Track Purity: fraction of longest overlapping GT
+      tracklets on each prediction tracklet
+
+    Args:
+        include_division_edges (bool, optional): If True, include edges at division.
+            Defaults to True.
+
+    """
+
+    supports_many_to_one = True
+
+    def __init__(self, include_division_edges: bool = True):
+        self.include_division_edges = include_division_edges
+
+    def compute(self, matched: Matched) -> dict:
+        gt_tracklets = matched.gt_graph.get_tracklets(
+            include_division_edges=self.include_division_edges
+        )
+        pred_tracklets = matched.pred_graph.get_tracklets(
+            include_division_edges=self.include_division_edges
+        )
+
+        gt_pred_mapping = _mapping_to_dict(matched.mapping)
+        pred_gt_mapping = _mapping_to_dict(
+            [(pred_node, gt_node) for gt_node, pred_node in matched.mapping]
+        )
+
+        # calculate track purity and target effectiveness
+        track_purity = _calc_overlap_score(
+            pred_tracklets, gt_tracklets, gt_pred_mapping
+        )
+        target_effectiveness = _calc_overlap_score(
+            gt_tracklets, pred_tracklets, pred_gt_mapping
+        )
+        return {
+            "track_purity": track_purity,
+            "target_effectiveness": target_effectiveness,
+        }
+
+
+def _calc_overlap_score(
+    reference_tracklets: list[TrackingGraph],
+    overlap_tracklets: list[TrackingGraph],
+    overlap_reference_mapping: dict[Any, list[Any]],
+):
+    """Calculate weighted sum of the length of the longest overlap tracklet
+    for each reference tracklet.
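+    Returns the fraction of reference edges covered by the best-overlapping tracklets,
+    or -1 if there are no reference edges.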
+
+    Args:
+        reference_tracklets (List[TrackingGraph]): The reference tracklets
+        overlap_tracklets (List[TrackingGraph]): The tracklets that overlap
+        overlap_reference_mapping (Dict[Any, List[Any]]): Mapping as a dict
+            from the overlap tracklet nodes to the reference tracklet nodes
+
+    """
+    correct_count = 0
+    total_count = 0
+
+    # calculate all overlapping edges mapped onto reference ids
+    overlap_tracklets_edges_mapped = []
+    for overlap_tracklet in overlap_tracklets:
+        edges = []
+        for node1, node2 in overlap_tracklet.edges():
+            mapped_nodes1 = overlap_reference_mapping.get(node1, [])
+            mapped_nodes2 = overlap_reference_mapping.get(node2, [])
+            if mapped_nodes1 and mapped_nodes2:
+                for n1, n2 in product(mapped_nodes1, mapped_nodes2):
+                    edges.append((n1, n2))
+        overlap_tracklets_edges_mapped.append(edges)
+
+    for reference_tracklet in reference_tracklets:
+        # find the overlap tracklet with the largest overlap
+        overlaps = [
+            len(set(reference_tracklet.edges()) & set(edges))
+            for edges in overlap_tracklets_edges_mapped
+        ]
+        max_overlap = max(overlaps)
+        correct_count += max_overlap
+        total_count += len(reference_tracklet.edges())
+
+    return correct_count / total_count if total_count > 0 else -1
diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py
index f21f8b78..bee72a8c 100644
--- a/src/traccuracy/track_errors/_ctc.py
+++ b/src/traccuracy/track_errors/_ctc.py
@@ -6,10 +6,10 @@
 from tqdm import tqdm

-from traccuracy import EdgeAttr, NodeAttr
+from traccuracy._tracking_graph import EdgeAttr, NodeAttr

 if TYPE_CHECKING:
-    from traccuracy.matchers._matched import Matched
+    from traccuracy.matchers import Matched

 logger = logging.getLogger(__name__)

@@ -28,7 +28,7 @@ def get_vertex_errors(matched_data: Matched):

    Parameters
    ----------
-    matched_data: Matched
+    matched_data: traccuracy.matchers.Matched
        Matched data object containing gt and pred graphs with their associated mapping
    """
    comp_graph = matched_data.pred_graph
diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py
index de1defa1..7ab3fc6b 100644
--- a/src/traccuracy/track_errors/divisions.py
+++ b/src/traccuracy/track_errors/divisions.py
@@ -10,7 +10,7 @@
 from traccuracy._utils import find_gt_node_matches, find_pred_node_matches

 if TYPE_CHECKING:
-    from traccuracy.matchers._matched import Matched
+    from traccuracy.matchers import Matched

 logger = logging.getLogger(__name__)
diff --git a/tests/bench.py b/tests/bench.py
index 6068e459..8458d5a2 100644
--- a/tests/bench.py
+++ b/tests/bench.py
@@ -5,7 +5,7 @@
 import pytest

 from traccuracy.loaders import load_ctc_data
-from traccuracy.matchers import CTCMatched, IOUMatched
+from traccuracy.matchers import CTCMatcher, IOUMatcher
 from traccuracy.metrics import CTCMetrics, DivisionMetrics

 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
@@ -51,12 +51,12 @@ def pred_data():

 @pytest.fixture(scope="module")
 def ctc_matched(gt_data, pred_data):
-    return CTCMatched(gt_data, pred_data)
+    return CTCMatcher().compute_mapping(gt_data, pred_data)


 @pytest.fixture(scope="module")
 def iou_matched(gt_data, pred_data):
-    return IOUMatched(gt_data, pred_data, iou_threshold=0.1)
+    return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data, pred_data)


 def test_load_gt_data(benchmark):
@@ -88,13 +88,13 @@ def test_load_pred_data(benchmark):


 def test_ctc_matched(benchmark, gt_data, pred_data):
-    benchmark(CTCMatched, gt_data, pred_data)
+    benchmark(CTCMatcher().compute_mapping, gt_data, pred_data)
@pytest.mark.timeout(300) def test_ctc_metrics(benchmark, ctc_matched): def run_compute(): - return CTCMetrics(copy.deepcopy(ctc_matched)).compute() + return CTCMetrics().compute(copy.deepcopy(ctc_matched)) ctc_results = benchmark.pedantic(run_compute, rounds=1, iterations=1) @@ -108,7 +108,7 @@ def run_compute(): def test_ctc_div_metrics(benchmark, ctc_matched): def run_compute(): - return DivisionMetrics(copy.deepcopy(ctc_matched)).compute() + return DivisionMetrics().compute(copy.deepcopy(ctc_matched)) div_results = benchmark(run_compute) @@ -118,12 +118,12 @@ def run_compute(): def test_iou_matched(benchmark, gt_data, pred_data): - benchmark(IOUMatched, gt_data, pred_data, iou_threshold=0.5) + benchmark(IOUMatcher(iou_threshold=0.1).compute_mapping, gt_data, pred_data) def test_iou_div_metrics(benchmark, iou_matched): def run_compute(): - return DivisionMetrics(copy.deepcopy(iou_matched)).compute() + return DivisionMetrics().compute(copy.deepcopy(iou_matched)) div_results = benchmark(run_compute) diff --git a/tests/matchers/test_ctc.py b/tests/matchers/test_ctc.py index 02040098..e210d79c 100644 --- a/tests/matchers/test_ctc.py +++ b/tests/matchers/test_ctc.py @@ -2,19 +2,17 @@ import numpy as np import pytest from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._ctc import CTCMatched +from traccuracy.matchers._ctc import CTCMatcher from tests.test_utils import get_annotated_movie def test_match_ctc(): - # Bad input - with pytest.raises(ValueError): - CTCMatched("not tracking data", "not tracking data") + matcher = CTCMatcher() # shapes don't match with pytest.raises(ValueError): - CTCMatched( + matcher.compute_mapping( TrackingGraph(nx.DiGraph(), segmentation=np.zeros((5, 10, 10))), TrackingGraph(nx.DiGraph(), segmentation=np.zeros((5, 10, 5))), ) @@ -37,7 +35,7 @@ def test_match_ctc(): attrs[f"{i}_{t}"] = {"t": t, "y": 0, "x": 0, "segmentation_id": i} nx.set_node_attributes(g, attrs) - matched = CTCMatched( + matched = matcher.compute_mapping( TrackingGraph(g, segmentation=movie), TrackingGraph(g, segmentation=movie), ) diff --git a/tests/matchers/test_iou.py b/tests/matchers/test_iou.py index 2cfa171e..edf942b7 100644 --- a/tests/matchers/test_iou.py +++ b/tests/matchers/test_iou.py @@ -2,7 +2,7 @@ import numpy as np import pytest from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._iou import IOUMatched, _match_nodes, match_iou +from traccuracy.matchers._iou import IOUMatcher, _match_nodes, match_iou from tests.test_utils import get_annotated_image, get_movie_with_graph @@ -68,8 +68,10 @@ def test__init__(self): track_graph = get_movie_with_graph() data = TrackingGraph(track_graph.graph) + matcher = IOUMatcher() + with pytest.raises(ValueError): - IOUMatched(data, data) + matcher.compute_mapping(data, data) def test_compute_mapping(self): # Test 2d data @@ -79,7 +81,8 @@ def test_compute_mapping(self): ndims=3, n_frames=n_frames, n_labels=n_labels ) - matched = IOUMatched(gt_graph=track_graph, pred_graph=track_graph) + matcher = IOUMatcher() + matched = matcher.compute_mapping(gt_graph=track_graph, pred_graph=track_graph) # Check for correct number of pairs assert len(matched.mapping) == n_frames * n_labels diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index fb5bd601..bd92fc95 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -1,4 +1,4 @@ -from traccuracy.matchers._ctc import CTCMatched +from traccuracy.matchers._ctc import CTCMatcher from 
traccuracy.metrics._ctc import CTCMetrics from tests.test_utils import get_movie_with_graph @@ -10,10 +10,10 @@ def test_compute_mapping(): n_labels = 3 track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) - matched = CTCMatched(gt_graph=track_graph, pred_graph=track_graph) - metric = CTCMetrics(matched) - assert metric.results - assert "TRA" in metric.results - assert "DET" in metric.results - assert metric.results["TRA"] == 1 - assert metric.results["DET"] == 1 + matched = CTCMatcher().compute_mapping(gt_graph=track_graph, pred_graph=track_graph) + results = CTCMetrics().compute(matched) + assert results + assert "TRA" in results + assert "DET" in results + assert results["TRA"] == 1 + assert results["DET"] == 1 diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index e2679115..3aa9b4c6 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -1,30 +1,20 @@ from traccuracy import TrackingGraph -from traccuracy.matchers._matched import Matched +from traccuracy.matchers import Matched from traccuracy.metrics._divisions import DivisionMetrics from tests.test_utils import get_division_graphs -class DummyMatched(Matched): - def __init__(self, gt_data, pred_data, mapper): - self.mapper = mapper - super().__init__(gt_data, pred_data) - - def compute_mapping(self): - return self.mapper - - def test_DivisionMetrics(): g_gt, g_pred, mapper = get_division_graphs() - matched = DummyMatched( + matched = Matched( TrackingGraph(g_gt), TrackingGraph(g_pred), - mapper=mapper, + mapper, ) frame_buffer = (0, 1, 2) - metrics = DivisionMetrics(matched, frame_buffer=frame_buffer) - results = metrics.compute() + results = DivisionMetrics(frame_buffer=frame_buffer).compute(matched) for name, r in results.items(): buffer = int(name[-1:]) diff --git a/tests/metrics/test_track_overlap_metrics.py b/tests/metrics/test_track_overlap_metrics.py new file mode 100644 index 00000000..f70a9304 --- /dev/null +++ b/tests/metrics/test_track_overlap_metrics.py @@ -0,0 +1,189 @@ +from copy import deepcopy + +import networkx as nx +import pytest +from traccuracy import TrackingGraph +from traccuracy.matchers import Matched +from traccuracy.metrics._track_overlap import TrackOverlapMetrics, _mapping_to_dict + + +def add_frame(tree): + attrs = {} + for node in tree.nodes: + attrs[node] = {"t": int(node.split("_")[0]), "x": 0, "y": 0} + nx.set_node_attributes(tree, attrs) + return tree + + +TEST_TREES = [ + { + "name": "simple1", + "gt_edges": [ + # 0 - 0 - 0 - 0 - 0 - 0 + # | + # - 1 - 1 - 1 + # + # 2 - 2 - 2 - 2 + ("0_0", "1_0"), + ("1_0", "2_0"), + ("2_0", "3_0"), + ("3_0", "4_0"), + ("4_0", "5_0"), + ("2_0", "3_1"), + ("3_1", "4_1"), + ("4_1", "5_1"), + ("1_2", "2_2"), + ("2_2", "3_2"), + ("3_2", "4_2"), + ], + "pred_edges": [ + # 0 - 0 - 0 - 0 0 - 0 + # | + # - 1 + # - 1 - 1 + # | + # 2 - 2 - 2 - + ("0_0", "1_0"), + ("1_0", "2_0"), + ("2_0", "3_0"), + ("4_0", "5_0"), + ("2_0", "3_1"), + ("1_2", "2_2"), + ("2_2", "3_2"), + ("3_2", "4_1"), + ("4_1", "5_1"), + ("4_1", "5_1"), + ("4_1", "5_1"), + ], + "results_with_division_edges": { + "track_purity": 7 / 9, + "target_effectiveness": 6 / 11, + }, + "results_without_division_edges": { + "track_purity": 5 / 7, + "target_effectiveness": 6 / 9, + }, + }, + { + "name": "overlap", + # 0 - 0 - 0 - 0 + # | + # - 1 - 1 + "gt_edges": [ + ("0_0", "1_0"), + ("1_0", "2_0"), + ("2_0", "3_0"), + ("1_0", "2_1"), + ("2_1", "3_1"), + ], + # 0 - 0 - 0 + # | + # - 1 - 1 + # 2 - 2 - 2 + # (2 and 1 overlap) + 
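+        # i.e. pred tracks 1 and 2 both match the same gt daughter nodes (2_1 and 3_1)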
"pred_edges": [ + ("0_0", "1_0"), + ("1_0", "2_0"), + ("1_0", "2_1"), + ("2_1", "3_1"), + ("1_2", "2_2"), + ("2_2", "3_2"), + ], + "mapping": [ # GT to pred mapping + ("0_0", "0_0"), + ("1_0", "1_0"), + ("2_0", "2_0"), + ("3_0", "3_0"), + ("2_1", "2_1"), + ("3_1", "3_1"), + ("2_1", "2_2"), + ("3_1", "3_2"), + ], + "results_with_division_edges": { + "track_purity": 5 / 6, + "target_effectiveness": 4 / 5, + }, + "results_without_division_edges": { + "track_purity": 3 / 4, + "target_effectiveness": 2 / 3, + }, + }, +] + +simple2 = deepcopy(TEST_TREES[0]) +simple2["name"] = "simple2" +# 0 - 0 - 0 - 0 0 - 0 +# | +# - 1 +# - 1 - 1 +# | +# 2 - 2 - 2 - +# | +# - 3 - 3 +simple2["pred_edges"].extend( + [ + ("2_2", "3_3"), + ("3_3", "4_3"), + ] +) +simple2["results_with_division_edges"] = { + "track_purity": 7 / 11, + "target_effectiveness": 5 / 11, +} +simple2["results_without_division_edges"] = { + "track_purity": 5 / 7, + "target_effectiveness": 5 / 9, +} +TEST_TREES.append(simple2) +assert TEST_TREES[0] != TEST_TREES[1] + + +@pytest.mark.parametrize("data", TEST_TREES) +@pytest.mark.parametrize("inverse", [False, True]) +def test_track_overlap_metrics(data, inverse) -> None: + g_gt = add_frame(nx.from_edgelist(data["gt_edges"], create_using=nx.DiGraph)) + g_pred = add_frame(nx.from_edgelist(data["pred_edges"], create_using=nx.DiGraph)) + if "mapping" in data: + mapping = data["mapping"] + else: + mapping = [(n, n) for n in g_gt.nodes] + + if inverse: + g_gt, g_pred = g_pred, g_gt + mapping = [(b, a) for a, b in mapping] + + matched = Matched( + TrackingGraph(g_gt), + TrackingGraph(g_pred), + mapping, + ) + + metric = TrackOverlapMetrics() + results = metric.compute(matched) + assert results + + expected = data["results_with_division_edges"] + if inverse: + expected = { + "track_purity": expected["target_effectiveness"], + "target_effectiveness": expected["track_purity"], + } + assert results == expected, f"{data['name']} failed with division edges" + + metric = TrackOverlapMetrics(include_division_edges=False) + results = metric.compute(matched) + assert results + + expected = data["results_without_division_edges"] + if inverse: + expected = { + "track_purity": expected["target_effectiveness"], + "target_effectiveness": expected["track_purity"], + } + assert results == expected, f"{data['name']} failed without division edges" + + +def test_mapping_to_dict(): + mapping = [("1", "2"), ("2", "3"), ("1", "3"), ("2", "3")] + mapping_dict = _mapping_to_dict(mapping) + assert mapping_dict == {"1": ["2", "3"], "2": ["3", "3"]} diff --git a/tests/test_run_metrics.py b/tests/test_run_metrics.py new file mode 100644 index 00000000..5972e9f8 --- /dev/null +++ b/tests/test_run_metrics.py @@ -0,0 +1,70 @@ +import pytest +from traccuracy import run_metrics +from traccuracy.matchers._base import Matched, Matcher +from traccuracy.metrics._base import Metric + +from tests.test_utils import get_movie_with_graph + + +class DummyMetric(Metric): + def compute(self, matched): + return {} + + +class DummyMetricParam(Metric): + def __init__(self, param="value"): + self.param = param + + def compute(self, matched): + return {} + + +class DummyMatcher(Matcher): + def __init__(self, mapping=None): + if mapping: + self.mapping = mapping + else: + self.mapping = [] + + def _compute_mapping(self, gt_graph, pred_graph): + return Matched(gt_graph, pred_graph, self.mapping) + + +def test_run_metrics(): + graph = get_movie_with_graph() + mapping = [(n, n) for n in graph.nodes()] + + # Check matcher input -- not instantiated + with 
pytest.raises(TypeError): + run_metrics(graph, graph, DummyMatcher, [DummyMetric()]) + + # Check matcher input -- wrong type + with pytest.raises(TypeError): + run_metrics(graph, graph, "rando", DummyMetric()) + + # Check metric input -- not instantiated + with pytest.raises(TypeError): + run_metrics(graph, graph, DummyMatcher(), [DummyMetric]) + + # Check metric input -- wrong type + with pytest.raises(TypeError): + run_metrics(graph, graph, DummyMatcher(), [DummyMetric(), "rando"]) + + # One metric + results = run_metrics(graph, graph, DummyMatcher(mapping), [DummyMetric()]) + assert isinstance(results, list) + assert len(results) == 1 + assert results[0]["metric"]["name"] == "DummyMetric" + + # Duplicate metric with different params + results = run_metrics( + graph, + graph, + DummyMatcher(mapping), + [DummyMetricParam("param1"), DummyMetricParam("param2")], + ) + assert len(results) == 2 + assert results[0]["metric"]["name"] == "DummyMetricParam" + assert results[0]["metric"].get("param") == "param1" + assert results[1]["metric"]["name"] == "DummyMetricParam" + assert results[1]["metric"].get("param") == "param2" diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index ceba4989..53d4a7f7 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,15 +1,10 @@ import networkx as nx import numpy as np from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph -from traccuracy.matchers._matched import Matched +from traccuracy.matchers import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors -class DummyMatched(Matched): - def compute_mapping(self): - return [] - - def test_get_vertex_errors(): comp_ids = [3, 7, 10] comp_ids_2 = list(np.asarray(comp_ids) + 1) @@ -39,27 +34,26 @@ def test_get_vertex_errors(): ) G_comp = TrackingGraph(comp_g) - matched_data = DummyMatched(G_gt, G_comp) - matched_data.mapping = mapping + matched_data = Matched(G_gt, G_comp, mapping) get_vertex_errors(matched_data) - assert len(G_comp.get_nodes_with_flag(NodeAttr.NON_SPLIT)) == 1 - assert len(G_comp.get_nodes_with_flag(NodeAttr.TRUE_POS)) == 3 - assert len(G_comp.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 - assert len(G_gt.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)) == 1 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 - assert gt_g.nodes[15][NodeAttr.FALSE_NEG] - assert not gt_g.nodes[17][NodeAttr.FALSE_NEG] + assert matched_data.gt_graph.graph.nodes[15][NodeAttr.FALSE_NEG] + assert not matched_data.gt_graph.graph.nodes[17][NodeAttr.FALSE_NEG] - assert comp_g.nodes[3][NodeAttr.NON_SPLIT] - assert not comp_g.nodes[7][NodeAttr.NON_SPLIT] + assert matched_data.pred_graph.graph.nodes[3][NodeAttr.NON_SPLIT] + assert not matched_data.pred_graph.graph.nodes[7][NodeAttr.NON_SPLIT] - assert comp_g.nodes[7][NodeAttr.TRUE_POS] - assert not comp_g.nodes[3][NodeAttr.TRUE_POS] + assert matched_data.pred_graph.graph.nodes[7][NodeAttr.TRUE_POS] + assert not matched_data.pred_graph.graph.nodes[3][NodeAttr.TRUE_POS] - assert comp_g.nodes[10][NodeAttr.FALSE_POS] - assert not comp_g.nodes[7][NodeAttr.FALSE_POS] + assert matched_data.pred_graph.graph.nodes[10][NodeAttr.FALSE_POS] + assert not 
matched_data.pred_graph.graph.nodes[7][NodeAttr.FALSE_POS] def test_assign_edge_errors(): @@ -95,13 +89,12 @@ def test_assign_edge_errors(): ) G_gt = TrackingGraph(gt_g) - matched_data = DummyMatched(G_gt, G_comp) - matched_data.mapping = mapping + matched_data = Matched(G_gt, G_comp, mapping) get_edge_errors(matched_data) - assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS] - assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG] + assert matched_data.pred_graph.graph.edges[(7, 8)][EdgeAttr.FALSE_POS] + assert matched_data.gt_graph.graph.edges[(17, 18)][EdgeAttr.FALSE_NEG] def test_assign_edge_errors_semantics(): @@ -136,9 +129,8 @@ def test_assign_edge_errors_semantics(): # Define mapping with all nodes matching except for 2_3 in comp mapping = [(n, n) for n in gt.nodes] - matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp)) - matched_data.mapping = mapping + matched_data = Matched(TrackingGraph(gt), TrackingGraph(comp), mapping) get_edge_errors(matched_data) - assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] + assert matched_data.pred_graph.graph.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 0a10bf66..6538e644 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -2,7 +2,7 @@ import numpy as np import pytest from traccuracy import NodeAttr, TrackingGraph -from traccuracy.matchers._matched import Matched +from traccuracy.matchers import Matched from traccuracy.track_errors.divisions import ( _classify_divisions, _correct_shifted_divisions, @@ -14,11 +14,6 @@ from tests.test_utils import get_division_graphs -class DummyMatched(Matched): - def compute_mapping(self): - return [] - - @pytest.fixture def g(): """ @@ -51,22 +46,19 @@ def g(): def test_classify_divisions_tp(g): # Define mapper assuming all nodes match mapper = [(n, n) for n in g.nodes] - g_gt = TrackingGraph(g.copy()) - g_pred = TrackingGraph(g.copy()) - matched_data = DummyMatched(g_gt, g_pred) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g.copy()), TrackingGraph(g.copy()), mapper) # Test true positive _classify_divisions(matched_data) - assert len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] - assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 + assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes()["2_2"] + assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes()["2_2"] # Check division flag - assert g_gt.division_annotations - assert g_pred.division_annotations + assert matched_data.gt_graph.division_annotations + assert matched_data.pred_graph.division_annotations def test_classify_divisions_fp(g): @@ -84,17 +76,14 @@ def test_classify_divisions_fp(g): nx.set_node_attributes(h, {"5_3": {"t": 3, "x": 0, "y": 0}}) mapper = [(n, n) for n in h.nodes] - g_gt = TrackingGraph(g) - g_pred = TrackingGraph(h) - matched_data = DummyMatched(g_gt, g_pred) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g), TrackingGraph(h), mapper) _classify_divisions(matched_data) - assert len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert NodeAttr.FP_DIV in g_pred.nodes()["1_2"] - assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] - assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + 
assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 + assert NodeAttr.FP_DIV in matched_data.pred_graph.nodes()["1_2"] + assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes()["2_2"] + assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes()["2_2"] def test_classify_divisions_fn(g): @@ -107,16 +96,13 @@ def test_classify_divisions_fn(g): h.remove_nodes_from(["3_3", "4_3"]) mapper = [(n, n) for n in h.nodes] - g_gt = TrackingGraph(g) - g_pred = TrackingGraph(h) - matched_data = DummyMatched(g_gt, g_pred) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g), TrackingGraph(h), mapper) _classify_divisions(matched_data) - assert len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert len(g_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 - assert NodeAttr.FN_DIV in g_gt.nodes()["2_2"] + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 + assert NodeAttr.FN_DIV in matched_data.gt_graph.nodes()["2_2"] @pytest.fixture @@ -177,8 +163,7 @@ def test_no_change(self): g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) # buffer of 1, no change new_matched = _correct_shifted_divisions(matched_data, n_frames=1) @@ -195,8 +180,7 @@ def test_fn_early(self): g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) # buffer of 3, corrections new_matched = _correct_shifted_divisions(matched_data, n_frames=3) @@ -214,8 +198,7 @@ def test_fp_early(self): g_pred.nodes["1_1"][NodeAttr.FP_DIV] = True g_gt.nodes["1_3"][NodeAttr.FN_DIV] = True - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) # buffer of 3, corrections new_matched = _correct_shifted_divisions(matched_data, n_frames=3) @@ -232,8 +215,7 @@ def test_evaluate_division_events(): g_gt, g_pred, mapper = get_division_graphs() frame_buffer = (0, 1, 2) - matched_data = DummyMatched(TrackingGraph(g_gt), TrackingGraph(g_pred)) - matched_data.mapping = mapper + matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) results = _evaluate_division_events(matched_data, frame_buffer=frame_buffer)
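
Editor's note on the pattern above: the common thread running through these test diffs is that the removed `DummyMatched` subclasses are replaced by constructing `Matched(gt_graph, pred_graph, mapping)` directly. Below is a minimal sketch of that pattern feeding `TrackOverlapMetrics`; it is illustration, not part of the diff. The `"{frame}_{track}"` node ids, the `t`/`x`/`y` node attributes, and the `traccuracy.metrics._track_overlap` import path are assumptions carried over from these tests.

```python
# A minimal sketch of the new Matched-construction pattern (assumptions noted above).
import networkx as nx

from traccuracy import TrackingGraph
from traccuracy.matchers import Matched
from traccuracy.metrics._track_overlap import TrackOverlapMetrics  # import path assumed

# Ground truth: track 0 divides at frame 1 into tracks 0 and 1.
gt = nx.DiGraph([("0_0", "1_0"), ("1_0", "2_0"), ("1_0", "2_1")])
# Prediction: identical, except the second daughter edge is missing.
pred = nx.DiGraph([("0_0", "1_0"), ("1_0", "2_0")])

# The tests above give every node a frame ("t") and location ("x", "y") attribute.
for g in (gt, pred):
    for node in g.nodes:
        g.nodes[node].update({"t": int(node.split("_")[0]), "x": 0, "y": 0})

# Identity mapping: each predicted node matches the GT node of the same name.
mapping = [(n, n) for n in pred.nodes]
matched = Matched(TrackingGraph(gt), TrackingGraph(pred), mapping)

# compute() returns a plain dict, as the parametrized test asserts.
# TrackOverlapMetrics(include_division_edges=False) would drop division edges first.
results = TrackOverlapMetrics().compute(matched)
assert set(results) == {"track_purity", "target_effectiveness"}
```

Passing the mapping through the constructor keeps `Matched` a plain data holder, which is exactly what lets every test file above drop its `DummyMatched` boilerplate.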
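
`tests/test_run_metrics.py` above also pins down the `run_metrics` calling convention: the matcher and each metric must be passed as instances (classes or other types raise `TypeError`), and the return value is a list with one entry per metric whose `"metric"` sub-dict records the class name and any constructor parameters. A self-contained sketch under those assumptions follows; `IdentityMatcher` and `EdgeCountMetric` are hypothetical names for illustration, not traccuracy classes.

```python
# A sketch of the run_metrics contract exercised above; classes here are hypothetical.
import networkx as nx

from traccuracy import TrackingGraph, run_metrics
from traccuracy.matchers._base import Matched, Matcher
from traccuracy.metrics._base import Metric


class IdentityMatcher(Matcher):
    # Pair up nodes that carry the same id in both graphs.
    def _compute_mapping(self, gt_graph, pred_graph):
        common = set(gt_graph.graph.nodes) & set(pred_graph.graph.nodes)
        return Matched(gt_graph, pred_graph, [(n, n) for n in common])


class EdgeCountMetric(Metric):
    # Report raw edge counts on both sides of the match.
    def compute(self, matched):
        return {
            "gt_edges": matched.gt_graph.graph.number_of_edges(),
            "pred_edges": matched.pred_graph.graph.number_of_edges(),
        }


g = nx.DiGraph([("0_0", "1_0"), ("1_0", "2_0")])
for node in g.nodes:
    g.nodes[node].update({"t": int(node.split("_")[0]), "x": 0, "y": 0})
gt, pred = TrackingGraph(g.copy()), TrackingGraph(g.copy())

# Instances, not classes: run_metrics(gt, pred, IdentityMatcher, [EdgeCountMetric])
# would raise TypeError, exactly as test_run_metrics asserts.
results = run_metrics(gt, pred, IdentityMatcher(), [EdgeCountMetric()])
assert results[0]["metric"]["name"] == "EdgeCountMetric"
```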
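
Finally, the division tests use the same one-line `Matched` setup before annotating division flags. Here is a compact sketch of a true-positive division, calling the private `_classify_divisions` helper exactly as the tests do; the toy graph and its attribute conventions are assumptions, not library requirements.

```python
# A compact sketch of the division workflow with the new Matched constructor.
import networkx as nx

from traccuracy import NodeAttr, TrackingGraph
from traccuracy.matchers import Matched
from traccuracy.track_errors.divisions import _classify_divisions

# Parent "1_0" divides into "2_0" and "2_1" in both GT and prediction.
g = nx.DiGraph([("0_0", "1_0"), ("1_0", "2_0"), ("1_0", "2_1")])
for node in g.nodes:
    g.nodes[node].update({"t": int(node.split("_")[0]), "x": 0, "y": 0})

mapper = [(n, n) for n in g.nodes]
matched = Matched(TrackingGraph(g.copy()), TrackingGraph(g.copy()), mapper)
_classify_divisions(matched)

# A division reproduced exactly on both sides is flagged as a true positive.
assert NodeAttr.TP_DIV in matched.gt_graph.nodes()["1_0"]
assert NodeAttr.TP_DIV in matched.pred_graph.nodes()["1_0"]
```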