Rhoadesj simple changes #25

Merged (2 commits) on Feb 8, 2024

Changes from all commits
196 changes: 192 additions & 4 deletions dacapo/apply.py
@@ -1,12 +1,200 @@
import logging
from typing import Optional, Union
from funlib.geometry import Roi, Coordinate
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays.array import Array
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.run import Run

from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
    PostProcessorParameters,
)
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.compute_context import LocalTorch, ComputeContext
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store import (
    create_config_store,
    create_weights_store,
)

from pathlib import Path

logger = logging.getLogger(__name__)


-def apply(run_name: str, iteration: int, dataset_name: str):
def apply(
    run_name: str,
    input_container: Union[Path, str],
    input_dataset: str,
    output_path: Union[Path, str],
    validation_dataset: Optional[Union[Dataset, str]] = None,
    criterion: str = "voi",
    iteration: Optional[int] = None,
    parameters: Optional[Union[PostProcessorParameters, str]] = None,
    roi: Optional[Union[Roi, str]] = None,
    num_cpu_workers: int = 30,
    output_dtype: Union[np.dtype, str] = np.uint8,
    compute_context: ComputeContext = LocalTorch(),
    overwrite: bool = True,
    file_format: str = "zarr",
):
"""Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
if isinstance(output_dtype, str):
output_dtype = np.dtype(output_dtype)

if isinstance(roi, str):
start, end = zip(
*[
tuple(int(coord) for coord in axis.split(":"))
for axis in roi.strip("[]").split(",")
]
)
roi = Roi(
Coordinate(start),
Coordinate(end) - Coordinate(start),
)

assert (validation_dataset is not None and isinstance(criterion, str)) or (
isinstance(iteration, int)
), "Either validation_dataset and criterion, or iteration must be provided."

    # retrieving run
    logger.info("Loading run %s", run_name)
    config_store = create_config_store()
    run_config = config_store.retrieve_run_config(run_name)
    run = Run(run_config)

    # create weights store
    weights_store = create_weights_store()

    # load weights
    if iteration is None:
        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
    logger.info("Loading weights for iteration %i", iteration)
    weights_store.retrieve_weights(run, iteration)  # shouldn't this be load_weights?

    # find the best parameters
    if isinstance(validation_dataset, str):
        val_ds_name = validation_dataset
        validation_dataset = [
            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
        ][0]
    logger.info("Finding best parameters for validation dataset %s", validation_dataset)
    if parameters is None:
        parameters = run.task.evaluator.get_overall_best_parameters(
            validation_dataset, criterion
        )
        assert (
            parameters is not None
        ), "Unable to retrieve parameters. Parameters must be provided explicitly."

    elif isinstance(parameters, str):
        try:
            post_processor_name = parameters.split("(")[0]
            post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
            post_processor_kwargs = {
                key.strip(): value.strip()
                for key, value in [arg.split("=") for arg in post_processor_kwargs]
            }
            for key, value in post_processor_kwargs.items():
                if value.isdigit():
                    post_processor_kwargs[key] = int(value)
                elif value.replace(".", "", 1).isdigit():
                    post_processor_kwargs[key] = float(value)
        except Exception:
            raise ValueError(
                f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
            )
        try:
            parameters = getattr(post_processors, post_processor_name)(
                **post_processor_kwargs
            )
        except Exception as e:
            logger.error(
                f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.",
                exc_info=True,
            )
            raise e

    assert isinstance(
        parameters, PostProcessorParameters
    ), "Parameters must be parsable to a PostProcessorParameters object."

    # make array identifiers for input, predictions and outputs
    input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
    input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
    if roi is None:
        roi = input_array.roi
    else:
        roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
            input_array.roi
        )
    output_container = Path(
        output_path,
        "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
    )
    prediction_array_identifier = LocalArrayIdentifier(
        output_container, f"prediction_{run_name}_{iteration}"
    )
    output_array_identifier = LocalArrayIdentifier(
        output_container, f"output_{run_name}_{iteration}_{parameters}"
    )

    logger.info(
-       "Applying results from run %s at iteration %d to dataset %s",
-       run_name,
        "Applying best results from run %s at iteration %i to dataset %s",
        run.name,
        iteration,
-       dataset_name,
        Path(input_container, input_dataset),
    )
    return apply_run(
        run,
        parameters,
        input_array,
        prediction_array_identifier,
        output_array_identifier,
        roi,
        num_cpu_workers,
        output_dtype,
        compute_context,
        overwrite,
    )


def apply_run(
    run: Run,
    parameters: PostProcessorParameters,
    input_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    output_array_identifier: LocalArrayIdentifier,
    roi: Optional[Roi] = None,
    num_cpu_workers: int = 30,
    output_dtype: np.dtype = np.uint8,
    compute_context: ComputeContext = LocalTorch(),
    overwrite: bool = True,
):
    """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""
    run.model.eval()

    # render prediction dataset
    logger.info("Predicting on dataset %s", prediction_array_identifier)
    predict(
        run.model,
        input_array,
        prediction_array_identifier,
        output_roi=roi,
        num_cpu_workers=num_cpu_workers,
        output_dtype=output_dtype,
        compute_context=compute_context,
        overwrite=overwrite,
    )

    # post-process the output
    logger.info("Post-processing output to dataset %s", output_array_identifier)
    post_processor = run.task.post_processor
    post_processor.set_prediction(prediction_array_identifier)
    post_processor.process(
        parameters, output_array_identifier, overwrite=overwrite, blockwise=True
    )

    logger.info("Done")
    return
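
For reference, the roi string format introduced above can be checked in isolation; a minimal standalone sketch with hypothetical coordinates:

from funlib.geometry import Coordinate, Roi

# Hypothetical roi string in the "[lower:upper, ...]" format accepted by apply()
roi_str = "[0:100, 20:120, 40:140]"
start, end = zip(
    *[
        tuple(int(coord) for coord in axis.split(":"))
        for axis in roi_str.strip("[]").split(",")
    ]
)
# Roi takes an offset and a shape, so the upper bounds are converted to extents
roi = Roi(Coordinate(start), Coordinate(end) - Coordinate(start))
print(roi)  # offset (0, 20, 40), shape (100, 100, 100)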
55 changes: 44 additions & 11 deletions dacapo/cli.py
@@ -1,3 +1,5 @@
from typing import Optional

import dacapo
import click
import logging
@@ -40,21 +42,52 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
-   "-r", "--run", required=True, type=str, help="The name of the run to use."
    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
-   "-i",
-   "--iteration",
    "-ic",
    "--input_container",
    required=True,
-   type=int,
-   help="The iteration weights and parameters to use.",
    type=click.Path(exists=True, file_okay=False),
)
@click.option("-id", "--input_dataset", required=True, type=str)
@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
@click.option("-vd", "--validation_dataset", type=str, default=None)
@click.option("-c", "--criterion", default="voi")
@click.option("-i", "--iteration", type=int, default=None)
@click.option("-p", "--parameters", type=str, default=None)
@click.option(
-   "-r",
-   "--dataset",
-   required=True,
    "-roi",
    "--roi",
    type=str,
-   help="The name of the dataset to apply the run to.",
    required=False,
    help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
)
-def apply(run_name, iteration, dataset_name):
-   dacapo.apply(run_name, iteration, dataset_name)
@click.option("-w", "--num_cpu_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
def apply(
    run_name: str,
    input_container: str,
    input_dataset: str,
    output_path: str,
    validation_dataset: Optional[str] = None,
    criterion: str = "voi",
    iteration: Optional[int] = None,
    parameters: Optional[str] = None,
    roi: Optional[str] = None,
    num_cpu_workers: int = 30,
    output_dtype: str = "uint8",
):
    dacapo.apply(
        run_name,
        input_container,
        input_dataset,
        output_path,
        validation_dataset,
        criterion,
        iteration,
        parameters,
        roi,
        num_cpu_workers,
        output_dtype,
    )
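
The -p/--parameters string is forwarded to dacapo.apply, which parses it into a PostProcessorParameters instance. A minimal sketch of that parsing, with a hypothetical post-processor name and values:

# Form: "post_processor_name(arg1=val1, arg2=val2, ...)"
parameters = "ThresholdPostProcessorParameters(id=1, threshold=0.5)"
post_processor_name = parameters.split("(")[0]
kwargs = {
    key.strip(): value.strip()
    for key, value in (
        arg.split("=") for arg in parameters.split("(")[1].strip(")").split(",")
    )
}
for key, value in kwargs.items():
    if value.isdigit():
        kwargs[key] = int(value)  # "1" -> 1
    elif value.replace(".", "", 1).isdigit():
        kwargs[key] = float(value)  # "0.5" -> 0.5
# dacapo.apply then instantiates getattr(post_processors, post_processor_name)(**kwargs)

Assuming the dacapo console entry point, a full invocation might look like this (paths and names are illustrative):

dacapo apply -r my_run -ic /data/raw.zarr -id volumes/raw -op predictions -vd my_val_dataset -roi "[0:1000,0:1000,0:1000]" -p "ThresholdPostProcessorParameters(id=1, threshold=0.5)"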
dacapo/experiments/datasplits/datasets/arrays/zarr_array.py
@@ -52,7 +52,7 @@ def axes(self):
        logger.debug(
            "DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
            f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
-           f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
            f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
        )
        return ["c", "z", "y", "x"][-self.dims : :]

6 changes: 5 additions & 1 deletion dacapo/experiments/tasks/affinities_task.py
@@ -12,7 +12,11 @@ def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood, lsds=task_config.lsds
neighborhood=task_config.neighborhood,
lsds=task_config.lsds,
num_voxels=task_config.num_voxels,
downsample_lsds=task_config.downsample_lsds,
grow_boundary_iterations=task_config.grow_boundary_iterations,
)
self.loss = AffinitiesLoss(len(task_config.neighborhood))
self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
20 changes: 20 additions & 0 deletions dacapo/experiments/tasks/affinities_task_config.py
@@ -30,3 +30,23 @@ class AffinitiesTaskConfig(TaskConfig):
"It has been shown that lsds as an auxiliary task can help affinity predictions."
},
)
num_voxels: int = attr.ib(
default=20,
metadata={
"help_text": "The number of voxels to use for the gaussian sigma when computing lsds."
},
)
downsample_lsds: int = attr.ib(
default=1,
metadata={
"help_text": "The amount to downsample the lsds. "
"This is useful for speeding up training and inference."
},
)
grow_boundary_iterations: int = attr.ib(
default=0,
metadata={
"help_text": "The number of iterations to run the grow boundaries algorithm. "
"This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects."
},
)
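
For illustration, a hypothetical config exercising the three new fields; this assumes AffinitiesTaskConfig is exported from dacapo.experiments.tasks and that the base TaskConfig requires a name:

from funlib.geometry import Coordinate
from dacapo.experiments.tasks import AffinitiesTaskConfig

task_config = AffinitiesTaskConfig(
    name="affs_with_lsds",
    neighborhood=[Coordinate(0, 0, 1), Coordinate(0, 1, 0), Coordinate(1, 0, 0)],
    lsds=True,
    num_voxels=20,  # Gaussian sigma for the lsds spans 20 voxels
    downsample_lsds=2,  # compute lsds at half resolution for speed
    grow_boundary_iterations=1,  # dilate masked boundaries once when balancing weights
)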
33 changes: 24 additions & 9 deletions dacapo/experiments/tasks/predictors/affinities_predictor.py
@@ -17,9 +17,17 @@


class AffinitiesPredictor(Predictor):
-   def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
    def __init__(
        self,
        neighborhood: List[Coordinate],
        lsds: bool = True,
        num_voxels: int = 20,
        downsample_lsds: int = 1,
        grow_boundary_iterations: int = 0,
    ):
        self.neighborhood = neighborhood
        self.lsds = lsds
        self.num_voxels = num_voxels
        if lsds:
            self._extractor = None
            if self.dims == 2:
@@ -30,12 +38,16 @@ def __init__(self, neighborhood: List[Coordinate], lsds: bool = True):
                raise ValueError(
                    f"Cannot compute lsds on volumes with {self.dims} dimensions"
                )
            self.downsample_lsds = downsample_lsds
        else:
            self.num_lsds = 0
        self.grow_boundary_iterations = grow_boundary_iterations

    def extractor(self, voxel_size):
        if self._extractor is None:
-           self._extractor = LsdExtractor(self.sigma(voxel_size))
            self._extractor = LsdExtractor(
                self.sigma(voxel_size), downsample=self.downsample_lsds
            )

        return self._extractor

@@ -45,8 +57,7 @@ def dims(self):

    def sigma(self, voxel_size):
        voxel_dist = max(voxel_size)  # arbitrarily chosen
-       num_voxels = 10  # arbitrarily chosen
-       sigma = voxel_dist * num_voxels
        sigma = voxel_dist * self.num_voxels  # arbitrarily chosen
        return Coordinate((sigma,) * self.dims)

    def lsd_pad(self, voxel_size):
@@ -118,7 +129,9 @@ def _grow_boundaries(self, mask, slab):
                slice(start[d], start[d] + slab[d]) for d in range(len(slab))
            )
            mask_slab = mask[slices]
-           dilated_mask_slab = ndimage.binary_dilation(mask_slab, iterations=1)
            dilated_mask_slab = ndimage.binary_dilation(
                mask_slab, iterations=self.grow_boundary_iterations
            )
            foreground[slices] = dilated_mask_slab

        # label new background
@@ -130,10 +143,12 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
        (moving_class_counts, moving_lsd_class_counts) = (
            moving_class_counts if moving_class_counts is not None else (None, None)
        )
-       # mask_data = self._grow_boundaries(
-       #     mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
-       # )
-       mask_data = mask[target.roi]
        if self.grow_boundary_iterations > 0:
            mask_data = self._grow_boundaries(
                mask[target.roi], slab=tuple(1 if c == "c" else -1 for c in target.axes)
            )
        else:
            mask_data = mask[target.roi]
        aff_weights, moving_class_counts = balance_weights(
            target[target.roi][: self.num_channels - self.num_lsds].astype(np.uint8),
            2,
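
To see the effect of the num_voxels change on the lsd sigma, a standalone sketch (voxel size hypothetical):

from funlib.geometry import Coordinate

voxel_size = Coordinate(8, 8, 8)  # hypothetical isotropic voxel size
num_voxels = 20  # new configurable default; previously hardcoded to 10
sigma = max(voxel_size) * num_voxels
print(Coordinate((sigma,) * 3))  # (160, 160, 160): sigma scales linearly with num_voxels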