44 finish applypy #68

Merged: 6 commits, Feb 13, 2024
205 changes: 200 additions & 5 deletions dacapo/apply.py
@@ -1,13 +1,208 @@
import logging
from typing import Optional
from funlib.geometry import Roi, Coordinate
import numpy as np
from dacapo.experiments.datasplits.datasets.arrays.array import Array
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.run import Run

from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
PostProcessorParameters,
)
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.compute_context import LocalTorch, ComputeContext
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store import (
create_config_store,
create_weights_store,
)

from pathlib import Path

logger = logging.getLogger(__name__)


-def apply(run_name: str, iteration: int, dataset_name: str):
def apply(
run_name: str,
input_container: Path or str,
input_dataset: str,
output_path: Path or str,
validation_dataset: Optional[Dataset or str] = None,
criterion: Optional[str] = "voi",
iteration: Optional[int] = None,
parameters: Optional[PostProcessorParameters or str] = None,
roi: Optional[Roi or str] = None,
num_cpu_workers: int = 30,
output_dtype: Optional[np.dtype or str] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
overwrite: bool = True,
file_format: str = "zarr",
):
"""Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
if isinstance(output_dtype, str):
output_dtype = np.dtype(output_dtype)

if isinstance(roi, str):
start, end = zip(
*[
tuple(int(coord) for coord in axis.split(":"))
for axis in roi.strip("[]").split(",")
]
)
roi = Roi(
Coordinate(start),
Coordinate(end) - Coordinate(start),
)

assert (validation_dataset is not None and isinstance(criterion, str)) or (
isinstance(iteration, int)
), "Either validation_dataset and criterion, or iteration must be provided."

# retrieving run
logger.info("Loading run %s", run_name)
config_store = create_config_store()
run_config = config_store.retrieve_run_config(run_name)
run = Run(run_config)

# create weights store
weights_store = create_weights_store()

# load weights
if iteration is None:
iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion) # type: ignore
logger.info("Loading weights for iteration %i", iteration)
weights_store.retrieve_weights(run_name, iteration)

# find the best parameters
if isinstance(validation_dataset, str) and run.datasplit.validate is not None:
val_ds_name = validation_dataset
validation_dataset = [
dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
][0]
elif isinstance(validation_dataset, Dataset) or parameters is not None:
pass
else:
raise ValueError(
"validation_dataset must be a dataset name or a Dataset object, or parameters must be provided explicitly."
)
if parameters is None:
logger.info(
"Finding best parameters for validation dataset %s", validation_dataset
)
parameters = run.task.evaluator.get_overall_best_parameters( # TODO
validation_dataset, criterion
)
assert (
parameters is not None
), "Unable to retieve parameters. Parameters must be provided explicitly."

elif isinstance(parameters, str):
try:
post_processor_name = parameters.split("(")[0]
post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
post_processor_kwargs = {
key.strip(): value.strip()
for key, value in [arg.split("=") for arg in post_processor_kwargs]
}
for key, value in post_processor_kwargs.items():
if value.isdigit():
post_processor_kwargs[key] = int(value) # type: ignore
elif value.replace(".", "", 1).isdigit():
post_processor_kwargs[key] = float(value) # type: ignore
except:
raise ValueError(
f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
)
try:
parameters = getattr(post_processors, post_processor_name)(
**post_processor_kwargs
)
except Exception as e:
logger.error(
f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.",
exc_info=True,
)
raise e

assert isinstance(
parameters, PostProcessorParameters
), "Parameters must be parsable to a PostProcessorParameters object."

# make array identifiers for input, predictions and outputs
input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
if roi is None:
roi = input_array.roi
else:
roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
input_array.roi
)
output_container = Path(
output_path,
"".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
)
prediction_array_identifier = LocalArrayIdentifier(
output_container, f"prediction_{run_name}_{iteration}"
)
output_array_identifier = LocalArrayIdentifier(
output_container, f"output_{run_name}_{iteration}_{parameters}"
)

logger.info(
"Applying results from run %s at iteration %d to dataset %s",
run_name,
"Applying best results from run %s at iteration %i to dataset %s",
run.name,
iteration,
dataset_name,
Path(input_container, input_dataset),
)
return apply_run(
run,
parameters,
input_array,
prediction_array_identifier,
output_array_identifier,
roi,
num_cpu_workers,
output_dtype,
compute_context,
overwrite,
)
-    raise NotImplementedError("This function is not yet implemented.")


def apply_run(
run: Run,
parameters: PostProcessorParameters,
input_array: Array,
prediction_array_identifier: LocalArrayIdentifier,
output_array_identifier: LocalArrayIdentifier,
roi: Optional[Roi] = None,
num_cpu_workers: int = 30,
output_dtype: Optional[np.dtype] = np.uint8, # type: ignore
compute_context: ComputeContext = LocalTorch(),
overwrite: bool = True,
):
"""Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""
run.model.eval()

# render prediction dataset
logger.info("Predicting on dataset %s", prediction_array_identifier)
predict(
run.model,
input_array,
prediction_array_identifier,
output_roi=roi,
num_cpu_workers=num_cpu_workers,
output_dtype=output_dtype,
compute_context=compute_context,
overwrite=overwrite,
)

# post-process the output
logger.info("Post-processing output to dataset %s", output_array_identifier)
post_processor = run.task.post_processor
post_processor.set_prediction(prediction_array_identifier)
post_processor.process(parameters, output_array_identifier)

logger.info("Done")
return
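For reference, a minimal sketch of how the reworked apply entry point could be called from Python, based on the signature added above. The run name, container, and dataset paths are placeholder values, and the ROI string follows the [lower:upper, ...] format parsed at the top of the function.

from dacapo import apply

# Placeholder run name and paths for illustration only; substitute real values.
apply(
    run_name="my_run",                 # name of a stored run config
    input_container="/data/raw.zarr",  # zarr container holding the input
    input_dataset="volumes/raw",       # dataset within the container
    output_path="/data/predictions",   # predictions and post-processed outputs land here
    validation_dataset="validate_ds",  # used to look up the best iteration and parameters
    criterion="voi",                   # evaluation criterion defining "best"
    roi="[0:100,0:100,0:100]",         # optional; parsed into a funlib Roi
    output_dtype="uint8",              # converted to np.dtype inside apply
)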
55 changes: 44 additions & 11 deletions dacapo/cli.py
@@ -1,3 +1,5 @@
from typing import Optional

import dacapo
import click
import logging
@@ -40,21 +42,52 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
"-r", "--run-name", required=True, type=str, help="The name of the run to use."
"-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
"-i",
"--iteration",
"-ic",
"--input_container",
required=True,
type=int,
help="The iteration weights and parameters to use.",
type=click.Path(exists=True, file_okay=False),
)
@click.option("-id", "--input_dataset", required=True, type=str)
@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
@click.option("-vd", "--validation_dataset", type=str, default=None)
@click.option("-c", "--criterion", default="voi")
@click.option("-i", "--iteration", type=int, default=None)
@click.option("-p", "--parameters", type=str, default=None)
@click.option(
"-r",
"--dataset",
required=True,
"-roi",
"--roi",
type=str,
help="The name of the dataset to apply the run to.",
required=False,
help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
)
-def apply(run_name, iteration, dataset_name):
-    dacapo.apply(run_name, iteration, dataset_name)
@click.option("-w", "--num_cpu_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
def apply(
run_name: str,
input_container: str,
input_dataset: str,
output_path: str,
validation_dataset: Optional[str] = None,
criterion: Optional[str] = "voi",
iteration: Optional[int] = None,
parameters: Optional[str] = None,
roi: Optional[str] = None,
num_cpu_workers: int = 30,
output_dtype: Optional[str] = "uint8",
):
dacapo.apply(
run_name,
input_container,
input_dataset,
output_path,
validation_dataset,
criterion,
iteration,
parameters,
roi,
num_cpu_workers,
output_dtype,
)
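A hedged sketch of exercising the new CLI options programmatically through click's test runner. The flag values are assumed placeholders, and the cli group is imported from dacapo.cli as implied by the @cli.command() decorator above.

from click.testing import CliRunner

from dacapo.cli import cli  # the click group the apply command is registered on

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "apply",
        "-r", "my_run",                 # run name (placeholder)
        "-ic", "/data/raw.zarr",        # input container (placeholder path, must exist)
        "-id", "volumes/raw",           # input dataset
        "-op", "/data/predictions",     # output path
        "-i", "5000",                   # explicit iteration instead of validation lookup
        "-roi", "[0:100,0:100,0:100]",  # optional ROI string
    ],
)
print(result.exit_code, result.output)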
8 changes: 4 additions & 4 deletions dacapo/predict.py
@@ -24,7 +24,7 @@ def predict(  # TODO: MAKE THIS CLI ACCESSIBLE
num_cpu_workers: int = 4,
compute_context: ComputeContext = LocalTorch(),
output_roi: Optional[Roi] = None,
-    output_dtype: np.dtype = np.float32,  # type: ignore
output_dtype: Optional[np.dtype] = np.uint8, # type: ignore
overwrite: bool = False,
):
# get the model's input and output size
@@ -112,7 +112,7 @@ def predict(  # TODO: MAKE THIS CLI ACCESSIBLE
# write to zarr
pipeline += gp.ZarrWrite(
{prediction: prediction_array_identifier.dataset},
-        prediction_array_identifier.container.parent,
str(prediction_array_identifier.container.parent),
prediction_array_identifier.container.name,
dataset_dtypes={prediction: output_dtype},
)
@@ -130,8 +130,8 @@ def predict(  # TODO: MAKE THIS CLI ACCESSIBLE
with gp.build(pipeline):
pipeline.request_batch(gp.BatchRequest())

-    container = zarr.open(prediction_array_identifier.container)
container = zarr.open(str(prediction_array_identifier.container))
dataset = container[prediction_array_identifier.dataset]
dataset.attrs["axes"] = (
dataset.attrs["axes"] = ( # type: ignore
raw_array.axes if "c" in raw_array.axes else ["c"] + raw_array.axes
)
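The axes attribute written in the last hunk prepends a channel axis when the raw array has none. A tiny illustration of that logic, with assumed axis names:

# Illustration of the axes handling written to the prediction zarr attrs.
raw_axes = ["z", "y", "x"]                                # assumed raw array axes
axes = raw_axes if "c" in raw_axes else ["c"] + raw_axes  # -> ["c", "z", "y", "x"]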
2 changes: 2 additions & 0 deletions dacapo/store/create_store.py
@@ -25,6 +25,8 @@ def create_config_store():
elif store_type == "files":
store_path = Path(options.runs_base_dir).expanduser()
return FileConfigStore(store_path / "configs")
else:
raise ValueError(f"Unknown store type {store_type}")


def create_stats_store():
3 changes: 2 additions & 1 deletion dacapo/store/local_weights_store.py
@@ -1,3 +1,4 @@
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from .weights_store import WeightsStore, Weights
from dacapo.experiments.run import Run

@@ -100,7 +101,7 @@ def store_best(self, run: str, iteration: int, dataset: str, criterion: str):
with best_weights_json.open("w") as f:
f.write(json.dumps({"iteration": iteration}))

-    def retrieve_best(self, run: str, dataset: str, criterion: str) -> int:
def retrieve_best(self, run: str, dataset: str | Dataset, criterion: str) -> int:
logger.info("Retrieving weights for run %s, criterion %s", run, criterion)

weights_info = json.loads(
56 changes: 56 additions & 0 deletions tests/operations/test_apply.py
@@ -0,0 +1,56 @@
from ..fixtures import *

from dacapo.experiments import Run
from dacapo.compute_context import LocalTorch
from dacapo.store import create_config_store, create_weights_store
from dacapo import apply

import pytest
from pytest_lazyfixture import lazy_fixture

import logging

logging.basicConfig(level=logging.INFO)


@pytest.mark.parametrize(
"run_config",
[
lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
lazy_fixture("onehot_run"),
],
)
def test_apply(
options,
run_config,
):
# TODO: test the apply function
return # remove this line to run the test
compute_context = LocalTorch(device="cpu")

# create a store

store = create_config_store()
weights_store = create_weights_store()

# store the configs

store.store_run_config(run_config)

run_config = store.retrieve_run_config(run_config.name)
run = Run(run_config)

# -------------------------------------

# apply

# test validating iterations for which we know there are weights
weights_store.store_weights(run, 0)
apply(run_config.name, 0, compute_context=compute_context)
weights_store.store_weights(run, 1)
apply(run_config.name, 1, compute_context=compute_context)

# test validating weights that don't exist
with pytest.raises(FileNotFoundError):
apply(run_config.name, 2, compute_context=compute_context)