From c37938e093fcd9d6a0851473463115cc9f495343 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Fri, 5 Jul 2024 10:26:29 +0100 Subject: [PATCH] [BugFix] Do not reload empty buffers --- benchmarl/experiment/experiment.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/benchmarl/experiment/experiment.py b/benchmarl/experiment/experiment.py index 6ace35b4..1d8a03a2 100644 --- a/benchmarl/experiment/experiment.py +++ b/benchmarl/experiment/experiment.py @@ -815,7 +815,7 @@ def state_dict(self) -> OrderedDict: state=state, **{f"loss_{k}": item.state_dict() for k, item in self.losses.items()}, **{ - f"buffer_{k}": item.state_dict() + f"buffer_{k}": item.state_dict() if len(item) else None for k, item in self.replay_buffers.items() }, ) @@ -832,7 +832,10 @@ def load_state_dict(self, state_dict: Dict) -> None: """ for group in self.group_map.keys(): self.losses[group].load_state_dict(state_dict[f"loss_{group}"]) - self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"]) + if state_dict[f"buffer_{group}"] is not None: + self.replay_buffers[group].load_state_dict( + state_dict[f"buffer_{group}"] + ) if not self.config.collect_with_grad: self.collector.load_state_dict(state_dict["collector"]) self.total_time = state_dict["state"]["total_time"]