
Commit

reset run and starter, will be a different pull request
mzouink committed Feb 7, 2024
1 parent 68dbb60 commit b33eec2
Showing 2 changed files with 13 additions and 111 deletions.
41 changes: 8 additions & 33 deletions dacapo/experiments/run.py
@@ -6,10 +6,8 @@
 from .validation_scores import ValidationScores
 from .starts import Start
 from .model import Model
-import logging
-import torch
-
-logger = logging.getLogger(__file__)
+import torch
 
 
 class Run:
@@ -55,37 +53,14 @@ def __init__(self, run_config):
             self.task.parameters, self.datasplit.validate, self.task.evaluation_scores
         )
 
-        if run_config.start_config is None:
-            return
-        try:
-            from ..store import create_config_store
-
-            start_config_store = create_config_store()
-            starter_config = start_config_store.retrieve_run_config(
-                run_config.start_config.run
-            )
-        except Exception as e:
-            logger.error(
-                f"could not load start config: {e} Should be added to the database config store RUN"
-            )
-            raise e
-
-        # preloaded weights from previous run
-        if run_config.task_config.name == starter_config.task_config.name:
-            self.start = Start(run_config.start_config)
-        else:
-            # Match labels between old and new head
-            if hasattr(run_config.task_config, "channels"):
-                # Map old head and new head
-                old_head = starter_config.task_config.channels
-                new_head = run_config.task_config.channels
-                self.start = Start(
-                    run_config.start_config, old_head=old_head, new_head=new_head
-                )
-            else:
-                logger.warning("Not implemented channel match for this task")
-                self.start = Start(run_config.start_config, remove_head=True)
-        self.start.initialize_weights(self.model)
+        self.start = (
+            Start(run_config.start_config)
+            if run_config.start_config is not None
+            else None
+        )
+        if self.start is not None:
+            self.start.initialize_weights(self.model)
 
     @staticmethod
     def get_validation_scores(run_config) -> ValidationScores:
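For orientation, here is a minimal, self-contained sketch of the start handling that remains in Run.__init__ after this change: a Start is only built when a start_config is present, and weight loading is reduced to a single initialize_weights call. The Start stub and the SimpleNamespace config objects below are stand-ins for illustration, not the real dacapo classes.

from types import SimpleNamespace


class Start:
    # Stand-in for the simplified Start (assumed shape: a run name and a criterion).
    def __init__(self, start_config):
        self.run = start_config.run
        self.criterion = start_config.criterion

    def initialize_weights(self, model):
        # The real class retrieves stored weights here and loads them into the model.
        print(f"would load weights from run {self.run!r}, criterion {self.criterion!r}")


def build_start(run_config, model=None):
    # Mirrors the simplified Run.__init__ logic: optional Start, conditional loading.
    start = Start(run_config.start_config) if run_config.start_config is not None else None
    if start is not None:
        start.initialize_weights(model)
    return start


build_start(SimpleNamespace(start_config=SimpleNamespace(run="baseline_run", criterion="best")))
build_start(SimpleNamespace(start_config=None))  # no starter configured: nothing is loaded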
83 changes: 5 additions & 78 deletions dacapo/experiments/starts/start.py
@@ -3,94 +3,21 @@
 
 logger = logging.getLogger(__file__)
 
-# self.old_head =["ecs","plasma_membrane","mito","mito_membrane","vesicle","vesicle_membrane","mvb","mvb_membrane","er","er_membrane","eres","nucleus","microtubules","microtubules_out"]
-# self.new_head = ["mito","nucleus","ld","ecs","peroxisome"]
-head_keys = [
-    "prediction_head.weight",
-    "prediction_head.bias",
-    "chain.1.weight",
-    "chain.1.bias",
-]
-
-# Hack
-# if label is mito_peroxisome or peroxisome then change it to mito
-mitos = ["mito_proxisome", "peroxisome"]
-
-
-def match_heads(model, head_weights, old_head, new_head):
-    # match the heads
-    for label in new_head:
-        old_label = label
-        if label in mitos:
-            old_label = "mito"
-        if old_label in old_head:
-            logger.warning(f"matching head for {label}")
-            # find the index of the label in the old_head
-            old_index = old_head.index(old_label)
-            # find the index of the label in the new_head
-            new_index = new_head.index(label)
-            # get the weight and bias of the old head
-            for key in head_keys:
-                if key in model.state_dict().keys():
-                    n_val = head_weights[key][old_index]
-                    model.state_dict()[key][new_index] = n_val
-            logger.warning(f"matched head for {label} with {old_label}")
-
-
 class Start(ABC):
-    def __init__(self, start_config, remove_head=False, old_head=None, new_head=None):
+    def __init__(self, start_config):
         self.run = start_config.run
         self.criterion = start_config.criterion
-        self.remove_head = remove_head
-        self.old_head = old_head
-        self.new_head = new_head
 
     def initialize_weights(self, model):
        from dacapo.store.create_store import create_weights_store
 
         weights_store = create_weights_store()
         weights = weights_store._retrieve_weights(self.run, self.criterion)
-
-        logger.warning(
-            f"loading weights from run {self.run}, criterion: {self.criterion}"
-        )
-
+        logger.info(f"loading weights from run {self.run}, criterion: {self.criterion}")
         # load the model weights (taken from torch load_state_dict source)
         try:
-            if self.old_head and self.new_head:
-                try:
-                    self.load_model_using_head_matching(model, weights)
-                except RuntimeError as e:
-                    logger.error(f"ERROR starter matching head: {e}")
-                    self.load_model_using_head_removal(model, weights)
-            elif self.remove_head:
-                self.load_model_using_head_removal(model, weights)
-            else:
-                model.load_state_dict(weights.model)
+            model.load_state_dict(weights.model)
         except RuntimeError as e:
-            logger.warning(f"ERROR starter: {e}")
-
-    def load_model_using_head_removal(self, model, weights):
-        logger.warning(
-            f"removing head from run {self.run}, criterion: {self.criterion}"
-        )
-        for key in head_keys:
-            weights.model.pop(key, None)
-        logger.warning(f"removed head from run {self.run}, criterion: {self.criterion}")
-        model.load_state_dict(weights.model, strict=False)
-        logger.warning(
-            f"loaded weights in non strict mode from run {self.run}, criterion: {self.criterion}"
-        )
-
-    def load_model_using_head_matching(self, model, weights):
-        logger.warning(
-            f"matching heads from run {self.run}, criterion: {self.criterion}"
-        )
-        logger.warning(f"old head: {self.old_head}")
-        logger.warning(f"new head: {self.new_head}")
-        head_weights = {}
-        for key in head_keys:
-            head_weights[key] = weights.model[key]
-        for key in head_keys:
-            weights.model.pop(key, None)
-        model.load_state_dict(weights.model, strict=False)
-        model = match_heads(model, head_weights, self.old_head, self.new_head)
+            logger.warning(e)
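A short usage sketch of the simplified Start class above. The import path follows the relative import used in run.py, the start_config shape (a run name plus a criterion) is inferred from the constructor, and the concrete values are illustrative only.

from types import SimpleNamespace

from dacapo.experiments.starts import Start

start = Start(SimpleNamespace(run="previous_run", criterion="best"))
# start.initialize_weights(model)  # retrieves the stored weights for that run and calls
#                                  # model.load_state_dict; a RuntimeError is logged, not raised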
