diff --git a/src/traccuracy/_run_metrics.py b/src/traccuracy/_run_metrics.py index 8e328922..7d39c2e5 100644 --- a/src/traccuracy/_run_metrics.py +++ b/src/traccuracy/_run_metrics.py @@ -1,19 +1,16 @@ -from typing import TYPE_CHECKING +from __future__ import annotations from traccuracy._tracking_graph import TrackingGraph from traccuracy.matchers._base import Matcher from traccuracy.metrics._base import Metric -if TYPE_CHECKING: - from typing import Dict, List - def run_metrics( - gt_data: "TrackingGraph", - pred_data: "TrackingGraph", - matcher: "Matcher", - metrics: "List[Metric]", -) -> "List[Dict]": + gt_data: TrackingGraph, + pred_data: TrackingGraph, + matcher: Matcher, + metrics: list[Metric], +) -> list[dict]: """Compute given metrics on data using the given matcher. The returned result dictionary will contain all metrics computed by diff --git a/src/traccuracy/cli.py b/src/traccuracy/cli.py index ce178b83..a8b3249d 100644 --- a/src/traccuracy/cli.py +++ b/src/traccuracy/cli.py @@ -13,10 +13,10 @@ def load_all_ctc( - gt_dir: "str", - pred_dir: "str", - gt_track_path: "Optional[str]" = None, - pred_track_path: "Optional[str]" = None, + gt_dir: str, + pred_dir: str, + gt_track_path: Optional[str] = None, + pred_track_path: Optional[str] = None, ): gt_data = load_ctc_data(gt_dir, gt_track_path) pred_data = load_ctc_data(pred_dir, pred_track_path) @@ -25,18 +25,18 @@ def load_all_ctc( @app.command() def run_ctc( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("ctc_log.json", help="Path to save results"), + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("ctc_log.json", help="Path to save results"), ): """ Run TRA and DET metric on gt and pred data using CTC matching. @@ -66,34 +66,34 @@ def run_ctc( @app.command() def run_aogm( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("aogm_log.json", help="Path to save results"), - vertex_ns_weight: "float" = typer.Option( + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("aogm_log.json", help="Path to save results"), + vertex_ns_weight: float = typer.Option( 1, help="Weight to assign to nonsplit vertex errors" ), - vertex_fp_weight: "float" = typer.Option( + vertex_fp_weight: float = typer.Option( 1, help="Weight to assign to false positive vertex errors" ), - vertex_fn_weight: "float" = typer.Option( + vertex_fn_weight: float = typer.Option( 1, help="Weight to assign to false negative vertex errors" ), - edge_fp_weight: "float" = typer.Option( + edge_fp_weight: float = typer.Option( 1, help="Weight to assign to false positive edge errors" ), - edge_fn_weight: "float" = typer.Option( + edge_fn_weight: float = typer.Option( 1, help="Weight to assign to false negative edge errors" ), - edge_ws_weight: "float" = typer.Option( + edge_ws_weight: float = typer.Option( 1, help="Weight to assign to edges with incorrect semantics" ), ): @@ -139,24 +139,24 @@ def run_aogm( @app.command() def run_divisions_on_iou( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("div_log_iou.json", help="Path to save results"), - match_threshold: "float" = typer.Option( + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("div_log_iou.json", help="Path to save results"), + match_threshold: float = typer.Option( 1, help="Threshold above which the intersection over union of a gt and predicted" " detection match. Default of 1 requires exact matching.", ), - frame_buffer: "int" = typer.Option( + frame_buffer: int = typer.Option( 0, help="Number of frames to use for division tolerance." " Numbers greater than 0 will produce metrics for 0...n inclusive.", @@ -199,19 +199,19 @@ def run_divisions_on_iou( @app.command() def run_divisions_on_ctc( - gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False), - pred_dir: "str" = typer.Argument( + gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False), + pred_dir: str = typer.Argument( ..., help="Path to prediction/RES tiffs", show_default=False ), - gt_track_path: "Optional[str]" = typer.Option( + gt_track_path: Optional[str] = typer.Option( None, help="Path to ctc gt track file", show_default=False ), - pred_track_path: "Optional[str]" = typer.Option( + pred_track_path: Optional[str] = typer.Option( None, help="Path to predicted track file", show_default=False ), - loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"), - out_path: "str" = typer.Option("div_log_ctc.json", help="Path to save results"), - frame_buffer: "int" = typer.Option( + loader: str = typer.Option("ctc", help="Loader to bring data into memory"), + out_path: str = typer.Option("div_log_ctc.json", help="Path to save results"), + frame_buffer: int = typer.Option( 0, help="Number of frames to use for division tolerance." " Numbers greater than 0 will produce metrics for 0...n inclusive.", diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index ca245e36..23f02be8 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING import networkx as nx @@ -24,7 +26,7 @@ class CTCMatcher(Matcher): for complete details. """ - def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Run ctc matching Args: @@ -95,7 +97,7 @@ def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph return Matched(gt_graph, pred_graph, mapping) -def detection_test(gt_blob: "np.ndarray", comp_blob: "np.ndarray") -> int: +def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int: """Check if computed marker overlaps majority of the reference marker. Given a reference marker and computer marker in original coordinates, diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 189c4e6d..7a48fbc4 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from tqdm import tqdm @@ -110,7 +112,7 @@ class IOUMatcher(Matcher): def __init__(self, iou_threshold=0.6): self.iou_threshold = iou_threshold - def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): + def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): """Computes IOU mapping for a set of grpahs Args: diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index 795050dd..28b6ad11 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import TYPE_CHECKING from traccuracy._tracking_graph import EdgeAttr, NodeAttr @@ -32,7 +34,7 @@ def __init__( "ws": edge_ws_weight, } - def compute(self, data: "Matched"): + def compute(self, data: Matched): evaluate_ctc_events(data) vertex_error_counts = { @@ -84,7 +86,7 @@ def __init__(self): edge_ws_weight=edge_weight_ws, ) - def compute(self, data: "Matched"): + def compute(self, data: Matched): # AOGM-0 is the cost of creating the gt graph from scratch gt_graph = data.gt_graph.graph n_nodes = gt_graph.number_of_nodes() diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 27bdc8dc..bee72a8c 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from collections import defaultdict from typing import TYPE_CHECKING @@ -12,7 +14,7 @@ logger = logging.getLogger(__name__) -def evaluate_ctc_events(matched_data: "Matched"): +def evaluate_ctc_events(matched_data: Matched): """Annotates ground truth and predicted graph with node and edge error types Annotations are made in place @@ -21,7 +23,7 @@ def evaluate_ctc_events(matched_data: "Matched"): get_edge_errors(matched_data) -def get_vertex_errors(matched_data: "Matched"): +def get_vertex_errors(matched_data: Matched): """Count vertex errors and assign class to each comp/gt node. Parameters @@ -70,7 +72,7 @@ def get_vertex_errors(matched_data: "Matched"): gt_graph.node_errors = True -def get_edge_errors(matched_data: "Matched"): +def get_edge_errors(matched_data: Matched): comp_graph = matched_data.pred_graph gt_graph = matched_data.gt_graph node_mapping = matched_data.mapping diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index d2b359ec..7ab3fc6b 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import itertools import logging @@ -13,7 +15,7 @@ logger = logging.getLogger(__name__) -def _classify_divisions(matched_data: "Matched"): +def _classify_divisions(matched_data: Matched): """Identify each division as a true positive, false positive or false negative This function only works on node mappers that are one-to-one @@ -139,7 +141,7 @@ def _get_succ_by_t(g, node, delta_frames): return node -def _correct_shifted_divisions(matched_data: "Matched", n_frames=1): +def _correct_shifted_divisions(matched_data: Matched, n_frames=1): """Allows for divisions to occur within a frame buffer and still be correct This implementation asserts that the parent lineages and daughter lineages must match. @@ -236,7 +238,7 @@ def _correct_shifted_divisions(matched_data: "Matched", n_frames=1): return new_matched -def _evaluate_division_events(matched_data: "Matched", frame_buffer=(0)): +def _evaluate_division_events(matched_data: Matched, frame_buffer=(0)): """Classify division errors and correct shifted divisions according to frame_buffer Note: A copy of matched_data will be created for each frame_buffer other than 0.