Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev/post-process: post-processing related changes #54

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@ def set_prediction(self, prediction_array_identifier):
prediction_array_identifier
)

def process(self, parameters, output_array_identifier):
def process(self, parameters, output_array_identifier, overwrite: bool = False):
output_array = ZarrArray.create_from_array_identifier(
output_array_identifier,
[dim for dim in self.prediction_array.axes if dim != "c"],
self.prediction_array.roi,
None,
self.prediction_array.voxel_size,
np.uint8,
overwrite=overwrite,
)

output_array[self.prediction_array.roi] = np.argmax(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def enumerate_parameters(self) -> Iterable[DummyPostProcessorParameters]:
def set_prediction(self, prediction_array):
    """No-op: the dummy post-processor ignores the prediction array."""

def process(self, parameters, output_array_identifier, overwrite: bool = False):
    """Write a constant dummy volume into the output container.

    Fills a (10, 10, 10) array with ``parameters.min_size`` so downstream
    code has something to read; ``overwrite`` is accepted for interface
    compatibility but not used here.
    """
    # Open (or create) the zarr container in append mode and store the data.
    container = zarr.open(str(output_array_identifier.container), "a")
    dummy_data = np.ones((10, 10, 10)) * parameters.min_size
    container[output_array_identifier.dataset] = dummy_data
2 changes: 2 additions & 0 deletions dacapo/experiments/tasks/post_processors/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def process(
self,
parameters: "PostProcessorParameters",
output_array_identifier: "LocalArrayIdentifier",
overwrite: "bool",
blockwise: "bool",
) -> "Array":
"""Convert predictions into the final output."""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def process(
self,
parameters: "PostProcessorParameters",
output_array_identifier: "LocalArrayIdentifier",
overwrite: bool = False,
) -> ZarrArray:
# TODO: Investigate Liskov substitution principle and whether it is a problem here
# OOP theory states the super class should always be replaceable with its subclasses
Expand All @@ -47,6 +48,7 @@ def process(
self.prediction_array.num_channels,
self.prediction_array.voxel_size,
np.uint8,
overwrite=overwrite,
)

output_array[self.prediction_array.roi] = (
Expand Down
108 changes: 65 additions & 43 deletions dacapo/experiments/tasks/post_processors/watershed_post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,53 +24,75 @@ def enumerate_parameters(self):
"""Enumerate all possible parameters of this post-processor. Should
return instances of ``PostProcessorParameters``."""

for i, bias in enumerate([0.1, 0.25, 0.5, 0.75, 0.9]):
for i, bias in enumerate(
[0.1, 0.3, 0.5, 0.7, 0.9]
): # TODO: add this to the config
yield WatershedPostProcessorParameters(id=i, bias=bias)

def set_prediction(self, prediction_array_identifier):
    """Open the affinity prediction array and remember it for ``process``."""
    opened = ZarrArray.open_from_array_identifier(prediction_array_identifier)
    self.prediction_array = opened

def process(self, parameters, output_array_identifier):
    """Run mutex watershed on the stored affinity prediction.

    Agglomerates with ``mws.agglom`` using ``parameters.bias``, removes
    fragments whose mean affinity falls below the bias, and writes the
    resulting uint64 label volume to ``output_array_identifier``.

    Returns the created output array.
    """
    spatial_axes = [ax for ax in self.prediction_array.axes if ax != "c"]
    output_array = ZarrArray.create_from_array_identifier(
        output_array_identifier,
        spatial_axes,
        self.prediction_array.roi,
        None,
        self.prediction_array.voxel_size,
        np.uint64,
    )
    # if a previous segmentation is provided, it must have a "grid graph"
    # in its metadata.
    prediction = self.prediction_array[self.prediction_array.roi]
    affs = prediction[: len(self.offsets)].astype(np.float64)
    segmentation = mws.agglom(affs - parameters.bias, self.offsets)

    # Filter out fragments whose mean affinity is below the bias.
    average_affs = np.mean(affs, axis=0)
    fragment_ids = np.unique(segmentation)
    fragment_means = measurements.mean(average_affs, segmentation, fragment_ids)
    filtered_fragments = np.array(
        [
            fragment
            for fragment, mean in zip(fragment_ids, fragment_means)
            if mean < parameters.bias
        ],
        dtype=segmentation.dtype,
    )
    replace = np.zeros_like(filtered_fragments)

    # DGA: had to add in flatten and reshape since remap (in particular
    # indices) didn't seem to work with ndarrays for the input
    if filtered_fragments.size > 0:
        segmentation = npi.remap(
            segmentation.flatten(), filtered_fragments, replace
        ).reshape(segmentation.shape)

    output_array[self.prediction_array.roi] = segmentation

    return output_array
def process(
    self,
    parameters,
    output_array_identifier,
    overwrite: bool = False,
    blockwise: bool = False,
):  # TODO: will probably break with large arrays...
    """Run mutex watershed on the stored affinity prediction.

    Agglomerates with ``mws.agglom`` using ``parameters.bias``, removes
    fragments whose mean affinity falls below the bias, and writes the
    resulting uint64 label volume to ``output_array_identifier``.

    Raises ``NotImplementedError`` when ``blockwise`` is requested.
    Returns the created output array.
    """
    if blockwise:
        # TODO: add rusty mws
        raise NotImplementedError("Blockwise processing not yet implemented.")

    spatial_axes = [axis for axis in self.prediction_array.axes if axis != "c"]
    output_array = ZarrArray.create_from_array_identifier(
        output_array_identifier,
        spatial_axes,
        self.prediction_array.roi,
        None,
        self.prediction_array.voxel_size,
        np.uint64,
        overwrite=overwrite,
    )
    # if a previous segmentation is provided, it must have a "grid graph"
    # in its metadata.
    raw_affs = self.prediction_array[self.prediction_array.roi][: len(self.offsets)]
    # uint8 affinities are assumed to be stored scaled to [0, 255];
    # normalize back to [0, 1] before agglomeration.
    scale = 255.0 if raw_affs.dtype == np.uint8 else 1.0
    affs = raw_affs.astype(np.float64) / scale
    segmentation = mws.agglom(affs - parameters.bias, self.offsets)

    # Filter out fragments whose mean affinity is below the bias.
    average_affs = np.mean(affs, axis=0)
    fragment_ids = np.unique(segmentation)
    fragment_means = measurements.mean(average_affs, segmentation, fragment_ids)
    filtered_fragments = np.array(
        [
            fragment
            for fragment, mean in zip(fragment_ids, fragment_means)
            if mean < parameters.bias
        ],
        dtype=segmentation.dtype,
    )
    replace = np.zeros_like(filtered_fragments)

    # DGA: had to add in flatten and reshape since remap (in particular
    # indices) didn't seem to work with ndarrays for the input
    if filtered_fragments.size > 0:
        segmentation = npi.remap(
            segmentation.flatten(), filtered_fragments, replace
        ).reshape(segmentation.shape)

    output_array[self.prediction_array.roi] = segmentation

    return output_array
Loading