Skip to content

Commit

Permalink
[BugFix] Do not reload empty buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini committed Jul 5, 2024
1 parent a930915 commit c37938e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,7 @@ def state_dict(self) -> OrderedDict:
state=state,
**{f"loss_{k}": item.state_dict() for k, item in self.losses.items()},
**{
f"buffer_{k}": item.state_dict()
f"buffer_{k}": item.state_dict() if len(item) else None
for k, item in self.replay_buffers.items()
},
)
Expand All @@ -832,7 +832,10 @@ def load_state_dict(self, state_dict: Dict) -> None:
"""
for group in self.group_map.keys():
self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"])
if state_dict[f"buffer_{group}"] is not None:
self.replay_buffers[group].load_state_dict(
state_dict[f"buffer_{group}"]
)
if not self.config.collect_with_grad:
self.collector.load_state_dict(state_dict["collector"])
self.total_time = state_dict["state"]["total_time"]
Expand Down

0 comments on commit c37938e

Please sign in to comment.