reduce minibatches, fix reference to config
thatguy11325 committed Jun 21, 2024
1 parent e28afbd commit 948598c
Showing 2 changed files with 5 additions and 8 deletions.
2 changes: 1 addition & 1 deletion config.yaml
@@ -74,7 +74,7 @@ train:
   float32_matmul_precision: "high"
   total_timesteps: 100_000_000_000
   batch_size: 65536
-  minibatch_size: 32768
+  minibatch_size: 2048
   learning_rate: 2.0e-4
   anneal_lr: False
   gamma: 0.998
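For context: in a CleanRL-style PPO update, the batch is shuffled and sliced into batch_size / minibatch_size optimizer steps per pass, so shrinking the minibatch from 32768 to 2048 raises that count from 2 to 32. A minimal sketch of the pattern, not the repo's code:

# Minimal sketch (illustrative, not the repo's code) of how batch_size and
# minibatch_size interact in a CleanRL-style update loop:
# 65536 / 2048 = 32 optimizer steps per pass instead of 65536 / 32768 = 2.
import numpy as np

batch_size = 65536
minibatch_size = 2048
num_minibatches = batch_size // minibatch_size  # 32

indices = np.arange(batch_size)
np.random.shuffle(indices)
for start in range(0, batch_size, minibatch_size):
    mb_indices = indices[start : start + minibatch_size]
    # ...compute the PPO loss on this slice and take one optimizer step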
11 changes: 4 additions & 7 deletions pokemonred_puffer/cleanrl_puffer.py
@@ -12,6 +12,7 @@
 import pufferlib
 import pufferlib.emulation
 import pufferlib.frameworks.cleanrl
+import pufferlib.pytorch
 import pufferlib.utils
 import pufferlib.vector

@@ -178,12 +179,6 @@ def __post_init__(self):
 
         if self.config.compile:
             self.policy = torch.compile(self.policy, mode=self.config.compile_mode)
-            self.policy.get_value = torch.compile(
-                self.policy.get_value, mode=self.config.compile_mode
-            )
-            self.policy.get_action_and_value = torch.compile(
-                self.policy.get_action_and_value, mode=self.config.compile_mode
-            )
 
         self.optimizer = torch.optim.Adam(
             self.policy.parameters(), lr=self.config.learning_rate, eps=1e-5
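One plausible reading of this hunk: compiling the module once is kept, while the extra per-method compiles are dropped. A hedged sketch of the single-compile pattern, using a hypothetical TinyPolicy in place of the repo's policy class:

# Hedged sketch of the single-compile pattern kept by this commit.
# torch.compile on an nn.Module returns a wrapper whose forward pass is
# traced and optimized on first call; the removed per-method compiles
# would have produced separate compiled artifacts.
import torch
import torch.nn as nn

class TinyPolicy(nn.Module):  # hypothetical stand-in for the repo's policy
    def __init__(self) -> None:
        super().__init__()
        self.body = nn.Linear(8, 8)
        self.value_head = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.body(x)

    def get_value(self, x: torch.Tensor) -> torch.Tensor:
        return self.value_head(self.body(x))

policy = TinyPolicy()
policy = torch.compile(policy)                # one compile call for the module
out = policy(torch.randn(4, 8))               # forward runs through the compiled wrapper
value = policy.get_value(torch.randn(4, 8))   # still reachable; runs eagerly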
@@ -418,7 +413,9 @@ def train(self):
         with self.profile.learn:
             self.optimizer.zero_grad()
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+            torch.nn.utils.clip_grad_norm_(
+                self.policy.parameters(), self.config.max_grad_norm
+            )
             self.optimizer.step()
             if self.config.device == "cuda":
                 torch.cuda.synchronize()
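This is the "fix reference to config" half of the commit: max_grad_norm lives on the config object, not on the trainer, so the old self.max_grad_norm lookup would fail at runtime. A minimal sketch of the corrected update step, with illustrative names:

# Minimal sketch of the corrected update step (illustrative names).
# max_grad_norm is read from the config object; self.max_grad_norm would
# have raised AttributeError on the trainer.
import torch

def update_step(policy, optimizer, loss, config):
    optimizer.zero_grad()
    loss.backward()
    # Rescale gradients in place so their global L2 norm does not
    # exceed config.max_grad_norm.
    torch.nn.utils.clip_grad_norm_(policy.parameters(), config.max_grad_norm)
    optimizer.step()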
