Merge pull request #74 from Janelia-Trackathon-2023/3d-benchmark
Extend benchmarks
bentaculum authored Nov 14, 2024
2 parents ab3cccf + 9aa9a7a commit 3f33797
Showing 8 changed files with 322 additions and 110 deletions.
30 changes: 25 additions & 5 deletions .github/workflows/ci.yml
@@ -60,6 +60,26 @@ jobs:
         with:
           fetch-depth: 50 # this is to make sure we obtain the target base commit

+      - name: Retrieve cached data
+        uses: actions/cache/restore@v4
+        id: cache_data
+        with:
+          path: downloads
+          key: ${{ hashFiles('scripts/download_test_data.py') }}

+      - name: Download Samples
+        if: steps.cache_data.outputs.cache-hit != 'true'
+        run: |
+          pip install requests
+          python scripts/download_test_data.py
+      - name: Cache sample data
+        uses: actions/cache/save@v4
+        if: steps.cache_data.outputs.cache-hit != 'true'
+        with:
+          path: downloads
+          key: ${{ hashFiles('scripts/download_test_data.py') }}

       - name: Set up Python
         uses: actions/setup-python@v5
         with:
@@ -75,28 +95,28 @@ jobs:
       - name: Retrieve cached baseline if available
         uses: actions/cache/restore@v4
-        id: cache
+        id: cache_baseline
         with:
           path: baseline.json
           key: ${{ github.event.pull_request.base.sha }}

       - name: Run baseline benchmark if not in cache
-        if: steps.cache.outputs.cache-hit != 'true'
+        if: steps.cache_baseline.outputs.cache-hit != 'true'
         run: |
           git checkout ${{ github.event.pull_request.base.sha }}
-          pytest tests/bench.py --benchmark-json baseline.json
+          pytest tests/bench.py -v --benchmark-json baseline.json
       - name: Cache baseline results
         uses: actions/cache/save@v4
-        if: steps.cache.outputs.cache-hit != 'true'
+        if: steps.cache_baseline.outputs.cache-hit != 'true'
         with:
           path: baseline.json
           key: ${{ github.event.pull_request.base.sha }}

       - name: Run benchmark on PR head commit
         run: |
           git checkout ${{ github.event.pull_request.head.sha }}
-          pytest tests/bench.py --benchmark-json pr.json
+          pytest tests/bench.py -v --benchmark-json pr.json
       - name: Generate report
         run: python .github/workflows/benchmark-pr.py baseline.json pr.json report.md
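The report step relies on .github/workflows/benchmark-pr.py, which is not part of this diff. As a rough illustrative sketch only, assuming the standard pytest-benchmark JSON layout (a top-level "benchmarks" list whose entries carry "name" and "stats"), a comparison of the two result files could look like the following; this is not the repository's actual report script:

# Illustrative comparison of two pytest-benchmark JSON files (baseline vs. PR).
# NOT the repository's benchmark-pr.py; it only assumes the standard
# pytest-benchmark output schema: {"benchmarks": [{"name": ..., "stats": {...}}]}.
import json
import sys


def load_means(path):
    """Map each benchmark name to its mean runtime in seconds."""
    with open(path) as f:
        data = json.load(f)
    return {b["name"]: b["stats"]["mean"] for b in data["benchmarks"]}


def main(baseline_path, pr_path):
    baseline, pr = load_means(baseline_path), load_means(pr_path)
    # Report only benchmarks present in both runs.
    for name in sorted(baseline.keys() & pr.keys()):
        ratio = pr[name] / baseline[name]
        print(f"{name}: {baseline[name]:.3f}s -> {pr[name]:.3f}s ({ratio:.2f}x)")


if __name__ == "__main__":
    main(sys.argv[1], sys.argv[2])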
37 changes: 37 additions & 0 deletions scripts/download_test_data.py
@@ -0,0 +1,37 @@
+import os
+import urllib.request
+import zipfile
+from pathlib import Path
+
+ROOT_DIR = Path(__file__).resolve().parents[1]
+DATASETS = [
+    "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip",
+    "http://data.celltrackingchallenge.net/training-datasets/PhC-C2DL-PSC.zip",
+    "http://data.celltrackingchallenge.net/training-datasets/Fluo-N3DH-CE.zip",
+]
+
+
+def download_gt_data(url, root_dir):
+    data_dir = os.path.join(root_dir, "downloads")
+
+    if not os.path.exists(data_dir):
+        os.mkdir(data_dir)
+
+    filename = url.split("/")[-1]
+    file_path = os.path.join(data_dir, filename)
+
+    if not os.path.exists(file_path):
+        urllib.request.urlretrieve(url, file_path)
+
+        # Unzip the data
+        with zipfile.ZipFile(file_path, "r") as zip_ref:
+            zip_ref.extractall(data_dir)
+
+
+def main():
+    for url in DATASETS:
+        download_gt_data(url, ROOT_DIR)
+
+
+if __name__ == "__main__":
+    main()
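For local runs, the same datasets can be fetched once with this script and then passed to the CTC loader touched below. A minimal sketch, assuming the zips unpack to the usual Cell Tracking Challenge ground-truth layout (e.g. downloads/Fluo-N2DL-HeLa/01_GT/TRA); the exact paths are an assumption, not something this diff specifies:

# Minimal local usage sketch. The folder layout below follows the Cell Tracking
# Challenge ground-truth convention and is an assumption, not part of this diff.
from traccuracy.loaders import load_ctc_data

gt_data = load_ctc_data(
    "downloads/Fluo-N2DL-HeLa/01_GT/TRA",
    track_path="downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt",
    name="Fluo-N2DL-HeLa-01-GT",
)
# gt_data is a TrackingGraph ready for matching and metrics.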
18 changes: 11 additions & 7 deletions src/traccuracy/loaders/_ctc.py
@@ -66,7 +66,10 @@ def _get_node_attributes(masks):
             segmentation_id, x, y, z, t
     """
     data_df = pd.concat(
-        [_detections_from_image(masks, idx) for idx in range(masks.shape[0])]
+        [
+            _detections_from_image(masks, idx)
+            for idx in tqdm(range(masks.shape[0]), desc="Computing node attributes")
+        ],
     ).reset_index(drop=True)
     data_df = data_df.rename(
         columns={
@@ -193,9 +196,9 @@ def _check_ctc(tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray
         parent_end = tracks[tracks["Cell_ID"] == row["Parent_ID"]]["End"].iloc[0]
         if parent_end >= row["Start"]:
             raise ValueError(
-                f"Invalid tracklet connection: Daughter tracklet with ID {row['Cell_ID']} "
-                f"starts at t={row['Start']}, "
-                f"but parent tracklet with ID {row['Parent_ID']} only ends at t={parent_end}."
+                "Invalid tracklet connection: Daughter tracklet with ID"
+                f" {row['Cell_ID']} starts at t={row['Start']}, but parent tracklet"
+                f" with ID {row['Parent_ID']} only ends at t={parent_end}."
             )

     for t in range(tracks["Start"].min(), tracks["End"].max()):
@@ -241,12 +244,13 @@ def load_ctc_data(data_dir, track_path=None, name=None, run_checks=True):
         track_paths = list(glob.glob(os.path.join(data_dir, "*_track.txt")))
         if not track_paths:
             raise ValueError(
-                f"No track_path passed and a *_track.txt file could not be found in {data_dir}"
+                "No track_path passed and a *_track.txt file could not be found in"
+                f" {data_dir}"
             )
         if len(track_paths) > 1:
             raise ValueError(
-                f"No track_path passed and multiple *_track.txt files found: {track_paths}."
-                + " Please pick one and pass it explicitly."
+                "No track_path passed and multiple *_track.txt files found:"
+                f" {track_paths}." + " Please pick one and pass it explicitly."
             )
         track_path = track_paths[0]

6 changes: 1 addition & 5 deletions src/traccuracy/matchers/_base.py
@@ -1,6 +1,5 @@
 from __future__ import annotations

-import copy
 import logging
 from abc import ABC, abstractmethod
 from typing import Any
@@ -41,10 +40,7 @@ def compute_mapping(
"Input data must be a TrackingData object with a graph and segmentations"
)

# Copy graphs to avoid possible changes to graphs while computing mapping
matched = self._compute_mapping(
copy.deepcopy(gt_graph), copy.deepcopy(pred_graph)
)
matched = self._compute_mapping(gt_graph, pred_graph)

# Record matcher info on Matched object
matched.matcher_info = self.info
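With the deepcopy gone, compute_mapping now passes the caller's graphs directly to the matcher implementation; the removed comment only guarded against possible in-place changes. A caller that wants to be certain its graphs stay untouched can still copy them on its side. A sketch under that assumption, where IOUMatcher stands in for any concrete Matcher subclass and the "01_RES" prediction folder is hypothetical:

# Illustrative only, not part of this diff: defensive copies on the caller side,
# now that compute_mapping no longer deep-copies its inputs.
# "01_RES" is a hypothetical prediction folder following the CTC naming scheme.
import copy

from traccuracy.loaders import load_ctc_data
from traccuracy.matchers import IOUMatcher

gt_data = load_ctc_data("downloads/Fluo-N2DL-HeLa/01_GT/TRA")
pred_data = load_ctc_data("downloads/Fluo-N2DL-HeLa/01_RES")

# Copy before matching if the original graphs must remain unmodified.
matched = IOUMatcher().compute_mapping(copy.deepcopy(gt_data), copy.deepcopy(pred_data))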