Skip to content

Commit

Permalink
If num_channels ==0, uncertainty has no channels.
Browse files Browse the repository at this point in the history
  • Loading branch information
lmanan committed Jun 18, 2024
1 parent d4b3906 commit d4d03e0
Showing 1 changed file with 48 additions and 37 deletions.
85 changes: 48 additions & 37 deletions src/napari_cellulus/widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,25 +791,14 @@ def infer(self):
)

if self.napari_dataset.get_num_channels() == 0:
output_shape = gp.Coordinate(
self.model(
torch.zeros((1, *crop_size_tuple), dtype=torch.float32).to(
self.device
)
).shape
)
output_shape = gp.Coordinate((1, 1, *predicted_crop_size_tuple))
else:
output_shape = gp.Coordinate(
self.model(
torch.zeros(
(
1,
self.napari_dataset.get_num_channels(),
*crop_size_tuple,
),
dtype=torch.float32,
).to(self.device)
).shape
(
1,
self.napari_dataset.get_num_channels(),
*predicted_crop_size_tuple,
)
)

voxel_size = (
Expand Down Expand Up @@ -932,26 +921,48 @@ def infer(self):

embeddings_centered[sample] = embeddings_centered_sample

embeddings_layers = [
(
embeddings_centered[:, i : i + 1, ...].copy(),
{
"name": "Offset ("
+ "zyx"[self.napari_dataset.get_num_spatial_dims() - i]
+ ")"
if i < self.napari_dataset.get_num_spatial_dims()
else "Uncertainty",
"colormap": colormaps[
self.napari_dataset.get_num_spatial_dims() - i
]
if i < self.napari_dataset.get_num_spatial_dims()
else "gray",
"blending": "additive",
},
"image",
)
for i in range(self.napari_dataset.get_num_spatial_dims() + 1)
]
if self.napari_dataset.get_num_channels() == 0:
embeddings_layers = [
(
embeddings_centered[:, i, ...].copy(),
{
"name": "Offset ("
+ "zyx"[self.napari_dataset.get_num_spatial_dims() - i]
+ ")"
if i < self.napari_dataset.get_num_spatial_dims()
else "Uncertainty",
"colormap": colormaps[
self.napari_dataset.get_num_spatial_dims() - i
]
if i < self.napari_dataset.get_num_spatial_dims()
else "gray",
"blending": "additive",
},
"image",
)
for i in range(self.napari_dataset.get_num_spatial_dims() + 1)
]
else:
embeddings_layers = [
(
embeddings_centered[:, i : i + 1, ...].copy(),
{
"name": "Offset ("
+ "zyx"[self.napari_dataset.get_num_spatial_dims() - i]
+ ")"
if i < self.napari_dataset.get_num_spatial_dims()
else "Uncertainty",
"colormap": colormaps[
self.napari_dataset.get_num_spatial_dims() - i
]
if i < self.napari_dataset.get_num_spatial_dims()
else "gray",
"blending": "additive",
},
"image",
)
for i in range(self.napari_dataset.get_num_spatial_dims() + 1)
]

print("Clustering Objects in the obtained Foreground Mask ...")
if hasattr(self, "detection"):
Expand Down

0 comments on commit d4d03e0

Please sign in to comment.