Skip to content

Commit

Permalink
Merge pull request #50 from Janelia-Trackathon-2023/rmv-tracking-data
Browse files Browse the repository at this point in the history
Remove `TrackingData`
  • Loading branch information
msschwartz21 authored Sep 8, 2023
2 parents a60a976 + f2c814e commit d2064c4
Show file tree
Hide file tree
Showing 21 changed files with 119 additions and 170 deletions.
3 changes: 1 addition & 2 deletions src/traccuracy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
__version__ = "uninstalled"

from ._run_metrics import run_metrics
from ._tracking_data import TrackingData
from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph

__all__ = ["TrackingData", "TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"]
__all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"]
6 changes: 3 additions & 3 deletions src/traccuracy/_run_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
if TYPE_CHECKING:
from typing import Dict, List, Optional, Type

from traccuracy import TrackingData
from traccuracy import TrackingGraph
from traccuracy.matchers._matched import Matched
from traccuracy.metrics._base import Metric


def run_metrics(
gt_data: "TrackingData",
pred_data: "TrackingData",
gt_data: "TrackingGraph",
pred_data: "TrackingGraph",
matcher: "Type[Matched]",
metrics: "List[Type[Metric]]",
matcher_kwargs: "Optional[Dict]" = None,
Expand Down
24 changes: 0 additions & 24 deletions src/traccuracy/_tracking_data.py

This file was deleted.

6 changes: 5 additions & 1 deletion src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class TrackingGraph:
def __init__(
self,
graph,
segmentation=None,
frame_key="t",
label_key="segmentation_id",
location_keys=("x", "y"),
Expand All @@ -122,6 +123,9 @@ def __init__(
solution where edges go forward in time. If the graph already
has annotations that are strings included in NodeAttrs or
EdgeAttrs, this will likely ruin metrics computation!
segmentation (numpy-like array, optional): A numpy-like array of segmentations.
The location of each node in tracking_graph is assumed to be inside the
area of the corresponding segmentation. Defaults to None.
frame_key (str, optional): The key on each node in graph that
contains the time frameof the node. Every node must have a
value stored at this key. Defaults to 't'.
Expand All @@ -134,7 +138,7 @@ def __init__(
node must have a value stored at each of these keys.
Defaults to ('x', 'y').
"""
self.graph = graph
self.segmentation = segmentation
if NodeAttr.has_value(frame_key):
raise ValueError(
f"Specified frame key {frame_key} is reserved for graph"
Expand Down
4 changes: 2 additions & 2 deletions src/traccuracy/loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
""" Subpackage for loading tracking data into memory
This subpackage contains functions for loading ground
truth or tracking method outputs into memory as TrackingData objects.
Each loading function must return one TrackingData object which has a
truth or tracking method outputs into memory as TrackingGraph objects.
Each loading function must return one TrackingGraph object which has a
track graph and optionally contains a corresponding segmentation.
"""
from ._ctc import load_ctc_data
Expand Down
5 changes: 1 addition & 4 deletions src/traccuracy/loaders/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from tifffile import imread
from tqdm import tqdm

from traccuracy._tracking_data import TrackingData
from traccuracy._tracking_graph import TrackingGraph


Expand Down Expand Up @@ -199,6 +198,4 @@ def load_ctc_data(data_dir, track_path=None):

G = ctc_to_graph(tracks, detections)

data = TrackingData(TrackingGraph(G), segmentation=masks)

return data
return TrackingGraph(G, segmentation=masks)
26 changes: 13 additions & 13 deletions src/traccuracy/matchers/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from tqdm import tqdm

from traccuracy._tracking_data import TrackingData
from traccuracy._tracking_graph import TrackingGraph

from ._compute_overlap import get_labels_with_overlap
from ._matched import Matched
Expand Down Expand Up @@ -30,21 +30,21 @@ def _match_ctc(self):
and pred node
Raises:
ValueError: gt and pred must be a TrackingData object
ValueError: gt and pred must be a TrackingGraph object
ValueError: GT and pred segmentations must be the same shape
"""
if not isinstance(self.gt_data, TrackingData) or not isinstance(
self.pred_data, TrackingData
if not isinstance(self.gt_graph, TrackingGraph) or not isinstance(
self.pred_graph, TrackingGraph
):
raise ValueError(
"Input data must be a TrackingData object with a graph and segmentations"
)
gt = self.gt_data
pred = self.pred_data
gt_label_key = self.gt_data.tracking_graph.label_key
pred_label_key = self.pred_data.tracking_graph.label_key
G_gt, mask_gt = gt.tracking_graph, gt.segmentation
G_pred, mask_pred = pred.tracking_graph, pred.segmentation
gt = self.gt_graph
pred = self.pred_graph
gt_label_key = self.gt_graph.label_key
pred_label_key = self.pred_graph.label_key
G_gt, mask_gt = gt, gt.segmentation
G_pred, mask_pred = pred, pred.segmentation

if mask_gt.shape != mask_pred.shape:
raise ValueError("Segmentation shapes must match between gt and pred")
Expand All @@ -53,14 +53,14 @@ def _match_ctc(self):
# Get overlaps for each frame
for i, t in enumerate(
tqdm(
range(gt.tracking_graph.start_frame, gt.tracking_graph.end_frame),
range(gt.start_frame, gt.end_frame),
desc="Matching frames",
)
):
gt_frame = mask_gt[i]
pred_frame = mask_pred[i]
gt_frame_nodes = gt.tracking_graph.nodes_by_frame[t]
pred_frame_nodes = pred.tracking_graph.nodes_by_frame[t]
gt_frame_nodes = gt.nodes_by_frame[t]
pred_frame_nodes = pred.nodes_by_frame[t]

# get the labels for this frame
gt_labels = dict(
Expand Down
43 changes: 21 additions & 22 deletions src/traccuracy/matchers/_iou.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from tqdm import tqdm

from traccuracy._tracking_data import TrackingData
from traccuracy._tracking_graph import TrackingGraph

from ._compute_overlap import get_labels_with_overlap
from ._matched import Matched
Expand Down Expand Up @@ -52,8 +52,8 @@ def match_iou(gt, pred, threshold=0.6):
and that the label is recorded on each node using label_key
Args:
gt (TrackingData): Tracking data object containing graph and segmentations
pred (TrackingData): Tracking data object containing graph and segmentations
gt (TrackingGraph): Tracking data object containing graph and segmentations
pred (TrackingGraph): Tracking data object containing graph and segmentations
threshold (float, optional): Minimum IoU for matching cells. Defaults to 0.6.
Returns:
Expand All @@ -63,51 +63,50 @@ def match_iou(gt, pred, threshold=0.6):
ValueError: gt and pred must be a TrackingData object
ValueError: GT and pred segmentations must be the same shape
"""
if not isinstance(gt, TrackingData) or not isinstance(pred, TrackingData):
if not isinstance(gt, TrackingGraph) or not isinstance(pred, TrackingGraph):
raise ValueError(
"Input data must be a TrackingData object with a graph and segmentations"
)

mapper = []

G_gt, mask_gt = gt.tracking_graph, gt.segmentation
G_pred, mask_pred = pred.tracking_graph, pred.segmentation

if mask_gt.shape != mask_pred.shape:
if gt.segmentation.shape != pred.segmentation.shape:
raise ValueError("Segmentation shapes must match between gt and pred")

# Get overlaps for each frame
frame_range = range(gt.tracking_graph.start_frame, gt.tracking_graph.end_frame)
frame_range = range(gt.start_frame, gt.end_frame)
total = len(list(frame_range))
for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total):
matches = _match_nodes(mask_gt[i], mask_pred[i], threshold=threshold)
matches = _match_nodes(
gt.segmentation[i], pred.segmentation[i], threshold=threshold
)

# Construct node id tuple for each match
for gt_id, pred_id in zip(*matches):
# Find node id based on time and segmentation label
gt_node = G_gt.get_nodes_with_attribute(
G_gt.label_key,
gt_node = gt.get_nodes_with_attribute(
gt.label_key,
criterion=lambda x: x == gt_id, # noqa
limit_to=G_gt.get_nodes_in_frame(t),
limit_to=gt.get_nodes_in_frame(t),
)[0]
pred_node = G_pred.get_nodes_with_attribute(
G_pred.label_key,
pred_node = pred.get_nodes_with_attribute(
pred.label_key,
criterion=lambda x: x == pred_id, # noqa
limit_to=G_pred.get_nodes_in_frame(t),
limit_to=pred.get_nodes_in_frame(t),
)[0]
mapper.append((gt_node, pred_node))
return mapper


class IOUMatched(Matched):
def __init__(self, gt_data, pred_data, iou_threshold=0.6):
def __init__(self, gt_graph, pred_graph, iou_threshold=0.6):
"""Constructs a mapping between gt and pred nodes using the IoU of the segmentations
Lower values for iou_threshold will be more permissive of imperfect matches
Args:
gt_data (TrackingData): TrackingData for the ground truth with segmentations
pred_data (TrackingData): TrackingData for the prediction with segmentations
gt_graph (TrackingGraph): TrackingGraph for the ground truth with segmentations
pred_graph (TrackingGraph): TrackingGraph for the prediction with segmentations
iou_threshold (float, optional): Minimum IoU value to assign a match. Defaults to 0.6.
Raises:
Expand All @@ -116,12 +115,12 @@ def __init__(self, gt_data, pred_data, iou_threshold=0.6):
self.iou_threshold = iou_threshold

# Check that segmentations exist in the data
if gt_data.segmentation is None or pred_data.segmentation is None:
if gt_graph.segmentation is None or pred_graph.segmentation is None:
raise ValueError(
"Segmentation data must be provided for both gt and pred data"
)

super().__init__(gt_data, pred_data)
super().__init__(gt_graph, pred_graph)

def compute_mapping(self):
return match_iou(self.gt_data, self.pred_data, threshold=self.iou_threshold)
return match_iou(self.gt_graph, self.pred_graph, threshold=self.iou_threshold)
16 changes: 8 additions & 8 deletions src/traccuracy/matchers/_matched.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
from abc import ABC, abstractmethod

from traccuracy._tracking_data import TrackingData
from traccuracy._tracking_graph import TrackingGraph


class Matched(ABC):
def __init__(self, gt_data: "TrackingData", pred_data: "TrackingData"):
def __init__(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"):
"""Matched class which takes TrackingData objects for gt and pred, and computes matching.
Each current matching method will be a subclass of Matched e.g. CTCMatched or IOUMatched.
The Matched objects will store both gt and pred data, as well as the mapping,
and any additional private attributes that may be needed/used e.g. detection matrices.
Args:
gt_data (TrackingData): Tracking data object for the gt
pred_data (TrackingData): Tracking data object for the pred
gt_graph (TrackingGraph): Tracking graph object for the gt
pred_graph (TrackingGraph): Tracking graph object for the pred
"""
self.gt_data = gt_data
self.pred_data = pred_data
self.gt_graph = gt_graph
self.pred_graph = pred_graph

self.mapping = self.compute_mapping()

# Report matching performance
total_gt = len(self.gt_data.tracking_graph.nodes())
total_gt = len(self.gt_graph.nodes())
matched_gt = len({m[0] for m in self.mapping})
total_pred = len(self.pred_data.tracking_graph.nodes())
total_pred = len(self.pred_graph.nodes())
matched_pred = len({m[1] for m in self.mapping})
print(f"Matched {matched_gt} out of {total_gt} ground truth nodes.")
print(f"Matched {matched_pred} out of {total_pred} predicted nodes.")
Expand Down
14 changes: 7 additions & 7 deletions src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,34 +39,34 @@ def compute(self):

vertex_error_counts = {
"ns": len(
self.data.pred_data.tracking_graph.get_nodes_with_attribute(
self.data.pred_graph.get_nodes_with_attribute(
NodeAttr.NON_SPLIT, lambda x: x
)
),
"fp": len(
self.data.pred_data.tracking_graph.get_nodes_with_attribute(
self.data.pred_graph.get_nodes_with_attribute(
NodeAttr.FALSE_POS, lambda x: x
)
),
"fn": len(
self.data.gt_data.tracking_graph.get_nodes_with_attribute(
self.data.gt_graph.get_nodes_with_attribute(
NodeAttr.FALSE_NEG, lambda x: x
)
),
}
edge_error_counts = {
"ws": len(
self.data.pred_data.tracking_graph.get_edges_with_attribute(
self.data.pred_graph.get_edges_with_attribute(
EdgeAttr.WRONG_SEMANTIC, lambda x: x
)
),
"fp": len(
self.data.pred_data.tracking_graph.get_edges_with_attribute(
self.data.pred_graph.get_edges_with_attribute(
EdgeAttr.FALSE_POS, lambda x: x
)
),
"fn": len(
self.data.gt_data.tracking_graph.get_edges_with_attribute(
self.data.gt_graph.get_edges_with_attribute(
EdgeAttr.FALSE_NEG, lambda x: x
)
),
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(self, matched_data: "Matched"):

def compute(self):
# AOGM-0 is the cost of creating the gt graph from scratch
gt_graph = self.data.gt_data.tracking_graph.graph
gt_graph = self.data.gt_graph.graph
n_nodes = gt_graph.number_of_nodes()
n_edges = gt_graph.number_of_edges()
aogm_0 = n_nodes * self.v_weights["fn"] + n_edges * self.e_weights["fn"]
Expand Down
4 changes: 2 additions & 2 deletions src/traccuracy/metrics/_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def compute(self):

return {
f"Frame Buffer {fb}": self._calculate_metrics(
matched_data.gt_data.tracking_graph,
matched_data.pred_data.tracking_graph,
matched_data.gt_graph,
matched_data.pred_graph,
)
for fb, matched_data in div_annotations.items()
}
Expand Down
8 changes: 4 additions & 4 deletions src/traccuracy/track_errors/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def get_vertex_errors(matched_data: "Matched"):
matched_data: Matched
Matched data object containing gt and pred graphs with their associated mapping
"""
comp_graph = matched_data.pred_data.tracking_graph
gt_graph = matched_data.gt_data.tracking_graph
comp_graph = matched_data.pred_graph
gt_graph = matched_data.gt_graph
mapping = matched_data.mapping

if comp_graph.node_errors and gt_graph.node_errors:
Expand Down Expand Up @@ -72,8 +72,8 @@ def get_vertex_errors(matched_data: "Matched"):


def get_edge_errors(matched_data: "Matched"):
comp_graph = matched_data.pred_data.tracking_graph
gt_graph = matched_data.gt_data.tracking_graph
comp_graph = matched_data.pred_graph
gt_graph = matched_data.gt_graph
node_mapping = matched_data.mapping

if comp_graph.edge_errors and gt_graph.edge_errors:
Expand Down
Loading

0 comments on commit d2064c4

Please sign in to comment.