Skip to content

Commit

Permalink
feat: ⚡️ Incorporate start related changes from rhoadesj/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
rhoadesScholar committed Feb 8, 2024
1 parent 5f50f9b commit 812acc1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 11 deletions.
41 changes: 33 additions & 8 deletions dacapo/experiments/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from .validation_scores import ValidationScores
from .starts import Start
from .model import Model

import logging
import torch

logger = logging.getLogger(__file__)


class Run:
name: str
Expand Down Expand Up @@ -53,14 +55,37 @@ def __init__(self, run_config):
self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
)

if run_config.start_config is None:
return
try:
from ..store import create_config_store

start_config_store = create_config_store()
starter_config = start_config_store.retrieve_run_config(
run_config.start_config.run
)
except Exception as e:
logger.error(
f"could not load start config: {e} Should be added to the database config store RUN"
)
raise e

# preloaded weights from previous run
self.start = (
Start(run_config.start_config)
if run_config.start_config is not None
else None
)
if self.start is not None:
self.start.initialize_weights(self.model)
if run_config.task_config.name == starter_config.task_config.name:
self.start = Start(run_config.start_config)
else:
# Match labels between old and new head
if hasattr(run_config.task_config, "channels"):
# Map old head and new head
old_head = starter_config.task_config.channels
new_head = run_config.task_config.channels
self.start = Start(
run_config.start_config, old_head=old_head, new_head=new_head
)
else:
logger.warning("Not implemented channel match for this task")
self.start = Start(run_config.start_config, remove_head=True)
self.start.initialize_weights(self.model)

@staticmethod
def get_validation_scores(run_config) -> ValidationScores:
Expand Down
62 changes: 59 additions & 3 deletions dacapo/experiments/starts/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,77 @@

logger = logging.getLogger(__file__)

# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]


def match_heads(model, weights, old_head, new_head):
# match the heads
for label in new_head:
if label in old_head:
logger.warning(f"matching head for {label}")
# find the index of the label in the old_head
old_index = old_head.index(label)
# find the index of the label in the new_head
new_index = new_head.index(label)
# get the weight and bias of the old head
for key in [
"prediction_head.weight",
"prediction_head.bias",
"chain.1.weight",
"chain.1.bias",
]:
if key in model.state_dict().keys():
n_val = weights.model[key][old_index]
model.state_dict()[key][new_index] = n_val
logger.warning(f"matched head for {label}")
return model


class Start(ABC):
def __init__(self, start_config):
def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
self.run = start_config.run
self.criterion = start_config.criterion
self.remove_head = remove_head
self.old_head = old_head
self.new_head = new_head

def initialize_weights(self, model):
from dacapo.store.create_store import create_weights_store

weights_store = create_weights_store()
weights = weights_store._retrieve_weights(self.run, self.criterion)

logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")

# load the model weights (taken from torch load_state_dict source)
try:
model.load_state_dict(weights.model)
if self.old_head and self.new_head:
logger.warning(
f"matching heads from run {self.run}, criterion: {self.criterion}"
)
logger.info(f"old head: {self.old_head}")
logger.info(f"new head: {self.new_head}")
model = match_heads(model, weights, self.old_head, self.new_head)
logger.warning(
f"matched heads from run {self.run}, criterion: {self.criterion}"
)
self.remove_head = True
if self.remove_head:
logger.warning(
f"removing head from run {self.run}, criterion: {self.criterion}"
)
weights.model.pop("prediction_head.weight", None)
weights.model.pop("prediction_head.bias", None)
weights.model.pop("chain.1.weight", None)
weights.model.pop("chain.1.bias", None)
logger.warning(
f"removed head from run {self.run}, criterion: {self.criterion}"
)
model.load_state_dict(weights.model, strict=False)
logger.warning(
f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
)
else:
model.load_state_dict(weights.model)
except RuntimeError as e:
logger.warning(e)

0 comments on commit 812acc1

Please sign in to comment.