From 61226eff203cd11d6ca218173680de73c4451014 Mon Sep 17 00:00:00 2001
From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com>
Date: Thu, 21 Mar 2024 20:53:28 -0400
Subject: [PATCH] Remove reward buffer, move params calc

---
 pokemonred_puffer/cleanrl_puffer.py | 28 +++++-----------------------
 1 file changed, 5 insertions(+), 23 deletions(-)

diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index eb49bf6..3d37047 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -251,10 +251,6 @@ def __init__(
         # TODO: Figure out how to compile the optimizer!
         # self.calculate_loss = torch.compile(self.calculate_loss, mode=config.compile_mode)
 
-        if config.verbose:
-            self.n_params = sum(p.numel() for p in self.agent.parameters() if p.requires_grad)
-            print(f"Model Size: {self.n_params//1000} K parameters")
-
         if self.opt_state is not None:
             self.optimizer.load_state_dict(resume_state["optimizer_state_dict"])
 
@@ -437,27 +433,13 @@ def evaluate(self):
             with env_profiler:
                 self.pool.send(actions)
 
-        self.reward_buffer.append(r.cpu().sum().numpy())
-        # Probably should normalize the rewards before trying to take the variance...
-        reward_var = np.var(self.reward_buffer)
-        if self.log and self.wandb is not None:
-            self.wandb.log(
-                {
-                    "reward/reward_var": reward_var,
-                    "reward/reward_buffer_len": len(self.reward_buffer),
-                },
-            )
-        if (
-            self.taught_cut
-            and len(self.reward_buffer) == self.reward_buffer.maxlen
-            and reward_var < 2.5e-3
-        ):
-            self.reward_buffer.clear()
-            # reset lr update if the reward starts stalling
-            self.lr_update = 1.0
-
         eval_profiler.stop()
 
+        # Now that we initialized the model, we can get the number of parameters
+        if self.global_step == 0 and self.config.verbose:
+            self.n_params = sum(p.numel() for p in self.agent.parameters() if p.requires_grad)
+            print(f"Model Size: {self.n_params//1000} K parameters")
+
         self.total_agent_steps += padded_steps_collected
 
        new_step = np.mean(self.infos["learner"]["stats/step"])
        if new_step > self.global_step:
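
Note (not part of the patch): a minimal, hypothetical sketch of the relocated parameter count, assuming a torch.nn.Module agent and a verbose flag. The names agent, verbose, global_step, and count_trainable_params below are illustrative stand-ins; the real trainer gates the count on self.global_step == 0 inside evaluate(), as shown above, so it runs once after the model is initialized.

import torch.nn as nn


def count_trainable_params(model: nn.Module) -> int:
    # Sum element counts over parameters that require gradients,
    # mirroring the sum(p.numel() ...) expression used in the patch.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    # Hypothetical stand-in for the agent network.
    agent = nn.Sequential(nn.Linear(128, 512), nn.ReLU(), nn.Linear(512, 7))
    verbose = True
    global_step = 0
    # Matches the patched behavior: report the size only once, on the
    # first evaluation pass, rather than during __init__.
    if global_step == 0 and verbose:
        n_params = count_trainable_params(agent)
        print(f"Model Size: {n_params // 1000} K parameters")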