diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 509e122f..530d63ce 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -60,6 +60,26 @@ jobs: with: fetch-depth: 50 # this is to make sure we obtain the target base commit + - name: Retrieve cached data + uses: actions/cache/restore@v4 + id: cache_data + with: + path: downloads + key: ${{ hashFiles('scripts/download_test_data.py') }} + + - name: Download Samples + if: steps.cache_data.outputs.cache-hit != 'true' + run: | + pip install requests + python scripts/download_test_data.py + + - name: Cache sample data + uses: actions/cache/save@v4 + if: steps.cache_data.outputs.cache-hit != 'true' + with: + path: downloads + key: ${{ hashFiles('scripts/download_test_data.py') }} + - name: Set up Python uses: actions/setup-python@v5 with: @@ -75,20 +95,20 @@ jobs: - name: Retrieve cached baseline if available uses: actions/cache/restore@v4 - id: cache + id: cache_baseline with: path: baseline.json key: ${{ github.event.pull_request.base.sha }} - name: Run baseline benchmark if not in cache - if: steps.cache.outputs.cache-hit != 'true' + if: steps.cache_baseline.outputs.cache-hit != 'true' run: | git checkout ${{ github.event.pull_request.base.sha }} - pytest tests/bench.py --benchmark-json baseline.json + pytest tests/bench.py -v --benchmark-json baseline.json - name: Cache baseline results uses: actions/cache/save@v4 - if: steps.cache.outputs.cache-hit != 'true' + if: steps.cache_baseline.outputs.cache-hit != 'true' with: path: baseline.json key: ${{ github.event.pull_request.base.sha }} @@ -96,7 +116,7 @@ jobs: - name: Run benchmark on PR head commit run: | git checkout ${{ github.event.pull_request.head.sha }} - pytest tests/bench.py --benchmark-json pr.json + pytest tests/bench.py -v --benchmark-json pr.json - name: Generate report run: python .github/workflows/benchmark-pr.py baseline.json pr.json report.md diff --git a/scripts/download_test_data.py b/scripts/download_test_data.py new file mode 100644 index 00000000..e7597a85 --- /dev/null +++ b/scripts/download_test_data.py @@ -0,0 +1,37 @@ +import os +import urllib.request +import zipfile +from pathlib import Path + +ROOT_DIR = Path(__file__).resolve().parents[1] +DATASETS = [ + "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip", + "http://data.celltrackingchallenge.net/training-datasets/PhC-C2DL-PSC.zip", + "http://data.celltrackingchallenge.net/training-datasets/Fluo-N3DH-CE.zip", +] + + +def download_gt_data(url, root_dir): + data_dir = os.path.join(root_dir, "downloads") + + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + filename = url.split("/")[-1] + file_path = os.path.join(data_dir, filename) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + + # Unzip the data + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + + +def main(): + for url in DATASETS: + download_gt_data(url, ROOT_DIR) + + +if __name__ == "__main__": + main() diff --git a/src/traccuracy/loaders/_ctc.py b/src/traccuracy/loaders/_ctc.py index e710ba10..401dc473 100644 --- a/src/traccuracy/loaders/_ctc.py +++ b/src/traccuracy/loaders/_ctc.py @@ -66,7 +66,10 @@ def _get_node_attributes(masks): segmentation_id, x, y, z, t """ data_df = pd.concat( - [_detections_from_image(masks, idx) for idx in range(masks.shape[0])] + [ + _detections_from_image(masks, idx) + for idx in tqdm(range(masks.shape[0]), desc="Computing node attributes") + ], ).reset_index(drop=True) data_df = data_df.rename( columns={ @@ -193,9 +196,9 @@ def _check_ctc(tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray parent_end = tracks[tracks["Cell_ID"] == row["Parent_ID"]]["End"].iloc[0] if parent_end >= row["Start"]: raise ValueError( - f"Invalid tracklet connection: Daughter tracklet with ID {row['Cell_ID']} " - f"starts at t={row['Start']}, " - f"but parent tracklet with ID {row['Parent_ID']} only ends at t={parent_end}." + "Invalid tracklet connection: Daughter tracklet with ID" + f" {row['Cell_ID']} starts at t={row['Start']}, but parent tracklet" + f" with ID {row['Parent_ID']} only ends at t={parent_end}." ) for t in range(tracks["Start"].min(), tracks["End"].max()): @@ -241,12 +244,13 @@ def load_ctc_data(data_dir, track_path=None, name=None, run_checks=True): track_paths = list(glob.glob(os.path.join(data_dir, "*_track.txt"))) if not track_paths: raise ValueError( - f"No track_path passed and a *_track.txt file could not be found in {data_dir}" + "No track_path passed and a *_track.txt file could not be found in" + f" {data_dir}" ) if len(track_paths) > 1: raise ValueError( - f"No track_path passed and multiple *_track.txt files found: {track_paths}." - + " Please pick one and pass it explicitly." + "No track_path passed and multiple *_track.txt files found:" + f" {track_paths}." + " Please pick one and pass it explicitly." ) track_path = track_paths[0] diff --git a/src/traccuracy/matchers/_base.py b/src/traccuracy/matchers/_base.py index d3056356..3017455f 100644 --- a/src/traccuracy/matchers/_base.py +++ b/src/traccuracy/matchers/_base.py @@ -1,6 +1,5 @@ from __future__ import annotations -import copy import logging from abc import ABC, abstractmethod from typing import Any @@ -41,10 +40,7 @@ def compute_mapping( "Input data must be a TrackingData object with a graph and segmentations" ) - # Copy graphs to avoid possible changes to graphs while computing mapping - matched = self._compute_mapping( - copy.deepcopy(gt_graph), copy.deepcopy(pred_graph) - ) + matched = self._compute_mapping(gt_graph, pred_graph) # Record matcher info on Matched object matched.matcher_info = self.info diff --git a/tests/bench.py b/tests/bench.py index 7b30bc51..c07c9851 100644 --- a/tests/bench.py +++ b/tests/bench.py @@ -1,7 +1,6 @@ import copy import os -import urllib.request -import zipfile +from pathlib import Path import pandas as pd import pytest @@ -14,67 +13,78 @@ from traccuracy.matchers import CTCMatcher, IOUMatcher from traccuracy.metrics import CTCMetrics, DivisionMetrics -ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +ROOT_DIR = Path(__file__).resolve().parents[1] +TIMEOUT = 300 -def download_gt_data(): - # Download GT data -- look into caching this in github actions - url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip" - data_dir = os.path.join(ROOT_DIR, "downloads") - - if not os.path.exists(data_dir): - os.mkdir(data_dir) - - filename = url.split("/")[-1] - file_path = os.path.join(data_dir, filename) - - if not os.path.exists(file_path): - urllib.request.urlretrieve(url, file_path) - - # Unzip the data - with zipfile.ZipFile(file_path, "r") as zip_ref: - zip_ref.extractall(data_dir) - - -@pytest.fixture(scope="module") -def gt_data(): - download_gt_data() +@pytest.fixture(scope="function") +def gt_data_2d(): + path = "downloads/Fluo-N2DL-HeLa/01_GT/TRA" return load_ctc_data( - os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA"), - os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt"), + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "man_track.txt"), run_checks=False, ) -@pytest.fixture(scope="module") -def pred_data(): +@pytest.fixture(scope="function") +def gt_data_3d(): + path = "downloads/Fluo-N3DH-CE/01_GT/TRA" return load_ctc_data( - os.path.join(ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES"), - os.path.join( - ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt" - ), + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "man_track.txt"), run_checks=False, ) -@pytest.fixture(scope="module") -def ctc_matched(gt_data, pred_data): - return CTCMatcher().compute_mapping(gt_data, pred_data) +@pytest.fixture(scope="function") +def pred_data_2d(gt_data_2d): + # For now this is also GT data. + return copy.deepcopy(gt_data_2d) + + +@pytest.fixture(scope="function") +def pred_data_3d(gt_data_3d): + # For now this is also GT data. + return copy.deepcopy(gt_data_3d) + +@pytest.fixture(scope="function") +def ctc_matched_2d(gt_data_2d, pred_data_2d): + return CTCMatcher().compute_mapping(gt_data_2d, pred_data_2d) -@pytest.fixture(scope="module") -def iou_matched(gt_data, pred_data): - return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data, pred_data) +@pytest.fixture(scope="function") +def ctc_matched_3d(gt_data_3d, pred_data_3d): + return CTCMatcher().compute_mapping(gt_data_3d, pred_data_3d) -def test_load_gt_data(benchmark): - download_gt_data() + +@pytest.fixture(scope="function") +def iou_matched_2d(gt_data_2d, pred_data_2d): + return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data_2d, pred_data_2d) + + +@pytest.fixture(scope="function") +def iou_matched_3d(gt_data_3d, pred_data_3d): + return IOUMatcher(iou_threshold=0.1).compute_mapping(gt_data_3d, pred_data_3d) + + +@pytest.mark.parametrize( + "dataset", + ["PhC-C2DL-PSC", "Fluo-N3DH-CE"], + ids=["2d", "3d"], +) +def test_load_gt_ctc_data( + benchmark, + dataset, +): + path = f"downloads/{dataset}/01_GT/TRA" benchmark.pedantic( load_ctc_data, args=( - "downloads/Fluo-N2DL-HeLa/01_GT/TRA", - "downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt", + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "man_track.txt"), ), kwargs={"run_checks": False}, rounds=1, @@ -82,14 +92,20 @@ def test_load_gt_data(benchmark): ) -def test_load_pred_data(benchmark): +# TODO Add 3d results +@pytest.mark.parametrize( + "path", + [ + "examples/sample-data/Fluo-N2DL-HeLa/01_RES", + ], + ids=["2d"], +) +def test_load_pred_ctc_data(benchmark, path): benchmark.pedantic( load_ctc_data, args=( - os.path.join(ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES"), - os.path.join( - ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt" - ), + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "res_track.txt"), ), kwargs={"run_checks": False}, rounds=1, @@ -97,65 +113,89 @@ def test_load_pred_data(benchmark): ) -def test_ctc_checks(benchmark): +@pytest.mark.parametrize( + "dataset", + ["PhC-C2DL-PSC", "Fluo-N3DH-CE"], + ids=["2d", "3d"], +) +def test_ctc_checks(benchmark, dataset): + path = f"downloads/{dataset}/01_GT/TRA" names = ["Cell_ID", "Start", "End", "Parent_ID"] - tracks = pd.read_csv( - os.path.join( - ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt" - ), + os.path.join(ROOT_DIR, path, "man_track.txt"), header=None, sep=" ", names=names, ) - - masks = _load_tiffs( - os.path.join(ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES") - ) + masks = _load_tiffs(os.path.join(ROOT_DIR, path)) detections = _get_node_attributes(masks) benchmark(_check_ctc, tracks, detections, masks) -def test_ctc_matched(benchmark, gt_data, pred_data): - benchmark(CTCMatcher().compute_mapping, gt_data, pred_data) - - -@pytest.mark.timeout(300) -def test_ctc_metrics(benchmark, ctc_matched): - def run_compute(): - return CTCMetrics().compute(copy.deepcopy(ctc_matched)) - - ctc_results = benchmark.pedantic(run_compute, rounds=1, iterations=1) +@pytest.mark.timeout(TIMEOUT) +@pytest.mark.parametrize( + "gt_data,pred_data", + [ + ("gt_data_2d", "pred_data_2d"), + ("gt_data_3d", "pred_data_3d"), + ], + ids=["2d", "3d"], +) +def test_ctc_matcher(benchmark, gt_data, pred_data, request): + gt_data = request.getfixturevalue(gt_data) + pred_data = request.getfixturevalue(pred_data) + benchmark.pedantic( + CTCMatcher().compute_mapping, + args=(gt_data, pred_data), + rounds=1, + iterations=1, + ) - assert ctc_results.results["fn_edges"] == 87 - assert ctc_results.results["fn_nodes"] == 39 - assert ctc_results.results["fp_edges"] == 60 - assert ctc_results.results["fp_nodes"] == 0 - assert ctc_results.results["ns_nodes"] == 0 - assert ctc_results.results["ws_edges"] == 47 +@pytest.mark.parametrize( + "ctc_matched", + ["ctc_matched_2d", "ctc_matched_3d"], + ids=["2d", "3d"], +) +def test_ctc_metrics(benchmark, ctc_matched, request): + ctc_matched = request.getfixturevalue(ctc_matched) -def test_ctc_div_metrics(benchmark, ctc_matched): def run_compute(): - return DivisionMetrics().compute(copy.deepcopy(ctc_matched)) + return CTCMetrics().compute(ctc_matched) - div_results = benchmark(run_compute) + benchmark.pedantic(run_compute, rounds=1, iterations=1) - assert div_results.results["Frame Buffer 0"]["False Negative Divisions"] == 18 - assert div_results.results["Frame Buffer 0"]["False Positive Divisions"] == 30 - assert div_results.results["Frame Buffer 0"]["True Positive Divisions"] == 76 +@pytest.mark.timeout(TIMEOUT) +@pytest.mark.parametrize( + "gt_data,pred_data", + [ + ("gt_data_2d", "pred_data_2d"), + ("gt_data_3d", "pred_data_3d"), + ], + ids=["2d", "3d"], +) +def test_iou_matcher(benchmark, gt_data, pred_data, request): + gt_data = request.getfixturevalue(gt_data) + pred_data = request.getfixturevalue(pred_data) + benchmark.pedantic( + IOUMatcher(iou_threshold=0.1).compute_mapping, + args=(gt_data, pred_data), + rounds=1, + iterations=1, + ) -def test_iou_matched(benchmark, gt_data, pred_data): - benchmark(IOUMatcher(iou_threshold=0.1).compute_mapping, gt_data, pred_data) +@pytest.mark.timeout(TIMEOUT) +@pytest.mark.parametrize( + "iou_matched", + ["iou_matched_2d", "iou_matched_3d"], + ids=["2d", "3d"], +) +def test_iou_div_metrics(benchmark, iou_matched, request): + iou_matched = request.getfixturevalue(iou_matched) -def test_iou_div_metrics(benchmark, iou_matched): def run_compute(): - return DivisionMetrics().compute(copy.deepcopy(iou_matched)) - - div_results = benchmark(run_compute) + return DivisionMetrics().compute(iou_matched) - assert div_results.results["Frame Buffer 0"]["False Negative Divisions"] == 25 - assert div_results.results["Frame Buffer 0"]["False Positive Divisions"] == 31 - assert div_results.results["Frame Buffer 0"]["True Positive Divisions"] == 69 + benchmark.pedantic(run_compute, rounds=1, iterations=1) diff --git a/tests/metrics/test_ctc_metrics.py b/tests/metrics/test_ctc_metrics.py index 2bf18a5e..adc805d3 100644 --- a/tests/metrics/test_ctc_metrics.py +++ b/tests/metrics/test_ctc_metrics.py @@ -1,7 +1,42 @@ -from traccuracy.matchers._ctc import CTCMatcher -from traccuracy.metrics._ctc import CTCMetrics +import os +from pathlib import Path -from tests.test_utils import get_movie_with_graph +import pytest +from traccuracy.loaders import load_ctc_data +from traccuracy.matchers import CTCMatcher +from traccuracy.metrics import CTCMetrics + +from tests.test_utils import get_movie_with_graph, gt_data + +ROOT_DIR = Path(__file__).resolve().parents[2] + + +@pytest.fixture(scope="module") +def gt_hela(): + url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip" + path = "downloads/Fluo-N2DL-HeLa/01_GT/TRA" + return gt_data(url, ROOT_DIR, path) + + +@pytest.fixture(scope="module") +def pred_hela(): + path = "examples/sample-data/Fluo-N2DL-HeLa/01_RES" + return load_ctc_data( + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "res_track.txt"), + ) + + +def test_ctc_metrics(gt_hela, pred_hela): + ctc_matched = CTCMatcher().compute_mapping(gt_hela, pred_hela) + ctc_results = CTCMetrics().compute(ctc_matched) + + assert ctc_results.results["fn_edges"] == 87 + assert ctc_results.results["fn_nodes"] == 39 + assert ctc_results.results["fp_edges"] == 60 + assert ctc_results.results["fp_nodes"] == 0 + assert ctc_results.results["ns_nodes"] == 0 + assert ctc_results.results["ws_edges"] == 47 def test_compute_mapping(): diff --git a/tests/metrics/test_divisions.py b/tests/metrics/test_divisions.py index f306ca11..db6edd98 100644 --- a/tests/metrics/test_divisions.py +++ b/tests/metrics/test_divisions.py @@ -1,8 +1,57 @@ +import os +from pathlib import Path + +import pytest from traccuracy import TrackingGraph -from traccuracy.matchers import Matched +from traccuracy.loaders import load_ctc_data +from traccuracy.matchers import CTCMatcher, IOUMatcher, Matched from traccuracy.metrics._divisions import DivisionMetrics -from tests.test_utils import get_division_graphs +from tests.test_utils import download_gt_data, get_division_graphs + +ROOT_DIR = Path(__file__).resolve().parents[2] + + +@pytest.fixture(scope="module") +def download_gt_hela(): + url = "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DL-HeLa.zip" + download_gt_data(url, ROOT_DIR) + + +@pytest.fixture(scope="function") +def gt_hela(): + path = "downloads/Fluo-N2DL-HeLa/01_GT/TRA" + return load_ctc_data( + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "man_track.txt"), + ) + + +@pytest.fixture(scope="function") +def pred_hela(): + path = "examples/sample-data/Fluo-N2DL-HeLa/01_RES" + return load_ctc_data( + os.path.join(ROOT_DIR, path), + os.path.join(ROOT_DIR, path, "res_track.txt"), + ) + + +def test_ctc_div_metrics(gt_hela, pred_hela): + ctc_matched = CTCMatcher().compute_mapping(gt_hela, pred_hela) + div_results = DivisionMetrics().compute(ctc_matched) + + assert div_results.results["Frame Buffer 0"]["False Negative Divisions"] == 18 + assert div_results.results["Frame Buffer 0"]["False Positive Divisions"] == 30 + assert div_results.results["Frame Buffer 0"]["True Positive Divisions"] == 76 + + +def test_iou_div_metrics(gt_hela, pred_hela): + iou_matched = IOUMatcher(iou_threshold=0.1).compute_mapping(gt_hela, pred_hela) + div_results = DivisionMetrics().compute(iou_matched) + + assert div_results.results["Frame Buffer 0"]["False Negative Divisions"] == 25 + assert div_results.results["Frame Buffer 0"]["False Positive Divisions"] == 31 + assert div_results.results["Frame Buffer 0"]["True Positive Divisions"] == 69 def test_DivisionMetrics(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 2913a8f9..5fb541d3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,38 @@ +import os +import urllib.request +import zipfile + import networkx as nx import numpy as np import skimage as sk from traccuracy._tracking_graph import TrackingGraph +from traccuracy.loaders import load_ctc_data + + +def download_gt_data(url, root_dir): + # Download GT data -- look into caching this in github actions + data_dir = os.path.join(root_dir, "downloads") + + if not os.path.exists(data_dir): + os.mkdir(data_dir) + + filename = url.split("/")[-1] + file_path = os.path.join(data_dir, filename) + + if not os.path.exists(file_path): + urllib.request.urlretrieve(url, file_path) + + # Unzip the data + with zipfile.ZipFile(file_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + + +def gt_data(url, root_dir, path): + download_gt_data(url, root_dir) + return load_ctc_data( + os.path.join(root_dir, path), + os.path.join(root_dir, path, "man_track.txt"), + ) def get_annotated_image(img_size=256, num_labels=3, sequential=True, seed=1):