Skip to content

Commit

Permalink
Dev/main (#218)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar authored Mar 27, 2024
2 parents 4b9e01e + c4ca798 commit 40c6e22
Show file tree
Hide file tree
Showing 12 changed files with 55 additions and 53 deletions.
3 changes: 2 additions & 1 deletion dacapo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__version__ = 0.3.0
__version__ = "0.3.0"
__version_info__ = tuple(int(i) for i in __version__.split("."))

from .options import Options # noqa
from . import experiments, utils # noqa
Expand Down
2 changes: 1 addition & 1 deletion dacapo/blockwise/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def run_blockwise(
read_roi: Roi,
write_roi: Roi,
num_workers: int = 16,
max_retries: int = 2,
max_retries: int = 1,
timeout=None,
upstream_tasks=None,
*args,
Expand Down
15 changes: 10 additions & 5 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,23 @@ def __init__(self, run_config):

# preloaded weights from previous run
self.start = (
run_config.start_config.start_type(run_config.start_config)
(
run_config.start_config.start_type(run_config.start_config)
if hasattr(run_config.start_config, "start_type")
else Start(run_config.start_config)
)
if run_config.start_config is not None
else None
)
if self.start is None:
return
else:

new_head = None
if hasattr(run_config, "task_config"):
if hasattr(run_config.task_config, "channels"):
new_head = run_config.task_config.channels
else:
new_head = None
self.start.initialize_weights(self.model, new_head=new_head)

self.start.initialize_weights(self.model, new_head=new_head)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
Expand Down
9 changes: 5 additions & 4 deletions dacapo/experiments/starts/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def __init__(self, start_config):
self.run = start_config.run
self.criterion = start_config.criterion

if hasattr(start_config.task_config, "channels"):
self.channels = start_config.task_config.channels
else:
self.channels = None
self.channels = None

if hasattr(start_config, "task_config"):
if hasattr(start_config.task_config, "channels"):
self.channels = start_config.task_config.channels

def initialize_weights(self, model, new_head=None):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@

@attr.s
class InstanceEvaluationScores(EvaluationScores):
criteria = ["voi_split", "voi_merge", "voi", "avg_iou"]
criteria = ["voi_split", "voi_merge", "voi"]

voi_split: float = attr.ib(default=float("nan"))
voi_merge: float = attr.ib(default=float("nan"))
avg_iou: float = attr.ib(default=float("nan"))

@property
def voi(self):
Expand All @@ -22,7 +21,6 @@ def higher_is_better(criterion: str) -> bool:
"voi_split": False,
"voi_merge": False,
"voi": False,
"avg_iou": True,
}
return mapping[criterion]

Expand All @@ -34,7 +32,6 @@ def bounds(
"voi_split": (0, 1),
"voi_merge": (0, 1),
"voi": (0, 1),
"avg_iou": (0, None),
}
return mapping[criterion]

Expand Down
31 changes: 12 additions & 19 deletions dacapo/experiments/tasks/evaluators/instance_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@

from .evaluator import Evaluator
from .instance_evaluation_scores import InstanceEvaluationScores

from funlib.evaluate import rand_voi, detection_scores
from dacapo.utils.voi import voi as _voi

import numpy as np
import numpy_indexed as npi

import logging

logger = logging.getLogger(__name__)


def relabel(array, return_backwards_map=False, inplace=False):
"""Relabel array, such that IDs are consecutive. Excludes 0.
Expand Down Expand Up @@ -68,34 +71,24 @@ def relabel(array, return_backwards_map=False, inplace=False):


class InstanceEvaluator(Evaluator):
criteria: List[str] = ["voi_merge", "voi_split", "voi", "avg_iou"]
criteria: List[str] = ["voi_merge", "voi_split", "voi"]

def evaluate(self, output_array_identifier, evaluation_array):
output_array = ZarrArray.open_from_array_identifier(output_array_identifier)
evaluation_data = evaluation_array[evaluation_array.roi].astype(np.uint64)
output_data = output_array[output_array.roi].astype(np.uint64)
results = rand_voi(evaluation_data, output_data)
try:
output_data, _ = relabel(output_data)
results.update(
detection_scores(
evaluation_data,
output_data,
matching_score="iou",
)
)
except Exception:
results["avg_iou"] = 0
logger.warning(
"Could not compute IoU because of an unknown error. Sorry about that."
)
results = voi(evaluation_data, output_data)

return InstanceEvaluationScores(
voi_merge=results["voi_merge"],
voi_split=results["voi_split"],
avg_iou=results["avg_iou"],
)

@property
def score(self) -> InstanceEvaluationScores:
return InstanceEvaluationScores()


def voi(truth, test):
voi_split, voi_merge = _voi(test + 1, truth + 1, ignore_groundtruth=[])
return {"voi_split": voi_split, "voi_merge": voi_merge}
12 changes: 10 additions & 2 deletions dacapo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from dacapo.blockwise import run_blockwise
import dacapo.blockwise
from dacapo.experiments import Run
from dacapo.store.create_store import create_config_store
from dacapo.store.create_store import create_config_store, create_weights_store
from dacapo.store.local_array_store import LocalArrayIdentifier
from dacapo.experiments.datasplits.datasets.arrays import ZarrArray
from dacapo.compute_context import create_compute_context, LocalTorch

from funlib.geometry import Coordinate, Roi
import numpy as np
import zarr

from typing import Optional
import logging
Expand Down Expand Up @@ -75,6 +74,15 @@ def predict(

model = run.model.eval()

if iteration is not None:
# create weights store
weights_store = create_weights_store()

# load weights
run.model.load_state_dict(
weights_store.retrieve_weights(run_name, iteration).model
)

input_voxel_size = Coordinate(raw_array.voxel_size)
output_voxel_size = model.scale(input_voxel_size)
input_shape = Coordinate(model.eval_input_shape)
Expand Down
3 changes: 1 addition & 2 deletions dacapo/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def validate_run(
or len(run.datasplit.validate) == 0
or run.datasplit.validate[0].gt is None
):
print(f"Cannot validate run {run.name}. Continuing training!")
return None, None
raise ValueError(f"Cannot validate run {run.name} at iteration {iteration}.")

# get array and weight store
array_store = create_array_store()
Expand Down
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,22 @@ dependencies = [
"attrs",
"bokeh",
"numpy-indexed>=0.3.7",
"daisy @ git+https://github.com/funkelab/daisy",
"daisy",
"funlib.math>=0.1",
"funlib.geometry>=0.2",
"mwatershed>=0.1",
"cellmap-models",
"funlib.persistence @ git+https://github.com/janelia-cellmap/funlib.persistence",
"funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
"funlib.persistence>=0.3.0",
"gunpowder>=1.3",
# "lsds>=0.1.3",
"lsds @ git+https://github.com/funkelab/lsd",
# "lsds @ git+https://github.com/funkelab/lsd",
"lsds",
"xarray",
"cattrs",
"numpy-indexed",
"click",
"pyyaml",
"scipy",
]

# extras
Expand All @@ -64,13 +65,11 @@ dependencies = [
test = ["pytest==7.4.4", "pytest-cov", "pytest-lazy-fixture"]
dev = [
"black",
"ipython",
"mypy",
"pdbpp",
"rich",
"ruff",
"pre-commit",
"jupyter",
]
docs = [
"sphinx-autodoc-typehints",
Expand Down
4 changes: 2 additions & 2 deletions tests/operations/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
@pytest.mark.parametrize(
"run_config",
[
lazy_fixture("distance_run"),
# lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
lazy_fixture("onehot_run"),
# lazy_fixture("onehot_run"),
],
)
def test_apply(options, run_config, zarr_array, tmp_path):
Expand Down
6 changes: 3 additions & 3 deletions tests/operations/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
@pytest.mark.parametrize(
"run_config",
[
lazy_fixture("distance_run"),
# lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
lazy_fixture("onehot_run"),
# lazy_fixture("onehot_run"),
],
)
def test_predict(options, run_config, zarr_array, tmp_path):
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_predict(options, run_config, zarr_array, tmp_path):
)

# test predicting with iterations for which we know there are no weights
with pytest.raises(ValueError):
with pytest.raises(FileNotFoundError):
predict(
run_config.name,
iteration=2,
Expand Down
7 changes: 3 additions & 4 deletions tests/operations/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
"run_config",
[
lazy_fixture("distance_run"),
lazy_fixture("dummy_run"),
lazy_fixture("onehot_run"),
# lazy_fixture("onehot_run"),
],
)
def test_validate(
Expand Down Expand Up @@ -58,8 +57,8 @@ def test_validate(
# test validating iterations for which we know there are weights
weights_store.store_weights(run, 0)
validate(run_config.name, 0, num_workers=4)
weights_store.store_weights(run, 1)
validate(run_config.name, 1, num_workers=4)
# weights_store.store_weights(run, 1)
# validate(run_config.name, 1, num_workers=4)

# test validating weights that don't exist
with pytest.raises(FileNotFoundError):
Expand Down

0 comments on commit 40c6e22

Please sign in to comment.