Skip to content

Commit

Permalink
Merge pull request #136 from Janelia-Trackathon-2023/framebuffer
Browse files Browse the repository at this point in the history
Change `frame_buffer` kwarg to `max_frame_buffer`
  • Loading branch information
msschwartz21 authored Jan 11, 2024
2 parents 3d85480 + e04b03c commit 32de15e
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 20 deletions.
4 changes: 2 additions & 2 deletions src/traccuracy/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def run_divisions_on_iou(
gt_data,
pred_data,
IOUMatcher(iou_threshold=match_threshold),
[DivisionMetrics(frame_buffer=frame_buffer_tuple)],
[DivisionMetrics(max_frame_buffer=frame_buffer_tuple)],
)
with open(out_path, "w") as fp:
json.dump(result, fp)
Expand Down Expand Up @@ -242,7 +242,7 @@ def run_divisions_on_ctc(
gt_data,
pred_data,
CTCMatcher(),
[DivisionMetrics(frame_buffer=frame_buffer_tuple)],
[DivisionMetrics(max_frame_buffer=frame_buffer_tuple)],
)
with open(out_path, "w") as fp:
json.dump(result, fp)
Expand Down
11 changes: 6 additions & 5 deletions src/traccuracy/metrics/_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,15 @@ class DivisionMetrics(Metric):
3, 734559 (2021).
Args:
frame_buffer (tuple(int), optional): Tuple of integers. Value used as n_frames
to tolerate in correct_shifted_divisions. Defaults to (0).
max_frame_buffer (int, optional): Maximum value of frame buffer to use in correcting
shifted divisions. Divisions will be evaluated for all integer values of frame
buffer between 0 and max_frame_buffer
"""

needs_one_to_one = True

def __init__(self, frame_buffer=(0,)):
self.frame_buffer = frame_buffer
def __init__(self, max_frame_buffer=0):
self.frame_buffer = max_frame_buffer

def _compute(self, data: Matched):
"""Runs `_evaluate_division_events` and calculates summary metrics for each frame buffer
Expand All @@ -78,7 +79,7 @@ def _compute(self, data: Matched):
"""
div_annotations = _evaluate_division_events(
data,
frame_buffer=self.frame_buffer,
max_frame_buffer=self.frame_buffer,
)

return {
Expand Down
13 changes: 5 additions & 8 deletions src/traccuracy/track_errors/divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
return new_matched


def _evaluate_division_events(matched_data: Matched, frame_buffer=(0)):
def _evaluate_division_events(matched_data: Matched, max_frame_buffer=0):
"""Classify division errors and correct shifted divisions according to frame_buffer
Note: A copy of matched_data will be created for each frame_buffer other than 0.
Expand All @@ -248,8 +248,9 @@ def _evaluate_division_events(matched_data: Matched, frame_buffer=(0)):
Args:
matched_data (Matched): Matched data object containing gt and pred graphs
with their associated mapping
frame_buffer (tuple, optional): Tuple of integers. Value used as n_frames
to tolerate in correct_shifted_divisions. Defaults to (0).
max_frame_buffer (int, optional): Maximum value of frame buffer to use in correcting
shifted divisions. Divisions will be evaluated for all integer values of frame
buffer between 0 and max_frame_buffer
Returns:
dict {frame_buffer: matched_data}: A dictionary where each key corresponds to a frame
Expand All @@ -263,11 +264,7 @@ def _evaluate_division_events(matched_data: Matched, frame_buffer=(0)):
div_annotations[0] = matched_data

# Correct shifted divisions for each nonzero value in frame_buffer
for delta in frame_buffer:
# Skip 0 because we used that in baseline classification
if delta == 0:
continue

for delta in range(1, max_frame_buffer + 1):
corrected_matched = _correct_shifted_divisions(matched_data, n_frames=delta)
div_annotations[delta] = corrected_matched

Expand Down
6 changes: 3 additions & 3 deletions tests/metrics/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ def test_DivisionMetrics():
TrackingGraph(g_pred),
mapper,
)
frame_buffer = (0, 1, 2)
frame_buffer = 2

results = DivisionMetrics(frame_buffer=frame_buffer)._compute(matched)
results = DivisionMetrics(max_frame_buffer=frame_buffer)._compute(matched)

for name, r in results.items():
buffer = int(name[-1:])
assert buffer in frame_buffer
assert buffer in list(range(frame_buffer + 1))
if buffer in (0, 1):
# No corrections
assert r["True Positive Divisions"] == 0
Expand Down
4 changes: 2 additions & 2 deletions tests/track_errors/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,10 @@ def test_fp_early(self):

def test_evaluate_division_events():
g_gt, g_pred, mapper = get_division_graphs()
frame_buffer = (0, 1, 2)
frame_buffer = 2

matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper)

results = _evaluate_division_events(matched_data, frame_buffer=frame_buffer)
results = _evaluate_division_events(matched_data, max_frame_buffer=frame_buffer)

assert np.all([isinstance(k, int) for k in results.keys()])

0 comments on commit 32de15e

Please sign in to comment.