diff --git a/simpa/core/simulation_modules/volume_creation_module/segmentation_based_adapter.py b/simpa/core/simulation_modules/volume_creation_module/segmentation_based_adapter.py
index bf2358d8..b11f3296 100644
--- a/simpa/core/simulation_modules/volume_creation_module/segmentation_based_adapter.py
+++ b/simpa/core/simulation_modules/volume_creation_module/segmentation_based_adapter.py
@@ -21,12 +21,22 @@ class SegmentationBasedAdapter(VolumeCreationAdapterBase):
     def create_simulation_volume(self) -> dict:
         volumes, x_dim_px, y_dim_px, z_dim_px = self.create_empty_volumes()
         wavelength = self.global_settings[Tags.WAVELENGTH]
-        for key in volumes.keys():
-            volumes[key] = volumes[key].to('cpu')
-        segmentation_volume = self.component_settings[Tags.INPUT_SEGMENTATION_VOLUME]
-        segmentation_classes = np.unique(segmentation_volume, return_counts=False)
-        x_dim_seg_px, y_dim_seg_px, z_dim_seg_px = np.shape(segmentation_volume)
+        segmentation_volume = torch.tensor(self.component_settings[Tags.INPUT_SEGMENTATION_VOLUME], device=self.torch_device)
+        class_mapping = self.component_settings[Tags.SEGMENTATION_CLASS_MAPPING]
+
+        if torch.is_floating_point(segmentation_volume):
+            assert len(segmentation_volume.shape) == 4 and segmentation_volume.shape[0] == len(class_mapping), \
+                "Fuzzy segmentation must be a 4D array with the first dimension being the number of classes."
+            fuzzy = True
+            segmentation_classes = np.arange(segmentation_volume.shape[0])
+
+        else:
+            assert len(segmentation_volume.shape) == 3, "Hard segmentations must be a 3D array."
+            fuzzy = False
+            segmentation_classes = torch.unique(segmentation_volume, return_counts=False).cpu().numpy()
+
+        x_dim_seg_px, y_dim_seg_px, z_dim_seg_px = np.shape(segmentation_volume)[-3:]
 
         if x_dim_px != x_dim_seg_px:
             raise ValueError("x_dim of volumes and segmentation must perfectly match but was {} and {}"
                              .format(x_dim_px, x_dim_seg_px))
@@ -38,8 +48,6 @@ def create_simulation_volume(self) -> dict:
             raise ValueError("z_dim of volumes and segmentation must perfectly match but was {} and {}"
                              .format(z_dim_px, z_dim_seg_px))
 
-        class_mapping = self.component_settings[Tags.SEGMENTATION_CLASS_MAPPING]
-
         for seg_class in segmentation_classes:
             class_properties = class_mapping[seg_class].get_properties_for_wavelength(self.global_settings, wavelength)
             for volume_key in volumes.keys():
@@ -47,7 +55,10 @@ def create_simulation_volume(self) -> dict:
                     assigned_prop = class_properties[volume_key]
                     if assigned_prop is None:
                         assigned_prop = torch.nan
-                    volumes[volume_key][segmentation_volume == seg_class] = assigned_prop
+                    if fuzzy:
+                        volumes[volume_key] += segmentation_volume[seg_class] * assigned_prop
+                    else:
+                        volumes[volume_key][segmentation_volume == seg_class] = assigned_prop
                 elif len(torch.Tensor.size(class_properties[volume_key])) == 3:  # 3D map
                     assigned_prop = class_properties[volume_key][torch.tensor(segmentation_volume == seg_class)]
                     assigned_prop[assigned_prop is None] = torch.nan
@@ -57,6 +68,6 @@ def create_simulation_volume(self) -> dict:
 
         # convert volumes back to CPU
         for key in volumes.keys():
-            volumes[key] = volumes[key].numpy().astype(np.float64, copy=False)
+            volumes[key] = volumes[key].cpu().numpy().astype(np.float64, copy=False)
 
         return volumes
diff --git a/simpa/utils/quality_assurance/data_sanity_testing.py b/simpa/utils/quality_assurance/data_sanity_testing.py
index dd20c2f2..5ae3100b 100644
--- a/simpa/utils/quality_assurance/data_sanity_testing.py
+++ b/simpa/utils/quality_assurance/data_sanity_testing.py
@@ -20,12 +20,10 @@ def assert_equal_shapes(numpy_arrays: list):
     if len(numpy_arrays) < 2:
         return
 
-    shapes = np.asarray([np.shape(_arr) for _arr in numpy_arrays]).astype(float)
-    mean = np.mean(shapes, axis=0)
-    for i in range(len(shapes)):
-        shapes[i, :] = shapes[i, :] - mean
+    first_array_shape = numpy_arrays[0].shape
+    equal = ([_arr.shape == first_array_shape for _arr in numpy_arrays])
 
-    if not np.sum(np.abs(shapes)) <= 1e-5:
+    if not all(equal):
        raise AssertionError("The given volumes did not all have the same"
                             " dimensions. Please double check the simulation"
                             f" parameters. Called from {inspect.stack()[1].function}")
diff --git a/simpa_examples/segmentation_loader.py b/simpa_examples/segmentation_loader.py
index a4d9cd39..e45ae890 100644
--- a/simpa_examples/segmentation_loader.py
+++ b/simpa_examples/segmentation_loader.py
@@ -6,7 +6,7 @@
 import simpa as sp
 import numpy as np
 from skimage.data import shepp_logan_phantom
-from scipy.ndimage import zoom
+from scipy.ndimage import zoom, gaussian_filter
 from skimage.transform import resize
 
 # FIXME temporary workaround for newest Intel architectures
@@ -20,8 +20,8 @@
 
 
 @profile
-def run_segmentation_loader(spacing: float | int = 1.0, input_spacing: float | int = 0.2, path_manager=None,
-                            visualise: bool = True):
+def run_segmentation_loader(spacing: float | int = 1.0, input_spacing: float | int = 0.2, fuzzy: bool = False,
+                            path_manager=None, visualise: bool = True):
     """
 
     :param spacing: The simulation spacing between voxels in mm
@@ -30,12 +30,14 @@ def run_segmentation_loader(spacing: float | int = 1.0, input_spacing: float | i
     :param visualise: If VISUALIZE is set to True, the reconstruction result will be plotted
     :return: a run through of the example
     """
+
     if path_manager is None:
         path_manager = sp.PathManager()
 
+    C = 11  # number of classes
     label_mask = shepp_logan_phantom()
-    label_mask = np.digitize(label_mask, bins=np.linspace(0.0, 1.0, 11), right=True)
+    label_mask = np.digitize(label_mask, bins=np.linspace(0.0, 1.0, C), right=True)
     label_mask = label_mask[100:300, 100:300]
     label_mask = np.reshape(label_mask, (label_mask.shape[0], 1, label_mask.shape[1]))
@@ -43,6 +45,13 @@
     segmentation_volume_mask = sp.round_x5_away_from_zero(zoom(segmentation_volume_tiled, input_spacing/spacing,
                                                                 order=0)).astype(int)
 
+    if fuzzy:
+        segmentation_volume_mask = np.eye(C)[segmentation_volume_mask]
+        segmentation_volume_mask = np.moveaxis(segmentation_volume_mask, -1, 0)
+        segmentation_volume_mask = gaussian_filter(segmentation_volume_mask, sigma=1e-5, axes=(1, 2, 3))  # smooth the segmentation
+        segmentation_volume_mask /= segmentation_volume_mask.sum(axis=0, keepdims=True)
+
+
     def segmentation_class_mapping():
         ret_dict = dict()
         ret_dict[0] = sp.TISSUE_LIBRARY.heavy_water()
@@ -68,14 +77,14 @@ def segmentation_class_mapping():
     settings[Tags.RANDOM_SEED] = 1234
     settings[Tags.WAVELENGTHS] = [700, 800]
     settings[Tags.SPACING_MM] = spacing
-    settings[Tags.DIM_VOLUME_X_MM] = segmentation_volume_mask.shape[0] * spacing
-    settings[Tags.DIM_VOLUME_Y_MM] = segmentation_volume_mask.shape[1] * spacing
-    settings[Tags.DIM_VOLUME_Z_MM] = segmentation_volume_mask.shape[2] * spacing
+    x_dim_mm, y_dim_mm, z_dim_mm = segmentation_volume_mask.shape[-3:]
+    settings[Tags.DIM_VOLUME_X_MM] = x_dim_mm * spacing
+    settings[Tags.DIM_VOLUME_Y_MM] = y_dim_mm * spacing
+    settings[Tags.DIM_VOLUME_Z_MM] = z_dim_mm * spacing
 
     settings.set_volume_creation_settings({
         Tags.INPUT_SEGMENTATION_VOLUME: segmentation_volume_mask,
         Tags.SEGMENTATION_CLASS_MAPPING: segmentation_class_mapping(),
-
     })
 
     settings.set_optical_settings({
@@ -108,9 +117,10 @@ def segmentation_class_mapping():
     parser = ArgumentParser(description='Run the segmentation loader example')
     parser.add_argument("--spacing", default=1, type=float, help='the voxel spacing in mm')
    parser.add_argument("--input_spacing", default=0.2, type=float, help='the input spacing in mm')
+    parser.add_argument("--fuzzy", default=False, type=bool, help='whether to use fuzzy segmentation adapter')
     parser.add_argument("--path_manager", default=None, help='the path manager, None uses sp.PathManager')
     parser.add_argument("--visualise", default=True, type=bool, help='whether to visualise the result')
     config = parser.parse_args()
-    run_segmentation_loader(spacing=config.spacing, input_spacing=config.input_spacing,
+    run_segmentation_loader(spacing=config.spacing, input_spacing=config.input_spacing, fuzzy=config.fuzzy,
                             path_manager=config.path_manager, visualise=config.visualise)
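For reference, the sketch below shows how a fuzzy segmentation input for the segmentation-based volume creation can be prepared. It is illustrative only and not part of the changes above: the class count, array size, and smoothing sigma are placeholder assumptions, mirroring the one-hot construction and per-voxel normalisation added to segmentation_loader.py. The modified adapter treats a floating-point 4D array of shape (num_classes, x, y, z) as a fuzzy segmentation (detected via torch.is_floating_point) and an integer 3D array as a hard segmentation.

```python
# Illustrative sketch only (not part of this diff): preparing a fuzzy segmentation
# input for the SegmentationBasedAdapter. Class count, array size, and sigma are
# placeholder assumptions chosen for this example.
import numpy as np
from scipy.ndimage import gaussian_filter  # the `axes` argument requires SciPy >= 1.11

num_classes = 11
hard_labels = np.random.randint(0, num_classes, size=(40, 20, 40))  # hypothetical 3D label map

# One-hot encode and move the class axis to the front: (x, y, z) -> (num_classes, x, y, z).
fuzzy_seg = np.eye(num_classes)[hard_labels]
fuzzy_seg = np.moveaxis(fuzzy_seg, -1, 0)

# Blur each class channel over the spatial axes, then renormalise so the
# per-voxel class weights sum to one again.
fuzzy_seg = gaussian_filter(fuzzy_seg, sigma=1.0, axes=(1, 2, 3))
fuzzy_seg /= fuzzy_seg.sum(axis=0, keepdims=True)

# A floating-point 4D array like this is read as a fuzzy segmentation by the adapter;
# an integer 3D array is treated as a hard segmentation instead.
assert fuzzy_seg.ndim == 4 and fuzzy_seg.shape[0] == num_classes
```

With such an array passed via Tags.INPUT_SEGMENTATION_VOLUME, the updated example can be driven in the same way, e.g. run_segmentation_loader(spacing=1.0, input_spacing=0.2, fuzzy=True).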