Commit

[Doc] Fix README example (#2398)
vmoens authored Aug 14, 2024
1 parent 25e8bd2 commit e82a69f
Showing 1 changed file with 49 additions and 48 deletions: README.md
````diff
@@ -99,68 +99,69 @@ lines of code*!
 from torchrl.collectors import SyncDataCollector
 from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
     LazyTensorStorage, SamplerWithoutReplacement
 from torchrl.envs.libs.gym import GymEnv
 from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
 from torchrl.objectives import ClipPPOLoss
 from torchrl.objectives.value import GAE

 env = GymEnv("Pendulum-v1")
 model = TensorDictModule(
     nn.Sequential(
         nn.Linear(3, 128), nn.Tanh(),
         nn.Linear(128, 128), nn.Tanh(),
         nn.Linear(128, 128), nn.Tanh(),
         nn.Linear(128, 2),
         NormalParamExtractor()
     ),
     in_keys=["observation"],
     out_keys=["loc", "scale"]
 )
 critic = ValueOperator(
     nn.Sequential(
         nn.Linear(3, 128), nn.Tanh(),
         nn.Linear(128, 128), nn.Tanh(),
         nn.Linear(128, 128), nn.Tanh(),
         nn.Linear(128, 1),
     ),
     in_keys=["observation"],
 )
 actor = ProbabilisticActor(
     model,
     in_keys=["loc", "scale"],
     distribution_class=TanhNormal,
-    distribution_kwargs={"min": -1.0, "max": 1.0},
+    distribution_kwargs={"low": -1.0, "high": 1.0},
     return_log_prob=True
 )
 buffer = TensorDictReplayBuffer(
-    LazyTensorStorage(1000),
-    SamplerWithoutReplacement()
+    storage=LazyTensorStorage(1000),
+    sampler=SamplerWithoutReplacement(),
+    batch_size=50,
 )
 collector = SyncDataCollector(
     env,
     actor,
     frames_per_batch=1000,
-    total_frames=1_000_000
+    total_frames=1_000_000,
 )
-loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
-adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
+loss_fn = ClipPPOLoss(actor, critic)
+adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
 optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)

 for data in collector:  # collect data
     for epoch in range(10):
         adv_fn(data)  # compute advantage
-        buffer.extend(data.view(-1))
-        for i in range(20):  # consume data
-            sample = buffer.sample(50)  # mini-batch
+        buffer.extend(data)
+        for sample in buffer:  # consume data
             loss_vals = loss_fn(sample)
             loss_val = sum(
                 value for key, value in loss_vals.items() if
                 key.startswith("loss")
             )
             loss_val.backward()
             optim.step()
             optim.zero_grad()
     print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
 ```
 </details>
````
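In the fixed example, `TanhNormal` is bounded with `low`/`high` rather than the previous `min`/`max` keywords, the discount `gamma` is passed only to `GAE` and no longer to `ClipPPOLoss`, and the replay buffer is built with keyword arguments plus an explicit `batch_size`. That last change is what lets the training loop iterate over the buffer directly instead of calling `buffer.sample(50)` twenty times: with a `SamplerWithoutReplacement`, iteration yields mini-batches until every stored frame has been consumed. A minimal sketch of that behaviour, using hypothetical dummy data in place of the collector output:

```python
import torch
from tensordict import TensorDict
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement

# Buffer configured as in the fixed example: a storage, a without-replacement
# sampler, and an explicit batch_size, which makes the buffer iterable.
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)

# Hypothetical stand-in for the collector output: 1000 frames of dummy data.
data = TensorDict(
    {"observation": torch.randn(1000, 3), "reward": torch.randn(1000, 1)},
    batch_size=[1000],
)
buffer.extend(data)

# Iterating yields batch_size-sized mini-batches; with SamplerWithoutReplacement
# the loop ends once every stored frame has been sampled, i.e. after
# 1000 / 50 = 20 batches here (the role the old `range(20)` loop played).
n_batches = 0
for sample in buffer:
    n_batches += 1
print(n_batches, sample.batch_size)  # expected: 20 torch.Size([50])
```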

