diff --git a/dacapo/experiments/starts/start.py b/dacapo/experiments/starts/start.py index 70f77e316..68dcc0a28 100644 --- a/dacapo/experiments/starts/start.py +++ b/dacapo/experiments/starts/start.py @@ -25,6 +25,9 @@ def initialize_weights(self, model): model_dict = model.state_dict() common_layers = set(model_dict.keys()) & set(weights.model.keys()) for layer in common_layers: - model_dict[layer] = weights.model[layer] + if model_dict[layer].shape == weights.model[layer].shape: + model_dict[layer] = weights.model[layer] + else: + logger.warning(f"layer {layer} has different shape, not loading") model.load_state_dict(model_dict) logger.warning(f"loaded only common layers from weights")