Skip to content

Commit

Permalink
Merge branch 'main' into gap-closing
Browse files Browse the repository at this point in the history
  • Loading branch information
DragaDoncila authored Dec 15, 2023
2 parents a885b5f + ad3f68d commit 6e84746
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 19 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ repos:
- id: typos

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.4
rev: v0.1.6
hooks:
- id: ruff
args: [--fix]

- repo: https://github.com/psf/black
rev: 23.10.1
rev: 23.11.0
hooks:
- id: black

Expand All @@ -26,7 +26,7 @@ repos:
- id: validate-pyproject

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.1
rev: v1.7.1
hooks:
- id: mypy
files: "^src/"
Expand Down
4 changes: 2 additions & 2 deletions src/traccuracy/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
Each loading function must return one TrackingGraph object which has a
track graph and optionally contains a corresponding segmentation.
"""
from ._ctc import load_ctc_data
from ._ctc import _check_ctc, _get_node_attributes, _load_tiffs, load_ctc_data

__all__ = ["load_ctc_data"]
__all__ = ["load_ctc_data", "_check_ctc", "_load_tiffs", "_get_node_attributes"]
99 changes: 86 additions & 13 deletions src/traccuracy/loaders/_ctc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import glob
import logging
import os

import networkx as nx
import numpy as np
import pandas as pd
from skimage.measure import regionprops_table
from skimage.measure import label, regionprops_table
from tifffile import imread
from tqdm import tqdm

from traccuracy._tracking_graph import TrackingGraph

logger = logging.getLogger(__name__)

def load_tiffs(data_dir):

def _load_tiffs(data_dir):
"""Load a directory of individual frames into a stack.
Args:
Expand Down Expand Up @@ -52,7 +55,7 @@ def _detections_from_image(stack, idx):
return pd.DataFrame(props)


def get_node_attributes(masks):
def _get_node_attributes(masks):
"""Calculates x,y,z,t,label for each detection in a movie.
Args:
Expand Down Expand Up @@ -83,14 +86,11 @@ def ctc_to_graph(df, detections):
Args:
data (pd.DataFrame): DataFrame of CTC-style info
detections (pd.DataFrame): Dataframe from get_node_attributes with position
detections (pd.DataFrame): Dataframe from _get_node_attributes with position
and segmentation label for each cell detection
Returns:
networkx.Graph: Graph representation of the CTC data.
Raises:
ValueError: If the Parent_ID is not in any previous frames.
"""
edges = []

Expand Down Expand Up @@ -157,19 +157,83 @@ def ctc_to_graph(df, detections):
return G


def load_ctc_data(data_dir, track_path=None):
def _check_ctc(
    tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray
) -> None:
    """Sanity checks for valid CTC format.

    Hard checks (throws exception):
    - Tracklet IDs in tracks file must be unique and positive
    - Parent tracklet IDs must exist in the tracks file
    - Intertracklet edges must be directed forward in time.
    - In each time point, the set of segmentation IDs present in the detections
      must equal the set of tracklet IDs in the tracks file that overlap this
      time point.

    Soft checks (prints warning):
    - No duplicate tracklet IDs (non-connected pixels with same ID) in a single
      timepoint.

    Args:
        tracks (pd.DataFrame): Tracks in CTC format with columns Cell_ID, Start,
            End, Parent_ID.
        detections (pd.DataFrame): Detections extracted from masks, containing
            columns segmentation_id, t.
        masks (np.ndarray): Set of masks with time in the first axis.

    Raises:
        ValueError: If any of the hard checks fail.
    """
    logger.info("Running CTC format checks")
    if tracks["Cell_ID"].min() < 1:
        raise ValueError("Cell_IDs in tracks file must be positive integers.")
    if len(tracks["Cell_ID"]) < len(tracks["Cell_ID"].unique()):
        raise ValueError("Cell_IDs in tracks file must be unique integers.")

    # Cell_IDs are known to be unique at this point, so a plain mapping gives
    # the parent's end frame without rescanning the DataFrame for every row.
    end_by_id = dict(zip(tracks["Cell_ID"], tracks["End"]))
    for cell_id, start, parent_id in zip(
        tracks["Cell_ID"], tracks["Start"], tracks["Parent_ID"]
    ):
        # Parent_ID == 0 marks a tracklet without a parent.
        if parent_id == 0:
            continue
        if parent_id not in end_by_id:
            raise ValueError(f"Parent_ID {parent_id} is not present in tracks.")
        parent_end = end_by_id[parent_id]
        if parent_end >= start:
            raise ValueError(
                f"Invalid tracklet connection: Daughter tracklet with ID {cell_id} "
                f"starts at t={start}, "
                f"but parent tracklet with ID {parent_id} only ends at t={parent_end}."
            )

    # `End` is inclusive in the CTC format (tracklets overlap t when
    # Start <= t <= End), so the final frame needs `+ 1` to be checked too.
    for t in range(tracks["Start"].min(), tracks["End"].max() + 1):
        track_ids = set(
            tracks[(tracks["Start"] <= t) & (tracks["End"] >= t)]["Cell_ID"]
        )
        det_ids = set(detections[(detections["t"] == t)]["segmentation_id"])
        if not track_ids.issubset(det_ids):
            raise ValueError(f"Missing IDs in masks at t={t}: {track_ids - det_ids}")
        if not det_ids.issubset(track_ids):
            raise ValueError(
                f"IDs {det_ids - track_ids} at t={t} not represented in tracks file."
            )

    # Soft check: more connected components than detections in a frame means
    # at least one ID covers several disconnected regions.
    for t, frame in enumerate(masks):
        _, n_components = label(frame, return_num=True)
        n_labels = len(detections[detections["t"] == t])
        if n_labels < n_components:
            logger.warning(
                "%s non-connected masks at t=%s.", n_components - n_labels, t
            )


def load_ctc_data(data_dir, track_path=None, run_checks=True):
"""Read the CTC segmentations and track file and create TrackingData.
Args:
data_dir (str): Path to directory containing CTC tiffs.
track_path (optional, str): Path to CTC track file. If not passed,
finds `*_track.txt` in data_dir.
run_checks (optional, bool): If set to `True` (default), runs checks on the data to ensure
valid CTC format.
Returns:
TrackingData: Object containing segmentations and TrackingGraph.
Raises:
ValueError: If the Parent_ID is not in any previous frames.
ValueError:
If `run_checks` is True, whenever any of the CTC format checks are violated.
If `run_checks` is False, whenever any other Exception occurs while creating the graph.
"""
names = ["Cell_ID", "Start", "End", "Parent_ID"]
if not track_path:
Expand All @@ -187,9 +251,18 @@ def load_ctc_data(data_dir, track_path=None):

tracks = pd.read_csv(track_path, header=None, sep=" ", names=names)

masks = load_tiffs(data_dir)
detections = get_node_attributes(masks)

G = ctc_to_graph(tracks, detections)
masks = _load_tiffs(data_dir)
detections = _get_node_attributes(masks)
if run_checks:
_check_ctc(tracks, detections, masks)

try:
G = ctc_to_graph(tracks, detections)
except BaseException as e:
logger.error(e)
raise ValueError(
"Error in converting CTC to graph. "
"Consider setting `run_checks=True` for detailed error message."
) from e

return TrackingGraph(G, segmentation=masks)
31 changes: 30 additions & 1 deletion tests/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@
import urllib.request
import zipfile

import pandas as pd
import pytest
from traccuracy.loaders import load_ctc_data
from traccuracy.loaders import (
_check_ctc,
_get_node_attributes,
_load_tiffs,
load_ctc_data,
)
from traccuracy.matchers import CTCMatcher, IOUMatcher
from traccuracy.metrics import CTCMetrics, DivisionMetrics

Expand Down Expand Up @@ -36,6 +42,7 @@ def gt_data():
return load_ctc_data(
os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA"),
os.path.join(ROOT_DIR, "downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt"),
run_checks=False,
)


Expand All @@ -46,6 +53,7 @@ def pred_data():
os.path.join(
ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt"
),
run_checks=False,
)


Expand All @@ -68,6 +76,7 @@ def test_load_gt_data(benchmark):
"downloads/Fluo-N2DL-HeLa/01_GT/TRA",
"downloads/Fluo-N2DL-HeLa/01_GT/TRA/man_track.txt",
),
kwargs={"run_checks": False},
rounds=1,
iterations=1,
)
Expand All @@ -82,11 +91,31 @@ def test_load_pred_data(benchmark):
ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt"
),
),
kwargs={"run_checks": False},
rounds=1,
iterations=1,
)


def test_ctc_checks(benchmark):
    """Benchmark the CTC format checks on the sample prediction data."""
    # Load the tracks file using the standard CTC column layout.
    track_file = os.path.join(
        ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES/res_track.txt"
    )
    ctc_columns = ["Cell_ID", "Start", "End", "Parent_ID"]
    track_table = pd.read_csv(track_file, header=None, sep=" ", names=ctc_columns)

    # Load the segmentation stack and derive per-detection attributes from it.
    mask_dir = os.path.join(ROOT_DIR, "examples/sample-data/Fluo-N2DL-HeLa/01_RES")
    mask_stack = _load_tiffs(mask_dir)
    node_table = _get_node_attributes(mask_stack)

    # Only the checks themselves are timed; loading happens outside the benchmark.
    benchmark(_check_ctc, track_table, node_table, mask_stack)


def test_ctc_matched(benchmark, gt_data, pred_data):
benchmark(CTCMatcher().compute_mapping, gt_data, pred_data)

Expand Down

1 comment on commit 6e84746

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE ad3f68d Mean (s) HEAD 6e84746 Percent Change
test_load_gt_data 1.28065 1.24617 -2.69
test_load_pred_data 1.15569 1.16436 0.75
test_ctc_checks 0.41545 0.41925 0.91
test_ctc_matched 2.19752 2.25848 2.77
test_ctc_metrics 0.52157 0.53773 3.1
test_ctc_div_metrics 0.28202 0.27889 -1.11
test_iou_matched 9.41798 9.00748 -4.36
test_iou_div_metrics 0.27595 0.26978 -2.24

Please sign in to comment.