Basic tutorial (#310)
Add some improvements to DaCapo and the basic tutorial for learning
instance segmentation through affinities + watershed.

DaCapo:
1) Avoid doing any special postprocessing to convert predictions to uint8
when writing them out; just store them as float32 (a sketch of the rounding
error this avoids follows this list).
2) Avoid converting back to float in the watershed postprocessor; just use
the predictions as saved in the zarr.
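
To make point 1 concrete, here is a small, hypothetical numpy sketch (not code from this commit) of the information the old uint8 round trip destroyed. The encode step mirrors the lines removed from dacapo/predict_local.py below; the decode step is the assumed inverse mapping:

```python
import numpy as np

# Encode step, following the removed predict_local.py lines: float
# predictions (roughly in [-1, 1]) were squashed into uint8.
preds = np.array([-1.0, -0.5, 0.0, 0.7031, 1.0], dtype=np.float32)
encoded = (preds + 1) * 255.0 / 2.0
encoded[encoded > 254] = 0  # values near 1.0 wrap around to 0
encoded = np.round(encoded).astype(np.uint8)

# Assumed inverse mapping back to float: most values come back with a
# small rounding error, and the 1.0 prediction has wrapped to -1.0.
decoded = encoded.astype(np.float32) * 2.0 / 255.0 - 1.0
print(preds - decoded)  # [ 0.  -0.002 -0.004  0.001  2. ]
```

Storing float32 directly sidesteps both the quantization and the wrap-around of values near 1.0.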

Tutorial:
1) Add a labels colormap.
2) Train z affinities.
3) Use valid padding.

Loss and validation plots are included below.
There is still something strange happening with the loss after the first
validation, and the results are still not as good as they should be on such
a simple toy dataset.


![Figure_2](https://github.com/user-attachments/assets/8badb5ca-5f4b-4e8e-821b-80f13c4987eb)

![Figure_3](https://github.com/user-attachments/assets/6f7276b6-5e3b-40b1-bcdc-ba0f2b865052)
mzouink authored Oct 23, 2024 · 2 parents f1fdef3 + 51b3821 · commit f4ac3a6
Showing 3 changed files with 22 additions and 34 deletions.
```diff
@@ -132,7 +132,7 @@ def process(
             self.prediction_array_identifier.dataset,
         )

-        data = to_ndarray(input_array, output_array.roi)
+        data = to_ndarray(input_array, output_array.roi).astype(float)
         segmentation = mws.agglom(
             data - parameters.bias, offsets=self.offsets, randomized_strides=True
         )
```
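For orientation, a minimal, hypothetical sketch of the mutex-watershed call this hunk feeds, assuming the `mwatershed` package (imported as `mws` in DaCapo) and toy affinities; the `.astype(float)` added above matches the double-precision input used here:

```python
import numpy as np
import mwatershed as mws

# Toy affinities: one channel per offset, values in [0, 1].
offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
affs = np.random.rand(3, 8, 32, 32).astype(float)  # (channel, z, y, x)

# Subtracting the bias recenters values around zero: positive edges merge,
# negative edges split, as in the hunk above (bias=0.5 in the validation run).
segmentation = mws.agglom(affs - 0.5, offsets=offsets)
print(segmentation.shape)  # (8, 32, 32) array of instance labels
```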
5 changes: 1 addition & 4 deletions dacapo/predict_local.py

```diff
@@ -57,7 +57,7 @@ def predict(
         output_roi,
         num_channels,
         output_voxel_size,
-        np.uint8,
+        np.float32,
     )

     logger.info("Total input ROI: %s, output ROI: %s", input_size, output_roi)
@@ -82,9 +82,6 @@ def predict_fn(block):
             .cpu()
             .numpy()[0]
         )
-        predictions = (predictions + 1) * 255.0 / 2.0
-        predictions[predictions > 254] = 0
-        predictions = np.round(predictions).astype(np.uint8)

         save_ndarray(predictions, block.write_roi, result_dataset)
         # result_dataset[block.write_roi] = predictions
```
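The flip side of the dtype change, sketched with plain zarr and a made-up path (not the run's actual layout): float32 predictions are written verbatim, so nothing has to be decoded when the post-processor reads them back:

```python
import numpy as np
import zarr

# Toy container; the real run writes into the experiment's prediction zarr.
container = zarr.open("toy.zarr", mode="w")
preds = np.random.uniform(-1, 1, size=(3, 4, 16, 16)).astype(np.float32)
container["prediction"] = preds

# Lossless round trip: what the watershed post-processor reads is exactly
# what the network produced, up to float32 precision.
assert np.array_equal(container["prediction"][:], preds)
```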
49 changes: 20 additions & 29 deletions docs/source/notebooks/minimal_tutorial.py
```diff
@@ -81,6 +81,7 @@
 # import random

 import matplotlib.pyplot as plt
+from matplotlib.colors import ListedColormap
 import numpy as np
 from funlib.geometry import Coordinate, Roi
 from funlib.persistence import prepare_ds
```
```diff
@@ -134,29 +135,18 @@
 )
 labels_array[labels_array.roi] = label(mask_array.to_ndarray(mask_array.roi))[0]

-# Generate affinity graph
-affs_array = prepare_ds(
-    "cells3d.zarr",
-    "affs",
-    Roi((0, 0, 0), cell_data.shape[1:]) * voxel_size,
-    voxel_size=voxel_size,
-    num_channels=3,
-    dtype=np.uint8,
-)
-affs_array[affs_array.roi] = (
-    seg_to_affgraph(
-        labels_array.to_ndarray(labels_array.roi),
-        neighborhood=[Coordinate(1, 0, 0), Coordinate(0, 1, 0), Coordinate(0, 0, 1)],
-    )
-    * 255
-)
 print("Data saved to cells3d.zarr")


+# Create a custom label color map for showing instances
+np.random.seed(1)
+colors = [[0, 0, 0]] + [list(np.random.choice(range(256), size=3)) for _ in range(254)]
+label_cmap = ListedColormap(colors)
+
 # %% [markdown]
 # Here we show a slice of the raw data:
 # %%
-plt.imshow(cell_array.data[30])
+# plt.imshow(cell_array.data[30])

 # %% [markdown]
 # ## Datasplit
```
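A standalone sketch (toy labels, not the tutorial's data) of how the colormap added above is used. One caveat worth noting: `ListedColormap` expects RGB values in [0, 1], so this sketch scales the 0-255 integers:

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap

# Same recipe as the tutorial: black for background (label 0), a stable
# random color for every other entry. Scale to [0, 1] for ListedColormap.
np.random.seed(1)
colors = [[0, 0, 0]] + [list(np.random.choice(range(256), size=3)) for _ in range(254)]
label_cmap = ListedColormap(np.array(colors) / 255.0)

# Fake instance ids: distinct values pick distinct colors, 0 stays black.
toy_labels = np.random.randint(0, 6, size=(32, 32))
plt.imshow(toy_labels, cmap=label_cmap, interpolation="none")
plt.show()
```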
```diff
@@ -177,14 +167,14 @@
         raw_container="cells3d.zarr",
         raw_dataset="raw",
         gt_container="cells3d.zarr",
-        gt_dataset="mask",
+        gt_dataset="labels",
     ),
     DatasetSpec(
         dataset_type="val",
         raw_container="cells3d.zarr",
         raw_dataset="raw",
         gt_container="cells3d.zarr",
-        gt_dataset="mask",
+        gt_dataset="labels",
     ),
 ]
```
```diff
@@ -229,7 +219,7 @@
 # an example affinities task configuration
 affs_task_config = AffinitiesTaskConfig(
     name="example_affs",
-    neighborhood=[(0, 1, 0), (0, 0, 1)],
+    neighborhood=[(1, 0, 0), (0, 1, 0), (0, 0, 1)],
 )
 # config_store.delete_task_config(dist_task_config.name)
 config_store.store_task_config(affs_task_config)
```
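A hypothetical numpy sketch of what the enlarged neighborhood means: each offset becomes one affinity channel that is 1 where a voxel and its neighbor at that offset share a nonzero label, so the new (1, 0, 0) entry adds the z affinities mentioned in the commit message:

```python
import numpy as np

def toy_affinity(labels: np.ndarray, offset: tuple) -> np.ndarray:
    # 1 where the voxel and its neighbor at `offset` share a nonzero label.
    # np.roll wraps at the border, so edge values are meaningless in this toy.
    shifted = np.roll(labels, tuple(-o for o in offset), axis=(0, 1, 2))
    return ((labels == shifted) & (labels > 0)).astype(np.float32)

labels = np.zeros((2, 3, 3), dtype=np.int64)
labels[:, :2, :2] = 1  # one object spanning both z slices
labels[:, 2, 2] = 2    # a second object

neighborhood = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
affs = np.stack([toy_affinity(labels, o) for o in neighborhood])
print(affs.shape)  # (3, 2, 3, 3): z, y, and x affinity channels
```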
```diff
@@ -249,8 +239,8 @@
 # all with 1s in z meaning no downsampling or convolving in the z direction.
 architecture_config = CNNectomeUNetConfig(
     name="example_unet",
-    input_shape=(2, 64, 64),
-    eval_shape_increase=(7, 0, 0),
+    input_shape=(2, 132, 132),
+    eval_shape_increase=(8, 32, 32),
     fmaps_in=1,
     num_fmaps=8,
     fmaps_out=8,
@@ -259,7 +249,7 @@
     kernel_size_down=[[(1, 3, 3)] * 2] * 3,
     kernel_size_up=[[(1, 3, 3)] * 2] * 2,
     constant_upsample=True,
-    padding="same",
+    padding="valid",
 )
 config_store.store_architecture_config(architecture_config)
```
```diff
@@ -401,17 +391,18 @@
     )[0]
     pred_path = f"/Users/pattonw/dacapo/example_run/validation.zarr/{validation_it}/ds_{dataset}/prediction"
     out_path = f"/Users/pattonw/dacapo/example_run/validation.zarr/{validation_it}/ds_{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))"
-    output = zarr.open(
-        out_path
-    )[:]
+    output = zarr.open(out_path)[:]
     prediction = zarr.open(pred_path)[0]
     print(raw.shape, gt.shape, output.shape)
     c = (raw.shape[1] - gt.shape[1]) // 2
     if c != 0:
         raw = raw[:, c:-c, c:-c]
     ax[validation - 1, 0].imshow(raw[raw.shape[0] // 2])
-    ax[validation - 1, 1].imshow(gt[gt.shape[0] // 2])
+    ax[validation - 1, 1].imshow(
+        gt[gt.shape[0] // 2], cmap=label_cmap, interpolation="none"
+    )
     ax[validation - 1, 2].imshow(prediction[prediction.shape[0] // 2])
-    ax[validation - 1, 3].imshow(output[output.shape[0] // 2])
+    ax[validation - 1, 3].imshow(
+        output[output.shape[0] // 2], cmap=label_cmap, interpolation="none"
+    )
     ax[validation - 1, 0].set_ylabel(f"Validation {validation_it}")
 plt.show()
```
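The `c = (raw.shape[1] - gt.shape[1]) // 2` crop above exists because of that valid-padding shrinkage: the ground truth and outputs are smaller than the raw input, so raw is center-cropped before plotting. A toy version with made-up shapes:

```python
import numpy as np

raw = np.zeros((8, 132, 132))
gt = np.zeros((8, 100, 100))  # hypothetical post-network size

# Trim raw symmetrically so the slices line up with gt for plotting.
c = (raw.shape[1] - gt.shape[1]) // 2  # 16
if c != 0:
    raw = raw[:, c:-c, c:-c]
print(raw.shape)  # (8, 100, 100)
```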
