diff --git a/MARBLE/main.py b/MARBLE/main.py index a734b3eb..c8ee3502 100644 --- a/MARBLE/main.py +++ b/MARBLE/main.py @@ -413,7 +413,8 @@ def load_model(self, loadpath): self._epoch = checkpoint["epoch"] self.load_state_dict(checkpoint["model_state_dict"]) self.optimizer_state_dict = checkpoint["optimizer_state_dict"] - self.losses = checkpoint["losses"] + if hatattr(self, 'losses'): + self.losses = checkpoint["losses"] def save_model(self, optimizer, losses, outdir=None, best=False, timestamp=""): """Save model."""