diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 80dfd5602..a31421a75 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -605,11 +605,10 @@ def load_checkpoint( target_key = key if target_key in self.normalizers: - mkeys = ( - self.normalizers[target_key] - .load_state_dict(checkpoint["normalizers"][key]) - .to(map_location) + mkeys = self.normalizers[target_key].load_state_dict( + checkpoint["normalizers"][key] ) + self.normalizers[target_key].to(map_location) assert len(mkeys.missing_keys) == 0 assert len(mkeys.unexpected_keys) == 0