From 31b10f8ef16a6904dc171ce60ddf20feb24bf7ba Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 3 Nov 2023 17:01:10 -0400 Subject: [PATCH 1/3] style: cleanup annotations --- src/traccuracy/_run_metrics.py | 17 +++--- src/traccuracy/cli.py | 74 ++++++++++++------------ src/traccuracy/matchers/_ctc.py | 2 +- src/traccuracy/matchers/_matched.py | 4 +- src/traccuracy/metrics/_base.py | 4 +- src/traccuracy/metrics/_ctc.py | 6 +- src/traccuracy/track_errors/_ctc.py | 8 ++- src/traccuracy/track_errors/divisions.py | 8 ++- 8 files changed, 67 insertions(+), 56 deletions(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 9c643ca9..6b9c4d5a 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,9 +1,10 @@ +from __future__ import annotations + from typing import TYPE_CHECKING from traccuracy._utils import get_relevant_kwargs, validate_matched_data if TYPE_CHECKING: - from typing import Dict, List, Optional, Type from traccuracy import TrackingGraph from traccuracy.matchers._matched import Matched @@ -11,13 +12,13 @@ def run_metrics( - gt_data: "TrackingGraph", - pred_data: "TrackingGraph", - matcher: "Type[Matched]", - metrics: "List[Type[Metric]]", - matcher_kwargs: "Optional[Dict]" = None, - metrics_kwargs: "Optional[Dict]" = None, # weights -) -> "Dict": + gt_data: TrackingGraph, + pred_data: TrackingGraph, + matcher: type[Matched], + metrics: list[type[Metric]], + matcher_kwargs: dict | None = None, + metrics_kwargs: dict | None = None, # weights +) -> dict: """Compute given metrics on data using the given matcher. An error will be thrown if the given matcher is not compatible with diff --git a/src/traccuracy/cli.py b/src/traccuracy/cli.py index a7ce4c00..a991216c 100644 --- a/src/traccuracy/cli.py +++ b/src/traccuracy/cli.py @@ -13,10 +13,10 @@ def load_all_ctc( - gt_dir: "str", - pred_dir: "str", - gt_track_path: "Optional[str]" = None, - pred_track_path: "Optional[str]" = None, + gt_dir: str, + pred_dir: str, + gt_track_path: Optional[str] = None, + pred_track_path: Optional[str] = None, ): gt_data = load_ctc_data(gt_dir, gt_track_path) pred_data = load_ctc_data(pred_dir, pred_track_path) @@ -25,18 +25,18 @@ def load_all_ctc( @app.command() def run_ctc( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("ctc_log.json", help="Path to save results"), + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("ctc_log.json", help="Path to save results"), ): """ Run TRA and DET metric on gt and pred data using CTC matching. @@ -66,34 +66,34 @@ def run_ctc( @app.command() def run_aogm( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("aogm_log.json", help="Path to save results"), - vertex_ns_weight: "float" = typer.Option( + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("aogm_log.json", help="Path to save results"), + vertex_ns_weight: float = typer.Option( 1, help="Weight to assign to nonsplit vertex errors" ), - vertex_fp_weight: "float" = typer.Option( + vertex_fp_weight: float = typer.Option( 1, help="Weight to assign to false positive vertex errors" ), - vertex_fn_weight: "float" = typer.Option( + vertex_fn_weight: float = typer.Option( 1, help="Weight to assign to false negative vertex errors" ), - edge_fp_weight: "float" = typer.Option( + edge_fp_weight: float = typer.Option( 1, help="Weight to assign to false positive edge errors" ), - edge_fn_weight: "float" = typer.Option( + edge_fn_weight: float = typer.Option( 1, help="Weight to assign to false negative edge errors" ), - edge_ws_weight: "float" = typer.Option( + edge_ws_weight: float = typer.Option( 1, help="Weight to assign to edges with incorrect semantics" ), ): @@ -138,24 +138,24 @@ def run_aogm( @app.command() def run_divisions_on_iou( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("div_log_iou.json", help="Path to save results"), - match_threshold: "float" = typer.Option( + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("div_log_iou.json", help="Path to save results"), + match_threshold: float = typer.Option( 1, help="Threshold above which the intersection over union of a gt and predicted" " detection match. Default of 1 requires exact matching.", ), - frame_buffer: "int" = typer.Option( + frame_buffer: int = typer.Option( 0, help="Number of frames to use for division tolerance." " Numbers greater than 0 will produce metrics for 0...n inclusive.", @@ -202,19 +202,19 @@ def run_divisions_on_iou( @app.command() def run_divisions_on_ctc( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("div_log_ctc.json", help="Path to save results"), - frame_buffer: "int" = typer.Option( + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("div_log_ctc.json", help="Path to save results"), + frame_buffer: int = typer.Option( 0, help="Number of frames to use for division tolerance." " Numbers greater than 0 will produce metrics for 0...n inclusive.", diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index ed4980bb..745bbc0a 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -96,7 +96,7 @@ def _match_ctc(self): return mapping -def detection_test(gt_blob: "np.ndarray", comp_blob: "np.ndarray") -> int: +def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int: """Check if computed marker overlaps majority of the reference marker. Given a reference marker and computer marker in original coordinates, diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py index 85399848..d69ae7b3 100644 --- a/src/traccuracy/matchers/_matched.py +++ b/src/traccuracy/matchers/_matched.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from abc import ABC, abstractmethod from typing import TYPE_CHECKING @@ -9,7 +11,7 @@ class Matched(ABC): - def __init__(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def __init__(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Matched class which takes TrackingData objects for gt and pred, and computes matching. Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched. diff --git a/src/traccuracy/metrics/_base.py b/src/traccuracy/metrics/_base.py index dd763678..d98264a8 100644 --- a/src/traccuracy/metrics/_base.py +++ b/src/traccuracy/metrics/_base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import TYPE_CHECKING @@ -12,7 +14,7 @@ class Metric(ABC): supports_many_to_one = False supports_many_to_many = False - def __init__(self, matched_data: "Matched"): + def __init__(self, matched_data: Matched): """Add Matched class which takes TrackingData objects for gt and pred,and computes matching Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched. diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index a5f0a84f..937ad3cf 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING from traccuracy._tracking_graph import EdgeAttr, NodeAttr @@ -14,7 +16,7 @@ class AOGMMetrics(Metric): def __init__( self, - matched_data: "Matched", + matched_data: Matched, vertex_ns_weight=1, vertex_fp_weight=1, vertex_fn_weight=1, @@ -71,7 +73,7 @@ def compute(self): class CTCMetrics(AOGMMetrics): - def __init__(self, matched_data: "Matched"): + def __init__(self, matched_data: Matched): vertex_weight_ns = 5 vertex_weight_fn = 10 vertex_weight_fp = 1 diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 642399df..f21f8b78 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import defaultdict from typing import TYPE_CHECKING @@ -12,7 +14,7 @@ logger = logging.getLogger(__name__) -def evaluate_ctc_events(matched_data: "Matched"): +def evaluate_ctc_events(matched_data: Matched): """Annotates ground truth and predicted graph with node and edge error types Annotations are made in place @@ -21,7 +23,7 @@ def evaluate_ctc_events(matched_data: "Matched"): get_edge_errors(matched_data) -def get_vertex_errors(matched_data: "Matched"): +def get_vertex_errors(matched_data: Matched): """Count vertex errors and assign class to each comp/gt node. Parameters @@ -70,7 +72,7 @@ def get_vertex_errors(matched_data: "Matched"): gt_graph.node_errors = True -def get_edge_errors(matched_data: "Matched"): +def get_edge_errors(matched_data: Matched): comp_graph = matched_data.pred_graph gt_graph = matched_data.gt_graph node_mapping = matched_data.mapping diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index 9a432fe8..de1defa1 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import itertools import logging @@ -13,7 +15,7 @@ logger = logging.getLogger(__name__) -def _classify_divisions(matched_data: "Matched"): +def _classify_divisions(matched_data: Matched): """Identify each division as a true positive, false positive or false negative This function only works on node mappers that are one-to-one @@ -139,7 +141,7 @@ def _get_succ_by_t(g, node, delta_frames): return node -def _correct_shifted_divisions(matched_data: "Matched", n_frames=1): +def _correct_shifted_divisions(matched_data: Matched, n_frames=1): """Allows for divisions to occur within a frame buffer and still be correct This implementation asserts that the parent lineages and daughter lineages must match. @@ -236,7 +238,7 @@ def _correct_shifted_divisions(matched_data: "Matched", n_frames=1): return new_matched -def _evaluate_division_events(matched_data: "Matched", frame_buffer=(0)): +def _evaluate_division_events(matched_data: Matched, frame_buffer=(0)): """Classify division errors and correct shifted divisions according to frame_buffer Note: A copy of matched_data will be created for each frame_buffer other than 0. From cef06a38b9b87a2f6595060e7fb7e2e5fc2da74b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 3 Nov 2023 21:08:08 +0000 Subject: [PATCH 2/3] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/_run_metrics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 6b9c4d5a..91c2fe42 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -5,7 +5,6 @@ from traccuracy._utils import get_relevant_kwargs, validate_matched_data if TYPE_CHECKING: - from traccuracy import TrackingGraph from traccuracy.matchers._matched import Matched from traccuracy.metrics._base import Metric From af482abeacc6b5b7962a9ba692c8600a1bfc9bf7 Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Tue, 28 Nov 2023 10:30:41 -0500 Subject: [PATCH 3/3] remove more quotes --- src/traccuracy/matchers/_ctc.py | 4 +++- src/traccuracy/matchers/_iou.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 25827ff2..23f02be8 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING import networkx as nx @@ -24,7 +26,7 @@ class CTCMatcher(Matcher): for complete details. """ - def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Run ctc matching Args: diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 189c4e6d..7a48fbc4 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from tqdm import tqdm @@ -110,7 +112,7 @@ class IOUMatcher(Matcher): def __init__(self, iou_threshold=0.6): self.iou_threshold = iou_threshold - def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Computes IOU mapping for a set of grpahs Args: