Dev/main (#221)
mzouink authored May 9, 2024
2 parents d88faf8 + 9ee5feb commit d81281f
Showing 186 changed files with 18,430 additions and 931 deletions.
8 changes: 8 additions & 0 deletions CONTRIBUTING.md
@@ -26,3 +26,11 @@ This will also be run automatically when a PR is made to master and a codecov re
- For a completely new feature, make a branch off of the `dev/main` branch of CellMap's fork of DaCapo with a name describing the feature. If you are collaborating on a feature that already has a branch, you can branch off that feature branch.
- Currently, you should make your PRs into the `dev/main` branch of CellMap's fork, or the feature branch you branched off of. PRs currently require one maintainer's approval before merging. Once the PR is merged, the feature branch should be deleted.
- `dev/main` will be regularly merged to `main` when new features are fully implemented and all tests are passing.


## Documentation
Documentation is built using Sphinx. To build the documentation locally, run
```bash
sphinx-build -M html docs/source docs/build
```
This will generate the HTML files in the `docs/build/html` directory.
67 changes: 64 additions & 3 deletions dacapo/apply.py
@@ -18,7 +18,7 @@
create_weights_store,
)

from pathlib import Path
from upath import UPath as Path

logger = logging.getLogger(__name__)

@@ -38,7 +38,40 @@ def apply(
overwrite: bool = True,
file_format: str = "zarr",
):
"""Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
"""
Load weights and apply a trained model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used.
Args:
run_name (str): Name of the run to apply.
input_container (Path | str): Path to the input container.
input_dataset (str): Name of the input dataset.
output_path (Path | str): Path to the output container.
validation_dataset (Optional[Dataset | str], optional): Validation dataset to use for finding the best parameters. Defaults to None.
criterion (str, optional): Criterion to use for finding the best parameters. Defaults to "voi".
iteration (Optional[int], optional): Iteration to use. If None, the best iteration is used. Defaults to None.
parameters (Optional[PostProcessorParameters | str], optional): Post-processor parameters to use. If None, the best parameters are found. Defaults to None.
roi (Optional[Roi | str], optional): Region of interest to use. If None, the whole input dataset is used. Defaults to None.
num_workers (int, optional): Number of workers to use. Defaults to 12.
output_dtype (np.dtype | str, optional): Output dtype. Defaults to np.uint8.
overwrite (bool, optional): Overwrite existing output. Defaults to True.
file_format (str, optional): File format to use. Defaults to "zarr".
Raises:
ValueError: If validation_dataset is None and criterion is not None.
ValueError: If parameters is a string that cannot be parsed to PostProcessorParameters.
ValueError: If parameters is not a PostProcessorParameters object.
Examples:
>>> apply(
... run_name="run_1",
... input_container="data.zarr",
... input_dataset="raw",
... output_path="output.zarr",
... validation_dataset="validate",
... criterion="voi",
... num_workers=12,
... output_dtype=np.uint8,
... overwrite=True,
... )
"""
if isinstance(output_dtype, str):
output_dtype = np.dtype(output_dtype)
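For reference, a hedged usage sketch (not part of this commit) of driving the expanded `apply()` signature with an explicit iteration and region of interest; the run name, container paths, dataset names, iteration number, and ROI values below are placeholders, not values from this diff.

```python
# Hedged sketch: run a specific checkpoint over a sub-region.
# "run_1", "data.zarr", "raw", and "output.zarr" are placeholder names.
from daisy import Roi

from dacapo.apply import apply

apply(
    run_name="run_1",
    input_container="data.zarr",
    input_dataset="raw",
    output_path="output.zarr",
    validation_dataset="validate",          # used to pick post-processing parameters
    criterion="voi",
    iteration=50000,                        # use this checkpoint instead of the "best" one
    roi=Roi((0, 0, 0), (256, 256, 256)),    # restrict processing to this (offset, shape) region
    num_workers=4,
    output_dtype="uint8",                   # strings are converted via np.dtype, as above
    overwrite=True,
)
```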

@@ -178,8 +211,36 @@ def apply_run(
output_dtype: np.dtype | str = np.uint8, # type: ignore
overwrite: bool = True,
):
"""Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""
"""
Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded.
Args:
run (Run): The run object containing the task and post-processor.
iteration (int): The iteration number.
parameters (PostProcessorParameters): The post-processor parameters.
input_array_identifier (LocalArrayIdentifier): The identifier for the input array.
prediction_array_identifier (LocalArrayIdentifier): The identifier for the prediction array.
output_array_identifier (LocalArrayIdentifier): The identifier for the output array.
roi (Optional[Roi], optional): The region of interest. Defaults to None.
num_workers (int, optional): The number of workers for parallel processing. Defaults to 12.
output_dtype (np.dtype | str, optional): The output data type. Defaults to np.uint8.
overwrite (bool, optional): Whether to overwrite existing output. Defaults to True.
Raises:
ValueError: If the input array is not a ZarrArray.
Examples:
>>> apply_run(
... run=run,
... iteration=1,
... parameters=parameters,
... input_array_identifier=LocalArrayIdentifier(Path("data.zarr"), "raw"),
... prediction_array_identifier=LocalArrayIdentifier(Path("output.zarr"), "prediction_run_1_1"),
... output_array_identifier=LocalArrayIdentifier(Path("output.zarr"), "output_run_1_1"),
... roi=None,
... num_workers=12,
... output_dtype=np.uint8,
... overwrite=True,
... )
"""
# render prediction dataset
print(f"Predicting on dataset {prediction_array_identifier}")
predict(
65 changes: 50 additions & 15 deletions dacapo/blockwise/argmax_worker.py
@@ -1,4 +1,4 @@
from pathlib import Path
from upath import UPath as Path
import sys
from dacapo.experiments.datasplits.datasets.arrays.zarr_array import ZarrArray
from dacapo.store.array_store import LocalArrayIdentifier
@@ -27,6 +27,12 @@
default="INFO",
)
def cli(log_level):
"""
CLI for running the argmax worker.
Args:
log_level (str): The log level to use.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))


@@ -47,7 +53,17 @@ def start_worker(
input_dataset: str,
output_container: Path | str,
output_dataset: str,
return_io_loop: bool = False,
):
"""
Start the argmax worker.
Args:
input_container (Path | str): The input container.
input_dataset (str): The input dataset.
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
return_io_loop (bool, optional): If True, return the io_loop function instead of running it immediately. Defaults to False.
"""
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
@@ -57,34 +73,51 @@
)
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)

# wait for blocks to run pipeline
client = daisy.Client()
def io_loop():
# wait for blocks to run pipeline
client = daisy.Client()

while True:
print("getting block")
with client.acquire_block() as block:
if block is None:
break
while True:
print("getting block")
with client.acquire_block() as block:
if block is None:
break

# write to output array
output_array[block.write_roi] = np.argmax(
input_array[block.write_roi],
axis=input_array.axes.index("c"),
)
# write to output array
output_array[block.write_roi] = np.argmax(
input_array[block.write_roi],
axis=input_array.axes.index("c"),
)

if return_io_loop:
return io_loop
else:
io_loop()


def spawn_worker(
input_array_identifier: "LocalArrayIdentifier",
output_array_identifier: "LocalArrayIdentifier",
):
"""Spawn a worker to predict on a given dataset.
"""
Spawn a worker to apply argmax over the channel axis of a given dataset.
Args:
input_array_identifier (LocalArrayIdentifier): The identifier of the input array.
output_array_identifier (LocalArrayIdentifier): The identifier of the output (argmax) array.
Returns:
Callable: The function to run the worker.
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
return start_worker(
input_array_identifier.container,
input_array_identifier.dataset,
output_array_identifier.container,
output_array_identifier.dataset,
return_io_loop=True,
)

# Make the command for the worker to run
command = [
@@ -103,7 +136,9 @@ def spawn_worker(
]

def run_worker():
# Run the worker in the given compute context
"""
Run the worker in the given compute context.
"""
compute_context.execute(command)

return run_worker
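To illustrate the new in-process path added here, a hedged sketch (not from this diff) of obtaining the worker callable directly; the zarr path and dataset names are placeholders.

```python
# Hedged sketch: the container and dataset names below are placeholders.
from upath import UPath as Path

from dacapo.blockwise.argmax_worker import spawn_worker
from dacapo.store.array_store import LocalArrayIdentifier

input_id = LocalArrayIdentifier(Path("output.zarr"), "prediction_run_1_1")
output_id = LocalArrayIdentifier(Path("output.zarr"), "output_run_1_1")

# If the compute context has distribute_workers disabled, this returns the
# in-process io_loop; otherwise it returns a closure that launches the worker
# command in the configured compute context. Either way, the returned callable
# is invoked per worker while a daisy blockwise task is running.
worker = spawn_worker(input_id, output_id)
```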
41 changes: 40 additions & 1 deletion dacapo/blockwise/blockwise_task.py
@@ -1,10 +1,32 @@
from datetime import datetime
from importlib.machinery import SourceFileLoader
from pathlib import Path
from upath import UPath as Path
from daisy import Task, Roi


class DaCapoBlockwiseTask(Task):
"""
A daisy task that runs a blockwise worker function over a given ROI.
Attributes:
worker_file (str | Path): The path to the worker file.
total_roi (Roi): The ROI to process.
read_roi (Roi): The ROI to read from for a block.
write_roi (Roi): The ROI to write to for a block.
num_workers (int): The number of workers to use.
max_retries (int): The maximum number of times a task will be retried if it fails
(due to a failed post-check, an application crash, or a network failure).
timeout: The timeout for the task.
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
Methods:
__init__:
Initialize the task.
"""

def __init__(
self,
worker_file: str | Path,
@@ -18,6 +40,23 @@ def __init__(
*args,
**kwargs,
):
"""
Initialize the task.
Args:
worker_file (str | Path): The path to the worker file.
total_roi (Roi): The ROI to process.
read_roi (Roi): The ROI to read from for a block.
write_roi (Roi): The ROI to write to for a block.
num_workers (int): The number of workers to use.
max_retries (int): The maximum number of times a task will be retried if it fails
(due to a failed post-check, an application crash, or a network failure).
timeout: The timeout for the task.
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
"""
# Load worker functions
worker_name = Path(worker_file).stem
worker = SourceFileLoader(worker_name, str(worker_file)).load_module()
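For context, a hedged sketch (not part of this commit) of how such a task might be handed to daisy; the worker file path, ROI sizes, counts, and the positional argument order are illustrative assumptions based on the docstring above.

```python
# Hedged sketch: argument order follows the __init__ docstring above; the
# paths, ROIs, and worker counts are placeholders.
import daisy
from daisy import Roi
from upath import UPath as Path

from dacapo.blockwise.blockwise_task import DaCapoBlockwiseTask
from dacapo.store.array_store import LocalArrayIdentifier

block = (256, 256, 256)

task = DaCapoBlockwiseTask(
    "dacapo/blockwise/argmax_worker.py",  # worker_file (placeholder path)
    Roi((0, 0, 0), (1024, 1024, 1024)),   # total_roi to process
    Roi((0, 0, 0), block),                # read_roi for one block
    Roi((0, 0, 0), block),                # write_roi for one block
    4,                                    # num_workers
    2,                                    # max_retries
    None,                                 # timeout
    None,                                 # upstream_tasks
    # extra args are forwarded to the worker (per the docstring); for the
    # argmax worker these would be its input/output array identifiers:
    LocalArrayIdentifier(Path("output.zarr"), "prediction_run_1_1"),
    LocalArrayIdentifier(Path("output.zarr"), "output_run_1_1"),
)

# Assuming daisy's run_blockwise entry point: it schedules blocks over
# total_roi and hands each one to the worker loop.
daisy.run_blockwise([task])
```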
