Skip to content

Commit

Permalink
Merge branch 'main' into attention-unet
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Feb 9, 2024
2 parents c45e93c + 5f50f9b commit e1806fe
Show file tree
Hide file tree
Showing 15 changed files with 381 additions and 86 deletions.
34 changes: 21 additions & 13 deletions .github/workflows/black.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
name: Python Black

name: black-action
on: [push, pull_request]

jobs:
lint:
name: Python Lint
linter_name:
name: runner / black
runs-on: ubuntu-latest
steps:
- name: Setup Python
uses: actions/setup-python@v1
- name: Setup checkout
uses: actions/checkout@master
- name: Lint with Black
run: |
pip install black
black -v --check dacapo tests
- uses: actions/checkout@v2
- name: Check files using the black formatter
uses: rickstaa/action-black@v1
id: action_black
with:
black_args: "."
- name: Create Pull Request
if: steps.action_black.outputs.is_formatted == 'true'
uses: peter-evans/create-pull-request@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
title: "Format Python code with psf/black push"
commit-message: ":art: Format Python code with psf/black"
body: |
There appear to be some python formatting errors in ${{ github.sha }}. This pull request
uses the [psf/black](https://github.com/psf/black) formatter to fix these issues.
base: ${{ github.head_ref }} # Creates pull request onto pull request or commit branch
branch: actions/black
196 changes: 192 additions & 4 deletions dacapo/apply.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,200 @@
import logging
from pathlib import Path
from typing import Optional, Union

import numpy as np
from funlib.geometry import Roi, Coordinate

from dacapo.experiments.datasplits.datasets.arrays.array import Array
from dacapo.experiments.datasplits.datasets.dataset import Dataset
from dacapo.experiments.run import Run
from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
    PostProcessorParameters,
)
import dacapo.experiments.tasks.post_processors as post_processors
from dacapo.store.array_store import LocalArrayIdentifier
from dacapo.predict import predict
from dacapo.compute_context import LocalTorch, ComputeContext
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.store import (
    create_config_store,
    create_weights_store,
)

logger = logging.getLogger(__name__)


def apply(
    run_name: str,
    input_container: Union[Path, str],
    input_dataset: str,
    output_path: Union[Path, str],
    validation_dataset: Optional[Union[Dataset, str]] = None,
    criterion: Optional[str] = "voi",
    iteration: Optional[int] = None,
    parameters: Optional[Union[PostProcessorParameters, str]] = None,
    roi: Optional[Union[Roi, str]] = None,
    num_cpu_workers: int = 30,
    output_dtype: Optional[Union[np.dtype, str]] = np.uint8,
    compute_context: ComputeContext = LocalTorch(),
    overwrite: bool = True,
    file_format: str = "zarr",
):
    """Load weights and apply a model to a dataset.

    If iteration is None, the best iteration based on the criterion is used.
    If roi is None, the whole input dataset is used.

    Args:
        run_name: Name of the run whose config is retrieved from the config
            store.
        input_container: Container holding the input data.
        input_dataset: Name of the dataset inside ``input_container``.
        output_path: Directory in which the output container is created. The
            container is named after ``input_container`` with its last suffix
            replaced by ``file_format``.
        validation_dataset: Dataset (or the name of one of the run's
            validation datasets) used to pick the best iteration and/or the
            best post-processing parameters.
        criterion: Criterion used when looking up the best iteration or
            parameters.
        iteration: Iteration whose weights are retrieved. Looked up from
            ``validation_dataset`` and ``criterion`` when None.
        parameters: Post-processing parameters, either as an object or as a
            string of the form ``'post_processor_name(arg1=val1, ...)'``.
            Looked up via the run's evaluator when None.
        roi: Region to predict on, either as a ``Roi`` or as a string of the
            form ``'[lower:upper, lower:upper, ...]'``.
        num_cpu_workers: Forwarded to ``apply_run``.
        output_dtype: Prediction dtype; a string is converted with
            ``np.dtype``.
        compute_context: Forwarded to ``apply_run``.
        overwrite: Forwarded to ``apply_run``.
        file_format: Suffix (without the dot) of the output container.

    Returns:
        Whatever ``apply_run`` returns.

    Raises:
        AssertionError: If neither (``validation_dataset`` and ``criterion``)
            nor ``iteration`` is given, or if no usable parameters can be
            determined.
        ValueError: If a ``parameters`` string cannot be parsed.
    """
    if isinstance(output_dtype, str):
        output_dtype = np.dtype(output_dtype)

    if isinstance(roi, str):
        roi = _parse_roi_string(roi)

    assert (validation_dataset is not None and isinstance(criterion, str)) or (
        isinstance(iteration, int)
    ), "Either validation_dataset and criterion, or iteration must be provided."

    # retrieving run
    logger.info("Loading run %s", run_name)
    config_store = create_config_store()
    run_config = config_store.retrieve_run_config(run_name)
    run = Run(run_config)

    # create weights store
    weights_store = create_weights_store()

    # load weights
    if iteration is None:
        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
    logger.info("Loading weights for iteration %i", iteration)
    # NOTE(review): the return value is discarded here. Confirm that
    # retrieve_weights actually loads the weights into `run`; the original
    # author also asked whether this should be load_weights.
    weights_store.retrieve_weights(run, iteration)

    # find the best parameters
    if isinstance(validation_dataset, str):
        val_ds_name = validation_dataset
        validation_dataset = [
            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
        ][0]
    logger.info("Finding best parameters for validation dataset %s", validation_dataset)
    if parameters is None:
        parameters = run.task.evaluator.get_overall_best_parameters(
            validation_dataset, criterion
        )
        assert (
            parameters is not None
        ), "Unable to retieve parameters. Parameters must be provided explicitly."
    elif isinstance(parameters, str):
        parameters = _parse_post_processor_parameters(parameters)

    assert isinstance(
        parameters, PostProcessorParameters
    ), "Parameters must be parsable to a PostProcessorParameters object."

    # make array identifiers for input, predictions and outputs
    input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
    input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
    if roi is not None:
        # Only a user-supplied roi needs snapping/cropping; None is passed
        # through so that apply_run handles the "whole dataset" case.
        roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
            input_array.roi
        )
    # Path.stem removes only the last suffix, so interior dots in the name
    # are kept and a suffix-less input does not produce an empty name.
    output_container = Path(
        output_path,
        Path(input_container).stem + f".{file_format}",
    )
    prediction_array_identifier = LocalArrayIdentifier(
        output_container, f"prediction_{run_name}_{iteration}"
    )
    output_array_identifier = LocalArrayIdentifier(
        output_container, f"output_{run_name}_{iteration}_{parameters}"
    )

    logger.info(
        "Applying best results from run %s at iteration %i to dataset %s",
        run.name,
        iteration,
        Path(input_container, input_dataset),
    )
    return apply_run(
        run,
        parameters,
        input_array,
        prediction_array_identifier,
        output_array_identifier,
        roi,
        num_cpu_workers,
        output_dtype,
        compute_context,
        overwrite,
    )


def _parse_roi_string(roi: str):
    """Turn ``'[lower:upper, lower:upper, ...]'`` into a ``Roi``."""
    start, end = zip(
        *[
            tuple(int(coord) for coord in axis.split(":"))
            for axis in roi.strip("[]").split(",")
        ]
    )
    # Roi takes (offset, shape), hence end - start.
    return Roi(
        Coordinate(start),
        Coordinate(end) - Coordinate(start),
    )


def _coerce_number(text: str):
    """Return ``text`` as int or float if it is a plain (optionally signed) number."""
    unsigned = text[1:] if text.startswith(("+", "-")) else text
    if unsigned.isdigit():
        return int(text)
    if unsigned.replace(".", "", 1).isdigit():
        return float(text)
    return text


def _parse_post_processor_parameters(parameters: str):
    """Instantiate a post-processor parameters object from ``'name(arg=val, ...)'``."""
    try:
        pieces = parameters.split("(")
        post_processor_name = pieces[0]
        post_processor_kwargs = {}
        for argument in pieces[1].strip(")").split(","):
            key, value = argument.split("=")
            post_processor_kwargs[key.strip()] = _coerce_number(value.strip())
    except (IndexError, ValueError) as err:
        # IndexError: no "(" in the string; ValueError: an argument that is
        # not exactly `key=value`, or a digit string int()/float() rejects.
        raise ValueError(
            f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
        ) from err

    try:
        return getattr(post_processors, post_processor_name)(**post_processor_kwargs)
    except Exception:
        # Log with traceback for context, then let the original error propagate.
        logger.exception(
            "Could not instantiate post-processor %s with arguments %s.",
            post_processor_name,
            post_processor_kwargs,
        )
        raise


def apply_run(
    run: Run,
    parameters: PostProcessorParameters,
    input_array: Array,
    prediction_array_identifier: LocalArrayIdentifier,
    output_array_identifier: LocalArrayIdentifier,
    roi: Optional[Roi] = None,
    num_cpu_workers: int = 30,
    output_dtype: Optional[np.dtype] = np.uint8,
    compute_context: ComputeContext = LocalTorch(),
    overwrite: bool = True,
):
    """Predict on ``input_array`` with ``run.model``, then post-process the result.

    If roi is None, the whole input dataset is used. The model is assumed to
    be loaded already; no weights are read here.

    Args:
        run: Run providing the model and the task's post-processor.
        parameters: Parameters handed to the post-processor.
        input_array: Array the model is applied to.
        prediction_array_identifier: Where the prediction is written.
        output_array_identifier: Where the post-processed output is written.
        roi: Passed to ``predict`` as ``output_roi``.
        num_cpu_workers: Passed to ``predict``.
        output_dtype: Passed to ``predict``.
        compute_context: Passed to ``predict``.
        overwrite: Passed to both ``predict`` and the post-processor.
    """
    run.model.eval()

    # Step 1: render the prediction dataset.
    logger.info("Predicting on dataset %s", prediction_array_identifier)
    prediction_options = {
        "output_roi": roi,
        "num_cpu_workers": num_cpu_workers,
        "output_dtype": output_dtype,
        "compute_context": compute_context,
        "overwrite": overwrite,
    }
    predict(
        run.model,
        input_array,
        prediction_array_identifier,
        **prediction_options,
    )

    # Step 2: post-process the prediction into the output dataset.
    logger.info("Post-processing output to dataset %s", output_array_identifier)
    task_post_processor = run.task.post_processor
    task_post_processor.set_prediction(prediction_array_identifier)
    task_post_processor.process(
        parameters,
        output_array_identifier,
        overwrite=overwrite,
        blockwise=True,
    )

    logger.info("Done")
55 changes: 44 additions & 11 deletions dacapo/cli.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import dacapo
import click
import logging
Expand Down Expand Up @@ -40,21 +42,52 @@ def validate(run_name, iteration):

@cli.command()
@click.option(
    "-r", "--run-name", required=True, type=str, help="The name of the run to apply."
)
@click.option(
    "-ic",
    "--input_container",
    required=True,
    type=click.Path(exists=True, file_okay=False),
)
@click.option("-id", "--input_dataset", required=True, type=str)
@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
@click.option("-vd", "--validation_dataset", type=str, default=None)
@click.option("-c", "--criterion", default="voi")
@click.option("-i", "--iteration", type=int, default=None)
@click.option("-p", "--parameters", type=str, default=None)
@click.option(
    "-roi",
    "--roi",
    type=str,
    required=False,
    help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
)
@click.option("-w", "--num_cpu_workers", type=int, default=30)
@click.option("-dt", "--output_dtype", type=str, default="uint8")
def apply(
    run_name: str,
    input_container: str,
    input_dataset: str,
    output_path: str,
    validation_dataset: Optional[str] = None,
    criterion: Optional[str] = "voi",
    iteration: Optional[int] = None,
    parameters: Optional[str] = None,
    roi: Optional[str] = None,
    num_cpu_workers: int = 30,
    output_dtype: Optional[str] = "uint8",
):
    # Thin CLI wrapper: every option is forwarded unchanged to dacapo.apply.
    # Deliberately a comment rather than a docstring, since click would show
    # a docstring as the command's help text.
    dacapo.apply(
        run_name=run_name,
        input_container=input_container,
        input_dataset=input_dataset,
        output_path=output_path,
        validation_dataset=validation_dataset,
        criterion=criterion,
        iteration=iteration,
        parameters=parameters,
        roi=roi,
        num_cpu_workers=num_cpu_workers,
        output_dtype=output_dtype,
    )
5 changes: 3 additions & 2 deletions dacapo/experiments/datasplits/datasets/arrays/concat_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
axis=0,
)
if concatenated.shape[0] == 1:
logger.info(f"Concatenated array has only one channel: {self.name} {concatenated.shape}")
# raise Exception(f"{concatenated.shape}, shapes")
logger.info(
f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
)
return concatenated
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def axes(self):
logger.debug(
"DaCapo expects Zarr datasets to have an 'axes' attribute!\n"
f"Zarr {self.file_name} and dataset {self.dataset} has attributes: {list(self._attributes.items())}\n"
f"Using default {['t', 'z', 'y', 'x'][-self.dims::]}",
f"Using default {['c', 'z', 'y', 'x'][-self.dims::]}",
)
return ["c", "z", "y", "x"][-self.dims : :]

Expand Down
2 changes: 1 addition & 1 deletion dacapo/experiments/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(
self,
architecture: Architecture,
prediction_head: torch.nn.Module,
eval_activation: torch.nn.Module = None,
eval_activation: torch.nn.Module | None = None,
):
super().__init__()

Expand Down
6 changes: 5 additions & 1 deletion dacapo/experiments/tasks/affinities_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ def __init__(self, task_config):
"""Create a `DummyTask` from a `DummyTaskConfig`."""

self.predictor = AffinitiesPredictor(
neighborhood=task_config.neighborhood, lsds=task_config.lsds
neighborhood=task_config.neighborhood,
lsds=task_config.lsds,
num_voxels=task_config.num_voxels,
downsample_lsds=task_config.downsample_lsds,
grow_boundary_iterations=task_config.grow_boundary_iterations,
)
self.loss = AffinitiesLoss(len(task_config.neighborhood))
self.post_processor = WatershedPostProcessor(offsets=task_config.neighborhood)
Expand Down
20 changes: 20 additions & 0 deletions dacapo/experiments/tasks/affinities_task_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,23 @@ class AffinitiesTaskConfig(TaskConfig):
"It has been shown that lsds as an auxiliary task can help affinity predictions."
},
)
num_voxels: int = attr.ib(
default=20,
metadata={
"help_text": "The number of voxels to use for the gaussian sigma when computing lsds."
},
)
downsample_lsds: int = attr.ib(
default=1,
metadata={
"help_text": "The amount to downsample the lsds. "
"This is useful for speeding up training and inference."
},
)
grow_boundary_iterations: int = attr.ib(
default=0,
metadata={
"help_text": "The number of iterations to run the grow boundaries algorithm. "
"This is useful for refining the boundaries of the affinities, and reducing merging of adjacent objects."
},
)
Loading

0 comments on commit e1806fe

Please sign in to comment.