Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ribbon model and refactor IO functionality #75

Merged
merged 2 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion synaptic_reconstruction/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import mrcfile
import numpy as np


def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]:
Expand All @@ -23,3 +26,54 @@ def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, Lis
return tomograms
assert len(tomograms) == n_tomograms, f"{folder}: {len(tomograms)}, {n_tomograms}"
return tomograms[0] if n_tomograms == 1 else tomograms


def _parse_voxel_size(voxel_size):
parsed_voxel_size = None
try:
# The voxel sizes are stored in Angsrrom in the MRC header, but we want them
# in nanometer. Hence we divide by a factor of 10 here.
parsed_voxel_size = {
"x": voxel_size.x / 10,
"y": voxel_size.y / 10,
"z": voxel_size.z / 10,
}
except Exception as e:
print(f"Failed to read voxel size: {e}")
return parsed_voxel_size


def read_voxel_size(path: str) -> Dict[str, float] | None:
"""Read voxel size from mrc/rec file.

The original unit of voxel size is Angstrom and we convert it to nanometers by dividing it by ten.

Args:
path: Path to mrc/rec file.

Returns:
Mapping from the axis name to voxel size. None if the voxel size could not be read.
"""
with mrcfile.open(path, permissive=True) as mrc:
voxel_size = _parse_voxel_size(mrc.voxel_size)
return voxel_size


def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]:
"""Read data and voxel size from mrc/rec file.

Args:
path: Path to mrc/rec file.

Returns:
The data read from the file.
The voxel size read from the file.
"""
with mrcfile.open(path, permissive=True) as mrc:
voxel_size = _parse_voxel_size(mrc.voxel_size)
data = np.asarray(mrc.data[:])
assert data.ndim in (2, 3)

# Transpose the data to match python axis order.
data = np.flip(data, axis=1) if data.ndim == 3 else np.flip(data, axis=0)
return data, voxel_size
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@


# TODO
# - merge compartments which share vesicles (based on threshold for merging)
# - filter out compartments with less than some threshold vesicles
def postpocess_compartments():
pass
2 changes: 1 addition & 1 deletion synaptic_reconstruction/inference/postprocessing/ribbon.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def segment_ribbon(
n_slices_exclude: The number of slices to exclude on the top / bottom
in order to avoid segmentation errors due to imaging artifacts in top and bottom.
n_ribbons: The number of ribbons in the tomogram.
max_vesicle_distance: The maximal distance to associate a vesicle with a ribbon.
max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon.
"""
assert ribbon_prediction.shape == vesicle_segmentation.shape

Expand Down
21 changes: 20 additions & 1 deletion synaptic_reconstruction/napari.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
name: synaptic_reconstruction
display_name: SynapseNet
# see https://napari.org/stable/plugins/manifest.html for valid categories

# See https://napari.org/stable/plugins/manifest.html for valid categories.
categories: ["Image Processing", "Annotation"]

contributions:
commands:
# Commands for widgets.
- id: synaptic_reconstruction.segment
python_name: synaptic_reconstruction.tools.segmentation_widget:SegmentationWidget
title: Segment
Expand All @@ -20,6 +23,14 @@ contributions:
python_name: synaptic_reconstruction.tools.vesicle_pool_widget:VesiclePoolWidget
title: Vesicle Pooling

# Commands for sample data.
- id: synaptic_reconstruction.sample_data_tem_2d
python_name: synaptic_reconstruction.sample_data:sample_data_tem_2d
title: Load TEM 2D sample data
- id: synaptic_reconstruction.sample_data_tem_tomo
python_name: synaptic_reconstruction.sample_data:sample_data_tem_tomo
title: Load TEM Tomo sample data

readers:
- command: synaptic_reconstruction.file_reader
filename_patterns:
Expand All @@ -37,3 +48,11 @@ contributions:
display_name: Morphology Analysis
- command: synaptic_reconstruction.vesicle_pooling
display_name: Vesicle Pooling

sample_data:
- command: synaptic_reconstruction.sample_data_tem_2d
display_name: TEM 2D Sample Data
key: synapse-net-tem-2d
- command: synaptic_reconstruction.sample_data_tem_tomo
display_name: TEM Tomo Sample Data
key: synapse-net-tem-tomo
22 changes: 21 additions & 1 deletion synaptic_reconstruction/sample_data.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import os
import pooch

from .file_utils import read_mrc


def get_sample_data(name: str) -> str:
"""Get the filepath to SynapseNet sample data, stored as mrc file.

Args:
name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data.
name: The name of the sample data. Currently, we only provide 'tem_2d' and 'tem_tomo'.

Returns:
The filepath to the downloaded sample data.
"""
registry = {
"tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28",
"tem_tomo.mrc": "24af31a10761b59fa6ad9f0e763f8f084304e4f31c59b482dd09dde8cd443ed7",
}
urls = {
"tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download",
"tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/NeP7gOv76Vj26lm/download",
}
key = f"{name}.mrc"

Expand All @@ -32,3 +36,19 @@ def get_sample_data(name: str) -> str:
)
file_path = data_registry.fetch(key)
return file_path


def _sample_data(name):
file_path = get_sample_data(name)
data, voxel_size = read_mrc(file_path)
metadata = {"file_path": file_path, "voxel_size": voxel_size}
add_image_kwargs = {"name": name, "metadata": metadata, "colormap": "gray"}
return [(data, add_image_kwargs)]


def sample_data_tem_2d():
return _sample_data("tem_2d")


def sample_data_tem_tomo():
return _sample_data("tem_tomo")
21 changes: 11 additions & 10 deletions synaptic_reconstruction/tools/base_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ def __init__(self):
self.attribute_dict = {}

def _create_layer_selector(self, selector_name, layer_type="Image"):
"""
Create a layer selector for an image or labels and store it in a dictionary.
"""Create a layer selector for an image or labels and store it in a dictionary.

Parameters:
- selector_name (str): The name of the selector, used as a key in the dictionary.
- layer_type (str): The type of layer to filter for ("Image" or "Labels").
Args:
selector_name (str): The name of the selector, used as a key in the dictionary.
layer_type (str): The type of layer to filter for ("Image" or "Labels").
"""
if not hasattr(self, "layer_selectors"):
self.layer_selectors = {}
Expand Down Expand Up @@ -286,17 +285,19 @@ def _get_file_path(self, name, textbox, tooltip=None):
# Handle the case where the selected path is not a file
print("Invalid file selected. Please try again.")

def _handle_resolution(self, metadata, voxel_size_param, ndim):
def _handle_resolution(self, metadata, voxel_size_param, ndim, return_as_list=True):
# Get the resolution / voxel size from the layer metadata if available.
resolution = metadata.get("voxel_size", None)
if resolution is not None:
resolution = [resolution[ax] for ax in ("zyx" if ndim == 3 else "yx")]

# If user input was given then override resolution from metadata.
axes = "zyx" if ndim == 3 else "yx"
if voxel_size_param.value() != 0.0: # Changed from default.
resolution = ndim * [voxel_size_param.value()]
resolution = {ax: voxel_size_param.value() for ax in axes}

if resolution is not None and return_as_list:
resolution = [resolution[ax] for ax in axes]
assert len(resolution) == ndim

assert len(resolution) == ndim
return resolution

def _save_table(self, save_path, data):
Expand Down
52 changes: 32 additions & 20 deletions synaptic_reconstruction/tools/segmentation_widget.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import copy

import napari
import numpy as np

from napari.utils.notifications import show_info
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox

from .base_widget import BaseWidget
from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device,
get_current_tiling, compute_scale_from_voxel_size, load_custom_model)
from synaptic_reconstruction.inference.util import get_default_tiling
import copy
from ..inference.util import get_default_tiling


class SegmentationWidget(BaseWidget):
Expand Down Expand Up @@ -79,37 +82,41 @@ def on_predict(self):
show_info("INFO: Please choose an image.")
return

# load current tiling
# Get the current tiling.
self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type)

# Get the voxel size.
metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True)
voxel_size = metadata.get("voxel_size", None)
scale = None
voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False)

if self.voxel_size_param.value() != 0.0: # changed from default
voxel_size = {}
# override voxel size with user input
if len(image.shape) == 3:
voxel_size["x"] = self.voxel_size_param.value()
voxel_size["y"] = self.voxel_size_param.value()
voxel_size["z"] = self.voxel_size_param.value()
else:
voxel_size["x"] = self.voxel_size_param.value()
voxel_size["y"] = self.voxel_size_param.value()
# Determine the scaling based on the voxel size.
scale = None
if voxel_size:
if model_type == "custom":
show_info("INFO: The image is not rescaled for a custom model.")
else:
# calculate scale so voxel_size is the same as in training
scale = compute_scale_from_voxel_size(voxel_size, model_type)
show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.")

scale_info = list(map(lambda x: np.round(x, 2), scale))
show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.")

# Some models require an additional segmentation for inference or postprocessing.
# For these models we read out the 'Extra Segmentation' widget.
if model_type == "ribbon": # Currently only the ribbon model needs the extra seg.
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
kwargs = {"extra_segmentation": extra_seg}
else:
kwargs = {}
segmentation = run_segmentation(
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale
image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs
)

# Add the segmentation layer
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
# Add the segmentation layer(s).
if isinstance(segmentation, dict):
for name, seg in segmentation.items():
self.viewer.add_labels(seg, name=name, metadata=metadata)
else:
self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata)
show_info(f"INFO: Segmentation of {model_type} added to layers.")

def _create_settings_widget(self):
Expand Down Expand Up @@ -156,5 +163,10 @@ def _create_settings_widget(self):
)
setting_values.layout().addLayout(layout)

# Add selection UI for additional segmentation, which some models require for inference or postproc.
self.extra_seg_selector_name = "Extra Segmentation"
self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels")
setting_values.layout().addWidget(self.extra_selector_widget)

settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
return settings
Loading
Loading