diff --git a/README.md b/README.md
index 9b812a21aa0..64559f7af37 100644
--- a/README.md
+++ b/README.md
@@ -99,68 +99,69 @@ lines of code*!
 from torchrl.collectors import SyncDataCollector
 from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
-  LazyTensorStorage, SamplerWithoutReplacement
+    LazyTensorStorage, SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE
-env = GymEnv("Pendulum-v1")
+env = GymEnv("Pendulum-v1")
 model = TensorDictModule(
-  nn.Sequential(
-    nn.Linear(3, 128), nn.Tanh(),
-    nn.Linear(128, 128), nn.Tanh(),
-    nn.Linear(128, 128), nn.Tanh(),
-    nn.Linear(128, 2),
-    NormalParamExtractor()
-  ),
-  in_keys=["observation"],
-  out_keys=["loc", "scale"]
+    nn.Sequential(
+        nn.Linear(3, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 2),
+        NormalParamExtractor()
+    ),
+    in_keys=["observation"],
+    out_keys=["loc", "scale"]
 )
 critic = ValueOperator(
-  nn.Sequential(
-    nn.Linear(3, 128), nn.Tanh(),
-    nn.Linear(128, 128), nn.Tanh(),
-    nn.Linear(128, 128), nn.Tanh(),
-    nn.Linear(128, 1),
-  ),
-  in_keys=["observation"],
+    nn.Sequential(
+        nn.Linear(3, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 128), nn.Tanh(),
+        nn.Linear(128, 1),
+    ),
+    in_keys=["observation"],
 )
 actor = ProbabilisticActor(
-  model,
-  in_keys=["loc", "scale"],
-  distribution_class=TanhNormal,
-  distribution_kwargs={"min": -1.0, "max": 1.0},
-  return_log_prob=True
-  )
+    model,
+    in_keys=["loc", "scale"],
+    distribution_class=TanhNormal,
+    distribution_kwargs={"low": -1.0, "high": 1.0},
+    return_log_prob=True
+    )
 buffer = TensorDictReplayBuffer(
-  LazyTensorStorage(1000),
-  SamplerWithoutReplacement()
-  )
+    storage=LazyTensorStorage(1000),
+    sampler=SamplerWithoutReplacement(),
+    batch_size=50,
+    )
 collector = SyncDataCollector(
-  env,
-  actor,
-  frames_per_batch=1000,
-  total_frames=1_000_000
-  )
-loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
+    env,
+    actor,
+    frames_per_batch=1000,
+    total_frames=1_000_000,
+    )
+loss_fn = ClipPPOLoss(actor, critic)
+adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
 optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
-adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
+
 for data in collector:  # collect data
-  for epoch in range(10):
-    adv_fn(data)  # compute advantage
-    buffer.extend(data.view(-1))
-    for i in range(20):  # consume data
-      sample = buffer.sample(50)  # mini-batch
-      loss_vals = loss_fn(sample)
-      loss_val = sum(
-        value for key, value in loss_vals.items() if
-        key.startswith("loss")
-      )
-      loss_val.backward()
-      optim.step()
-      optim.zero_grad()
-  print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
+    for epoch in range(10):
+        adv_fn(data)  # compute advantage
+        buffer.extend(data)
+        for sample in buffer:  # consume data
+            loss_vals = loss_fn(sample)
+            loss_val = sum(
+                value for key, value in loss_vals.items() if
+                key.startswith("loss")
+            )
+            loss_val.backward()
+            optim.step()
+            optim.zero_grad()
+    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
 ```
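
For readers skimming the replay-buffer change above: the mini-batch size now lives on the buffer itself (`batch_size=50`), and iterating the buffer replaces the explicit `buffer.sample(50)` calls, with `SamplerWithoutReplacement` ending the iteration after one pass over the stored frames. A minimal sketch of that pattern in isolation, using a dummy `TensorDict` (hypothetical data, not part of the README example) in place of collector output:

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import (
    LazyTensorStorage,
    SamplerWithoutReplacement,
    TensorDictReplayBuffer,
)

# Same buffer configuration as the updated README snippet: the mini-batch
# size is fixed at construction time rather than passed to .sample().
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)

# Dummy stand-in for one collector batch (hypothetical data, 1000 frames).
data = TensorDict({"observation": torch.randn(1000, 3)}, batch_size=[1000])
buffer.extend(data)

# Iterating draws 50-frame mini-batches without replacement and stops after
# one pass over the stored frames, i.e. 1000 / 50 = 20 iterations here.
n_batches = 0
for sample in buffer:
    assert sample["observation"].shape == torch.Size([50, 3])
    n_batches += 1
print(n_batches)  # 20
```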