Basic tutorial (#308)
Add a basic in-memory postprocessor for watershed.

It doesn't seem to be working that well, because the prediction itself seems broken.
Here's a picture:

![Figure_1](https://github.com/user-attachments/assets/01c209e3-914e-44a2-85b4-70343b1a984f)

I expect the postprocessor to work significantly better if the predictions are
made with enough context to avoid the block artifacts, and if the predictions
are post-processed properly so that high affinity within objects maps to an
affinity of 1 (it looks like it is set to 0 here) and low affinity in the
background maps to an affinity of 0 (it looks like it is set to 0.5 here).
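As a rough sketch of the remapping described above (not part of this commit;
`normalize_affinities` is a hypothetical helper and the array shapes are assumed),
the raw predictions could be rescaled into [0, 1] before the bias is subtracted:

```python
# Hypothetical sketch, not part of this commit: rescale raw affinity predictions
# into [0, 1] so that confident within-object edges sit near 1 and background
# edges sit near 0 before the bias is subtracted for mutex watershed.
import numpy as np


def normalize_affinities(affs: np.ndarray) -> np.ndarray:
    # Min-max rescale to [0, 1]; the epsilon guards against a constant array.
    affs = affs.astype(np.float64)
    lo, hi = affs.min(), affs.max()
    return (affs - lo) / max(hi - lo, 1e-8)


# affs: assumed shape (n_offsets, z, y, x), raw network output
# segmentation = mws.agglom(normalize_affinities(affs) - 0.5, offsets=offsets)
```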
mzouink authored Oct 22, 2024
2 parents 1cf8ecc + f9b0ac1 commit be53a19
Showing 3 changed files with 54 additions and 23 deletions.
@@ -3,6 +3,10 @@
from dacapo.blockwise.scheduler import segment_blockwise
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.utils.array_utils import to_ndarray, save_ndarray
from funlib.persistence import open_ds
import daisy
import mwatershed as mws

from .watershed_post_processor_parameters import WatershedPostProcessorParameters
from .post_processor import PostProcessor
@@ -123,29 +127,15 @@ def process(
np.uint64,
block_size * self.prediction_array.voxel_size,
)
input_array = open_ds(
self.prediction_array_identifier.container.path,
self.prediction_array_identifier.dataset,
)

read_roi = Roi((0, 0, 0), self.prediction_array.voxel_size * block_size)
# run blockwise prediction
pars = {
"offsets": self.offsets,
"bias": parameters.bias,
"context": parameters.context,
}
segment_blockwise(
segment_function_file=str(
Path(Path(dacapo.blockwise.__file__).parent, "watershed_function.py")
),
context=parameters.context,
total_roi=self.prediction_array.roi,
read_roi=read_roi.grow(parameters.context, parameters.context),
write_roi=read_roi,
num_workers=num_workers,
max_retries=2, # TODO: make this an option
timeout=None, # TODO: make this an option
######
input_array_identifier=self.prediction_array_identifier,
output_array_identifier=output_array_identifier,
parameters=pars,
data = to_ndarray(input_array, output_array.roi)
segmentation = mws.agglom(
data - parameters.bias, offsets=self.offsets, randomized_strides=True
)
save_ndarray(segmentation, self.prediction_array.roi, output_array)

return output_array
return output_array_identifier
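For readers skimming the diff, here is a small standalone sketch of the kind of
`mws.agglom` call used in the new in-memory path; the toy affinities, offsets, and
bias below are illustrative rather than taken from DaCapo:

```python
import numpy as np
import mwatershed as mws

# Toy affinities with shape (n_offsets, z, y, x) and values in [0, 1].
offsets = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
affs = np.random.rand(len(offsets), 8, 32, 32)

# Subtracting the bias turns affinities into signed edge weights:
# positive weights encourage merging, negative weights encourage splitting.
bias = 0.5
segmentation = mws.agglom(affs - bias, offsets=offsets)
print(segmentation.shape)  # expected: the spatial shape (8, 32, 32)
```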
39 changes: 39 additions & 0 deletions docs/source/notebooks/minimal_tutorial.py
@@ -376,3 +376,42 @@
    ax[snapshot, 2].imshow(prediction[prediction.shape[0] // 2])
    ax[snapshot, 0].set_ylabel(f"Snapshot {snapshot_it}")
plt.show()

# %%
# Visualize validations
import zarr

num_validations = run_config.num_iterations // run_config.validation_interval
fig, ax = plt.subplots(num_validations, 4, figsize=(10, 2 * num_validations))

# Set column titles
column_titles = ["Raw", "Ground Truth", "Prediction", "Segmentation"]
for col in range(len(column_titles)):
    ax[0, col].set_title(column_titles[col])

for validation in range(1, num_validations + 1):
    dataset = run.datasplit.validate[0].name
    validation_it = validation * run_config.validation_interval
    # break
    raw = zarr.open(
        f"/Users/pattonw/dacapo/example_run/validation.zarr/inputs/{dataset}/raw"
    )[:]
    gt = zarr.open(
        f"/Users/pattonw/dacapo/example_run/validation.zarr/inputs/{dataset}/gt"
    )[0]
    pred_path = f"/Users/pattonw/dacapo/example_run/validation.zarr/{validation_it}/ds_{dataset}/prediction"
    out_path = f"/Users/pattonw/dacapo/example_run/validation.zarr/{validation_it}/ds_{dataset}/output/WatershedPostProcessorParameters(id=2, bias=0.5, context=(32, 32, 32))"
    output = zarr.open(
        out_path
    )[:]
    prediction = zarr.open(pred_path)[0]
    print(raw.shape, gt.shape, output.shape)
    c = (raw.shape[1] - gt.shape[1]) // 2
    if c != 0:
        raw = raw[:, c:-c, c:-c]
    ax[validation - 1, 0].imshow(raw[raw.shape[0] // 2])
    ax[validation - 1, 1].imshow(gt[gt.shape[0] // 2])
    ax[validation - 1, 2].imshow(prediction[prediction.shape[0] // 2])
    ax[validation - 1, 3].imshow(output[output.shape[0] // 2])
    ax[validation - 1, 0].set_ylabel(f"Validation {validation_it}")
plt.show()
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -78,6 +78,8 @@ docs = [
"sphinx-click",
"sphinx-rtd-theme",
"myst-parser",
"matplotlib",
"pooch",
]
examples = [
"ipython",
