Merge pull request #7 from funkelab/add_3d_example

Add 3d example

lmanan authored Mar 3, 2024
2 parents e8c8edb + b2f891f commit b15c21b
Showing 13 changed files with 674 additions and 285 deletions.
19 changes: 13 additions & 6 deletions cellulus/configs/inference_config.py
@@ -22,13 +22,13 @@ class InferenceConfig:
        Configuration object produced by predict.py.
-    segmentation_dataset_config:
+    detection_dataset_config:
-        Configuration object produced by segment.py.
+        Configuration object produced by detect.py.
-    post_processed_dataset_config:
+    segmentation_dataset_config:
-        Configuration object produced by post_process.py.
+        Configuration object produced by segment.py.
    evaluation_dataset_config:
@@ -72,6 +72,12 @@ class InferenceConfig:
        How to cluster the embeddings?
        Can be one of 'meanshift' or 'greedy'.
+    use_seeds (default = False):
+        If set to True, the local optima of the distance map from the
+        predicted object centers are used as seeds.
+        Otherwise, seeds are determined by sklearn.cluster.MeanShift.
    num_bandwidths (default = 1):
        Number of bandwidths to obtain segmentations for.
@@ -118,11 +124,11 @@ class InferenceConfig:
        default=None, converter=to_config(DatasetConfig)
    )

-    segmentation_dataset_config: DatasetConfig = attrs.field(
+    detection_dataset_config: DatasetConfig = attrs.field(
        default=None, converter=to_config(DatasetConfig)
    )

-    post_processed_dataset_config: DatasetConfig = attrs.field(
+    segmentation_dataset_config: DatasetConfig = attrs.field(
        default=None, converter=to_config(DatasetConfig)
    )

@@ -139,6 +145,7 @@ class InferenceConfig:
    clustering = attrs.field(
        default="meanshift", validator=in_(["meanshift", "greedy"])
    )
+    use_seeds = attrs.field(default=False, validator=instance_of(bool))
    bandwidth = attrs.field(
        default=None, validator=attrs.validators.optional(instance_of(float))
    )
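With this rename, `detect.py` now produces the detection dataset, `segment.py` the segmentation, and the `post_processed_dataset_config` slot is gone. A minimal sketch of how the new options might be set when constructing an `InferenceConfig` — the values are hypothetical, `num_bandwidths` is assumed from the docstring above, and the dataset-config fields would still need to point at real zarr containers:

```python
from cellulus.configs.inference_config import InferenceConfig

# Hypothetical settings; fields not listed keep their attrs defaults.
inference_config = InferenceConfig(
    clustering="meanshift",  # validator allows "meanshift" or "greedy"
    use_seeds=True,          # seed mean shift from distance-map optima
    bandwidth=7.0,           # optional float; None leaves it unset
    num_bandwidths=2,        # also cluster at bandwidth / 2
)
```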
23 changes: 15 additions & 8 deletions cellulus/datasets/zarr_dataset.py
@@ -137,17 +137,24 @@ def __yield_sample(self):

        with gp.build(self.pipeline):
            while True:
+                array_is_zero = True
-                # request one sample, all channels, plus crop dimensions
-                request = gp.BatchRequest()
-                request[self.raw] = gp.ArraySpec(
-                    roi=gp.Roi(
-                        (0,) * self.num_dims, (1, self.num_channels, *self.crop_size)
-                    )
-                )
+                while array_is_zero:
+                    request = gp.BatchRequest()
+                    request[self.raw] = gp.ArraySpec(
+                        roi=gp.Roi(
+                            (0,) * self.num_dims,
+                            (1, self.num_channels, *self.crop_size),
+                        )
+                    )

-                sample = self.pipeline.request_batch(request)
-                sample_data = sample[self.raw].data[0]
-                anchor_samples, reference_samples = self.sample_coordinates()
+                    sample = self.pipeline.request_batch(request)
+                    sample_data = sample[self.raw].data[0]
+                    if np.max(sample_data) <= 0.0:
+                        pass
+                    else:
+                        array_is_zero = False
+                        anchor_samples, reference_samples = self.sample_coordinates()
                yield sample_data, anchor_samples, reference_samples

def __read_meta_data(self):
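The rewritten `__yield_sample` re-requests crops until one contains non-zero intensity, so training never receives an empty patch. The same retry pattern in isolation — `sample_fn` and `max_tries` are hypothetical stand-ins, not part of cellulus, whose loop runs indefinitely:

```python
import numpy as np

def sample_nonempty(sample_fn, max_tries=100):
    """Redraw random crops until one contains signal (max intensity > 0)."""
    for _ in range(max_tries):
        crop = sample_fn()
        if np.max(crop) > 0.0:  # same test as in __yield_sample
            return crop
    raise RuntimeError(f"no non-empty crop found in {max_tries} tries")
```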
192 changes: 192 additions & 0 deletions cellulus/detect.py
@@ -0,0 +1,192 @@
import numpy as np
import zarr
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max
from skimage.filters import threshold_otsu
from tqdm import tqdm

from cellulus.configs.inference_config import InferenceConfig
from cellulus.datasets.meta_data import DatasetMetaData
from cellulus.utils.greedy_cluster import Cluster2d, Cluster3d
from cellulus.utils.mean_shift import mean_shift_segmentation


def detect(inference_config: InferenceConfig) -> None:
dataset_config = inference_config.dataset_config
dataset_meta_data = DatasetMetaData.from_dataset_config(dataset_config)

f = zarr.open(inference_config.detection_dataset_config.container_path)
ds = f[inference_config.detection_dataset_config.secondary_dataset_name]

# prepare the zarr dataset to write to
f_detection = zarr.open(inference_config.detection_dataset_config.container_path)
ds_detection = f_detection.create_dataset(
inference_config.detection_dataset_config.dataset_name,
shape=(
dataset_meta_data.num_samples,
inference_config.num_bandwidths,
*dataset_meta_data.spatial_array,
),
dtype=np.uint16,
)

ds_detection.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
-dataset_meta_data.num_spatial_dims :
]
ds_detection.attrs["resolution"] = (1,) * dataset_meta_data.num_spatial_dims
ds_detection.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims

# prepare the binary segmentation zarr dataset to write to
ds_binary_segmentation = f_detection.create_dataset(
"binary-segmentation",
shape=(
dataset_meta_data.num_samples,
1,
*dataset_meta_data.spatial_array,
),
dtype=np.uint16,
)

ds_binary_segmentation.attrs["axis_names"] = ["s", "c"] + ["t", "z", "y", "x"][
-dataset_meta_data.num_spatial_dims :
]
ds_binary_segmentation.attrs["resolution"] = (
1,
) * dataset_meta_data.num_spatial_dims
ds_binary_segmentation.attrs["offset"] = (0,) * dataset_meta_data.num_spatial_dims

# prepare the object centered embeddings zarr dataset to write to
ds_object_centered_embeddings = f_detection.create_dataset(
"centered-embeddings",
shape=(
dataset_meta_data.num_samples,
dataset_meta_data.num_spatial_dims + 1,
*dataset_meta_data.spatial_array,
),
dtype=float,
)

ds_object_centered_embeddings.attrs["axis_names"] = ["s", "c"] + [
"t",
"z",
"y",
"x",
][-dataset_meta_data.num_spatial_dims :]
ds_object_centered_embeddings.attrs["resolution"] = (
1,
) * dataset_meta_data.num_spatial_dims
ds_object_centered_embeddings.attrs["offset"] = (
0,
) * dataset_meta_data.num_spatial_dims

for sample in tqdm(range(dataset_meta_data.num_samples)):
embeddings = ds[sample]
embeddings_std = embeddings[-1, ...]
embeddings_mean = embeddings[
np.newaxis, : dataset_meta_data.num_spatial_dims, ...
].copy()
if inference_config.threshold is None:
threshold = threshold_otsu(embeddings_std)
else:
threshold = inference_config.threshold

print(f"For sample {sample}, binary threshold {threshold} was used.")
binary_mask = embeddings_std < threshold
ds_binary_segmentation[sample, 0, ...] = binary_mask

# find mean of embeddings
embeddings_centered = embeddings.copy()
embeddings_mean_masked = (
binary_mask[np.newaxis, np.newaxis, ...] * embeddings_mean
)
if embeddings_centered.shape[0] == 3:
c_x = embeddings_mean_masked[0, 0]
c_y = embeddings_mean_masked[0, 1]
c_x = c_x[c_x != 0].mean()
c_y = c_y[c_y != 0].mean()
embeddings_centered[0] -= c_x
embeddings_centered[1] -= c_y
elif embeddings_centered.shape[0] == 4:
c_x = embeddings_mean_masked[0, 0]
c_y = embeddings_mean_masked[0, 1]
c_z = embeddings_mean_masked[0, 2]
c_x = c_x[c_x != 0].mean()
c_y = c_y[c_y != 0].mean()
c_z = c_z[c_z != 0].mean()
embeddings_centered[0] -= c_x
embeddings_centered[1] -= c_y
embeddings_centered[2] -= c_z
ds_object_centered_embeddings[sample] = embeddings_centered

embeddings_centered_mean = embeddings_centered[
np.newaxis, : dataset_meta_data.num_spatial_dims
]
embeddings_centered_std = embeddings_centered[-1]

if inference_config.clustering == "meanshift":
for bandwidth_factor in range(inference_config.num_bandwidths):
if inference_config.use_seeds:
offset_magnitude = np.linalg.norm(embeddings_centered[:-1], axis=0)
offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2)
coordinates = peak_local_max(-offset_magnitude_smooth)
seeds = np.flip(coordinates, 1)
segmentation = mean_shift_segmentation(
embeddings_centered_mean,
embeddings_centered_std,
bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
min_size=inference_config.min_size,
reduction_probability=inference_config.reduction_probability,
threshold=threshold,
seeds=seeds,
)
embeddings_centered_mean = embeddings_centered[
np.newaxis, : dataset_meta_data.num_spatial_dims, ...
].copy()
else:
segmentation = mean_shift_segmentation(
embeddings_mean,
embeddings_std,
bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
min_size=inference_config.min_size,
reduction_probability=inference_config.reduction_probability,
threshold=threshold,
seeds=None,
)
# Note that the line below is needed
# because the embeddings_mean is modified
# by mean_shift_segmentation
embeddings_mean = embeddings[
np.newaxis, : dataset_meta_data.num_spatial_dims, ...
].copy()
ds_detection[sample, bandwidth_factor, ...] = segmentation
elif inference_config.clustering == "greedy":
if dataset_meta_data.num_spatial_dims == 3:
cluster3d = Cluster3d(
width=embeddings.shape[-1],
height=embeddings.shape[-2],
depth=embeddings.shape[-3],
fg_mask=binary_mask,
device=inference_config.device,
)
for bandwidth_factor in range(inference_config.num_bandwidths):
segmentation = cluster3d.cluster(
prediction=embeddings,
bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
min_object_size=inference_config.min_size,
)
ds_detection[sample, bandwidth_factor, ...] = segmentation
elif dataset_meta_data.num_spatial_dims == 2:
cluster2d = Cluster2d(
width=embeddings.shape[-1],
height=embeddings.shape[-2],
fg_mask=binary_mask,
device=inference_config.device,
)
for bandwidth_factor in range(inference_config.num_bandwidths):
segmentation = cluster2d.cluster(
prediction=embeddings,
bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
min_object_size=inference_config.min_size,
)

ds_detection[sample, bandwidth_factor, ...] = segmentation
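When `use_seeds` is enabled, the mean-shift branch above derives seeds from the local minima of the offset-vector magnitude. A self-contained sketch of that recipe — the toy offset field and its two centers are made up for illustration; only the last four statements mirror `detect.py`:

```python
import numpy as np
from scipy.ndimage import gaussian_filter
from skimage.feature import peak_local_max

# Toy stand-in for embeddings_centered[:-1]: a (2, H, W) field of spatial
# offsets that vanish at two object centers, (y=16, x=16) and (y=48, x=40).
h, w = 64, 64
yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
centers = np.array([[16, 16], [48, 40]])
nearest = np.argmin(
    np.stack([(yy - cy) ** 2 + (xx - cx) ** 2 for cy, cx in centers]), axis=0
)
cy, cx = centers[nearest, 0], centers[nearest, 1]
offsets = np.stack([xx - cx, yy - cy]).astype(float)

# The seed recipe from the use_seeds branch: smooth the offset magnitude,
# find its minima as peaks of the negated map, flip (y, x) -> (x, y).
offset_magnitude = np.linalg.norm(offsets, axis=0)
offset_magnitude_smooth = gaussian_filter(offset_magnitude, sigma=2)
coordinates = peak_local_max(-offset_magnitude_smooth)
seeds = np.flip(coordinates, 1)
print(seeds)  # expected near [[16, 16], [40, 48]]
```

With `num_bandwidths > 1`, the same procedure is repeated while the bandwidth is halved each round (`bandwidth / 2**bandwidth_factor`), writing one segmentation layer per bandwidth.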
12 changes: 6 additions & 6 deletions cellulus/infer.py
@@ -4,9 +4,9 @@
import torch

from cellulus.datasets.meta_data import DatasetMetaData
+from cellulus.detect import detect
from cellulus.evaluate import evaluate
from cellulus.models import get_model
-from cellulus.post_process import post_process
from cellulus.predict import predict
from cellulus.segment import segment

@@ -35,7 +35,7 @@ def infer(experiment_config):
            )
        elif dataset_meta_data.num_spatial_dims == 3:
            inference_config.min_size = int(
-                0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3)
+                0.1 * 4.0 / 3.0 * np.pi * (experiment_config.object_size**3) / 8
            )
# set model
model = get_model(
@@ -69,12 +69,12 @@ def infer(experiment_config):
    # get predicted embeddings...
    if inference_config.prediction_dataset_config is not None:
        predict(model, inference_config, normalization_factor)
-    # ...turn them into a segmentation...
+    # ...turn them into a detection ...
+    if inference_config.detection_dataset_config is not None:
+        detect(inference_config)
+    # ...and post-process the detection to obtain an instance segmentation
    if inference_config.segmentation_dataset_config is not None:
        segment(inference_config)
-    # ...and post-process the segmentation
-    if inference_config.post_processed_dataset_config is not None:
-        post_process(inference_config)
    # ...and evaluate if ground-truth exists
    if inference_config.evaluation_dataset_config is not None:
        evaluate(inference_config)
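The extra `/ 8` in the 3D branch above is consistent with treating `object_size` as a diameter: dividing the cube by 8 turns it into the cube of the radius, so `min_size` remains 10% of the expected sphere volume. A quick check under that reading (the value 20.0 is arbitrary):

```python
import numpy as np

object_size = 20.0  # hypothetical object diameter in voxels
radius = object_size / 2
min_size_3d = int(0.1 * 4.0 / 3.0 * np.pi * (object_size**3) / 8)
assert min_size_3d == int(0.1 * 4.0 / 3.0 * np.pi * radius**3)
print(min_size_3d)  # 418 -> 10% of the volume of a sphere of diameter 20
```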