[Example] Collecting with gradient #77

Merged · 6 commits · Jun 10, 2024
2 changes: 2 additions & 0 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -13,6 +13,8 @@ buffer_device: "cpu"
share_policy_params: True
# If an algorithm and an env support both continuous and discrete actions, what should be preferred
prefer_continuous_actions: True
# If False, collection is done using a collector (under no grad). If True, collection is done with gradients.
collect_with_grad: False

# Discount factor
gamma: 0.9
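For readers trying the new option, here is a minimal sketch of toggling it from Python. It is not part of this PR and assumes BenchMARL's usual `ExperimentConfig.get_from_yaml()` entry point:

```python
# Sketch: enable differentiable collection via the new flag.
# Assumes the standard BenchMARL config entry point; not code from this PR.
from benchmarl.experiment import ExperimentConfig

config = ExperimentConfig.get_from_yaml()  # loads the defaults from base_experiment.yaml
config.collect_with_grad = True  # collect by rolling out the env with gradients enabled
```

The same override should also be reachable from the YAML/Hydra side, e.g. by setting `experiment.collect_with_grad=True`.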
74 changes: 57 additions & 17 deletions benchmarl/experiment/experiment.py
@@ -22,7 +22,7 @@
from torchrl.collectors import SyncDataCollector
from torchrl.envs import SerialEnv, TransformedEnv
from torchrl.envs.transforms import Compose
from torchrl.envs.utils import ExplorationType, set_exploration_type
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
from torchrl.record.loggers import generate_exp_name
from tqdm import tqdm

@@ -54,6 +54,7 @@ class ExperimentConfig:

share_policy_params: bool = MISSING
prefer_continuous_actions: bool = MISSING
collect_with_grad: bool = MISSING

gamma: float = MISSING
lr: float = MISSING
@@ -456,17 +457,26 @@ def _setup_collector(self):
assert len(group_policy) == 1
self.group_policies.update({group: group_policy[0]})

self.collector = SyncDataCollector(
self.env_func,
self.policy,
device=self.config.sampling_device,
storing_device=self.config.train_device,
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
total_frames=self.config.get_max_n_frames(self.on_policy),
init_random_frames=(
self.config.off_policy_init_random_frames if not self.on_policy else 0
),
)
if not self.config.collect_with_grad:
self.collector = SyncDataCollector(
self.env_func,
self.policy,
device=self.config.sampling_device,
storing_device=self.config.train_device,
frames_per_batch=self.config.collected_frames_per_batch(self.on_policy),
total_frames=self.config.get_max_n_frames(self.on_policy),
init_random_frames=(
self.config.off_policy_init_random_frames
if not self.on_policy
else 0
),
)
else:
if self.config.off_policy_init_random_frames and not self.on_policy:
raise TypeError(
"Collection via rollouts does not support initial random frames as of now."
)
self.rollout_env = self.env_func().to(self.config.sampling_device)

def _setup_name(self):
self.algorithm_name = self.algorithm_config.associated_class().__name__.lower()
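As a standalone illustration of the branch added above (not code from this PR), the two collection paths look roughly like this in TorchRL: a `SyncDataCollector` steps its own copy of the env without gradient tracking and yields ready-made batches, while keeping a live env around lets later `rollout` calls stay differentiable. `make_env`, `policy`, and the frame count below are placeholders.

```python
# Illustrative sketch of the same branching, outside the Experiment class.
from torchrl.collectors import SyncDataCollector


def setup_collection(make_env, policy, collect_with_grad: bool, frames_per_batch: int = 6000):
    if not collect_with_grad:
        # The collector owns its env and steps it under no_grad.
        return SyncDataCollector(
            make_env,
            policy,
            frames_per_batch=frames_per_batch,
            total_frames=-1,  # run until the caller stops iterating
        )
    # Keep a live env instead; rollouts on it keep the computation graph.
    return make_env()
```

The `TypeError` guard in the hunk above exists because random warm-up frames (`off_policy_init_random_frames`) are a collector feature, so they are rejected when collecting via rollouts.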
@@ -544,8 +554,31 @@ def _collection_loop(self):
)
sampling_start = time.time()

if not self.config.collect_with_grad:
iterator = iter(self.collector)
else:
reset_batch = self.rollout_env.reset()

# Training/collection iterations
for batch in self.collector:
for _ in range(
self.n_iters_performed, self.config.get_max_n_iters(self.on_policy)
):
if not self.config.collect_with_grad:
batch = next(iterator)
else:
with set_exploration_type(ExplorationType.RANDOM):
batch = self.rollout_env.rollout(
max_steps=-(
-self.config.collected_frames_per_batch(self.on_policy)
// self.rollout_env.batch_size.numel()
),
policy=self.policy,
break_when_any_done=False,
auto_reset=False,
tensordict=reset_batch,
)
reset_batch = step_mdp(batch[..., -1])

# Logging collection
collection_time = time.time() - sampling_start
current_frames = batch.numel()
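Two details of the rollout branch are easy to miss: `-(-a // b)` is integer ceil division, so the requested `collected_frames_per_batch` is rounded up to a whole number of steps across the batched env, and because `auto_reset=False`, the final step of each rollout is advanced with `step_mdp` and fed back as the starting `tensordict` of the next iteration, so episodes continue across collection iterations instead of being reset. A hedged sketch with placeholder names, assuming a batched TorchRL env:

```python
from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp


def ceil_div(a: int, b: int) -> int:
    # -(-a // b) == ceil(a / b) for positive ints, e.g. -(-6000 // 32) == 188
    return -(-a // b)


def rollout_once(env, policy, frames_per_batch, reset_batch):
    """Illustrative: one gradient-carrying collection iteration."""
    with set_exploration_type(ExplorationType.RANDOM):
        batch = env.rollout(
            max_steps=ceil_div(frames_per_batch, env.batch_size.numel()),
            policy=policy,
            break_when_any_done=False,  # keep stepping past per-env terminations
            auto_reset=False,  # resume from the state carried over below
            tensordict=reset_batch,
        )
    # Advance the last step so the next call starts where this one ended.
    return batch, step_mdp(batch[..., -1])
```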
@@ -560,6 +593,7 @@

# Callback
self._on_batch_collected(batch)
batch = batch.detach()

# Loop over groups
training_start = time.time()
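The `detach()` added here is what scopes the new feature: callbacks in `_on_batch_collected` see the batch while it still carries a graph, and the detach cuts that graph before the batch reaches the replay buffers and losses. A plain-torch illustration, not BenchMARL code:

```python
import torch

x = torch.randn(4, requires_grad=True)
batch = x * 2.0  # stands in for a rollout collected with gradients
assert batch.grad_fn is not None  # still differentiable inside the callback
batch = batch.detach()  # what the added line does right after the callback
assert batch.grad_fn is None  # training never backpropagates into collection
```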
@@ -593,7 +627,8 @@ def _collection_loop(self):
explore_layer.step(current_frames)

# Update policy in collector
self.collector.update_policy_weights_()
if not self.config.collect_with_grad:
self.collector.update_policy_weights_()

# Timers
training_time = time.time() - training_start
@@ -635,7 +670,10 @@ def _collection_loop(self):

def close(self):
"""Close the experiment."""
self.collector.shutdown()
if not self.config.collect_with_grad:
self.collector.shutdown()
else:
self.rollout_env.close()
self.test_env.close()
self.logger.finish()

@@ -766,13 +804,14 @@ def state_dict(self) -> OrderedDict:
)
state_dict = OrderedDict(
state=state,
collector=self.collector.state_dict(),
**{f"loss_{k}": item.state_dict() for k, item in self.losses.items()},
**{
f"buffer_{k}": item.state_dict()
for k, item in self.replay_buffers.items()
},
)
if not self.config.collect_with_grad:
state_dict.update({"collector": self.collector.state_dict()})
return state_dict

def load_state_dict(self, state_dict: Dict) -> None:
@@ -785,7 +824,8 @@ def load_state_dict(self, state_dict: Dict) -> None:
for group in self.group_map.keys():
self.losses[group].load_state_dict(state_dict[f"loss_{group}"])
self.replay_buffers[group].load_state_dict(state_dict[f"buffer_{group}"])
self.collector.load_state_dict(state_dict["collector"])
if not self.config.collect_with_grad:
self.collector.load_state_dict(state_dict["collector"])
self.total_time = state_dict["state"]["total_time"]
self.total_frames = state_dict["state"]["total_frames"]
self.n_iters_performed = state_dict["state"]["n_iters_performed"]
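Checkpointing follows the same switch: when collecting with gradients there is no collector, so its entry is simply absent from the state dict and skipped on load. A hypothetical round trip, assuming an `experiment` built with `collect_with_grad=True`:

```python
sd = experiment.state_dict()
assert "collector" not in sd  # only losses, buffers, and counters are saved
experiment.load_state_dict(sd)
```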