Skip to content

Commit

Permalink
cosem starter (#170)
Browse files Browse the repository at this point in the history
Cosem Starter using Cellmap_models
```python
from dacapo.experiments.starts import CosemStartConfig
starter_config = CosemStartConfig(name='setup04/1820500')
```
  • Loading branch information
mzouink authored Mar 18, 2024
2 parents eaba59c + ac79956 commit cff9ce7
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 22 deletions.
2 changes: 1 addition & 1 deletion dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, run_config):

# preloaded weights from previous run
self.start = (
Start(run_config.start_config)
run_config.start_config.start_type(run_config.start_config)
if run_config.start_config is not None
else None
)
Expand Down
2 changes: 2 additions & 0 deletions dacapo/experiments/starts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .start import Start # noqa
from .start_config import StartConfig # noqa
from .cosem_start import CosemStart # noqa
from .cosem_start_config import CosemStartConfig # noqa
37 changes: 37 additions & 0 deletions dacapo/experiments/starts/cosem_start.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from abc import ABC
import logging
from cellmap_models import cosem
from pathlib import Path
from .start import Start

logger = logging.getLogger(__file__)


def format_name(name):
if "/" in name:
run, criterion = name.split("/")
return run, criterion
else:
raise ValueError(
f"Invalid starter name format {name}. Must be in the format run/criterion"
)


class CosemStart(Start):
def __init__(self, start_config):
run, criterion = format_name(start_config.name)
self.name = start_config.name
super().__init__(run, criterion)

def initialize_weights(self, model):
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
weights_dir = Path(weights_store.basedir, self.run, "checkpoints", "iterations")
if not (weights_dir / self.criterion).exists():
if not weights_dir.exists():
weights_dir.mkdir(parents=True, exist_ok=True)
path = weights_dir / self.criterion
cosem.download_checkpoint(self.name, path)
weights = weights_store._retrieve_weights(self.run, self.criterion)
super._set_weights(model, weights)
14 changes: 14 additions & 0 deletions dacapo/experiments/starts/cosem_start_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import attr
from .cosem_start import CosemStart


@attr.s
class CosemStartConfig:
"""Starter for COSEM pretained models. This is a subclass of `StartConfig` and
should be used to initialize the model with pretrained weights from a previous
run.
"""

start_type = CosemStart

name: str = attr.ib(metadata={"help_text": "The COSEM checkpoint name to use."})
43 changes: 22 additions & 21 deletions dacapo/experiments/starts/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,7 @@ def __init__(self, start_config):
self.run = start_config.run
self.criterion = start_config.criterion

def initialize_weights(self, model):
"""
Retrieves the weights from the dacapo store and load them into
the model.
Parameters
----------
model : obj
The model to which the weights are to be loaded.
Raises
------
RuntimeError
If weights of a non-existing or mismatched layer are being
loaded, a RuntimeError exception is thrown which is logged
and handled by loading only the common layers from weights.
"""
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)
def _set_weights(self, model, weights):
print(f"loading weights from run {self.run}, criterion: {self.criterion}")
# load the model weights (taken from torch load_state_dict source)
try:
Expand All @@ -72,3 +52,24 @@ def initialize_weights(self, model):
) # update only the existing and matching layers
model.load_state_dict(model_dict)
logger.warning(f"loaded only common layers from weights")

def initialize_weights(self, model):
"""
Retrieves the weights from the dacapo store and load them into
the model.
Parameters
----------
model : obj
The model to which the weights are to be loaded.
Raises
------
RuntimeError
If weights of a non-existing or mismatched layer are being
loaded, a RuntimeError exception is thrown which is logged
and handled by loading only the common layers from weights.
"""
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)
self._set_weights(model, weights)
3 changes: 3 additions & 0 deletions dacapo/experiments/starts/start_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import attr
from .start import Start


@attr.s
Expand All @@ -16,6 +17,8 @@ class StartConfig:
"""

start_type = Start

run: str = attr.ib(metadata={"help_text": "The Run to use as a starting point."})
criterion: str = attr.ib(
metadata={"help_text": "The criterion for choosing weights from run."}
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ dependencies = [
"funlib.geometry>=0.2",
"mwatershed>=0.1",
"funlib.persistence",
"cellmap-models",
# "funlib.persistence @ git+https://github.com/janelia-cellmap/funlib.persistence",
"funlib.evaluate @ git+https://github.com/pattonw/funlib.evaluate",
"gunpowder>=1.3",
Expand Down Expand Up @@ -163,6 +164,7 @@ exclude = [
# # module specific overrides
[[tool.mypy.overrides]]
module = [
"cellmap_models.*",
"funlib.*",
"toml.*",
"gunpowder.*",
Expand Down

0 comments on commit cff9ce7

Please sign in to comment.