Skip to content

Commit

Permalink
Merge pull request #75 from computational-cell-analytics/add-ribbon-i…
Browse files Browse the repository at this point in the history
…nference

Add ribbon model and refactor IO functionality
  • Loading branch information
constantinpape authored Dec 6, 2024
2 parents 0fe01c4 + 71f9b2c commit 57b7258
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 80 deletions.
56 changes: 55 additions & 1 deletion synaptic_reconstruction/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import mrcfile
import numpy as np


def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]:
Expand All @@ -23,3 +26,54 @@ def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, Lis
return tomograms
assert len(tomograms) == n_tomograms, f"{folder}: {len(tomograms)}, {n_tomograms}"
return tomograms[0] if n_tomograms == 1 else tomograms


def _parse_voxel_size(voxel_size):
parsed_voxel_size = None
try:
# The voxel sizes are stored in Angsrrom in the MRC header, but we want them
# in nanometer. Hence we divide by a factor of 10 here.
parsed_voxel_size = {
"x": voxel_size.x / 10,
"y": voxel_size.y / 10,
"z": voxel_size.z / 10,
}
except Exception as e:
print(f"Failed to read voxel size: {e}")
return parsed_voxel_size


def read_voxel_size(path: str) -> Dict[str, float] | None:
"""Read voxel size from mrc/rec file.
The original unit of voxel size is Angstrom and we convert it to nanometers by dividing it by ten.
Args:
path: Path to mrc/rec file.
Returns:
Mapping from the axis name to voxel size. None if the voxel size could not be read.
"""
with mrcfile.open(path, permissive=True) as mrc:
voxel_size = _parse_voxel_size(mrc.voxel_size)
return voxel_size


def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]:
"""Read data and voxel size from mrc/rec file.
Args:
path: Path to mrc/rec file.
Returns:
The data read from the file.
The voxel size read from the file.
"""
with mrcfile.open(path, permissive=True) as mrc:
voxel_size = _parse_voxel_size(mrc.voxel_size)
data = np.asarray(mrc.data[:])
assert data.ndim in (2, 3)

# Transpose the data to match python axis order.
data = np.flip(data, axis=1) if data.ndim == 3 else np.flip(data, axis=0)
return data, voxel_size
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@


# TODO
# - merge compartments which share vesicles (based on threshold for merging)
# - filter out compartments with less than some threshold vesicles
def postpocess_compartments():
pass
2 changes: 1 addition & 1 deletion synaptic_reconstruction/inference/postprocessing/ribbon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def segment_ribbon(
n_slices_exclude: The number of slices to exclude on the top / bottom
in order to avoid segmentation errors due to imaging artifacts in top and bottom.
n_ribbons: The number of ribbons in the tomogram.
max_vesicle_distance: The maximal distance to associate a vesicle with a ribbon.
max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.
"""
assert ribbon_prediction.shape == vesicle_segmentation.shape

Expand Down
21 changes: 20 additions & 1 deletion synaptic_reconstruction/napari.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
name: synaptic_reconstruction
display_name: SynapseNet
# see https://napari.org/stable/plugins/manifest.html for valid categories

# See https://napari.org/stable/plugins/manifest.html for valid categories.
categories: ["Image Processing", "Annotation"]

contributions:
commands:
# Commands for widgets.
- id: synaptic_reconstruction.segment
python_name: synaptic_reconstruction.tools.segmentation_widget:SegmentationWidget
title: Segment
Expand All @@ -20,6 +23,14 @@ contributions:
python_name: synaptic_reconstruction.tools.vesicle_pool_widget:VesiclePoolWidget
title: Vesicle Pooling

# Commands for sample data.
- id: synaptic_reconstruction.sample_data_tem_2d
python_name: synaptic_reconstruction.sample_data:sample_data_tem_2d
title: Load TEM 2D sample data
- id: synaptic_reconstruction.sample_data_tem_tomo
python_name: synaptic_reconstruction.sample_data:sample_data_tem_tomo
title: Load TEM Tomo sample data

readers:
- command: synaptic_reconstruction.file_reader
filename_patterns:
Expand All @@ -37,3 +48,11 @@ contributions:
display_name: Morphology Analysis
- command: synaptic_reconstruction.vesicle_pooling
display_name: Vesicle Pooling

sample_data:
- command: synaptic_reconstruction.sample_data_tem_2d
display_name: TEM 2D Sample Data
key: synapse-net-tem-2d
- command: synaptic_reconstruction.sample_data_tem_tomo
display_name: TEM Tomo Sample Data
key: synapse-net-tem-tomo
22 changes: 21 additions & 1 deletion synaptic_reconstruction/sample_data.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import os
import pooch

from .file_utils import read_mrc


def get_sample_data(name: str) -> str:
"""Get the filepath to SynapseNet sample data, stored as mrc file.
Args:
name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data.
name: The name of the sample data. Currently, we only provide 'tem_2d' and 'tem_tomo'.
Returns:
The filepath to the downloaded sample data.
"""
registry = {
"tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28",
"tem_tomo.mrc": "24af31a10761b59fa6ad9f0e763f8f084304e4f31c59b482dd09dde8cd443ed7",
}
urls = {
"tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download",
"tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/NeP7gOv76Vj26lm/download",
}
key = f"{name}.mrc"

Expand All @@ -32,3 +36,19 @@ def get_sample_data(name: str) -> str:
)
file_path = data_registry.fetch(key)
return file_path


def _sample_data(name):
file_path = get_sample_data(name)
data, voxel_size = read_mrc(file_path)
metadata = {"file_path": file_path, "voxel_size": voxel_size}
add_image_kwargs = {"name": name, "metadata": metadata, "colormap": "gray"}
return [(data, add_image_kwargs)]


def sample_data_tem_2d():
return _sample_data("tem_2d")


def sample_data_tem_tomo():
return _sample_data("tem_tomo")
21 changes: 11 additions & 10 deletions synaptic_reconstruction/tools/base_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ def __init__(self):
self.attribute_dict = {}

def _create_layer_selector(self, selector_name, layer_type="Image"):
"""
Create a layer selector for an image or labels and store it in a dictionary.
"""Create a layer selector for an image or labels and store it in a dictionary.
Parameters:
- selector_name (str): The name of the selector, used as a key in the dictionary.
- layer_type (str): The type of layer to filter for ("Image" or "Labels").
Args:
selector_name (str): The name of the selector, used as a key in the dictionary.
layer_type (str): The type of layer to filter for ("Image" or "Labels").
"""
if not hasattr(self, "layer_selectors"):
self.layer_selectors = {}
Expand Down Expand Up @@ -286,17 +285,19 @@ def _get_file_path(self, name, textbox, tooltip=None):
# Handle the case where the selected path is not a file
print("Invalid file selected. Please try again.")

def _handle_resolution(self, metadata, voxel_size_param, ndim):
def _handle_resolution(self, metadata, voxel_size_param, ndim, return_as_list=True):
# Get the resolution / voxel size from the layer metadata if available.
resolution = metadata.get("voxel_size", None)
if resolution is not None:
resolution = [resolution[ax] for ax in ("zyx" if ndim == 3 else "yx")]

# If user input was given then override resolution from metadata.
axes = "zyx" if ndim == 3 else "yx"
if voxel_size_param.value() != 0.0: # Changed from default.
resolution = ndim * [voxel_size_param.value()]
resolution = {ax: voxel_size_param.value() for ax in axes}

if resolution is not None and return_as_list:
resolution = [resolution[ax] for ax in axes]
assert len(resolution) == ndim

assert len(resolution) == ndim
return resolution

def _save_table(self, save_path, data):
Expand Down
52 changes: 32 additions & 20 deletions synaptic_reconstruction/tools/segmentation_widget.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import copy

import napari
import numpy as np

from napari.utils.notifications import show_info
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

from .base_widget import BaseWidget
from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
get_current_tiling, compute_scale_from_voxel_size, load_custom_model)
from synaptic_reconstruction.inference.util import get_default_tiling
import copy
from ..inference.util import get_default_tiling


class SegmentationWidget(BaseWidget):
Expand Down Expand Up @@ -79,37 +82,41 @@ def on_predict(self):
show_info("INFO: Please choose an image.")
return

# load current tiling
# Get the current tiling.
self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)

# Get the voxel size.
metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
voxel_size = metadata.get("voxel_size", None)
scale = None
voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

if self.voxel_size_param.value() != 0.0: # changed from default
voxel_size = {}
# override voxel size with user input
if len(image.shape) == 3:
voxel_size["x"] = self.voxel_size_param.value()
voxel_size["y"] = self.voxel_size_param.value()
voxel_size["z"] = self.voxel_size_param.value()
else:
voxel_size["x"] = self.voxel_size_param.value()
voxel_size["y"] = self.voxel_size_param.value()
# Determine the scaling based on the voxel size.
scale = None
if voxel_size:
if model_type == "custom":
show_info("INFO: The image is not rescaled for a custom model.")
else:
# calculate scale so voxel_size is the same as in training
scale = compute_scale_from_voxel_size(voxel_size, model_type)
show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")

scale_info = list(map(lambda x: np.round(x, 2), scale))
show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

# Some models require an additional segmentation for inference or postprocessing.
# For these models we read out the 'Extra Segmentation' widget.
if model_type == "ribbon": # Currently only the ribbon model needs the extra seg.
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
kwargs = {"extra_segmentation": extra_seg}
else:
kwargs = {}
segmentation = run_segmentation(
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
)

# Add the segmentation layer
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
# Add the segmentation layer(s).
if isinstance(segmentation, dict):
for name, seg in segmentation.items():
self.viewer.add_labels(seg, name=name, metadata=metadata)
else:
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
show_info(f"INFO: Segmentation of {model_type} added to layers.")

def _create_settings_widget(self):
Expand Down Expand Up @@ -156,5 +163,10 @@ def _create_settings_widget(self):
)
setting_values.layout().addLayout(layout)

# Add selection UI for additional segmentation, which some models require for inference or postproc.
self.extra_seg_selector_name = "Extra Segmentation"
self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
setting_values.layout().addWidget(self.extra_selector_widget)

settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
return settings
Loading

0 comments on commit 57b7258

Please sign in to comment.