From 948598cea91e65aad6e71a7e9592fe9357edc31e Mon Sep 17 00:00:00 2001
From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com>
Date: Fri, 21 Jun 2024 09:37:21 -0400
Subject: [PATCH] reduce minibatches, fix reference to config

---
 config.yaml                         |  2 +-
 pokemonred_puffer/cleanrl_puffer.py | 11 ++++-------
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/config.yaml b/config.yaml
index f02a0b8..bcf1c13 100644
--- a/config.yaml
+++ b/config.yaml
@@ -74,7 +74,7 @@ train:
   float32_matmul_precision: "high"
   total_timesteps: 100_000_000_000
   batch_size: 65536
-  minibatch_size: 32768
+  minibatch_size: 2048
   learning_rate: 2.0e-4
   anneal_lr: False
   gamma: 0.998
diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index d26c089..e308938 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -12,6 +12,7 @@
 import pufferlib
 import pufferlib.emulation
 import pufferlib.frameworks.cleanrl
+import pufferlib.pytorch
 import pufferlib.utils
 import pufferlib.vector
 
@@ -178,12 +179,6 @@ def __post_init__(self):
 
         if self.config.compile:
             self.policy = torch.compile(self.policy, mode=self.config.compile_mode)
-            self.policy.get_value = torch.compile(
-                self.policy.get_value, mode=self.config.compile_mode
-            )
-            self.policy.get_action_and_value = torch.compile(
-                self.policy.get_action_and_value, mode=self.config.compile_mode
-            )
 
         self.optimizer = torch.optim.Adam(
             self.policy.parameters(), lr=self.config.learning_rate, eps=1e-5
@@ -418,7 +413,9 @@ def train(self):
             with self.profile.learn:
                 self.optimizer.zero_grad()
                 loss.backward()
-                torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+                torch.nn.utils.clip_grad_norm_(
+                    self.policy.parameters(), self.config.max_grad_norm
+                )
                 self.optimizer.step()
                 if self.config.device == "cuda":
                     torch.cuda.synchronize()
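
Note on the cleanrl_puffer.py changes, as a minimal runnable sketch (not part
of the patch; the Linear stand-in, dummy loss, and the 0.5 clipping value are
illustrative assumptions, not the project's code): compiling the policy module
once with torch.compile already covers the methods that run through its
forward path, so the extra per-method compiles were redundant, and the
clipping norm is now read from the trainer's config, matching where the other
hyperparameters live.

    import torch

    # Stand-in for the policy network; the real project wraps a pufferlib policy.
    policy = torch.nn.Linear(8, 4)

    # One module-level compile is enough; separately compiling individual
    # methods on top of it duplicated work, as the patch's deletion reflects.
    policy = torch.compile(policy)

    optimizer = torch.optim.Adam(policy.parameters(), lr=2.0e-4, eps=1e-5)

    loss = policy(torch.randn(16, 8)).sum()  # dummy loss for illustration
    optimizer.zero_grad()
    loss.backward()
    # The fix: the max norm comes from the config object (0.5 is an assumed
    # value), not from an attribute on the trainer itself.
    torch.nn.utils.clip_grad_norm_(policy.parameters(), 0.5)
    optimizer.step()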