Skip to content

Commit

Permalink
Merge pull request #104 from tlambert03/future
Browse files Browse the repository at this point in the history
style: cleanup string type annotations
  • Loading branch information
cmalinmayor authored Nov 29, 2023
2 parents 2ed94df + af482ab commit ef50a44
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 57 deletions.
15 changes: 6 additions & 9 deletions src/traccuracy/_run_metrics.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from typing import TYPE_CHECKING
from __future__ import annotations

from traccuracy._tracking_graph import TrackingGraph
from traccuracy.matchers._base import Matcher
from traccuracy.metrics._base import Metric

if TYPE_CHECKING:
from typing import Dict, List


def run_metrics(
gt_data: "TrackingGraph",
pred_data: "TrackingGraph",
matcher: "Matcher",
metrics: "List[Metric]",
) -> "List[Dict]":
gt_data: TrackingGraph,
pred_data: TrackingGraph,
matcher: Matcher,
metrics: list[Metric],
) -> list[dict]:
"""Compute given metrics on data using the given matcher.
The returned result dictionary will contain all metrics computed by
Expand Down
74 changes: 37 additions & 37 deletions src/traccuracy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@


def load_all_ctc(
gt_dir: "str",
pred_dir: "str",
gt_track_path: "Optional[str]" = None,
pred_track_path: "Optional[str]" = None,
gt_dir: str,
pred_dir: str,
gt_track_path: Optional[str] = None,
pred_track_path: Optional[str] = None,
):
gt_data = load_ctc_data(gt_dir, gt_track_path)
pred_data = load_ctc_data(pred_dir, pred_track_path)
Expand All @@ -25,18 +25,18 @@ def load_all_ctc(

@app.command()
def run_ctc(
gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: "str" = typer.Argument(
gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: str = typer.Argument(
..., help="Path to prediction/RES tiffs", show_default=False
),
gt_track_path: "Optional[str]" = typer.Option(
gt_track_path: Optional[str] = typer.Option(
None, help="Path to ctc gt track file", show_default=False
),
pred_track_path: "Optional[str]" = typer.Option(
pred_track_path: Optional[str] = typer.Option(
None, help="Path to predicted track file", show_default=False
),
loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: "str" = typer.Option("ctc_log.json", help="Path to save results"),
loader: str = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: str = typer.Option("ctc_log.json", help="Path to save results"),
):
"""
Run TRA and DET metric on gt and pred data using CTC matching.
Expand Down Expand Up @@ -66,34 +66,34 @@ def run_ctc(

@app.command()
def run_aogm(
gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: "str" = typer.Argument(
gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: str = typer.Argument(
..., help="Path to prediction/RES tiffs", show_default=False
),
gt_track_path: "Optional[str]" = typer.Option(
gt_track_path: Optional[str] = typer.Option(
None, help="Path to ctc gt track file", show_default=False
),
pred_track_path: "Optional[str]" = typer.Option(
pred_track_path: Optional[str] = typer.Option(
None, help="Path to predicted track file", show_default=False
),
loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: "str" = typer.Option("aogm_log.json", help="Path to save results"),
vertex_ns_weight: "float" = typer.Option(
loader: str = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: str = typer.Option("aogm_log.json", help="Path to save results"),
vertex_ns_weight: float = typer.Option(
1, help="Weight to assign to nonsplit vertex errors"
),
vertex_fp_weight: "float" = typer.Option(
vertex_fp_weight: float = typer.Option(
1, help="Weight to assign to false positive vertex errors"
),
vertex_fn_weight: "float" = typer.Option(
vertex_fn_weight: float = typer.Option(
1, help="Weight to assign to false negative vertex errors"
),
edge_fp_weight: "float" = typer.Option(
edge_fp_weight: float = typer.Option(
1, help="Weight to assign to false positive edge errors"
),
edge_fn_weight: "float" = typer.Option(
edge_fn_weight: float = typer.Option(
1, help="Weight to assign to false negative edge errors"
),
edge_ws_weight: "float" = typer.Option(
edge_ws_weight: float = typer.Option(
1, help="Weight to assign to edges with incorrect semantics"
),
):
Expand Down Expand Up @@ -139,24 +139,24 @@ def run_aogm(

@app.command()
def run_divisions_on_iou(
gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: "str" = typer.Argument(
gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: str = typer.Argument(
..., help="Path to prediction/RES tiffs", show_default=False
),
gt_track_path: "Optional[str]" = typer.Option(
gt_track_path: Optional[str] = typer.Option(
None, help="Path to ctc gt track file", show_default=False
),
pred_track_path: "Optional[str]" = typer.Option(
pred_track_path: Optional[str] = typer.Option(
None, help="Path to predicted track file", show_default=False
),
loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: "str" = typer.Option("div_log_iou.json", help="Path to save results"),
match_threshold: "float" = typer.Option(
loader: str = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: str = typer.Option("div_log_iou.json", help="Path to save results"),
match_threshold: float = typer.Option(
1,
help="Threshold above which the intersection over union of a gt and predicted"
" detection match. Default of 1 requires exact matching.",
),
frame_buffer: "int" = typer.Option(
frame_buffer: int = typer.Option(
0,
help="Number of frames to use for division tolerance."
" Numbers greater than 0 will produce metrics for 0...n inclusive.",
Expand Down Expand Up @@ -199,19 +199,19 @@ def run_divisions_on_iou(

@app.command()
def run_divisions_on_ctc(
gt_dir: "str" = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: "str" = typer.Argument(
gt_dir: str = typer.Argument(..., help="Path to GT tiffs", show_default=False),
pred_dir: str = typer.Argument(
..., help="Path to prediction/RES tiffs", show_default=False
),
gt_track_path: "Optional[str]" = typer.Option(
gt_track_path: Optional[str] = typer.Option(
None, help="Path to ctc gt track file", show_default=False
),
pred_track_path: "Optional[str]" = typer.Option(
pred_track_path: Optional[str] = typer.Option(
None, help="Path to predicted track file", show_default=False
),
loader: "str" = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: "str" = typer.Option("div_log_ctc.json", help="Path to save results"),
frame_buffer: "int" = typer.Option(
loader: str = typer.Option("ctc", help="Loader to bring data into memory"),
out_path: str = typer.Option("div_log_ctc.json", help="Path to save results"),
frame_buffer: int = typer.Option(
0,
help="Number of frames to use for division tolerance."
" Numbers greater than 0 will produce metrics for 0...n inclusive.",
Expand Down
6 changes: 4 additions & 2 deletions src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import networkx as nx
Expand All @@ -24,7 +26,7 @@ class CTCMatcher(Matcher):
for complete details.
"""

def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"):
def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
"""Run ctc matching
Args:
Expand Down Expand Up @@ -95,7 +97,7 @@ def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph
return Matched(gt_graph, pred_graph, mapping)


def detection_test(gt_blob: "np.ndarray", comp_blob: "np.ndarray") -> int:
def detection_test(gt_blob: np.ndarray, comp_blob: np.ndarray) -> int:
"""Check if computed marker overlaps majority of the reference marker.
Given a reference marker and computer marker in original coordinates,
Expand Down
4 changes: 3 additions & 1 deletion src/traccuracy/matchers/_iou.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np
from tqdm import tqdm

Expand Down Expand Up @@ -110,7 +112,7 @@ class IOUMatcher(Matcher):
def __init__(self, iou_threshold=0.6):
self.iou_threshold = iou_threshold

def _compute_mapping(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"):
def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
"""Computes IOU mapping for a set of grpahs
Args:
Expand Down
6 changes: 4 additions & 2 deletions src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import EdgeAttr, NodeAttr
Expand Down Expand Up @@ -32,7 +34,7 @@ def __init__(
"ws": edge_ws_weight,
}

def compute(self, data: "Matched"):
def compute(self, data: Matched):
evaluate_ctc_events(data)

vertex_error_counts = {
Expand Down Expand Up @@ -84,7 +86,7 @@ def __init__(self):
edge_ws_weight=edge_weight_ws,
)

def compute(self, data: "Matched"):
def compute(self, data: Matched):
# AOGM-0 is the cost of creating the gt graph from scratch
gt_graph = data.gt_graph.graph
n_nodes = gt_graph.number_of_nodes()
Expand Down
8 changes: 5 additions & 3 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import TYPE_CHECKING
Expand All @@ -12,7 +14,7 @@
logger = logging.getLogger(__name__)


def evaluate_ctc_events(matched_data: "Matched"):
def evaluate_ctc_events(matched_data: Matched):
"""Annotates ground truth and predicted graph with node and edge error types
Annotations are made in place
Expand All @@ -21,7 +23,7 @@ def evaluate_ctc_events(matched_data: "Matched"):
get_edge_errors(matched_data)


def get_vertex_errors(matched_data: "Matched"):
def get_vertex_errors(matched_data: Matched):
"""Count vertex errors and assign class to each comp/gt node.
Parameters
Expand Down Expand Up @@ -70,7 +72,7 @@ def get_vertex_errors(matched_data: "Matched"):
gt_graph.node_errors = True


def get_edge_errors(matched_data: "Matched"):
def get_edge_errors(matched_data: Matched):
comp_graph = matched_data.pred_graph
gt_graph = matched_data.gt_graph
node_mapping = matched_data.mapping
Expand Down
8 changes: 5 additions & 3 deletions src/traccuracy/track_errors/divisions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import itertools
import logging
Expand All @@ -13,7 +15,7 @@
logger = logging.getLogger(__name__)


def _classify_divisions(matched_data: "Matched"):
def _classify_divisions(matched_data: Matched):
"""Identify each division as a true positive, false positive or false negative
This function only works on node mappers that are one-to-one
Expand Down Expand Up @@ -139,7 +141,7 @@ def _get_succ_by_t(g, node, delta_frames):
return node


def _correct_shifted_divisions(matched_data: "Matched", n_frames=1):
def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
"""Allows for divisions to occur within a frame buffer and still be correct
This implementation asserts that the parent lineages and daughter lineages must match.
Expand Down Expand Up @@ -236,7 +238,7 @@ def _correct_shifted_divisions(matched_data: "Matched", n_frames=1):
return new_matched


def _evaluate_division_events(matched_data: "Matched", frame_buffer=(0)):
def _evaluate_division_events(matched_data: Matched, frame_buffer=(0)):
"""Classify division errors and correct shifted divisions according to frame_buffer
Note: A copy of matched_data will be created for each frame_buffer other than 0.
Expand Down

0 comments on commit ef50a44

Please sign in to comment.