Skip to content

Commit

Permalink
[Feature] Buffer device (#87)
Browse files Browse the repository at this point in the history
* amend

* amend

* empty
  • Loading branch information
matteobettini authored Jun 8, 2024
1 parent 0a28a16 commit ac59796
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 8 deletions.
8 changes: 5 additions & 3 deletions benchmarl/algorithms/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, experiment):
self.experiment = experiment

self.device: DEVICE_TYPING = experiment.config.train_device
self.buffer_device: DEVICE_TYPING = experiment.config.buffer_device
self.experiment_config = experiment.config
self.model_config = experiment.model_config
self.critic_model_config = experiment.critic_model_config
Expand Down Expand Up @@ -141,11 +142,12 @@ def get_replay_buffer(
"""
memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
storing_device = self.device
sampler = SamplerWithoutReplacement() if self.on_policy else RandomSampler()

return TensorDictReplayBuffer(
storage=LazyTensorStorage(memory_size, device=storing_device),
storage=LazyTensorStorage(
memory_size,
device=self.device if self.on_policy else self.buffer_device,
),
sampler=sampler,
batch_size=sampling_size,
priority_key=(group, "td_error"),
Expand Down
2 changes: 2 additions & 0 deletions benchmarl/conf/experiment/base_experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defaults:
sampling_device: "cpu"
# The device for training (e.g. cuda)
train_device: "cpu"
# The device for the replay buffer of off-policy algorithms (e.g. cuda)
buffer_device: "cpu"

# Whether to share the parameters of the policy within agent groups
share_policy_params: True
Expand Down
9 changes: 5 additions & 4 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ExperimentConfig:

sampling_device: str = MISSING
train_device: str = MISSING
buffer_device: str = MISSING

share_policy_params: bool = MISSING
prefer_continuous_actions: bool = MISSING
Expand Down Expand Up @@ -462,9 +463,9 @@ def _setup_collector(self):
storing_device=self.config.train_device,
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
total_frames=self.config.get_max_n_frames(self.on_policy),
init_random_frames=self.config.off_policy_init_random_frames
if not self.on_policy
else 0,
init_random_frames=(
self.config.off_policy_init_random_frames if not self.on_policy else 0
),
)

def _setup_name(self):
Expand Down Expand Up @@ -647,7 +648,7 @@ def _get_excluded_keys(self, group: str):
return excluded_keys

def _optimizer_loop(self, group: str) -> TensorDictBase:
subdata = self.replay_buffers[group].sample()
subdata = self.replay_buffers[group].sample().to(self.config.train_device)
loss_vals = self.losses[group](subdata)
training_td = loss_vals.detach()
loss_vals = self.algorithm.process_loss_vals(group, loss_vals)
Expand Down
5 changes: 4 additions & 1 deletion benchmarl/hydra_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
from omegaconf import DictConfig, OmegaConf


def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
def load_experiment_from_hydra(
cfg: DictConfig, task_name: str, callbacks=()
) -> Experiment:
"""Creates an :class:`~benchmarl.experiment.Experiment` from hydra config.
Args:
Expand All @@ -41,6 +43,7 @@ def load_experiment_from_hydra(cfg: DictConfig, task_name: str) -> Experiment:
critic_model_config=critic_model_config,
seed=cfg.seed,
config=experiment_config,
callbacks=callbacks,
)


Expand Down
1 change: 1 addition & 0 deletions fine_tuned/smacv2/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ seed: 0
experiment:
sampling_device: "cpu"
train_device: "cuda"
buffer_device: "cuda"

share_policy_params: True
prefer_continuous_actions: True
Expand Down
1 change: 1 addition & 0 deletions fine_tuned/vmas/conf/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ experiment:

sampling_device: "cuda"
train_device: "cuda"
buffer_device: "cuda"

share_policy_params: True
prefer_continuous_actions: True
Expand Down

0 comments on commit ac59796

Please sign in to comment.