Add sample data, CI, and auto-scaling for segmentation CLI
constantinpape committed Dec 5, 2024
1 parent 414b4cb commit 9b3e7b5
Showing 7 changed files with 272 additions and 39 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/run_tests.yaml
@@ -0,0 +1,40 @@
name: test

on:
push:
branches:
- main
tags:
- "v*" # Push events to matching v*, i.e. v1.0, v20.15.10
pull_request: # run CI on commits to any open PR
workflow_dispatch: # can manually trigger CI from GitHub actions tab


jobs:
test:
name: ${{ matrix.os }} ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.11"]
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Setup micromamba
uses: mamba-org/setup-micromamba@v1
with:
environment-file: environment_cpu.yaml
create-args: >-
python=${{ matrix.python-version }}
- name: Install SynapseNet
shell: bash -l {0}
run: pip install --no-deps -e .

- name: Run tests
shell: bash -l {0}
run: python -m unittest discover -s test -v
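
For reference, a minimal local equivalent of the CI test step above (a sketch, assuming the package has already been installed into an environment created from environment_cpu.yaml and that this is run from the repository root, mirroring "python -m unittest discover -s test -v"):

import unittest

# Discover and run the same tests that the CI job executes.
suite = unittest.defaultTestLoader.discover("test")
unittest.TextTestRunner(verbosity=2).run(suite)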
70 changes: 62 additions & 8 deletions synaptic_reconstruction/inference/util.py
@@ -11,6 +11,7 @@

import imageio.v3 as imageio
import elf.parallel as parallel
import mrcfile
import numpy as np
import torch
import torch_em
@@ -131,7 +132,7 @@ def get_prediction(
# torch_em expects the root folder of a checkpoint path instead of the checkpoint itself.
if model_path.endswith("best.pt"):
model_path = os.path.split(model_path)[0]
print(f"tiling {tiling}")
# print(f"tiling {tiling}")
# Create updated_tiling with the same structure
updated_tiling = {
"tile": {},
@@ -140,7 +141,7 @@
# Update tile dimensions
for dim in tiling["tile"]:
updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim]
print(f"updated_tiling {updated_tiling}")
# print(f"updated_tiling {updated_tiling}")
pred = get_prediction_torch_em(
input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask
)
@@ -252,6 +253,33 @@ def _load_input(img_path, extra_files, i):
return input_volume


def _derive_scale(img_path, model_resolution):
try:
with mrcfile.open(img_path, "r") as f:
voxel_size = f.voxel_size
if len(model_resolution) == 2:
voxel_size = [voxel_size.y, voxel_size.x]
else:
voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x]

assert len(voxel_size) == len(model_resolution)
# The voxel size is given in Angstrom and we need to translate it to nanometer.
voxel_size = [vsize / 10 for vsize in voxel_size]

# Compute the correct scale factor.
scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution))
print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution)

except Exception:
warnings.warn(
f"The voxel size could not be read from the data for {img_path}. "
"This data will not be scaled for prediction."
)
scale = None

return scale


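A small worked example of the arithmetic in _derive_scale (the voxel size and model resolution below are hypothetical, chosen only to illustrate the computation):

voxel_size_angstrom = [17.36, 17.36]  # hypothetical y, x voxel size read from the mrc header
model_resolution = [0.868, 0.868]     # hypothetical training voxel size in nanometer
voxel_size_nm = [vsize / 10 for vsize in voxel_size_angstrom]  # Angstrom -> nanometer
scale = tuple(vsize / res for vsize, res in zip(voxel_size_nm, model_resolution))
# scale == (2.0, 2.0): the input is rescaled by a factor of 2 per axis before prediction
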
def inference_helper(
input_path: str,
output_root: str,
@@ -263,6 +291,8 @@
mask_input_ext: str = ".tif",
force: bool = False,
output_key: Optional[str] = None,
model_resolution: Optional[Tuple[float, float, float]] = None,
scale: Optional[Tuple[float, float, float]] = None,
) -> None:
"""Helper function to run segmentation for mrc files.
@@ -282,7 +312,13 @@
mask_input_ext: File extension for the mask inputs (by default .tif).
force: Whether to rerun segmentation for output files that are already present.
output_key: Output key for the prediction. If given, the result is written to an hdf5 file with this key; otherwise it is saved as tif.
model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction.
If given, the scaling factor will automatically be determined based on the voxel_size of the input data.
scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'.
"""
if (scale is not None) and (model_resolution is not None):
raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.")

# Get the input files. If input_path is a folder then this will load all
# the mrc files beneath it. Otherwise we assume this is an mrc file already
# and just return the path to this mrc file.
@@ -333,8 +369,18 @@
# Load the mask (if given).
mask = None if mask_files is None else imageio.imread(mask_files[i])

# Determine the scale factor:
# If neither the 'scale' nor the 'model_resolution' argument was passed then set it to None.
if scale is None and model_resolution is None:
this_scale = None
elif scale is not None: # If 'scale' was passed then use it.
this_scale = scale
else: # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data
assert model_resolution is not None
this_scale = _derive_scale(img_path, model_resolution)

# Run the segmentation.
segmentation = segmentation_function(input_volume, mask=mask)
segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale)

# Write the result to tif or h5.
os.makedirs(os.path.split(output_path)[0], exist_ok=True)
@@ -348,15 +394,21 @@
print(f"Saved segmentation to {output_path}.")


def get_default_tiling() -> Dict[str, Dict[str, int]]:
def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]:
"""Determine the tile shape and halo depending on the available VRAM.
Args:
is_2d: Whether to return tiling settings for 2d inference.
Returns:
The default tiling settings for the available computational resources.
"""
if torch.cuda.is_available():
print("Determining suitable tiling")
if is_2d:
tile = {"x": 768, "y": 768, "z": 1}
halo = {"x": 128, "y": 128, "z": 0}
return {"tile": tile, "halo": halo}

if torch.cuda.is_available():
# We always use the same default halo.
halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8

@@ -390,19 +442,21 @@ def get_default_tiling() -> Dict[str, Dict[str, int]]:

def parse_tiling(
tile_shape: Tuple[int, int, int],
halo: Tuple[int, int, int]
halo: Tuple[int, int, int],
is_2d: bool = False,
) -> Dict[str, Dict[str, int]]:
"""Helper function to parse tiling parameter input from the command line.
Args:
tile_shape: The tile shape. If None the default tile shape is used.
halo: The halo. If None the default halo is used.
is_2d: Whether to return tiling for a 2d model.
Returns:
The tiling specification.
"""

default_tiling = get_default_tiling()
default_tiling = get_default_tiling(is_2d=is_2d)

if tile_shape is None:
tile_shape = default_tiling["tile"]
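
For reference, a minimal sketch of how the new is_2d flag changes the tiling defaults (the values are taken from get_default_tiling above; the import path follows the file location in this diff):

from synaptic_reconstruction.inference.util import parse_tiling

# Passing None for tile_shape and halo falls back to the defaults.
tiling = parse_tiling(None, None, is_2d=True)
# Per get_default_tiling above, this yields:
# {"tile": {"x": 768, "y": 768, "z": 1}, "halo": {"x": 128, "y": 128, "z": 0}}
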
34 changes: 34 additions & 0 deletions synaptic_reconstruction/sample_data.py
@@ -0,0 +1,34 @@
import os
import pooch


def get_sample_data(name: str) -> str:
"""Get the filepath to SynapseNet sample data, stored as mrc file.
Args:
name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data.
Returns:
The filepath to the downloaded sample data.
"""
registry = {
"tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28",
}
urls = {
"tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download",
}
key = f"{name}.mrc"

if key not in registry:
valid_names = [k[:-4] for k in registry.keys()]
raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.")

cache_dir = os.path.expanduser(pooch.os_cache("synapse-net"))
data_registry = pooch.create(
path=os.path.join(cache_dir, "sample_data"),
base_url="",
registry=registry,
urls=urls,
)
file_path = data_registry.fetch(key)
return file_path
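
A minimal usage sketch for the new sample data helper (mrcfile is used here only to read the downloaded file; the function and sample name are the ones defined above):

import mrcfile

from synaptic_reconstruction.sample_data import get_sample_data

sample_path = get_sample_data("tem_2d")  # downloads to the pooch cache on first use
with mrcfile.open(sample_path, "r") as f:
    data = f.data
print(data.shape)
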
84 changes: 61 additions & 23 deletions synaptic_reconstruction/tools/cli.py
@@ -1,36 +1,47 @@
import argparse
from functools import partial

from .util import run_segmentation, get_model
from .util import (
run_segmentation, get_model, get_model_registry, get_model_training_resolution, load_custom_model
)
from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod
from ..inference.util import inference_helper, parse_tiling


def imod_point_cli():
parser = argparse.ArgumentParser(description="")
parser = argparse.ArgumentParser(
description="Convert a vesicle segmentation to an IMOD point model, "
"corresponding to a sphere for each vesicle in the segmentation."
)
parser.add_argument(
"--input_path", "-i", required=True,
help="The filepath to the mrc file or the directory containing the tomogram data."
)
parser.add_argument(
"--segmentation_path", "-s", required=True,
help="The filepath to the tif file or the directory containing the segmentations."
help="The filepath to the file or the directory containing the segmentations."
)
parser.add_argument(
"--output_path", "-o", required=True,
help="The filepath to directory where the segmentations will be saved."
)
parser.add_argument(
"--segmentation_key", "-k", help=""
"--segmentation_key", "-k",
help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
"If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
)
parser.add_argument(
"--min_radius", type=float, default=10.0, help=""
"--min_radius", type=float, default=10.0,
help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export." # noqa
)
parser.add_argument(
"--radius_factor", type=float, default=1.0, help="",
"--radius_factor", type=float, default=1.0,
help="A factor for scaling the sphere radius for the export. "
"This can be used to fit the size of segmented vesicles to the best matching spheres.",
)
parser.add_argument(
"--force", action="store_true", help="",
"--force", action="store_true",
help="Whether to over-write already present export results."
)
args = parser.parse_args()

@@ -51,24 +62,29 @@ def imod_point_cli():


def imod_object_cli():
parser = argparse.ArgumentParser(description="")
parser = argparse.ArgumentParser(
description="Convert segmented objects to close contour IMOD models."
)
parser.add_argument(
"--input_path", "-i", required=True,
help="The filepath to the mrc file or the directory containing the tomogram data."
)
parser.add_argument(
"--segmentation_path", "-s", required=True,
help="The filepath to the tif file or the directory containing the segmentations."
help="The filepath to the file or the directory containing the segmentations."
)
parser.add_argument(
"--output_path", "-o", required=True,
help="The filepath to directory where the segmentations will be saved."
)
parser.add_argument(
"--segmentation_key", "-k", help=""
"--segmentation_key", "-k",
help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif."
"If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset."
)
parser.add_argument(
"--force", action="store_true", help="",
"--force", action="store_true",
help="Whether to over-write already present export results."
)
args = parser.parse_args()
export_helper(
@@ -82,8 +98,6 @@


# TODO: handle kwargs
# TODO: add custom model path
# TODO: enable autoscaling from input resolution
def segmentation_cli():
parser = argparse.ArgumentParser(description="Run segmentation.")
parser.add_argument(
@@ -94,9 +108,11 @@
"--output_path", "-o", required=True,
help="The filepath to directory where the segmentations will be saved."
)
# TODO: list the availabel models here by parsing the keys of the model registry
model_names = list(get_model_registry().urls.keys())
model_names = ", ".join(model_names)
parser.add_argument(
"--model", "-m", required=True, help="The model type."
"--model", "-m", required=True,
help=f"The model type. The following models are currently available: {model_names}"
)
parser.add_argument(
"--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation."
@@ -119,23 +135,45 @@
"--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
)
parser.add_argument(
"--segmentation_key", "-s", help=""
"--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.",
)
# TODO enable autoscaling
parser.add_argument(
"--scale", type=float, default=None, help=""
"--segmentation_key", "-s",
help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.",
)
parser.add_argument(
"--scale", type=float,
help="The factor for rescaling the data before inference. "
"By default, the scaling factor will be derived from the voxel size of the input data. "
"If this parameter is given it will over-ride the default behavior. "
)
args = parser.parse_args()

model = get_model(args.model)
tiling = parse_tiling(args.tile_shape, args.halo)
scale = None if args.scale is None else 3 * (args.scale,)
if args.checkpoint is None:
model = get_model(args.model)
else:
model = load_custom_model(args.checkpoint)
assert model is not None, f"The model from {args.checkpoint} could not be loaded."

is_2d = "2d" in args.model
tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d)

# If the scale argument is not passed, then we get the average training resolution for the model.
# The inputs will then be scaled to match this resolution based on the voxel size from the mrc files.
if args.scale is None:
model_resolution = get_model_training_resolution(args.model)
model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))
scale = None
# Otherwise, we set the model resolution to None and use the scaling factor provided by the user.
else:
model_resolution = None
scale = (2 if is_2d else 3) * (args.scale,)

segmentation_function = partial(
run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, scale=scale
run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling,
)
inference_helper(
args.input_path, args.output_path, segmentation_function,
mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext,
output_key=args.segmentation_key,
output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale,
)
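
To make the new auto-scaling path easier to follow, here is a hedged sketch of what segmentation_cli does when --scale is not given. The functions are the ones imported in this file; the model name and paths are placeholders, not values prescribed by the commit:

from functools import partial

from synaptic_reconstruction.inference.util import inference_helper, parse_tiling
from synaptic_reconstruction.tools.util import get_model, get_model_training_resolution, run_segmentation

model_name = "vesicles_2d"  # placeholder; use any name from the model registry
model = get_model(model_name)
is_2d = "2d" in model_name
tiling = parse_tiling(None, None, is_2d=is_2d)

# Without --scale, the training resolution is looked up and inference_helper
# derives a per-tomogram scale factor from the voxel size in each mrc header.
model_resolution = get_model_training_resolution(model_name)
model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx"))

segmentation_function = partial(
    run_segmentation, model=model, model_type=model_name, verbose=False, tiling=tiling,
)
inference_helper(
    "/path/to/tomograms", "/path/to/output", segmentation_function,
    model_resolution=model_resolution,
)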