From 9b3e7b545c25973c7730a141414efbeceb5774fb Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 5 Dec 2024 20:47:34 +0100 Subject: [PATCH] Add sample data, CI, and auto-scaling for segmenation CLI --- .github/workflows/run_tests.yaml | 40 +++++++++ synaptic_reconstruction/inference/util.py | 70 ++++++++++++++-- synaptic_reconstruction/sample_data.py | 34 ++++++++ synaptic_reconstruction/tools/cli.py | 84 ++++++++++++++----- .../tools/segmentation_widget.py | 3 +- synaptic_reconstruction/tools/util.py | 12 +-- test/test_cli.py | 68 +++++++++++++++ 7 files changed, 272 insertions(+), 39 deletions(-) create mode 100644 .github/workflows/run_tests.yaml create mode 100644 synaptic_reconstruction/sample_data.py create mode 100644 test/test_cli.py diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml new file mode 100644 index 0000000..fd93a50 --- /dev/null +++ b/.github/workflows/run_tests.yaml @@ -0,0 +1,40 @@ +name: test + +on: + push: + branches: + - main + tags: + - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 + pull_request: # run CI on commits to any open PR + workflow_dispatch: # can manually trigger CI from GitHub actions tab + + +jobs: + test: + name: ${{ matrix.os }} ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + timeout-minutes: 60 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python-version: ["3.11"] + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup micromamba + uses: mamba-org/setup-micromamba@v1 + with: + environment-file: environment_cpu.yaml + create-args: >- + python=${{ matrix.python-version }} + + - name: Install SynapseNet + shell: bash -l {0} + run: pip install --no-deps -e . + + - name: Run tests + shell: bash -l {0} + run: python -m unittest discover -s test -v diff --git a/synaptic_reconstruction/inference/util.py b/synaptic_reconstruction/inference/util.py index 3d17f99..434fb32 100644 --- a/synaptic_reconstruction/inference/util.py +++ b/synaptic_reconstruction/inference/util.py @@ -11,6 +11,7 @@ import imageio.v3 as imageio import elf.parallel as parallel +import mrcfile import numpy as np import torch import torch_em @@ -131,7 +132,7 @@ def get_prediction( # torch_em expects the root folder of a checkpoint path instead of the checkpoint itself. if model_path.endswith("best.pt"): model_path = os.path.split(model_path)[0] - print(f"tiling {tiling}") + # print(f"tiling {tiling}") # Create updated_tiling with the same structure updated_tiling = { "tile": {}, @@ -140,7 +141,7 @@ def get_prediction( # Update tile dimensions for dim in tiling["tile"]: updated_tiling["tile"][dim] = tiling["tile"][dim] - 2 * tiling["halo"][dim] - print(f"updated_tiling {updated_tiling}") + # print(f"updated_tiling {updated_tiling}") pred = get_prediction_torch_em( input_volume, updated_tiling, model_path, model, verbose, with_channels, mask=mask ) @@ -252,6 +253,33 @@ def _load_input(img_path, extra_files, i): return input_volume +def _derive_scale(img_path, model_resolution): + try: + with mrcfile.open(img_path, "r") as f: + voxel_size = f.voxel_size + if len(model_resolution) == 2: + voxel_size = [voxel_size.y, voxel_size.x] + else: + voxel_size = [voxel_size.z, voxel_size.y, voxel_size.x] + + assert len(voxel_size) == len(model_resolution) + # The voxel size is given in Angstrom and we need to translate it to nanometer. + voxel_size = [vsize / 10 for vsize in voxel_size] + + # Compute the correct scale factor. + scale = tuple(vsize / res for vsize, res in zip(voxel_size, model_resolution)) + print("Rescaling the data at", img_path, "by", scale, "to match the training voxel size", model_resolution) + + except Exception: + warnings.warn( + f"The voxel size could not be read from the data for {img_path}. " + "This data will not be scaled for prediction." + ) + scale = None + + return scale + + def inference_helper( input_path: str, output_root: str, @@ -263,6 +291,8 @@ def inference_helper( mask_input_ext: str = ".tif", force: bool = False, output_key: Optional[str] = None, + model_resolution: Optional[Tuple[float, float, float]] = None, + scale: Optional[Tuple[float, float, float]] = None, ) -> None: """Helper function to run segmentation for mrc files. @@ -282,7 +312,13 @@ def inference_helper( mask_input_ext: File extension for the mask inputs (by default .tif). force: Whether to rerun segmentation for output files that are already present. output_key: Output key for the prediction. If none will write an hdf5 file. + model_resolution: The resolution / voxel size to which the inputs should be scaled for prediction. + If given, the scaling factor will automatically be determined based on the voxel_size of the input data. + scale: Fixed factor for scaling the model inputs. Cannot be passed together with 'model_resolution'. """ + if (scale is not None) and (model_resolution is not None): + raise ValueError("You must not provide both 'scale' and 'model_resolution' arguments.") + # Get the input files. If input_path is a folder then this will load all # the mrc files beneath it. Otherwise we assume this is an mrc file already # and just return the path to this mrc file. @@ -333,8 +369,18 @@ def inference_helper( # Load the mask (if given). mask = None if mask_files is None else imageio.imread(mask_files[i]) + # Determine the scale factor: + # If the neither the 'scale' nor 'model_resolution' arguments were passed then set it to None. + if scale is None and model_resolution is None: + this_scale = None + elif scale is not None: # If 'scale' was passed then use it. + this_scale = scale + else: # Otherwise 'model_resolution' was passed, use it to derive the scaling from the data + assert model_resolution is not None + this_scale = _derive_scale(img_path, model_resolution) + # Run the segmentation. - segmentation = segmentation_function(input_volume, mask=mask) + segmentation = segmentation_function(input_volume, mask=mask, scale=this_scale) # Write the result to tif or h5. os.makedirs(os.path.split(output_path)[0], exist_ok=True) @@ -348,15 +394,21 @@ def inference_helper( print(f"Saved segmentation to {output_path}.") -def get_default_tiling() -> Dict[str, Dict[str, int]]: +def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: """Determine the tile shape and halo depending on the available VRAM. + Args: + is_2d: Whether to return tiling settings for 2d inference. + Returns: The default tiling settings for the available computational resources. """ - if torch.cuda.is_available(): - print("Determining suitable tiling") + if is_2d: + tile = {"x": 768, "y": 768, "z": 1} + halo = {"x": 128, "y": 128, "z": 0} + return {"tile": tile, "halo": halo} + if torch.cuda.is_available(): # We always use the same default halo. halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8 @@ -390,19 +442,21 @@ def get_default_tiling() -> Dict[str, Dict[str, int]]: def parse_tiling( tile_shape: Tuple[int, int, int], - halo: Tuple[int, int, int] + halo: Tuple[int, int, int], + is_2d: bool = False, ) -> Dict[str, Dict[str, int]]: """Helper function to parse tiling parameter input from the command line. Args: tile_shape: The tile shape. If None the default tile shape is used. halo: The halo. If None the default halo is used. + is_2d: Whether to return tiling for a 2d model. Returns: The tiling specification. """ - default_tiling = get_default_tiling() + default_tiling = get_default_tiling(is_2d=is_2d) if tile_shape is None: tile_shape = default_tiling["tile"] diff --git a/synaptic_reconstruction/sample_data.py b/synaptic_reconstruction/sample_data.py new file mode 100644 index 0000000..c0a3e47 --- /dev/null +++ b/synaptic_reconstruction/sample_data.py @@ -0,0 +1,34 @@ +import os +import pooch + + +def get_sample_data(name: str) -> str: + """Get the filepath to SynapseNet sample data, stored as mrc file. + + Args: + name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data. + + Returns: + The filepath to the downloaded sample data. + """ + registry = { + "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28", + } + urls = { + "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download", + } + key = f"{name}.mrc" + + if key not in registry: + valid_names = [k[:-4] for k in registry.keys()] + raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.") + + cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) + data_registry = pooch.create( + path=os.path.join(cache_dir, "sample_data"), + base_url="", + registry=registry, + urls=urls, + ) + file_path = data_registry.fetch(key) + return file_path diff --git a/synaptic_reconstruction/tools/cli.py b/synaptic_reconstruction/tools/cli.py index 54a52a3..a103cb2 100644 --- a/synaptic_reconstruction/tools/cli.py +++ b/synaptic_reconstruction/tools/cli.py @@ -1,36 +1,47 @@ import argparse from functools import partial -from .util import run_segmentation, get_model +from .util import ( + run_segmentation, get_model, get_model_registry, get_model_training_resolution, load_custom_model +) from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod from ..inference.util import inference_helper, parse_tiling def imod_point_cli(): - parser = argparse.ArgumentParser(description="") + parser = argparse.ArgumentParser( + description="Convert a vesicle segmentation to an IMOD point model, " + "corresponding to a sphere for each vesicle in the segmentation." + ) parser.add_argument( "--input_path", "-i", required=True, help="The filepath to the mrc file or the directory containing the tomogram data." ) parser.add_argument( "--segmentation_path", "-s", required=True, - help="The filepath to the tif file or the directory containing the segmentations." + help="The filepath to the file or the directory containing the segmentations." ) parser.add_argument( "--output_path", "-o", required=True, help="The filepath to directory where the segmentations will be saved." ) parser.add_argument( - "--segmentation_key", "-k", help="" + "--segmentation_key", "-k", + help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif." + "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset." ) parser.add_argument( - "--min_radius", type=float, default=10.0, help="" + "--min_radius", type=float, default=10.0, + help="The minimum vesicle radius in nm. Objects that are smaller than this radius will be exclded from the export." # noqa ) parser.add_argument( - "--radius_factor", type=float, default=1.0, help="", + "--radius_factor", type=float, default=1.0, + help="A factor for scaling the sphere radius for the export. " + "This can be used to fit the size of segmented vesicles to the best matching spheres.", ) parser.add_argument( - "--force", action="store_true", help="", + "--force", action="store_true", + help="Whether to over-write already present export results." ) args = parser.parse_args() @@ -51,24 +62,29 @@ def imod_point_cli(): def imod_object_cli(): - parser = argparse.ArgumentParser(description="") + parser = argparse.ArgumentParser( + description="Convert segmented objects to close contour IMOD models." + ) parser.add_argument( "--input_path", "-i", required=True, help="The filepath to the mrc file or the directory containing the tomogram data." ) parser.add_argument( "--segmentation_path", "-s", required=True, - help="The filepath to the tif file or the directory containing the segmentations." + help="The filepath to the file or the directory containing the segmentations." ) parser.add_argument( "--output_path", "-o", required=True, help="The filepath to directory where the segmentations will be saved." ) parser.add_argument( - "--segmentation_key", "-k", help="" + "--segmentation_key", "-k", + help="The key in the segmentation files. If not given we assume that the segmentations are stored as tif." + "If given, we assume they are stored as hdf5 files, and use the key to load the internal dataset." ) parser.add_argument( - "--force", action="store_true", help="", + "--force", action="store_true", + help="Whether to over-write already present export results." ) args = parser.parse_args() export_helper( @@ -82,8 +98,6 @@ def imod_object_cli(): # TODO: handle kwargs -# TODO: add custom model path -# TODO: enable autoscaling from input resolution def segmentation_cli(): parser = argparse.ArgumentParser(description="Run segmentation.") parser.add_argument( @@ -94,9 +108,11 @@ def segmentation_cli(): "--output_path", "-o", required=True, help="The filepath to directory where the segmentations will be saved." ) - # TODO: list the availabel models here by parsing the keys of the model registry + model_names = list(get_model_registry().urls.keys()) + model_names = ", ".join(model_names) parser.add_argument( - "--model", "-m", required=True, help="The model type." + "--model", "-m", required=True, + help=f"The model type. The following models are currently available: {model_names}" ) parser.add_argument( "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation." @@ -119,23 +135,45 @@ def segmentation_cli(): "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc." ) parser.add_argument( - "--segmentation_key", "-s", help="" + "--checkpoint", "-c", help="Path to a custom model, e.g. from domain adaptation.", ) - # TODO enable autoscaling parser.add_argument( - "--scale", type=float, default=None, help="" + "--segmentation_key", "-s", + help="If given, the outputs will be saved to an hdf5 file with this key. Otherwise they will be saved as tif.", + ) + parser.add_argument( + "--scale", type=float, + help="The factor for rescaling the data before inference. " + "By default, the scaling factor will be derived from the voxel size of the input data. " + "If this parameter is given it will over-ride the default behavior. " ) args = parser.parse_args() - model = get_model(args.model) - tiling = parse_tiling(args.tile_shape, args.halo) - scale = None if args.scale is None else 3 * (args.scale,) + if args.checkpoint is None: + model = get_model(args.model) + else: + model = load_custom_model(args.checkpoint) + assert model is not None, f"The model from {args.checkpoint} could not be loaded." + + is_2d = "2d" in args.model + tiling = parse_tiling(args.tile_shape, args.halo, is_2d=is_2d) + + # If the scale argument is not passed, then we get the average training resolution for the model. + # The inputs will then be scaled to match this resolution based on the voxel size from the mrc files. + if args.scale is None: + model_resolution = get_model_training_resolution(args.model) + model_resolution = tuple(model_resolution[ax] for ax in ("yx" if is_2d else "zyx")) + scale = None + # Otherwise, we set the model resolution to None and use the scaling factor provided by the user. + else: + model_resolution = None + scale = (2 if is_2d else 3) * (args.scale,) segmentation_function = partial( - run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, scale=scale + run_segmentation, model=model, model_type=args.model, verbose=False, tiling=tiling, ) inference_helper( args.input_path, args.output_path, segmentation_function, mask_input_path=args.mask_path, force=args.force, data_ext=args.data_ext, - output_key=args.segmentation_key, + output_key=args.segmentation_key, model_resolution=model_resolution, scale=scale, ) diff --git a/synaptic_reconstruction/tools/segmentation_widget.py b/synaptic_reconstruction/tools/segmentation_widget.py index 0b63642..548a465 100644 --- a/synaptic_reconstruction/tools/segmentation_widget.py +++ b/synaptic_reconstruction/tools/segmentation_widget.py @@ -136,7 +136,6 @@ def _create_settings_widget(self): setting_values.layout().addLayout(layout) # Create UI for the halo. - self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param( ("halo_x", "halo_y", "halo_z"), (self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]), @@ -145,7 +144,7 @@ def _create_settings_widget(self): ) setting_values.layout().addLayout(layout) - # read voxel size from layer metadata + # Read voxel size from layer metadata. self.voxel_size_param, layout = self._add_float_param( "voxel_size", 0.0, min_val=0.0, max_val=100.0, ) diff --git a/synaptic_reconstruction/tools/util.py b/synaptic_reconstruction/tools/util.py index cb4b67b..edb51a1 100644 --- a/synaptic_reconstruction/tools/util.py +++ b/synaptic_reconstruction/tools/util.py @@ -54,7 +54,7 @@ def get_model_path(model_type: str) -> str: model_path = model_registry.fetch(model_type) return model_path - + def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: """Get the model for the given segmentation type. @@ -100,14 +100,14 @@ def run_segmentation( The segmentation. """ if model_type.startswith("vesicles"): - segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose) + segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "mitochondria": - segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose) + segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "active_zone": - segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose) + segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "compartments": - segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose) - elif model_type == "inner_ear_structures": + segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + elif model_type == "ribbon_synapse_structures": raise NotImplementedError else: raise ValueError(f"Unknown model type: {model_type}") diff --git a/test/test_cli.py b/test/test_cli.py new file mode 100644 index 0000000..6b0d1fb --- /dev/null +++ b/test/test_cli.py @@ -0,0 +1,68 @@ +import os +import unittest +from subprocess import run +from shutil import rmtree + +import imageio.v3 as imageio +import mrcfile +import pooch +from synaptic_reconstruction.sample_data import get_sample_data + + +class TestCLI(unittest.TestCase): + tmp_dir = "./tmp" + + def setUp(self): + self.data_path = get_sample_data("tem_2d") + os.makedirs(self.tmp_dir, exist_ok=True) + + def tearDown(self): + try: + rmtree(self.tmp_dir) + except OSError: + pass + + def check_segmentation_result(self): + output_path = os.path.join(self.tmp_dir, "tem_2d_prediction.tif") + self.assertTrue(os.path.exists(output_path)) + + prediction = imageio.imread(output_path) + with mrcfile.open(self.data_path, "r") as f: + data = f.data[:] + self.assertEqual(prediction.shape, data.shape) + + num_labels = prediction.max() + self.assertGreater(num_labels, 1) + + # import napari + # v = napari.Viewer() + # v.add_image(data) + # v.add_labels(prediction) + # napari.run() + + def test_segmentation_cli(self): + cmd = ["synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d"] + run(cmd) + self.check_segmentation_result() + + def test_segmentation_cli_with_scale(self): + cmd = [ + "synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d", + "--scale", "0.5" + ] + run(cmd) + self.check_segmentation_result() + + def test_segmentation_cli_with_checkpoint(self): + cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) + model_path = os.path.join(cache_dir, "models", "vesicles_2d") + cmd = [ + "synapse_net.run_segmentation", "-i", self.data_path, "-o", self.tmp_dir, "-m", "vesicles_2d", + "-c", model_path, + ] + run(cmd) + self.check_segmentation_result() + + +if __name__ == "__main__": + unittest.main()