Rhoadesj/hot distance #39

Merged: 45 commits, Feb 11, 2024

Commits
5fe699a
Update black.yaml - force formatting
rhoadesScholar Jan 16, 2024
0098fa2
Update black.yaml - automatic black
rhoadesScholar Jan 16, 2024
13ae1f0
Update black.yaml - remove auto black
rhoadesScholar Jan 16, 2024
d741626
Update black.yaml
mzouink Feb 7, 2024
9351c16
remove black
mzouink Feb 7, 2024
bb8cab5
add black format check
mzouink Feb 7, 2024
c25cb79
black format on pull request
mzouink Feb 7, 2024
08f134d
:art: Format Python code with psf/black
mzouink Feb 7, 2024
e73afa2
Merge pull request #19 from janelia-cellmap/actions/black
mzouink Feb 7, 2024
ff61f7c
bug fixes and better logs
mzouink Feb 7, 2024
1490440
Update train.py
mzouink Feb 7, 2024
3c5f2da
Update train.py
mzouink Feb 7, 2024
34e8253
Merge pull request #20 from janelia-cellmap/zouinkhim_fixes
rhoadesScholar Feb 7, 2024
33bbc8a
feat: 🚧 Incorporate simple change from rhoadesj/dev
rhoadesScholar Feb 8, 2024
fe23b5d
feat: 🚧 Incorporate simple change from rhoadesj/dev
rhoadesScholar Feb 8, 2024
5f50f9b
Merge pull request #25 from janelia-cellmap/rhoadesj_simple_changes
mzouink Feb 8, 2024
812acc1
feat: ⚡️ Incorporate start related changes from rhoadesj/dev
rhoadesScholar Feb 8, 2024
ce5d272
docs: 📝 Add authors and versioning.
rhoadesScholar Feb 9, 2024
4f1dfed
starter partial weight load
mzouink Feb 9, 2024
910b2e0
Merge pull request #35 from janelia-cellmap/fix_start_loader
mzouink Feb 9, 2024
906dfd6
:art: Format Python code with psf/black
mzouink Feb 9, 2024
8cce986
Merge pull request #37 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
f5e584a
publish to pypi
mzouink Feb 9, 2024
281a768
logo
rhoadesScholar Feb 9, 2024
65482f4
Update README.md
rhoadesScholar Feb 9, 2024
52409f9
Update README.md
rhoadesScholar Feb 9, 2024
b8e18b4
Update publish.yaml
mzouink Feb 9, 2024
448f766
feat: ⚡️ Incorporate hot_distance related changes from rhoadesj/dev
rhoadesScholar Feb 9, 2024
b4b2780
fix: 🐛 Fix broken dependencies for MacOS.
rhoadesScholar Feb 9, 2024
55a3892
include and use more biases during watershed post processing of affin…
davidackerman Feb 9, 2024
58c7abe
include weighting argument for affinities+lsd loss
davidackerman Feb 9, 2024
ce71fb5
make predictor node optional
davidackerman Feb 9, 2024
353b8cb
:art: Format Python code with psf/black
davidackerman Feb 9, 2024
4a54e31
Merge pull request #42 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
f243c7c
styles fixes for mypy
mzouink Feb 9, 2024
cebc737
update git action, fix doc and no more publish
mzouink Feb 9, 2024
7feab6a
remove unfinished cli and apply from main
mzouink Feb 9, 2024
53b57b6
Merge branch 'hot_distance' into rhoadesj/hot_distance
rhoadesScholar Feb 9, 2024
5d77af0
fix test action, pytest 8.0.0 working
mzouink Feb 9, 2024
e46acf0
:art: Format Python code with psf/black
mzouink Feb 9, 2024
cea9c03
Merge pull request #45 from janelia-cellmap/actions/black
mzouink Feb 9, 2024
232047c
:art: Format Python code with psf/black
mzouink Feb 9, 2024
70169e2
Merge branch 'rhoadesj/hot_distance' into actions/black
rhoadesScholar Feb 11, 2024
daa41b3
Merge pull request #47 from janelia-cellmap/actions/black
rhoadesScholar Feb 11, 2024
c810a0e
Revert GunpowderTrainer class and configuration to main
rhoadesScholar Feb 11, 2024
32 changes: 21 additions & 11 deletions .github/workflows/black.yaml
@@ -1,17 +1,27 @@
-name: Python Black
+name: black-action
 
 on: [push, pull_request]
 
 jobs:
-  lint:
-    name: Python Lint
+  linter_name:
+    name: runner / black
     runs-on: ubuntu-latest
     steps:
-      - name: Setup Python
-        uses: actions/setup-python@v1
-      - name: Setup checkout
-        uses: actions/checkout@master
-      - name: Lint with Black
-        run: |
-          pip install black
-          black -v --check dacapo tests
+      - uses: actions/checkout@v2
+      - name: Check files using the black formatter
+        uses: rickstaa/action-black@v1
+        id: action_black
+        with:
+          black_args: "."
+      - name: Create Pull Request
+        if: steps.action_black.outputs.is_formatted == 'true'
+        uses: peter-evans/create-pull-request@v3
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          title: "Format Python code with psf/black push"
+          commit-message: ":art: Format Python code with psf/black"
+          body: |
+            There appear to be some python formatting errors in ${{ github.sha }}. This pull request
+            uses the [psf/black](https://github.com/psf/black) formatter to fix these issues.
+          base: ${{ github.head_ref }} # Creates pull request onto pull request or commit branch
+          branch: actions/black
9 changes: 4 additions & 5 deletions .github/workflows/docs.yaml
@@ -1,8 +1,7 @@
-name: Pages
-on:
-  push:
-    branches:
-      - master
+name: Generate Pages
+
+on: [push, pull_request]
+
 jobs:
   docs:
     runs-on: ubuntu-latest
34 changes: 0 additions & 34 deletions .github/workflows/publish.yaml

This file was deleted.

3 changes: 1 addition & 2 deletions .github/workflows/tests.yaml
@@ -1,7 +1,6 @@
 name: Test
 
-on:
-  push:
+on: [push, pull_request]
 
 jobs:
   test:
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
-![DaCapo](docs/source/_static/dacapo.svg)
+# DaCapo ![DaCapo](docs/source/_static/icon_dacapo.png)
 
 [![tests](https://github.com/funkelab/dacapo/actions/workflows/tests.yaml/badge.svg)](https://github.com/funkelab/dacapo/actions/workflows/tests.yaml)
 [![black](https://github.com/funkelab/dacapo/actions/workflows/black.yaml/badge.svg)](https://github.com/funkelab/dacapo/actions/workflows/black.yaml)
197 changes: 5 additions & 192 deletions dacapo/apply.py
@@ -1,200 +1,13 @@
 import logging
-from typing import Optional
-from funlib.geometry import Roi, Coordinate
-import numpy as np
-from dacapo.experiments.datasplits.datasets.arrays.array import Array
-from dacapo.experiments.datasplits.datasets.dataset import Dataset
-from dacapo.experiments.run import Run
-
-from dacapo.experiments.tasks.post_processors.post_processor_parameters import (
-    PostProcessorParameters,
-)
-import dacapo.experiments.tasks.post_processors as post_processors
-from dacapo.store.array_store import LocalArrayIdentifier
-from dacapo.predict import predict
-from dacapo.compute_context import LocalTorch, ComputeContext
-from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
-from dacapo.store import (
-    create_config_store,
-    create_weights_store,
-)
-
-from pathlib import Path
 
 logger = logging.getLogger(__name__)
 
 
-def apply(
-    run_name: str,
-    input_container: Path or str,
-    input_dataset: str,
-    output_path: Path or str,
-    validation_dataset: Optional[Dataset or str] = None,
-    criterion: Optional[str] = "voi",
-    iteration: Optional[int] = None,
-    parameters: Optional[PostProcessorParameters or str] = None,
-    roi: Optional[Roi or str] = None,
-    num_cpu_workers: int = 30,
-    output_dtype: Optional[np.dtype or str] = np.uint8,
-    compute_context: ComputeContext = LocalTorch(),
-    overwrite: bool = True,
-    file_format: str = "zarr",
-):
-    """Load weights and apply a model to a dataset. If iteration is None, the best iteration based on the criterion is used. If roi is None, the whole input dataset is used."""
-    if isinstance(output_dtype, str):
-        output_dtype = np.dtype(output_dtype)
-
-    if isinstance(roi, str):
-        start, end = zip(
-            *[
-                tuple(int(coord) for coord in axis.split(":"))
-                for axis in roi.strip("[]").split(",")
-            ]
-        )
-        roi = Roi(
-            Coordinate(start),
-            Coordinate(end) - Coordinate(start),
-        )
-
-    assert (validation_dataset is not None and isinstance(criterion, str)) or (
-        isinstance(iteration, int)
-    ), "Either validation_dataset and criterion, or iteration must be provided."
-
-    # retrieving run
-    logger.info("Loading run %s", run_name)
-    config_store = create_config_store()
-    run_config = config_store.retrieve_run_config(run_name)
-    run = Run(run_config)
-
-    # create weights store
-    weights_store = create_weights_store()
-
-    # load weights
-    if iteration is None:
-        # weights_store._load_best(run, criterion)
-        iteration = weights_store.retrieve_best(run_name, validation_dataset, criterion)
-    logger.info("Loading weights for iteration %i", iteration)
-    weights_store.retrieve_weights(run, iteration)  # shouldn't this be load_weights?
-
-    # find the best parameters
-    if isinstance(validation_dataset, str):
-        val_ds_name = validation_dataset
-        validation_dataset = [
-            dataset for dataset in run.datasplit.validate if dataset.name == val_ds_name
-        ][0]
-    logger.info("Finding best parameters for validation dataset %s", validation_dataset)
-    if parameters is None:
-        parameters = run.task.evaluator.get_overall_best_parameters(
-            validation_dataset, criterion
-        )
-        assert (
-            parameters is not None
-        ), "Unable to retrieve parameters. Parameters must be provided explicitly."
-
-    elif isinstance(parameters, str):
-        try:
-            post_processor_name = parameters.split("(")[0]
-            post_processor_kwargs = parameters.split("(")[1].strip(")").split(",")
-            post_processor_kwargs = {
-                key.strip(): value.strip()
-                for key, value in [arg.split("=") for arg in post_processor_kwargs]
-            }
-            for key, value in post_processor_kwargs.items():
-                if value.isdigit():
-                    post_processor_kwargs[key] = int(value)
-                elif value.replace(".", "", 1).isdigit():
-                    post_processor_kwargs[key] = float(value)
-        except:
-            raise ValueError(
-                f"Could not parse parameters string {parameters}. Must be of the form 'post_processor_name(arg1=val1, arg2=val2, ...)'"
-            )
-        try:
-            parameters = getattr(post_processors, post_processor_name)(
-                **post_processor_kwargs
-            )
-        except Exception as e:
-            logger.error(
-                f"Could not instantiate post-processor {post_processor_name} with arguments {post_processor_kwargs}.",
-                exc_info=True,
-            )
-            raise e
-
-    assert isinstance(
-        parameters, PostProcessorParameters
-    ), "Parameters must be parsable to a PostProcessorParameters object."
-
-    # make array identifiers for input, predictions and outputs
-    input_array_identifier = LocalArrayIdentifier(input_container, input_dataset)
-    input_array = ZarrArray.open_from_array_identifier(input_array_identifier)
-    roi = roi.snap_to_grid(input_array.voxel_size, mode="grow").intersect(
-        input_array.roi
-    )
-    output_container = Path(
-        output_path,
-        "".join(Path(input_container).name.split(".")[:-1]) + f".{file_format}",
-    )
-    prediction_array_identifier = LocalArrayIdentifier(
-        output_container, f"prediction_{run_name}_{iteration}"
-    )
-    output_array_identifier = LocalArrayIdentifier(
-        output_container, f"output_{run_name}_{iteration}_{parameters}"
-    )
-
+def apply(run_name: str, iteration: int, dataset_name: str):
     logger.info(
-        "Applying best results from run %s at iteration %i to dataset %s",
-        run.name,
+        "Applying results from run %s at iteration %d to dataset %s",
+        run_name,
         iteration,
-        Path(input_container, input_dataset),
-    )
-    return apply_run(
-        run,
-        parameters,
-        input_array,
-        prediction_array_identifier,
-        output_array_identifier,
-        roi,
-        num_cpu_workers,
-        output_dtype,
-        compute_context,
-        overwrite,
-    )
-
-
-def apply_run(
-    run: Run,
-    parameters: PostProcessorParameters,
-    input_array: Array,
-    prediction_array_identifier: LocalArrayIdentifier,
-    output_array_identifier: LocalArrayIdentifier,
-    roi: Optional[Roi] = None,
-    num_cpu_workers: int = 30,
-    output_dtype: Optional[np.dtype] = np.uint8,
-    compute_context: ComputeContext = LocalTorch(),
-    overwrite: bool = True,
-):
-    """Apply the model to a dataset. If roi is None, the whole input dataset is used. Assumes model is already loaded."""
-    run.model.eval()
-
-    # render prediction dataset
-    logger.info("Predicting on dataset %s", prediction_array_identifier)
-    predict(
-        run.model,
-        input_array,
-        prediction_array_identifier,
-        output_roi=roi,
-        num_cpu_workers=num_cpu_workers,
-        output_dtype=output_dtype,
-        compute_context=compute_context,
-        overwrite=overwrite,
+        dataset_name,
     )
-
-    # post-process the output
-    logger.info("Post-processing output to dataset %s", output_array_identifier)
-    post_processor = run.task.post_processor
-    post_processor.set_prediction(prediction_array_identifier)
-    post_processor.process(
-        parameters, output_array_identifier, overwrite=overwrite, blockwise=True
-    )
-
-    logger.info("Done")
-    return
+    raise NotImplementedError("This function is not yet implemented.")
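
Review note: the `apply` implementation removed in this diff accepted a region of interest as a string of the form `[lower:upper, lower:upper, ...]` and converted it to an offset and a shape. A minimal sketch of that parsing step follows, assuming plain tuples in place of funlib's `Roi` and `Coordinate`. The helper name `parse_roi_string` is hypothetical and not part of dacapo, and the explicit validation is an addition for illustration.

```python
from typing import Tuple


def parse_roi_string(roi: str) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    """Parse "[lower:upper, lower:upper, ...]" into (offset, shape) tuples.

    Hypothetical helper: plain tuples stand in for funlib's Coordinate/Roi.
    """
    bounds = [
        tuple(int(coord) for coord in axis.split(":"))
        for axis in roi.strip("[]").split(",")
    ]
    for pair in bounds:
        # Each axis must supply exactly a lower and an upper bound.
        if len(pair) != 2:
            raise ValueError(f"Expected 'lower:upper' per axis, got {pair!r}")
    start, end = zip(*bounds)
    # The shape is the per-axis extent, i.e. end minus start.
    shape = tuple(e - s for s, e in zip(start, end))
    return tuple(start), shape
```

With funlib available, the result would presumably map onto `Roi(Coordinate(offset), Coordinate(shape))`, matching the offset-plus-shape construction in the removed code.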
55 changes: 11 additions & 44 deletions dacapo/cli.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import dacapo
 import click
 import logging
@@ -42,52 +40,21 @@ def validate(run_name, iteration):
 
 @cli.command()
 @click.option(
-    "-r", "--run_name", required=True, type=str, help="The name of the run to use."
+    "-r", "--run-name", required=True, type=str, help="The name of the run to use."
 )
 @click.option(
-    "-ic",
-    "--input_container",
+    "-i",
+    "--iteration",
     required=True,
-    type=click.Path(exists=True, file_okay=False),
+    type=int,
+    help="The iteration weights and parameters to use.",
 )
-@click.option("-id", "--input_dataset", required=True, type=str)
-@click.option("-op", "--output_path", required=True, type=click.Path(file_okay=False))
-@click.option("-vd", "--validation_dataset", type=str, default=None)
-@click.option("-c", "--criterion", default="voi")
-@click.option("-i", "--iteration", type=int, default=None)
-@click.option("-p", "--parameters", type=str, default=None)
 @click.option(
-    "-roi",
-    "--roi",
+    "-r",
+    "--dataset",
+    required=True,
     type=str,
-    required=False,
-    help="The roi to predict on. Passed in as [lower:upper, lower:upper, ... ]",
+    help="The name of the dataset to apply the run to.",
 )
-@click.option("-w", "--num_cpu_workers", type=int, default=30)
-@click.option("-dt", "--output_dtype", type=str, default="uint8")
-def apply(
-    run_name: str,
-    input_container: str,
-    input_dataset: str,
-    output_path: str,
-    validation_dataset: Optional[str] = None,
-    criterion: Optional[str] = "voi",
-    iteration: Optional[int] = None,
-    parameters: Optional[str] = None,
-    roi: Optional[str] = None,
-    num_cpu_workers: int = 30,
-    output_dtype: Optional[str] = "uint8",
-):
-    dacapo.apply(
-        run_name,
-        input_container,
-        input_dataset,
-        output_path,
-        validation_dataset,
-        criterion,
-        iteration,
-        parameters,
-        roi,
-        num_cpu_workers,
-        output_dtype,
-    )
+def apply(run_name, iteration, dataset_name):
+    dacapo.apply(run_name, iteration, dataset_name)
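
Review note: as written, the new `apply` command appears to reuse the short flag `-r` for both `--run-name` and `--dataset`. In addition, click derives the keyword argument from the long option name, so `--dataset` would arrive as `dataset` while the function signature expects `dataset_name`. A minimal sketch of one consistent wiring follows. The `-d` / `--dataset-name` spelling, the stand-in `cli` group, and the `demo_invocation` helper are assumptions for illustration, not part of the PR, and the command body only echoes its arguments instead of calling `dacapo.apply`.

```python
import click
from click.testing import CliRunner


@click.group()
def cli():
    """Hypothetical stand-in for the dacapo command group."""


@cli.command()
@click.option(
    "-r", "--run-name", required=True, type=str, help="The name of the run to use."
)
@click.option(
    "-i",
    "--iteration",
    required=True,
    type=int,
    help="The iteration weights and parameters to use.",
)
@click.option(
    "-d",  # assumption: a distinct short flag, so it no longer collides with -r
    "--dataset-name",  # assumption: click maps this to the dataset_name parameter
    required=True,
    type=str,
    help="The name of the dataset to apply the run to.",
)
def apply(run_name, iteration, dataset_name):
    # Echo instead of calling dacapo.apply, which this PR leaves unimplemented.
    click.echo(f"apply run={run_name} iteration={iteration} dataset={dataset_name}")


def demo_invocation():
    """Invoke the command in-process with click's test runner."""
    return CliRunner().invoke(
        cli, ["apply", "-r", "my_run", "-i", "100", "-d", "my_dataset"]
    )
```

Keeping `--dataset` and renaming the function parameter to `dataset` would work equally well. The point is only that the option name and the parameter name have to agree.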
@@ -5,6 +5,9 @@
 import numpy as np
 
 from typing import Dict, Any
+import logging
+
+logger = logging.getLogger(__file__)
 
 
 class ConcatArray(Array):
@@ -116,5 +119,7 @@ def __getitem__(self, roi: Roi) -> np.ndarray:
             axis=0,
         )
         if concatenated.shape[0] == 1:
-            raise Exception(f"{concatenated.shape}, shapes")
+            logger.info(
+                f"Concatenated array has only one channel: {self.name} {concatenated.shape}"
+            )
         return concatenated
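
Review note: the second hunk above replaces a hard failure with a log message when concatenation yields a single channel. A minimal sketch of that pattern follows, assuming plain nested lists in place of NumPy arrays. The helper `concatenate_channels` is hypothetical and not part of dacapo. The sketch names its logger with `__name__`, the usual convention, whereas the diff uses `__file__`.

```python
import logging
from typing import List, Sequence

logger = logging.getLogger(__name__)


def concatenate_channels(
    name: str, channels: Sequence[Sequence[float]]
) -> List[List[float]]:
    """Stack per-channel rows; log, rather than raise, when only one channel results."""
    concatenated = [list(channel) for channel in channels]
    if len(concatenated) == 1:
        # Previously an exception; a single channel is now allowed but reported.
        logger.info(
            "Concatenated array has only one channel: %s %s",
            name,
            (len(concatenated), len(concatenated[0])),
        )
    return concatenated
```

Passing the values as logging arguments instead of an f-string defers string formatting until a handler actually emits the record.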