From 4b3b539d6f0b1f15164e459c4856954036e60ece Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 7 Dec 2024 20:50:25 +0100 Subject: [PATCH 1/9] Bump version; add outline for analysis pipeline --- examples/analysis_pipeline.py | 7 +++++++ synapse_net/__version__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 examples/analysis_pipeline.py diff --git a/examples/analysis_pipeline.py b/examples/analysis_pipeline.py new file mode 100644 index 0000000..f7e1b7e --- /dev/null +++ b/examples/analysis_pipeline.py @@ -0,0 +1,7 @@ +# TODO implement analysis pipeline for our tomo sample data: +# - segment vesicles, AZ, and compartment +# - use compartment segmentation to find the presynaptic terminal +# - postprocess the AZ segmentation +# - measure distances between vesicles and AZ +# - vesicle pool assignment into docked and non-attached terminals +# - export table with distance and morphology measurements for the two pools diff --git a/synapse_net/__version__.py b/synapse_net/__version__.py index f102a9c..3dc1f76 100644 --- a/synapse_net/__version__.py +++ b/synapse_net/__version__.py @@ -1 +1 @@ -__version__ = "0.0.1" +__version__ = "0.1.0" From 8b5cf12e16cd214d93003475ef60dbbda76cdb8d Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 8 Dec 2024 09:26:56 +0100 Subject: [PATCH 2/9] Update example analysis pipeline and sample data --- examples/analysis_pipeline.py | 95 ++++++++++++++++++++++++++++++++--- synapse_net/sample_data.py | 4 +- 2 files changed, 90 insertions(+), 9 deletions(-) diff --git a/examples/analysis_pipeline.py b/examples/analysis_pipeline.py index f7e1b7e..246af5d 100644 --- a/examples/analysis_pipeline.py +++ b/examples/analysis_pipeline.py @@ -1,7 +1,88 @@ -# TODO implement analysis pipeline for our tomo sample data: -# - segment vesicles, AZ, and compartment -# - use compartment segmentation to find the presynaptic terminal -# - postprocess the AZ segmentation -# - measure distances between vesicles and AZ -# - vesicle pool assignment into docked and non-attached terminals -# - export table with distance and morphology measurements for the two pools +import napari +import pandas as pd + +from synapse_net.file_utils import read_mrc +from synapse_net.sample_data import get_sample_data +from synapse_net.tools.util import run_segmentation, get_model, compute_scale_from_voxel_size + + +def segment_structures(tomogram, voxel_size): + # Segment the synaptic vesicles. The data will automatically be resized + # to match the average voxel size of the training data. + model_name = "vesicles_3d" # This is the name for the vesicle model for EM tomography. + model = get_model(model_name) # Load the corresponding model. + # Compute the scale to match the tomogram voxel size to the training data. + scale = compute_scale_from_voxel_size(voxel_size, model_name) + vesicles = run_segmentation(tomogram, model, model_name, scale=scale) + + # Segment the active zone. + model_name = "active_zone" + model = get_model(model_name) + scale = compute_scale_from_voxel_size(voxel_size, model_name) + active_zone = run_segmentation(tomogram, model, model_name, scale=scale) + + # Segment the synaptic compartments. + model_name = "compartments" + model = get_model(model_name) + scale = compute_scale_from_voxel_size(voxel_size, model_name) + compartments = run_segmentation(tomogram, model, model_name, scale=scale) + + return {"vesicles": vesicles, "active_zone": active_zone, "compartments": compartments} + + +def postprocess_segmentation(segmentations): + pass + + +def measure_distances(segmentations): + pass + + +def assign_vesicle_pools(distances): + pass + + +def visualize_results(tomogram, segmentations, vesicle_pools): + # TODO vesicle pool visualization + viewer = napari.Viewer() + viewer.add_image(tomogram) + for name, segmentation in segmentations.items(): + viewer.add_labels(segmentation, name=name) + napari.run() + + +def save_analysis(segmentations, distances, vesicle_pools, save_path): + pass + + +def main(): + """This script implements an example analysis pipeline with SynapseNet and applies it to a tomogram. + Here, we analyze docked and non-attached vesicles in a sample tomogram.""" + + # Load the tomogram for our sample data. + mrc_path = get_sample_data("tem_tomo") + tomogram, voxel_size = read_mrc(mrc_path) + + # Segment synaptic vesicles, the active zone, and the synaptic compartment. + segmentations = segment_structures(tomogram, voxel_size) + + # Post-process the segmentations, to find the presynaptic terminal, + # filter out vesicles not in the terminal, and to 'snape' the AZ to the presynaptic boundary. + segmentations = postprocess_segmentation(segmentations) + + # Measure the distances between the AZ and vesicles. + distances = measure_distances(segmentations) + + # Assign the vesicle pools, 'docked' and 'non-attached' vesicles, based on the distances. + vesicle_pools = assign_vesicle_pools(distances) + + # Visualize the results. + visualize_results(tomogram, segmentations, vesicle_pools) + + # Compute the vesicle radii and combine and save all measurements. + save_path = "analysis_results.xlsx" + save_analysis(segmentations, distances, vesicle_pools, save_path) + + +if __name__ == "__main__": + main() diff --git a/synapse_net/sample_data.py b/synapse_net/sample_data.py index 85ca481..7d40c4e 100644 --- a/synapse_net/sample_data.py +++ b/synapse_net/sample_data.py @@ -15,11 +15,11 @@ def get_sample_data(name: str) -> str: """ registry = { "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28", - "tem_tomo.mrc": "24af31a10761b59fa6ad9f0e763f8f084304e4f31c59b482dd09dde8cd443ed7", + "tem_tomo.mrc": "eb790f83efb4c967c96961239ae52578d95da902fc32307629f76a26c3dc61fa", } urls = { "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download", - "tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/NeP7gOv76Vj26lm/download", + "tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/TmLjDCXi42E49Ef/download", } key = f"{name}.mrc" From 22ea7141ff0701d3b27ce384033b7d2abb35bd1a Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 8 Dec 2024 20:06:55 +0100 Subject: [PATCH 3/9] Refactor segmentation functionality --- synapse_net/file_utils.py | 11 + synapse_net/inference/__init__.py | 2 +- synapse_net/inference/inference.py | 225 +++++++++++++++++ synapse_net/inference/util.py | 63 ++++- synapse_net/sample_data.py | 4 +- synapse_net/tools/cli.py | 9 +- synapse_net/tools/segmentation_widget.py | 85 ++++++- synapse_net/tools/util.py | 309 ----------------------- test/test_plugin.py | 21 ++ 9 files changed, 405 insertions(+), 324 deletions(-) create mode 100644 synapse_net/inference/inference.py create mode 100644 test/test_plugin.py diff --git a/synapse_net/file_utils.py b/synapse_net/file_utils.py index 6b54e75..a186f80 100644 --- a/synapse_net/file_utils.py +++ b/synapse_net/file_utils.py @@ -3,6 +3,17 @@ import mrcfile import numpy as np +import pooch + + +def get_cache_dir() -> str: + """Get the cache directory of synapse net. + + Returns: + The cache directory. + """ + cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) + return cache_dir def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]: diff --git a/synapse_net/inference/__init__.py b/synapse_net/inference/__init__.py index 10dcfec..1619054 100644 --- a/synapse_net/inference/__init__.py +++ b/synapse_net/inference/__init__.py @@ -1,3 +1,3 @@ """This submodule implements SynapseNet's segmentation functionality. """ -from .vesicles import segment_vesicles +from .inference import run_segmentation, get_model diff --git a/synapse_net/inference/inference.py b/synapse_net/inference/inference.py new file mode 100644 index 0000000..9ea6e75 --- /dev/null +++ b/synapse_net/inference/inference.py @@ -0,0 +1,225 @@ +import os +from typing import Dict, List, Optional, Union + +import torch +import numpy as np +import pooch + +from .active_zone import segment_active_zone +from .compartments import segment_compartments +from .mitochondria import segment_mitochondria +from .ribbon_synapse import segment_ribbon_synapse_structures +from .vesicles import segment_vesicles +from .util import get_device +from ..file_utils import get_cache_dir + + +# +# Functions to access SynapseNet's pretrained models. +# + + +def _get_model_registry(): + registry = { + "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0", + "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1", + "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186", + "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9", + "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1", + "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29", + "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b", + } + urls = { + "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download", + "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download", + "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download", + "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download", + "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download", + "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download", + "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download", + } + cache_dir = get_cache_dir() + models = pooch.create( + path=os.path.join(cache_dir, "models"), + base_url="", + registry=registry, + urls=urls, + ) + return models + + +def get_model_path(model_type: str) -> str: + """Get the local path to a pretrained model. + + Args: + The model type. + + Returns: + The local path to the model. + """ + model_registry = _get_model_registry() + model_path = model_registry.fetch(model_type) + return model_path + + +def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: + """Get the model for a specific segmentation type. + + Args: + model_type: The model for one of the following segmentation tasks: + 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'. + device: The device to use. + + Returns: + The model. + """ + if device is None: + device = get_device(device) + model_path = get_model_path(model_type) + model = torch.load(model_path, weights_only=False) + model.to(device) + return model + + +# +# Functions for training resolution / voxel size. +# + + +def get_model_training_resolution(model_type: str) -> Dict[str, float]: + """Get the average resolution / voxel size of the training data for a given pretrained model. + + Args: + model_type: The name of the pretrained model. + + Returns: + Mapping of axis (x, y, z) to the voxel size (in nm) of that axis. + """ + resolutions = { + "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44}, + "compartments": {"x": 3.47, "y": 3.47, "z": 3.47}, + "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07}, + "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188}, + "vesicles_2d": {"x": 1.35, "y": 1.35}, + "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35}, + "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88}, + } + return resolutions[model_type] + + +def compute_scale_from_voxel_size( + voxel_size: Dict[str, float], + model_type: str +) -> List[float]: + """Compute the appropriate scale factor for inference with a given pretrained model. + + Args: + voxel_size: The voxel size of the data for inference. + model_type: The name of the pretrained model. + + Returns: + The scale factor, as a list in zyx order. + """ + training_voxel_size = get_model_training_resolution(model_type) + scale = [ + voxel_size["x"] / training_voxel_size["x"], + voxel_size["y"] / training_voxel_size["y"], + ] + if len(voxel_size) == 3 and len(training_voxel_size) == 3: + scale.append( + voxel_size["z"] / training_voxel_size["z"] + ) + return scale + + +# +# Convenience functions for segmentation. +# + + +def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons): + from synapse_net.inference.postprocessing import ( + segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based, + ) + + ribbon = segment_ribbon( + predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, + max_vesicle_distance=40, + ) + PD = segment_presynaptic_density( + predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40, + ) + ref_segmentation = PD if PD.sum() > 0 else ribbon + membrane = segment_membrane_distance_based( + predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude, + ) + + segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane} + return segmentations + + +def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs): + # Parse additional keyword arguments from the kwargs. + vesicles = kwargs.pop("extra_segmentation") + threshold = kwargs.pop("threshold", 0.5) + n_slices_exclude = kwargs.pop("n_slices_exclude", 20) + n_ribbons = kwargs.pop("n_slices_exclude", 1) + + predictions = segment_ribbon_synapse_structures( + image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs + ) + + # Otherwise, just return the predictions. + if vesicles is None: + if verbose: + print("Vesicle segmentation was not passed, WILL NOT run post-processing.") + segmentations = predictions + + # If the vesicles were passed then run additional post-processing. + else: + if verbose: + print("Vesicle segmentation was passed, WILL run post-processing.") + segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons) + + if return_predictions: + return segmentations, predictions + return segmentations + + +def run_segmentation( + image: np.ndarray, + model: torch.nn.Module, + model_type: str, + tiling: Optional[Dict[str, Dict[str, int]]] = None, + scale: Optional[List[float]] = None, + verbose: bool = False, + **kwargs, +) -> np.ndarray | Dict[str, np.ndarray]: + """Run synaptic structure segmentation. + + Args: + image: The input image or image volume. + model: The segmentation model. + model_type: The model type. This will determine which segmentation post-processing is used. + tiling: The tiling settings for inference. + scale: A scale factor for resizing the input before applying the model. + The output will be scaled back to the initial size. + verbose: Whether to print detailed information about the prediction and segmentation. + kwargs: Optional parameters for the segmentation function. + + Returns: + The segmentation. For models that return multiple segmentations, this function returns a dictionary. + """ + if model_type.startswith("vesicles"): + segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + elif model_type == "mitochondria": + segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + elif model_type == "active_zone": + segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + elif model_type == "compartments": + segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + elif model_type == "ribbon": + segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) + else: + raise ValueError(f"Unknown model type: {model_type}") + return segmentation diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py index 434fb32..1ad3a73 100644 --- a/synapse_net/inference/util.py +++ b/synapse_net/inference/util.py @@ -2,7 +2,7 @@ import time import warnings from glob import glob -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Union # # Suppress annoying import warnings. # with warnings.catch_warnings(): @@ -26,6 +26,11 @@ from tqdm import tqdm +# +# Utils for prediction. +# + + class _Scaler: def __init__(self, scale, verbose): self.scale = scale @@ -474,6 +479,11 @@ def parse_tiling( return tiling +# +# Utils for post-processing. +# + + def apply_size_filter( segmentation: np.ndarray, min_size: int, @@ -525,3 +535,54 @@ def _postprocess_seg_3d(seg, area_threshold=1000, iterations=4, iterations_3d=8) seg[bb][mask] = prop.label return seg + + +# +# Utils for torch device. +# + +def _get_default_device(): + # Check that we're in CI and use the CPU if we are. + # Otherwise the tests may run out of memory on MAC if MPS is used. + if os.getenv("GITHUB_ACTIONS") == "true": + return "cpu" + # Use cuda enabled gpu if it's available. + if torch.cuda.is_available(): + device = "cuda" + # As second priority use mps. + # See https://pytorch.org/docs/stable/notes/mps.html for details + elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): + device = "mps" + # Use the CPU as fallback. + else: + device = "cpu" + return device + + +def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: + """Get the torch device. + + If no device is passed the default device for your system is used. + Else it will be checked if the device you have passed is supported. + + Args: + device: The input device. + + Returns: + The device. + """ + if device is None or device == "auto": + device = _get_default_device() + else: + device_type = device if isinstance(device, str) else device.type + if device_type.lower() == "cuda": + if not torch.cuda.is_available(): + raise RuntimeError("PyTorch CUDA backend is not available.") + elif device_type.lower() == "mps": + if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): + raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") + elif device_type.lower() == "cpu": + pass # cpu is always available + else: + raise RuntimeError(f"Unsupported device: {device}. Please choose from 'cpu', 'cuda', or 'mps'.") + return device diff --git a/synapse_net/sample_data.py b/synapse_net/sample_data.py index 7d40c4e..cac30bd 100644 --- a/synapse_net/sample_data.py +++ b/synapse_net/sample_data.py @@ -1,7 +1,7 @@ import os import pooch -from .file_utils import read_mrc +from .file_utils import read_mrc, get_cache_dir def get_sample_data(name: str) -> str: @@ -27,7 +27,7 @@ def get_sample_data(name: str) -> str: valid_names = [k[:-4] for k in registry.keys()] raise ValueError(f"Invalid sample name {name}, please choose one of {valid_names}.") - cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) + cache_dir = get_cache_dir() data_registry = pooch.create( path=os.path.join(cache_dir, "sample_data"), base_url="", diff --git a/synapse_net/tools/cli.py b/synapse_net/tools/cli.py index a103cb2..11caeb7 100644 --- a/synapse_net/tools/cli.py +++ b/synapse_net/tools/cli.py @@ -1,10 +1,9 @@ import argparse from functools import partial -from .util import ( - run_segmentation, get_model, get_model_registry, get_model_training_resolution, load_custom_model -) +import torch from ..imod.to_imod import export_helper, write_segmentation_to_imod_as_points, write_segmentation_to_imod +from ..inference.inference import _get_model_registry, get_model, get_model_training_resolution, run_segmentation from ..inference.util import inference_helper, parse_tiling @@ -108,7 +107,7 @@ def segmentation_cli(): "--output_path", "-o", required=True, help="The filepath to directory where the segmentations will be saved." ) - model_names = list(get_model_registry().urls.keys()) + model_names = list(_get_model_registry().urls.keys()) model_names = ", ".join(model_names) parser.add_argument( "--model", "-m", required=True, @@ -152,7 +151,7 @@ def segmentation_cli(): if args.checkpoint is None: model = get_model(args.model) else: - model = load_custom_model(args.checkpoint) + model = torch.load(args.checkpoint, weights_only=False) assert model is not None, f"The model from {args.checkpoint} could not be loaded." is_2d = "2d" in args.model diff --git a/synapse_net/tools/segmentation_widget.py b/synapse_net/tools/segmentation_widget.py index 6fbb0cc..c4f84c0 100644 --- a/synapse_net/tools/segmentation_widget.py +++ b/synapse_net/tools/segmentation_widget.py @@ -1,15 +1,88 @@ import copy +import re +from typing import Optional, Union import napari import numpy as np +import torch from napari.utils.notifications import show_info from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox from .base_widget import BaseWidget -from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device, - get_current_tiling, compute_scale_from_voxel_size, load_custom_model) -from ..inference.util import get_default_tiling +from ..inference.inference import _get_model_registry, get_model, run_segmentation, compute_scale_from_voxel_size +from ..inference.util import get_default_tiling, get_device + + +def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: + model_path = _clean_filepath(model_path) + if device is None: + device = get_device(device) + try: + model = torch.load(model_path, map_location=torch.device(device), weights_only=False) + except Exception as e: + print(e) + print("model path", model_path) + return None + return model + + +def _available_devices(): + available_devices = [] + for i in ["cuda", "mps", "cpu"]: + try: + device = get_device(i) + except RuntimeError: + pass + else: + available_devices.append(device) + return available_devices + + +def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str): + # get tiling values from qt objects + for k, v in tiling.items(): + for k2, v2 in v.items(): + if isinstance(v2, int): + continue + tiling[k][k2] = v2.value() + # check if user inputs tiling/halo or not + if default_tiling == tiling: + if "2d" in model_type: + # if its a 2d model expand x,y and set z to 1 + tiling = { + "tile": {"x": 512, "y": 512, "z": 1}, + "halo": {"x": 64, "y": 64, "z": 1}, + } + elif "2d" in model_type: + # if its a 2d model set z to 1 + tiling["tile"]["z"] = 1 + tiling["halo"]["z"] = 1 + + return tiling + + +def _clean_filepath(filepath): + """Cleans a given filepath by: + - Removing newline characters (\n) + - Removing escape sequences + - Stripping the 'file://' prefix if present + + Args: + filepath (str): The original filepath + + Returns: + str: The cleaned filepath + """ + # Remove 'file://' prefix if present + if filepath.startswith("file://"): + filepath = filepath[7:] + + # Remove escape sequences and newlines + filepath = re.sub(r'\\.', '', filepath) + filepath = filepath.replace('\n', '').replace('\r', '') + + return filepath class SegmentationWidget(BaseWidget): @@ -42,7 +115,7 @@ def load_model_widget(self): model_widget = QWidget() title_label = QLabel("Select Model:") - models = ["- choose -"] + list(get_model_registry().urls.keys()) + models = ["- choose -"] + list(_get_model_registry().urls.keys()) self.model_selector = QComboBox() self.model_selector.addItems(models) # Create a layout and add the title label and combo box @@ -66,7 +139,7 @@ def on_predict(self): # Load the model. Override if user chose custom model if custom_model_path: - model = load_custom_model(custom_model_path, device) + model = _load_custom_model(custom_model_path, device) if model: show_info(f"INFO: Using custom model from path: {custom_model_path}") model_type = "custom" @@ -83,7 +156,7 @@ def on_predict(self): return # Get the current tiling. - self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type) + self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) # Get the voxel size. metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True) diff --git a/synapse_net/tools/util.py b/synapse_net/tools/util.py index 1495112..5d96522 100644 --- a/synapse_net/tools/util.py +++ b/synapse_net/tools/util.py @@ -1,314 +1,5 @@ import os -import re from typing import Dict, List, Optional, Union import torch import numpy as np -import pooch - -from ..inference.active_zone import segment_active_zone -from ..inference.compartments import segment_compartments -from ..inference.mitochondria import segment_mitochondria -from ..inference.ribbon_synapse import segment_ribbon_synapse_structures -from ..inference.vesicles import segment_vesicles - - -def load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: - model_path = _clean_filepath(model_path) - if device is None: - device = get_device(device) - try: - model = torch.load(model_path, map_location=torch.device(device), weights_only=False) - except Exception as e: - print(e) - print("model path", model_path) - return None - return model - - -def get_model_path(model_type: str) -> str: - """Get the local path to a given model. - - Args: - The model type. - - Returns: - The local path to the model. - """ - model_registry = get_model_registry() - model_path = model_registry.fetch(model_type) - return model_path - - -def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module: - """Get the model for the given segmentation type. - - Args: - model_type: The model type. You can choose One of: - 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'. - device: The device to use. - - Returns: - The model. - """ - if device is None: - device = get_device(device) - model_path = get_model_path(model_type) - model = torch.load(model_path, weights_only=False) - model.to(device) - return model - - -def _segment_ribbon_AZ(image, model, tiling, scale, verbose, **kwargs): - # Parse additional keyword arguments from the kwargs. - vesicles = kwargs.pop("extra_segmentation") - threshold = kwargs.pop("threshold", 0.5) - n_slices_exclude = kwargs.pop("n_slices_exclude", 20) - n_ribbons = kwargs.pop("n_slices_exclude", 1) - - predictions = segment_ribbon_synapse_structures( - image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs - ) - - # If the vesicles were passed then run additional post-processing. - if vesicles is None: - segmentation = predictions - - # Otherwise, just return the predictions. - else: - from synapse_net.inference.postprocessing import ( - segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based, - ) - - ribbon = segment_ribbon( - predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, - max_vesicle_distance=40, - ) - PD = segment_presynaptic_density( - predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40, - ) - ref_segmentation = PD if PD.sum() > 0 else ribbon - membrane = segment_membrane_distance_based( - predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude, - ) - - segmentation = {"ribbon": ribbon, "PD": PD, "membrane": membrane} - - return segmentation - - -def run_segmentation( - image: np.ndarray, - model: torch.nn.Module, - model_type: str, - tiling: Optional[Dict[str, Dict[str, int]]] = None, - scale: Optional[List[float]] = None, - verbose: bool = False, - **kwargs, -) -> np.ndarray | Dict[str, np.ndarray]: - """Run synaptic structure segmentation. - - Args: - image: The input image or image volume. - model: The segmentation model. - model_type: The model type. This will determine which segmentation post-processing is used. - tiling: The tiling settings for inference. - scale: A scale factor for resizing the input before applying the model. - The output will be scaled back to the initial size. - verbose: Whether to print detailed information about the prediction and segmentation. - kwargs: Optional parameters for the segmentation function. - - Returns: - The segmentation. For models that return multiple segmentations, this function returns a dictionary. - """ - if model_type.startswith("vesicles"): - segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) - elif model_type == "mitochondria": - segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) - elif model_type == "active_zone": - segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) - elif model_type == "compartments": - segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) - elif model_type == "ribbon": - segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) - else: - raise ValueError(f"Unknown model type: {model_type}") - return segmentation - - -def get_cache_dir(): - cache_dir = os.path.expanduser(pooch.os_cache("synapse-net")) - return cache_dir - - -def get_model_training_resolution(model_type): - resolutions = { - "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44}, - "compartments": {"x": 3.47, "y": 3.47, "z": 3.47}, - "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07}, - "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188}, - "vesicles_2d": {"x": 1.35, "y": 1.35}, - "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35}, - "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88}, - } - return resolutions[model_type] - - -def get_model_registry(): - registry = { - "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0", - "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1", - "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186", - "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9", - "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1", - "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29", - "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b", - } - urls = { - "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download", - "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download", - "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download", - "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download", - "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download", - "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download", - "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download", - } - cache_dir = get_cache_dir() - models = pooch.create( - path=os.path.join(cache_dir, "models"), - base_url="", - registry=registry, - urls=urls, - ) - return models - - -def _get_default_device(): - # check that we're in CI and use the CPU if we are - # otherwise the tests may run out of memory on MAC if MPS is used. - if os.getenv("GITHUB_ACTIONS") == "true": - return "cpu" - # Use cuda enabled gpu if it's available. - if torch.cuda.is_available(): - device = "cuda" - # As second priority use mps. - # See https://pytorch.org/docs/stable/notes/mps.html for details - elif torch.backends.mps.is_available() and torch.backends.mps.is_built(): - device = "mps" - # Use the CPU as fallback. - else: - device = "cpu" - return device - - -def get_device(device: Optional[Union[str, torch.device]] = None) -> Union[str, torch.device]: - """Get the torch device. - - If no device is passed the default device for your system is used. - Else it will be checked if the device you have passed is supported. - - Args: - device: The input device. - - Returns: - The device. - """ - if device is None or device == "auto": - device = _get_default_device() - else: - device_type = device if isinstance(device, str) else device.type - if device_type.lower() == "cuda": - if not torch.cuda.is_available(): - raise RuntimeError("PyTorch CUDA backend is not available.") - elif device_type.lower() == "mps": - if not (torch.backends.mps.is_available() and torch.backends.mps.is_built()): - raise RuntimeError("PyTorch MPS backend is not available or is not built correctly.") - elif device_type.lower() == "cpu": - pass # cpu is always available - else: - raise RuntimeError(f"Unsupported device: {device}\n" - "Please choose from 'cpu', 'cuda', or 'mps'.") - return device - - -def _available_devices(): - available_devices = [] - for i in ["cuda", "mps", "cpu"]: - try: - device = get_device(i) - except RuntimeError: - pass - else: - available_devices.append(device) - return available_devices - - -def get_current_tiling(tiling: dict, default_tiling: dict, model_type: str): - # get tiling values from qt objects - for k, v in tiling.items(): - for k2, v2 in v.items(): - if isinstance(v2, int): - continue - tiling[k][k2] = v2.value() - # check if user inputs tiling/halo or not - if default_tiling == tiling: - if "2d" in model_type: - # if its a 2d model expand x,y and set z to 1 - tiling = { - "tile": { - "x": 512, - "y": 512, - "z": 1 - }, - "halo": { - "x": 64, - "y": 64, - "z": 1 - } - } - elif "2d" in model_type: - # if its a 2d model set z to 1 - tiling["tile"]["z"] = 1 - tiling["halo"]["z"] = 1 - - return tiling - - -def compute_scale_from_voxel_size( - voxel_size: dict, - model_type: str -) -> List[float]: - training_voxel_size = get_model_training_resolution(model_type) - scale = [ - voxel_size["x"] / training_voxel_size["x"], - voxel_size["y"] / training_voxel_size["y"], - ] - if len(voxel_size) == 3 and len(training_voxel_size) == 3: - scale.append( - voxel_size["z"] / training_voxel_size["z"] - ) - return scale - - -def _clean_filepath(filepath): - """ - Cleans a given filepath by: - - Removing newline characters (\n) - - Removing escape sequences - - Stripping the 'file://' prefix if present - - Args: - filepath (str): The original filepath - - Returns: - str: The cleaned filepath - """ - # Remove 'file://' prefix if present - if filepath.startswith("file://"): - filepath = filepath[7:] - - # Remove escape sequences and newlines - filepath = re.sub(r'\\.', '', filepath) - filepath = filepath.replace('\n', '').replace('\r', '') - - return filepath diff --git a/test/test_plugin.py b/test/test_plugin.py new file mode 100644 index 0000000..acaa510 --- /dev/null +++ b/test/test_plugin.py @@ -0,0 +1,21 @@ +import os +import unittest + + +# We just test that the plugins can be imported. +class TestPlugin(unittest.TestCase): + def test_distance_measure_widget(self): + from synapse_net.tools.distance_measure_widget import DistanceMeasureWidget + + def test_morphology_widget(self): + from synapse_net.tools.morphology_widget import MorphologyWidget + + def test_segmentation_widget(self): + from synapse_net.tools.segmentation_widget import SegmentationWidget + + def test_vesicle_pool_widget(self): + from synapse_net.tools.vesicle_pool_widget import VesiclePoolWidget + + +if __name__ == "__main__": + unittest.main() From a7ff94efe215a5814ec3059913afbee7b396e283 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 8 Dec 2024 20:15:32 +0100 Subject: [PATCH 4/9] Udpate example analysis pipeline --- examples/analysis_pipeline.py | 6 +++++- synapse_net/inference/__init__.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/analysis_pipeline.py b/examples/analysis_pipeline.py index 246af5d..34451e0 100644 --- a/examples/analysis_pipeline.py +++ b/examples/analysis_pipeline.py @@ -3,7 +3,7 @@ from synapse_net.file_utils import read_mrc from synapse_net.sample_data import get_sample_data -from synapse_net.tools.util import run_segmentation, get_model, compute_scale_from_voxel_size +from synapse_net.inference import compute_scale_from_voxel_size, get_model, run_segmentation def segment_structures(tomogram, voxel_size): @@ -65,6 +65,10 @@ def main(): # Segment synaptic vesicles, the active zone, and the synaptic compartment. segmentations = segment_structures(tomogram, voxel_size) + import h5py + with h5py.File("seg.h5", "r") as f: + for name, seg in segmentations.items(): + f.create_dataset(name, data=seg, compression="gzip") # Post-process the segmentations, to find the presynaptic terminal, # filter out vesicles not in the terminal, and to 'snape' the AZ to the presynaptic boundary. diff --git a/synapse_net/inference/__init__.py b/synapse_net/inference/__init__.py index 1619054..bdcabec 100644 --- a/synapse_net/inference/__init__.py +++ b/synapse_net/inference/__init__.py @@ -1,3 +1,3 @@ """This submodule implements SynapseNet's segmentation functionality. """ -from .inference import run_segmentation, get_model +from .inference import run_segmentation, get_model, compute_scale_from_voxel_size From fb1fcd0691920ba5bf987ba5093c912fd774642d Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 8 Dec 2024 20:59:32 +0100 Subject: [PATCH 5/9] Fix issues in active zone segmentation --- synapse_net/inference/active_zone.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/synapse_net/inference/active_zone.py b/synapse_net/inference/active_zone.py index e273d45..e9040fe 100644 --- a/synapse_net/inference/active_zone.py +++ b/synapse_net/inference/active_zone.py @@ -87,7 +87,8 @@ def segment_active_zone( verbose: Whether to print timing information. scale: The scale factor to use for rescaling the input volume before prediction. mask: An optional mask that is used to restrict the segmentation. - compartment: + compartment: Pass a compartment segmentation, to intersect the boundaries of the + compartments with the active zone prediction. Returns: The foreground mask as a numpy array. @@ -108,13 +109,17 @@ def segment_active_zone( print(f"shape {foreground.shape}") segmentation = _run_segmentation(foreground, verbose=verbose, min_size=min_size) + segmentation = scaler.rescale_output(segmentation, is_segmentation=True) - # returning prediciton and intersection not possible atm, but currently do not need prediction anyways + # Returning prediciton and intersection currently not possible. if return_predictions: + assert compartment is None pred = scaler.rescale_output(pred, is_segmentation=False) return segmentation, pred if compartment is not None: + assert not return_predictions + compartment = scaler.scale_input(input_volume, is_segmentation=True) intersection = find_intersection_boundary(segmentation, compartment) return segmentation, intersection From ba56b574882cfe229ec223dbc407be2a818e59dd Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 8 Dec 2024 21:26:07 +0100 Subject: [PATCH 6/9] Update tomo sample data and example analysis pipeline --- examples/analysis_pipeline.py | 90 +++++++++++++++++++++++++++++------ synapse_net/sample_data.py | 4 +- 2 files changed, 77 insertions(+), 17 deletions(-) diff --git a/examples/analysis_pipeline.py b/examples/analysis_pipeline.py index 34451e0..80e2321 100644 --- a/examples/analysis_pipeline.py +++ b/examples/analysis_pipeline.py @@ -1,8 +1,14 @@ import napari import pandas as pd +import numpy as np + +from scipy.ndimage import binary_closing +from skimage.measure import regionprops +from skimage.segmentation import find_boundaries from synapse_net.file_utils import read_mrc from synapse_net.sample_data import get_sample_data +from synapse_net.distance_measurements import measure_segmentation_to_object_distances from synapse_net.inference import compute_scale_from_voxel_size, get_model, run_segmentation @@ -30,28 +36,77 @@ def segment_structures(tomogram, voxel_size): return {"vesicles": vesicles, "active_zone": active_zone, "compartments": compartments} +def n_vesicles(mask, ves): + return len(np.unique(ves[mask])) - 1 + + def postprocess_segmentation(segmentations): - pass + # We find the compartment corresponding to the presynaptic terminal + # by selecting the compartment with most vesicles. We filter out all + # vesicles that do not overlap with this compartment. + vesicles, compartments = segmentations["vesicles"], segmentations["compartments"] -def measure_distances(segmentations): - pass + # First, we find the compartment with most vesicles. + props = regionprops(compartments, intensity_image=vesicles, extra_properties=[n_vesicles]) + compartment_ids = [prop.label for prop in props] + vesicle_counts = [prop.n_vesicles for prop in props] + compartments = (compartments == compartment_ids[np.argmax(vesicle_counts)]).astype("uint8") + # Filter all vesicles that are not in the compartment. + props = regionprops(vesicles, compartments) + filter_ids = [prop.label for prop in props if prop.max_intensity == 0] + vesicles[np.isin(vesicles, filter_ids)] = 0 -def assign_vesicle_pools(distances): - pass + segmentations["vesicles"], segmentations["compartments"] = vesicles, compartments + + # We also apply closing to the active zone segmentation to avoid gaps and then + # intersect it with the boundary of the presynaptic compartment. + active_zone = segmentations["active_zone"] + active_zone = binary_closing(active_zone, iterations=4) + boundary = find_boundaries(compartments) + active_zone = np.logical_and(active_zone, boundary).astype("uint8") + segmentations["active_zone"] = active_zone + + return segmentations -def visualize_results(tomogram, segmentations, vesicle_pools): - # TODO vesicle pool visualization +def measure_distances(segmentations, voxel_size): + vesicles, active_zone = segmentations["vesicles"], segmentations["active_zone"] + voxel_size = tuple(voxel_size[ax] for ax in "zyx") + distances, _, _, vesicle_ids = measure_segmentation_to_object_distances( + vesicles, active_zone, resolution=voxel_size + ) + return pd.DataFrame({"vesicle_id": vesicle_ids, "distance": distances}) + + +def assign_vesicle_pools(vesicle_attributes): + docked_vesicle_distance = 2 # nm + vesicle_attributes["pool"] = vesicle_attributes["distance"].apply( + lambda x: "docked" if x < docked_vesicle_distance else "non-attached" + ) + return vesicle_attributes + + +def visualize_results(tomogram, segmentations, vesicle_attributes): + + # Create a segmentation to visualize the vesicle pools. + docked_ids = vesicle_attributes[vesicle_attributes.pool == "docked"].vesicle_id + non_attached_ids = vesicle_attributes[vesicle_attributes.pool == "non-attached"].vesicle_id + vesicles = segmentations["vesicles"] + vesicle_pools = np.isin(vesicles, docked_ids).astype("uint8") + vesicle_pools[np.isin(vesicles, non_attached_ids)] = 2 + viewer = napari.Viewer() viewer.add_image(tomogram) for name, segmentation in segmentations.items(): viewer.add_labels(segmentation, name=name) + viewer.add_labels(vesicle_pools) napari.run() -def save_analysis(segmentations, distances, vesicle_pools, save_path): +# TODO compute the vesicle radii and other features and then save the attributes. +def save_analysis(segmentations, vesicle_attributes, save_path): pass @@ -64,28 +119,33 @@ def main(): tomogram, voxel_size = read_mrc(mrc_path) # Segment synaptic vesicles, the active zone, and the synaptic compartment. - segmentations = segment_structures(tomogram, voxel_size) + # segmentations = segment_structures(tomogram, voxel_size) + + # Load saved segmentations for development. import h5py + segmentations = {} with h5py.File("seg.h5", "r") as f: - for name, seg in segmentations.items(): - f.create_dataset(name, data=seg, compression="gzip") + for name, ds in f.items(): + # f.create_dataset(name, data=seg, compression="gzip") + seg = ds[:] + segmentations[name] = seg # Post-process the segmentations, to find the presynaptic terminal, # filter out vesicles not in the terminal, and to 'snape' the AZ to the presynaptic boundary. segmentations = postprocess_segmentation(segmentations) # Measure the distances between the AZ and vesicles. - distances = measure_distances(segmentations) + vesicle_attributes = measure_distances(segmentations, voxel_size) # Assign the vesicle pools, 'docked' and 'non-attached' vesicles, based on the distances. - vesicle_pools = assign_vesicle_pools(distances) + vesicle_attributes = assign_vesicle_pools(vesicle_attributes) # Visualize the results. - visualize_results(tomogram, segmentations, vesicle_pools) + visualize_results(tomogram, segmentations, vesicle_attributes) # Compute the vesicle radii and combine and save all measurements. save_path = "analysis_results.xlsx" - save_analysis(segmentations, distances, vesicle_pools, save_path) + save_analysis(segmentations, vesicle_attributes, save_path) if __name__ == "__main__": diff --git a/synapse_net/sample_data.py b/synapse_net/sample_data.py index cac30bd..0271015 100644 --- a/synapse_net/sample_data.py +++ b/synapse_net/sample_data.py @@ -15,11 +15,11 @@ def get_sample_data(name: str) -> str: """ registry = { "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28", - "tem_tomo.mrc": "eb790f83efb4c967c96961239ae52578d95da902fc32307629f76a26c3dc61fa", + "tem_tomo.mrc": "fe862ce7c22000d4440e3aa717ca9920b42260f691e5b2ab64cd61c928693c99", } urls = { "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download", - "tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/TmLjDCXi42E49Ef/download", + "tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/FJDhDfbT4UxhtOn/download", } key = f"{name}.mrc" From c75abd9ddd839d0d6bb5ea2a70f37881b872c935 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 9 Dec 2024 15:37:16 +0100 Subject: [PATCH 7/9] Update the example scripts and add downloads for data from zenodo --- examples/.gitignore | 5 +++++ examples/analysis_pipeline.py | 37 +++++++++++++++++++------------ examples/domain_adaptation.py | 36 +++++++++++++++++------------- examples/network_training.py | 22 +++++++++++------- scripts/prepare_zenodo_uploads.py | 2 +- synapse_net/sample_data.py | 30 +++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 38 deletions(-) create mode 100644 examples/.gitignore diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 0000000..8ac35e8 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,5 @@ +data/ +set_up_pool.py +*.h5 +*.tif +*.mrc diff --git a/examples/analysis_pipeline.py b/examples/analysis_pipeline.py index 80e2321..fe01065 100644 --- a/examples/analysis_pipeline.py +++ b/examples/analysis_pipeline.py @@ -6,10 +6,11 @@ from skimage.measure import regionprops from skimage.segmentation import find_boundaries -from synapse_net.file_utils import read_mrc -from synapse_net.sample_data import get_sample_data from synapse_net.distance_measurements import measure_segmentation_to_object_distances +from synapse_net.file_utils import read_mrc +from synapse_net.imod.to_imod import convert_segmentation_to_spheres from synapse_net.inference import compute_scale_from_voxel_size, get_model, run_segmentation +from synapse_net.sample_data import get_sample_data def segment_structures(tomogram, voxel_size): @@ -72,15 +73,23 @@ def postprocess_segmentation(segmentations): def measure_distances(segmentations, voxel_size): + # Here, we measure the distances from each vesicle to the active zone. + # We use the function 'measure_segmentation_to_object_distances' for this, + # which uses an euclidean distance transform scaled with the voxel size + # to determine distances. vesicles, active_zone = segmentations["vesicles"], segmentations["active_zone"] voxel_size = tuple(voxel_size[ax] for ax in "zyx") distances, _, _, vesicle_ids = measure_segmentation_to_object_distances( vesicles, active_zone, resolution=voxel_size ) + # We convert the result to a pandas data frame. return pd.DataFrame({"vesicle_id": vesicle_ids, "distance": distances}) def assign_vesicle_pools(vesicle_attributes): + # We assign the vesicles to their respective pool, 'docked' and 'non-attached', + # based on the criterion of being within 2 nm from the active zone. + # We add the pool assignment as a new column to the dataframe with vesicle attributes. docked_vesicle_distance = 2 # nm vesicle_attributes["pool"] = vesicle_attributes["distance"].apply( lambda x: "docked" if x < docked_vesicle_distance else "non-attached" @@ -89,6 +98,7 @@ def assign_vesicle_pools(vesicle_attributes): def visualize_results(tomogram, segmentations, vesicle_attributes): + # Here, we visualize the segmentation and pool assignment result in napari. # Create a segmentation to visualize the vesicle pools. docked_ids = vesicle_attributes[vesicle_attributes.pool == "docked"].vesicle_id @@ -97,6 +107,7 @@ def visualize_results(tomogram, segmentations, vesicle_attributes): vesicle_pools = np.isin(vesicles, docked_ids).astype("uint8") vesicle_pools[np.isin(vesicles, non_attached_ids)] = 2 + # Create a napari viewer, add the tomogram data and the segmentation results. viewer = napari.Viewer() viewer.add_image(tomogram) for name, segmentation in segmentations.items(): @@ -105,9 +116,16 @@ def visualize_results(tomogram, segmentations, vesicle_attributes): napari.run() -# TODO compute the vesicle radii and other features and then save the attributes. def save_analysis(segmentations, vesicle_attributes, save_path): - pass + # Here, we compute the radii and centroid positions of the vesicles, + # add them to the vesicle attributes and then save all vesicle attributes to + # an excel table. You can use this table for evaluation of the analysis. + vesicles = segmentations["vesicles"] + coordinates, radii = convert_segmentation_to_spheres(vesicles, radius_factor=0.7) + vesicle_attributes["radius"] = radii + for ax_id, ax_name in enumerate("zyx"): + vesicle_attributes[f"center-{ax_name}"] = coordinates[:, ax_id] + vesicle_attributes.to_excel(save_path, index=False) def main(): @@ -119,16 +137,7 @@ def main(): tomogram, voxel_size = read_mrc(mrc_path) # Segment synaptic vesicles, the active zone, and the synaptic compartment. - # segmentations = segment_structures(tomogram, voxel_size) - - # Load saved segmentations for development. - import h5py - segmentations = {} - with h5py.File("seg.h5", "r") as f: - for name, ds in f.items(): - # f.create_dataset(name, data=seg, compression="gzip") - seg = ds[:] - segmentations[name] = seg + segmentations = segment_structures(tomogram, voxel_size) # Post-process the segmentations, to find the presynaptic terminal, # filter out vesicles not in the terminal, and to 'snape' the AZ to the presynaptic boundary. diff --git a/examples/domain_adaptation.py b/examples/domain_adaptation.py index afad37c..7966b25 100644 --- a/examples/domain_adaptation.py +++ b/examples/domain_adaptation.py @@ -4,35 +4,41 @@ a different electron tomogram with different specimen and sample preparation. You don't need any annotations in the new domain to run this script. -You can download example data for this script from: -- Adaptation to 2d TEM data: TODO zenodo link -- Adaptation to different tomography data: TODO zenodo link +We use data from the SynapseNet publication for this example: +- Adaptation to 2d TEM data: https://doi.org/10.5281/zenodo.14236381 +- Adaptation to different tomography data (3d data): https://doi.org/10.5281/zenodo.14232606 + +It is of course possible to adapt it to your own data. """ import os from glob import glob from sklearn.model_selection import train_test_split +from synapse_net.inference.inference import get_model_path +from synapse_net.sample_data import download_data_from_zenodo from synapse_net.training import mean_teacher_adaptation -from synapse_net.tools.util import get_model_path def main(): # Choose whether to adapt the model to 2D or to 3D data. - train_2d_model = True - - # TODO adjust to zenodo downloads - # These are the data folders for the example data downloaded from zenodo. - # Update these paths to apply the script to your own data. - # Check out the example data to see the data format for training. - data_root_folder_2d = "./data/2d_tem/train_unlabeled" - data_root_folder_3d = "./data/..." + train_2d_model = False - # Choose the correct data folder depending on 2d/3d training. - data_root_folder = data_root_folder_2d if train_2d_model else data_root_folder_3d + # Download the training data from zenodo. + # You have to replace this if you want to train on your own data. + # The training data should be stored in an hdf5 file per tomogram, + # with tomgoram data stored in the internal dataset 'raw'. + if train_2d_model: + data_root = "./data/2d_tem" + download_data_from_zenodo(data_root, "2d_tem") + train_root_folder = os.path.join(data_root, "train_unlabeled") + else: + data_root = "./data/inner_ear_ribbon_synapse" + download_data_from_zenodo(data_root, "inner_ear_ribbon_synapse") + train_root_folder = data_root # Get all files with ending .h5 in the training folder. - files = sorted(glob(os.path.join(data_root_folder, "**", "*.h5"), recursive=True)) + files = sorted(glob(os.path.join(train_root_folder, "**", "*.h5"), recursive=True)) # Crate a train / val split. train_ratio = 0.85 diff --git a/examples/network_training.py b/examples/network_training.py index c90d592..9775209 100644 --- a/examples/network_training.py +++ b/examples/network_training.py @@ -5,22 +5,28 @@ to adapt an already trained network to your data without the need for additional annotations then check out `domain_adaptation.py`. -You can download example data for this script from: -TODO zenodo link to Single-Ax / Chemical Fix data. +We will use the data from our manuscript here: +https://doi.org/10.5281/zenodo.14330011 + +You can also use your own data, if you prepare it in the same format. """ import os from glob import glob from sklearn.model_selection import train_test_split +from synapse_net.sample_data import download_data_from_zenodo from synapse_net.training import supervised_training def main(): - # This is the folder that contains your training data. - # The example was designed so that it runs for the sample data downloaded to './data'. - # If you want to train on your own data than change this filepath accordingly. - # TODO update to match zenodo download - data_root_folder = "./data/vesicles/train" + # Download the training data from zenodo. + # You have to replace this if you want to train on your own data. + # The training data should be stored in an hdf5 file per tomogram, + # with tomgoram data stored in the internal dataset 'raw' + # and the vesicle annotations stored in the internal dataset 'labels/vesicles'. + data_root = "./data/training_data" + download_data_from_zenodo(data_root, "training_data") + train_root_folder = os.path.join(data_root, "vesicles/train") # The training data should be saved as .h5 files, with: # an internal dataset called 'raw' that contains the image data @@ -28,7 +34,7 @@ def main(): label_key = "labels/vesicles" # Get all files with the ending .h5 in the training folder. - files = sorted(glob(os.path.join(data_root_folder, "**", "*.h5"), recursive=True)) + files = sorted(glob(os.path.join(train_root_folder, "**", "*.h5"), recursive=True)) # Crate a train / val split. train_ratio = 0.85 diff --git a/scripts/prepare_zenodo_uploads.py b/scripts/prepare_zenodo_uploads.py index b642c07..344532c 100644 --- a/scripts/prepare_zenodo_uploads.py +++ b/scripts/prepare_zenodo_uploads.py @@ -56,7 +56,7 @@ def _export_az(train_root, test_tomos, name): for tomo in tqdm(tomograms): fname = os.path.basename(tomo) - if tomo in test_tomos: + if fname in test_tomos: out_path = os.path.join(test_out, fname) else: out_path = os.path.join(train_out, fname) diff --git a/synapse_net/sample_data.py b/synapse_net/sample_data.py index 0271015..7eb53ad 100644 --- a/synapse_net/sample_data.py +++ b/synapse_net/sample_data.py @@ -1,4 +1,5 @@ import os +import tempfile import pooch from .file_utils import read_mrc, get_cache_dir @@ -52,3 +53,32 @@ def sample_data_tem_2d(): def sample_data_tem_tomo(): return _sample_data("tem_tomo") + + +def download_data_from_zenodo(path: str, name: str): + """Download data uploaded for the SynapseNet manuscript from zenodo. + + Args: + path: The path where the downloaded data will be saved. + name: The name of the zenodi dataset. + """ + from torch_em.data.datasets.util import download_source, unzip + + urls = { + "2d_tem": "https://zenodo.org/records`/14236382/files/tem_2d.zip?download=1", + "inner_ear_ribbon_synapse": "https://zenodo.org/records/14232607/files/inner-ear-ribbon-synapse-tomgrams.zip?download=1", # noqa + "training_data": "https://zenodo.org/records/14330011/files/synapse-net.zip?download=1" + } + assert name in urls + url = urls[name] + + # May need to adapt this for other datasets. + # Check if the download already exists. + dl_path = path + if os.path.exists(dl_path): + return + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = os.path.join(tmp, f"{name}.zip") + download_source(tmp_path, url, download=True, checksum=None) + unzip(tmp_path, path, remove=False) From c0afedb07dcd5325c13d7d0830dd8be6423eff78 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 9 Dec 2024 23:10:09 +0100 Subject: [PATCH 8/9] Update the cooper scripts --- scripts/cooper/README.md | 4 ++-- scripts/cooper/run_compartment_segmentation.py | 11 +++++++++-- scripts/cooper/run_mitochondria_segmentation.py | 10 ++++++++-- scripts/cooper/run_vesicle_segmentation.py | 10 ++++++++-- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/scripts/cooper/README.md b/scripts/cooper/README.md index f8d635a..dd07f5a 100644 --- a/scripts/cooper/README.md +++ b/scripts/cooper/README.md @@ -19,9 +19,9 @@ $ micromamba activate sam The segmentation scripts (`run_..._segmentation.py`) all work similarly and can either run segmentation for a single mrc file or for all mrcs in a folder structure. For example, you can run vesicle segmentation like this: ``` -$ python run_vesicle_segmentation.py -i /path/to/input_folder -o /path/to/output_folder -m /path/to/vesicle_model.pt +$ python run_vesicle_segmentation.py -i /path/to/input_folder -o /path/to/output_folder ``` -The filepath after `-i` specifices the location of the folder with the mrcs to be segmented, the segmentation results will be stored (as tifs) in the folder following `-o` and `-m` is used to specify the path to the segmentation model. +The filepath after `-i` specifices the location of the folder with the mrcs to be segmented and the segmentation results will be stored (as tifs) in the folder following `-o`. To segment vesicles with an additional mask, you can use the `--mask_path` option. The segmentation scripts accept additional parameters, e.g. `--force` to overwrite existing segmentations in the output folder (by default these are skipped to avoid unnecessary computation) and `--tile_shape ` to specify a different tile shape (which may be necessary to avoid running out of GPU memory). diff --git a/scripts/cooper/run_compartment_segmentation.py b/scripts/cooper/run_compartment_segmentation.py index 0c8e908..6396fd1 100644 --- a/scripts/cooper/run_compartment_segmentation.py +++ b/scripts/cooper/run_compartment_segmentation.py @@ -2,13 +2,20 @@ from functools import partial from synapse_net.inference.compartments import segment_compartments +from synapse_net.inference.inference import get_model_path from synapse_net.inference.util import inference_helper, parse_tiling def run_compartment_segmentation(args): tiling = parse_tiling(args.tile_shape, args.halo) + + if args.model is None: + model_path = get_model_path("compartments") + else: + model_path = args.model + segmentation_function = partial( - segment_compartments, model_path=args.model_path, verbose=False, tiling=tiling, scale=[0.25, 0.25, 0.25] + segment_compartments, model_path=model_path, verbose=False, tiling=tiling, scale=[0.25, 0.25, 0.25] ) inference_helper( args.input_path, args.output_path, segmentation_function, force=args.force, data_ext=args.data_ext @@ -26,7 +33,7 @@ def main(): help="The filepath to directory where the segmentation will be saved." ) parser.add_argument( - "--model_path", "-m", required=True, help="The filepath to the compartment model." + "--model", "-m", help="The filepath to the compartment model." ) parser.add_argument( "--force", action="store_true", diff --git a/scripts/cooper/run_mitochondria_segmentation.py b/scripts/cooper/run_mitochondria_segmentation.py index a71af91..adb641f 100644 --- a/scripts/cooper/run_mitochondria_segmentation.py +++ b/scripts/cooper/run_mitochondria_segmentation.py @@ -2,13 +2,19 @@ from functools import partial from synapse_net.inference.mitochondria import segment_mitochondria +from synapse_net.inference.inference import get_model_path from synapse_net.inference.util import inference_helper, parse_tiling def run_mitochondria_segmentation(args): + if args.model is None: + model_path = get_model_path("mitochondria") + else: + model_path = args.model + tiling = parse_tiling(args.tile_shape, args.halo) segmentation_function = partial( - segment_mitochondria, model_path=args.model_path, verbose=False, tiling=tiling, scale=[0.5, 0.5, 0.5] + segment_mitochondria, model_path=model_path, verbose=False, tiling=tiling, scale=[0.5, 0.5, 0.5] ) inference_helper( args.input_path, args.output_path, segmentation_function, @@ -27,7 +33,7 @@ def main(): help="The filepath to directory where the segmentation will be saved." ) parser.add_argument( - "--model_path", "-m", required=True, help="The filepath to the mitochondria model." + "--model", "-m", help="The filepath to the mitochondria model." ) parser.add_argument( "--force", action="store_true", diff --git a/scripts/cooper/run_vesicle_segmentation.py b/scripts/cooper/run_vesicle_segmentation.py index c4b749b..4ceb40d 100644 --- a/scripts/cooper/run_vesicle_segmentation.py +++ b/scripts/cooper/run_vesicle_segmentation.py @@ -2,13 +2,19 @@ from functools import partial from synapse_net.inference.vesicles import segment_vesicles +from synapse_net.inference.inference import get_model_path from synapse_net.inference.util import inference_helper, parse_tiling def run_vesicle_segmentation(args): + if args.model is None: + model_path = get_model_path("vesicles_3d") + else: + model_path = args.model + tiling = parse_tiling(args.tile_shape, args.halo) segmentation_function = partial( - segment_vesicles, model_path=args.model_path, verbose=False, tiling=tiling, + segment_vesicles, model_path=model_path, verbose=False, tiling=tiling, exclude_boundary=not args.include_boundary ) inference_helper( @@ -28,7 +34,7 @@ def main(): help="The filepath to directory where the segmentations will be saved." ) parser.add_argument( - "--model_path", "-m", required=True, help="The filepath to the vesicle model." + "--model_path", "-m", help="The filepath to the vesicle model." ) parser.add_argument( "--mask_path", help="The filepath to a tif file with a mask that will be used to restrict the segmentation." From 03d42a8c8f0e4a68949be14292479065ddb21ba7 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 10 Dec 2024 21:54:41 +0100 Subject: [PATCH 9/9] Fix issues with inference code --- synapse_net/inference/util.py | 10 ++++++---- synapse_net/tools/cli.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/synapse_net/inference/util.py b/synapse_net/inference/util.py index 1ad3a73..ea92f29 100644 --- a/synapse_net/inference/util.py +++ b/synapse_net/inference/util.py @@ -414,8 +414,8 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: return {"tile": tile, "halo": halo} if torch.cuda.is_available(): - # We always use the same default halo. - halo = {"x": 64, "y": 64, "z": 16} # before 64,64,8 + # The default halo size. + halo = {"x": 64, "y": 64, "z": 16} # Determine the GPU RAM and derive a suitable tiling. vram = torch.cuda.get_device_properties(0).total_memory / 1e9 @@ -426,9 +426,11 @@ def get_default_tiling(is_2d: bool = False) -> Dict[str, Dict[str, int]]: tile = {"x": 512, "y": 512, "z": 64} elif vram >= 20: tile = {"x": 352, "y": 352, "z": 48} + elif vram >= 10: + tile = {"x": 256, "y": 256, "z": 32} + halo = {"x": 64, "y": 64, "z": 8} # Choose a smaller halo in z. else: - # TODO determine tilings for smaller VRAM - raise NotImplementedError(f"Estimating the tile size for a GPU with {vram} GB is not yet supported.") + raise NotImplementedError(f"Infererence with a GPU with {vram} GB VRAM is not supported.") print(f"Determined tile size: {tile}") tiling = {"tile": tile, "halo": halo} diff --git a/synapse_net/tools/cli.py b/synapse_net/tools/cli.py index 11caeb7..609bb0e 100644 --- a/synapse_net/tools/cli.py +++ b/synapse_net/tools/cli.py @@ -124,11 +124,11 @@ def segmentation_cli(): ) parser.add_argument( "--tile_shape", type=int, nargs=3, - help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient." + help="The tile shape for prediction, in ZYX order. Lower the tile shape if GPU memory is insufficient." ) parser.add_argument( "--halo", type=int, nargs=3, - help="The halo for prediction. Increase the halo to minimize boundary artifacts." + help="The halo for prediction, in ZYX order. Increase the halo to minimize boundary artifacts." ) parser.add_argument( "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."