From 1cc595ebc4ae871f5199cf3606504bd67d975c6e Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 6 Dec 2024 08:46:35 +0100 Subject: [PATCH 1/2] Add ribbon model and refactor IO functionality --- synaptic_reconstruction/file_utils.py | 60 ++++++++++++++++++- .../inference/postprocessing/compartments.py | 7 +++ .../inference/postprocessing/ribbon.py | 2 +- synaptic_reconstruction/napari.yaml | 15 ++++- synaptic_reconstruction/sample_data.py | 10 ++++ synaptic_reconstruction/tools/base_widget.py | 21 +++---- .../tools/segmentation_widget.py | 52 +++++++++------- synaptic_reconstruction/tools/util.py | 59 +++++++++++++++--- .../tools/volume_reader.py | 42 ++----------- 9 files changed, 189 insertions(+), 79 deletions(-) create mode 100644 synaptic_reconstruction/inference/postprocessing/compartments.py diff --git a/synaptic_reconstruction/file_utils.py b/synaptic_reconstruction/file_utils.py index d88a31d..e70f93c 100644 --- a/synaptic_reconstruction/file_utils.py +++ b/synaptic_reconstruction/file_utils.py @@ -1,5 +1,8 @@ import os -from typing import List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union + +import mrcfile +import numpy as np def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, List[str]]: @@ -23,3 +26,58 @@ def get_data_path(folder: str, n_tomograms: Optional[int] = 1) -> Union[str, Lis return tomograms assert len(tomograms) == n_tomograms, f"{folder}: {len(tomograms)}, {n_tomograms}" return tomograms[0] if n_tomograms == 1 else tomograms + + +def _parse_voxel_size(voxel_size): + parsed_voxel_size = None + try: + # The voxel sizes are stored in Angsrrom in the MRC header, but we want them + # in nanometer. Hence we divide by a factor of 10 here. + parsed_voxel_size = { + "x": voxel_size.x / 10, + "y": voxel_size.y / 10, + "z": voxel_size.z / 10, + } + except Exception as e: + print(f"Failed to read voxel size: {e}") + return parsed_voxel_size + + +def read_voxel_size(path: str) -> Dict[str, float] | None: + """Read voxel size from mrc/rec file. + + The original unit of voxel size is Angstrom and we convert it to nanometers by dividing it by ten. + + Args: + path: Path to mrc/rec file. + + Returns: + Mapping from the axis name to voxel size. None if the voxel size could not be read. + """ + with mrcfile.open(path, permissive=True) as mrc: + voxel_size = _parse_voxel_size(mrc.voxel_size) + return voxel_size + + +# TODO: double check axis ordering with elf +def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]: + """Read data and voxel size from mrc/rec file. + + Args: + path: Path to mrc/rec file. + + Returns: + The data read from the file. + The voxel size read from the file. + """ + with mrcfile.open(path, permissive=True) as mrc: + voxel_size = _parse_voxel_size(mrc.voxel_size) + data = np.asarray(mrc.data[:]) + + # Transpose the data to match python axis order. + if data.ndim == 3: + data = np.flip(data, axis=1) + else: + data = np.flip(data, axis=0) + + return data, voxel_size diff --git a/synaptic_reconstruction/inference/postprocessing/compartments.py b/synaptic_reconstruction/inference/postprocessing/compartments.py new file mode 100644 index 0000000..977d84f --- /dev/null +++ b/synaptic_reconstruction/inference/postprocessing/compartments.py @@ -0,0 +1,7 @@ + + +# TODO +# - merge compartments which share vesicles (based on threshold for merging) +# - filter out compartments with less than some threshold vesicles +def postpocess_compartments(): + pass diff --git a/synaptic_reconstruction/inference/postprocessing/ribbon.py b/synaptic_reconstruction/inference/postprocessing/ribbon.py index 644a06d..8e24476 100644 --- a/synaptic_reconstruction/inference/postprocessing/ribbon.py +++ b/synaptic_reconstruction/inference/postprocessing/ribbon.py @@ -20,7 +20,7 @@ def segment_ribbon( n_slices_exclude: The number of slices to exclude on the top / bottom in order to avoid segmentation errors due to imaging artifacts in top and bottom. n_ribbons: The number of ribbons in the tomogram. - max_vesicle_distance: The maximal distance to associate a vesicle with a ribbon. + max_vesicle_distance: The maximal distance in pixels to associate a vesicle with a ribbon. """ assert ribbon_prediction.shape == vesicle_segmentation.shape diff --git a/synaptic_reconstruction/napari.yaml b/synaptic_reconstruction/napari.yaml index b36c72f..52364a3 100644 --- a/synaptic_reconstruction/napari.yaml +++ b/synaptic_reconstruction/napari.yaml @@ -1,9 +1,12 @@ name: synaptic_reconstruction display_name: SynapseNet -# see https://napari.org/stable/plugins/manifest.html for valid categories + +# See https://napari.org/stable/plugins/manifest.html for valid categories. categories: ["Image Processing", "Annotation"] + contributions: commands: + # Commands for widgets. - id: synaptic_reconstruction.segment python_name: synaptic_reconstruction.tools.segmentation_widget:SegmentationWidget title: Segment @@ -20,6 +23,11 @@ contributions: python_name: synaptic_reconstruction.tools.vesicle_pool_widget:VesiclePoolWidget title: Vesicle Pooling + # Commands for sample data. + - id: synaptic_reconstruction.sample_data_tem_2d + python_name: synaptic_reconstruction.sample_data:sample_data_tem_2d + title: Load TEM 2D sample data + readers: - command: synaptic_reconstruction.file_reader filename_patterns: @@ -37,3 +45,8 @@ contributions: display_name: Morphology Analysis - command: synaptic_reconstruction.vesicle_pooling display_name: Vesicle Pooling + + sample_data: + - command: synaptic_reconstruction.sample_data_tem_2d + display_name: TEM 2D Sample Data + key: synapse-net-tem-2d diff --git a/synaptic_reconstruction/sample_data.py b/synaptic_reconstruction/sample_data.py index c0a3e47..8dc699b 100644 --- a/synaptic_reconstruction/sample_data.py +++ b/synaptic_reconstruction/sample_data.py @@ -1,6 +1,8 @@ import os import pooch +from .file_utils import read_mrc + def get_sample_data(name: str) -> str: """Get the filepath to SynapseNet sample data, stored as mrc file. @@ -32,3 +34,11 @@ def get_sample_data(name: str) -> str: ) file_path = data_registry.fetch(key) return file_path + + +def sample_data_tem_2d(): + file_path = get_sample_data("tem_2d") + data, voxel_size = read_mrc(file_path) + metadata = {"file_path": file_path, "voxel_size": voxel_size} + add_image_kwargs = {"name": "tem_2d", "metadata": metadata, "colormap": "gray"} + return [(data, add_image_kwargs)] diff --git a/synaptic_reconstruction/tools/base_widget.py b/synaptic_reconstruction/tools/base_widget.py index ca39e8a..1195488 100644 --- a/synaptic_reconstruction/tools/base_widget.py +++ b/synaptic_reconstruction/tools/base_widget.py @@ -23,12 +23,11 @@ def __init__(self): self.attribute_dict = {} def _create_layer_selector(self, selector_name, layer_type="Image"): - """ - Create a layer selector for an image or labels and store it in a dictionary. + """Create a layer selector for an image or labels and store it in a dictionary. - Parameters: - - selector_name (str): The name of the selector, used as a key in the dictionary. - - layer_type (str): The type of layer to filter for ("Image" or "Labels"). + Args: + selector_name (str): The name of the selector, used as a key in the dictionary. + layer_type (str): The type of layer to filter for ("Image" or "Labels"). """ if not hasattr(self, "layer_selectors"): self.layer_selectors = {} @@ -286,17 +285,19 @@ def _get_file_path(self, name, textbox, tooltip=None): # Handle the case where the selected path is not a file print("Invalid file selected. Please try again.") - def _handle_resolution(self, metadata, voxel_size_param, ndim): + def _handle_resolution(self, metadata, voxel_size_param, ndim, return_as_list=True): # Get the resolution / voxel size from the layer metadata if available. resolution = metadata.get("voxel_size", None) - if resolution is not None: - resolution = [resolution[ax] for ax in ("zyx" if ndim == 3 else "yx")] # If user input was given then override resolution from metadata. + axes = "zyx" if ndim == 3 else "yx" if voxel_size_param.value() != 0.0: # Changed from default. - resolution = ndim * [voxel_size_param.value()] + resolution = {ax: voxel_size_param.value() for ax in axes} + + if resolution is not None and return_as_list: + resolution = [resolution[ax] for ax in axes] + assert len(resolution) == ndim - assert len(resolution) == ndim return resolution def _save_table(self, save_path, data): diff --git a/synaptic_reconstruction/tools/segmentation_widget.py b/synaptic_reconstruction/tools/segmentation_widget.py index 548a465..6fbb0cc 100644 --- a/synaptic_reconstruction/tools/segmentation_widget.py +++ b/synaptic_reconstruction/tools/segmentation_widget.py @@ -1,12 +1,15 @@ +import copy + import napari +import numpy as np + from napari.utils.notifications import show_info from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox from .base_widget import BaseWidget from .util import (run_segmentation, get_model, get_model_registry, _available_devices, get_device, get_current_tiling, compute_scale_from_voxel_size, load_custom_model) -from synaptic_reconstruction.inference.util import get_default_tiling -import copy +from ..inference.util import get_default_tiling class SegmentationWidget(BaseWidget): @@ -79,37 +82,41 @@ def on_predict(self): show_info("INFO: Please choose an image.") return - # load current tiling + # Get the current tiling. self.tiling = get_current_tiling(self.tiling, self.default_tiling, model_type) + # Get the voxel size. metadata = self._get_layer_selector_data(self.image_selector_name, return_metadata=True) - voxel_size = metadata.get("voxel_size", None) - scale = None + voxel_size = self._handle_resolution(metadata, self.voxel_size_param, image.ndim, return_as_list=False) - if self.voxel_size_param.value() != 0.0: # changed from default - voxel_size = {} - # override voxel size with user input - if len(image.shape) == 3: - voxel_size["x"] = self.voxel_size_param.value() - voxel_size["y"] = self.voxel_size_param.value() - voxel_size["z"] = self.voxel_size_param.value() - else: - voxel_size["x"] = self.voxel_size_param.value() - voxel_size["y"] = self.voxel_size_param.value() + # Determine the scaling based on the voxel size. + scale = None if voxel_size: if model_type == "custom": show_info("INFO: The image is not rescaled for a custom model.") else: # calculate scale so voxel_size is the same as in training scale = compute_scale_from_voxel_size(voxel_size, model_type) - show_info(f"INFO: Rescaled the image by {scale} to optimize for the selected model.") - + scale_info = list(map(lambda x: np.round(x, 2), scale)) + show_info(f"INFO: Rescaled the image by {scale_info} to optimize for the selected model.") + + # Some models require an additional segmentation for inference or postprocessing. + # For these models we read out the 'Extra Segmentation' widget. + if model_type == "ribbon": # Currently only the ribbon model needs the extra seg. + extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name) + kwargs = {"extra_segmentation": extra_seg} + else: + kwargs = {} segmentation = run_segmentation( - image, model=model, model_type=model_type, tiling=self.tiling, scale=scale + image, model=model, model_type=model_type, tiling=self.tiling, scale=scale, **kwargs ) - # Add the segmentation layer - self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata) + # Add the segmentation layer(s). + if isinstance(segmentation, dict): + for name, seg in segmentation.items(): + self.viewer.add_labels(seg, name=name, metadata=metadata) + else: + self.viewer.add_labels(segmentation, name=f"{model_type}-segmentation", metadata=metadata) show_info(f"INFO: Segmentation of {model_type} added to layers.") def _create_settings_widget(self): @@ -156,5 +163,10 @@ def _create_settings_widget(self): ) setting_values.layout().addLayout(layout) + # Add selection UI for additional segmentation, which some models require for inference or postproc. + self.extra_seg_selector_name = "Extra Segmentation" + self.extra_selector_widget = self._create_layer_selector(self.extra_seg_selector_name, layer_type="Labels") + setting_values.layout().addWidget(self.extra_selector_widget) + settings = self._make_collapsible(widget=setting_values, title="Advanced Settings") return settings diff --git a/synaptic_reconstruction/tools/util.py b/synaptic_reconstruction/tools/util.py index 0cedf0b..a99e177 100644 --- a/synaptic_reconstruction/tools/util.py +++ b/synaptic_reconstruction/tools/util.py @@ -9,6 +9,7 @@ from ..inference.active_zone import segment_active_zone from ..inference.compartments import segment_compartments from ..inference.mitochondria import segment_mitochondria +from ..inference.ribbon_synapse import segment_ribbon_synapse_structures from ..inference.vesicles import segment_vesicles @@ -43,8 +44,8 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None """Get the model for the given segmentation type. Args: - model_type: The model type. - One of 'vesicles', 'mitochondria', 'active_zone', 'compartments' or 'inner_ear_structures'. + model_type: The model type. You can choose One of: + 'vesicles_3d', 'active_zone', 'compartments', 'mitochondria', 'ribbon', 'vesicles_2d', 'vesicles_cryo'. device: The device to use. Returns: @@ -58,6 +59,44 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None return model +def _segment_ribbon_AZ(image, model, tiling, scale, verbose, **kwargs): + # Parse additional keyword arguments from the kwargs. + vesicles = kwargs.pop("extra_segmentation") + threshold = kwargs.pop("threshold", 0.5) + n_slices_exclude = kwargs.pop("n_slices_exclude", 20) + n_ribbons = kwargs.pop("n_slices_exclude", 1) + + predictions = segment_ribbon_synapse_structures( + image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs + ) + + # If the vesicles were passed then run additional post-processing. + if vesicles is None: + from synaptic_reconstruction.inference.postprocessing import ( + segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based, + ) + + ribbon = segment_ribbon( + predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons, + max_vesicle_distance=40, + ) + PD = segment_presynaptic_density( + predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40, + ) + ref_segmentation = PD if PD.sum() > 0 else ribbon + membrane = segment_membrane_distance_based( + predictions["membrane"], ref_segmentation, n_sclices_exclude=n_slices_exclude, max_distance=500 + ) + + segmentation = {"ribbon": ribbon, "PD": PD, "membrane": membrane} + + # Otherwise, just return the predictions. + else: + segmentation = predictions + + return segmentation + + def run_segmentation( image: np.ndarray, model: torch.nn.Module, @@ -66,22 +105,21 @@ def run_segmentation( scale: Optional[List[float]] = None, verbose: bool = False, **kwargs, -) -> np.ndarray: +) -> np.ndarray | Dict[str, np.ndarray]: """Run synaptic structure segmentation. Args: image: The input image or image volume. model: The segmentation model. - model_type: The model type. This will determine which segmentation - post-processing is used. + model_type: The model type. This will determine which segmentation post-processing is used. tiling: The tiling settings for inference. scale: A scale factor for resizing the input before applying the model. The output will be scaled back to the initial size. verbose: Whether to print detailed information about the prediction and segmentation. - kwargs: Optional parameter for the segmentation function. + kwargs: Optional parameters for the segmentation function. Returns: - The segmentation. + The segmentation. For models that return multiple segmentations, this function returns a dictionary. """ if model_type.startswith("vesicles"): segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) @@ -91,8 +129,8 @@ def run_segmentation( segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) elif model_type == "compartments": segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) - elif model_type == "ribbon_synapse_structures": - raise NotImplementedError + elif model_type == "ribbon": + segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs) else: raise ValueError(f"Unknown model type: {model_type}") return segmentation @@ -108,6 +146,7 @@ def get_model_training_resolution(model_type): "active_zone": {"x": 1.44, "y": 1.44, "z": 1.44}, "compartments": {"x": 3.47, "y": 3.47, "z": 3.47}, "mitochondria": {"x": 2.07, "y": 2.07, "z": 2.07}, + "ribbon": {"x": 1.188, "y": 1.188, "z": 1.188}, "vesicles_2d": {"x": 1.35, "y": 1.35}, "vesicles_3d": {"x": 1.35, "y": 1.35, "z": 1.35}, "vesicles_cryo": {"x": 1.35, "y": 1.35, "z": 0.88}, @@ -120,6 +159,7 @@ def get_model_registry(): "active_zone": "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0", "compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1", "mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186", + "ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9", "vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1", "vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29", "vesicles_cryo": "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b", @@ -128,6 +168,7 @@ def get_model_registry(): "active_zone": "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download", "compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download", "mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download", + "ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download", "vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download", "vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download", "vesicles_cryo": "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download", diff --git a/synaptic_reconstruction/tools/volume_reader.py b/synaptic_reconstruction/tools/volume_reader.py index 961dfa5..4e0c7ff 100644 --- a/synaptic_reconstruction/tools/volume_reader.py +++ b/synaptic_reconstruction/tools/volume_reader.py @@ -1,10 +1,10 @@ import os +from typing import Callable, List, Optional, Sequence, Union -from typing import Callable, Dict, List, Optional, Sequence, Union -import mrcfile +from elf.io import open_file, is_dataset from napari.types import LayerData +from synaptic_reconstruction.file_utils import read_mrc -from elf.io import open_file, is_dataset PathLike = str PathOrPaths = Union[PathLike, Sequence[PathLike]] @@ -19,21 +19,14 @@ def get_reader(path: PathOrPaths) -> Optional[ReaderFunction]: return None -# For mrcfiles we just read the data from it. def _read_mrc(path, fname): - with open_file(path, mode="r") as f: - data = f["data"][:] - voxel_size = read_voxel_size(path) - metadata = { - "file_path": path, - "voxel_size": voxel_size - } + data, voxel_size = read_mrc(path) + metadata = {"file_path": path, "voxel_size": voxel_size} layer_attributes = { "name": fname, "colormap": "gray", "metadata": metadata } - return [(data, layer_attributes)] @@ -72,28 +65,3 @@ def read_image_volume(path: PathOrPaths) -> List[LayerData]: except Exception as e: print(f"Failed to read file: {e}") return - - -def read_voxel_size(input_path: str) -> Dict[str, float] | None: - """Read voxel size from mrc/rec file and store it in layer_attributes. - The original unit of voxel size is Angstrom and we convert it to nanometers - by dividing it by ten. - - Args: - input_path: Path to mrc/rec file. - - Returns: - Mapping from the axis name to voxel size. None if the voxel size could not be read. - """ - new_voxel_size = None - with mrcfile.open(input_path, permissive=True) as mrc: - try: - voxel_size = mrc.voxel_size - new_voxel_size = { - "x": voxel_size.x / 10, - "y": voxel_size.y / 10, - "z": voxel_size.z / 10, - } - except Exception as e: - print(f"Failed to read voxel size: {e}") - return new_voxel_size From 71f9b2c11336a7846ef38354817b58dfebe6bf7b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 6 Dec 2024 09:05:52 +0100 Subject: [PATCH 2/2] Add new sample data and test for file utils --- synaptic_reconstruction/file_utils.py | 8 ++--- synaptic_reconstruction/napari.yaml | 6 ++++ synaptic_reconstruction/sample_data.py | 18 +++++++--- test/test_file_utils.py | 46 ++++++++++++++++++++++++++ 4 files changed, 68 insertions(+), 10 deletions(-) create mode 100644 test/test_file_utils.py diff --git a/synaptic_reconstruction/file_utils.py b/synaptic_reconstruction/file_utils.py index e70f93c..6b54e75 100644 --- a/synaptic_reconstruction/file_utils.py +++ b/synaptic_reconstruction/file_utils.py @@ -59,7 +59,6 @@ def read_voxel_size(path: str) -> Dict[str, float] | None: return voxel_size -# TODO: double check axis ordering with elf def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]: """Read data and voxel size from mrc/rec file. @@ -73,11 +72,8 @@ def read_mrc(path: str) -> Tuple[np.ndarray, Dict[str, float]]: with mrcfile.open(path, permissive=True) as mrc: voxel_size = _parse_voxel_size(mrc.voxel_size) data = np.asarray(mrc.data[:]) + assert data.ndim in (2, 3) # Transpose the data to match python axis order. - if data.ndim == 3: - data = np.flip(data, axis=1) - else: - data = np.flip(data, axis=0) - + data = np.flip(data, axis=1) if data.ndim == 3 else np.flip(data, axis=0) return data, voxel_size diff --git a/synaptic_reconstruction/napari.yaml b/synaptic_reconstruction/napari.yaml index 52364a3..578dc80 100644 --- a/synaptic_reconstruction/napari.yaml +++ b/synaptic_reconstruction/napari.yaml @@ -27,6 +27,9 @@ contributions: - id: synaptic_reconstruction.sample_data_tem_2d python_name: synaptic_reconstruction.sample_data:sample_data_tem_2d title: Load TEM 2D sample data + - id: synaptic_reconstruction.sample_data_tem_tomo + python_name: synaptic_reconstruction.sample_data:sample_data_tem_tomo + title: Load TEM Tomo sample data readers: - command: synaptic_reconstruction.file_reader @@ -50,3 +53,6 @@ contributions: - command: synaptic_reconstruction.sample_data_tem_2d display_name: TEM 2D Sample Data key: synapse-net-tem-2d + - command: synaptic_reconstruction.sample_data_tem_tomo + display_name: TEM Tomo Sample Data + key: synapse-net-tem-tomo diff --git a/synaptic_reconstruction/sample_data.py b/synaptic_reconstruction/sample_data.py index 8dc699b..85ca481 100644 --- a/synaptic_reconstruction/sample_data.py +++ b/synaptic_reconstruction/sample_data.py @@ -8,16 +8,18 @@ def get_sample_data(name: str) -> str: """Get the filepath to SynapseNet sample data, stored as mrc file. Args: - name: The name of the sample data. Currently, we only provide the 'tem_2d' sample data. + name: The name of the sample data. Currently, we only provide 'tem_2d' and 'tem_tomo'. Returns: The filepath to the downloaded sample data. """ registry = { "tem_2d.mrc": "3c6f9ff6d7673d9bf2fd46c09750c3c7dbb8fa1aa59dcdb3363b65cc774dcf28", + "tem_tomo.mrc": "24af31a10761b59fa6ad9f0e763f8f084304e4f31c59b482dd09dde8cd443ed7", } urls = { "tem_2d.mrc": "https://owncloud.gwdg.de/index.php/s/5sAQ0U4puAspcHg/download", + "tem_tomo.mrc": "https://owncloud.gwdg.de/index.php/s/NeP7gOv76Vj26lm/download", } key = f"{name}.mrc" @@ -36,9 +38,17 @@ def get_sample_data(name: str) -> str: return file_path -def sample_data_tem_2d(): - file_path = get_sample_data("tem_2d") +def _sample_data(name): + file_path = get_sample_data(name) data, voxel_size = read_mrc(file_path) metadata = {"file_path": file_path, "voxel_size": voxel_size} - add_image_kwargs = {"name": "tem_2d", "metadata": metadata, "colormap": "gray"} + add_image_kwargs = {"name": name, "metadata": metadata, "colormap": "gray"} return [(data, add_image_kwargs)] + + +def sample_data_tem_2d(): + return _sample_data("tem_2d") + + +def sample_data_tem_tomo(): + return _sample_data("tem_tomo") diff --git a/test/test_file_utils.py b/test/test_file_utils.py new file mode 100644 index 0000000..362f87f --- /dev/null +++ b/test/test_file_utils.py @@ -0,0 +1,46 @@ +import unittest + +import numpy as np +from elf.io import open_file +from synaptic_reconstruction.sample_data import get_sample_data + + +class TestFileUtils(unittest.TestCase): + + def test_read_mrc_2d(self): + from synaptic_reconstruction.file_utils import read_mrc + + file_path = get_sample_data("tem_2d") + data, voxel_size = read_mrc(file_path) + + with open_file(file_path, "r") as f: + data_exp = f["data"][:] + + self.assertTrue(data.shape, data_exp.shape) + self.assertTrue(np.allclose(data, data_exp)) + + resolution = 0.592 + self.assertTrue(np.isclose(voxel_size["x"], resolution)) + self.assertTrue(np.isclose(voxel_size["y"], resolution)) + self.assertTrue(np.isclose(voxel_size["z"], 0.0)) + + def test_read_mrc_3d(self): + from synaptic_reconstruction.file_utils import read_mrc + + file_path = get_sample_data("tem_tomo") + data, voxel_size = read_mrc(file_path) + + with open_file(file_path, "r") as f: + data_exp = f["data"][:] + + self.assertTrue(data.shape, data_exp.shape) + self.assertTrue(np.allclose(data, data_exp)) + + resolution = 1.554 + self.assertTrue(np.isclose(voxel_size["x"], resolution)) + self.assertTrue(np.isclose(voxel_size["y"], resolution)) + self.assertTrue(np.isclose(voxel_size["z"], resolution)) + + +if __name__ == "__main__": + unittest.main()