From 4c9d5698a640e4da3afed7f6482d61a1de6a81d7 Mon Sep 17 00:00:00 2001
From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com>
Date: Mon, 11 Mar 2024 21:36:29 -0400
Subject: [PATCH] Increase entropy if events stall

---
 pokemonred_puffer/cleanrl_puffer.py | 29 +++++++++--------------------
 1 file changed, 9 insertions(+), 20 deletions(-)

diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index 60e7b86..f3e1f46 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -1,3 +1,4 @@
+import math
 import os
 import random
 import time
@@ -318,6 +319,8 @@ def __init__(
 
         self.infos = {}
         self.log = False
+        self.ent_coef = self.config.ent_coef
+        self.events_avg = deque([0] * 500, maxlen=500)
 
     @pufferlib.utils.profile
     def evaluate(self):
@@ -434,25 +437,11 @@ def evaluate(self):
             with env_profiler:
                 self.pool.send(actions)
 
-            self.reward_buffer.append(r.cpu().sum().numpy())
-            # Probably should normalize the rewards before trying to take the variance...
-            reward_var = np.var(self.reward_buffer)
-            if self.log and self.wandb is not None:
-                self.wandb.log(
-                    {
-                        "reward/reward_var": reward_var,
-                        "reward/reward_buffer_len": len(self.reward_buffer),
-                    },
-                )
-            if (
-                self.taught_cut
-                and len(self.reward_buffer) == self.reward_buffer.maxlen
-                and reward_var < 2.5e-3
-            ):
-                self.reward_buffer.clear()
-                # reset lr update if the reward starts stalling
-                self.lr_update = 1.0
-
+        self.events_avg.append(np.mean(self.infos["learner"]["stats/event"]))
+        if math.fabs(self.events_avg[-1] - self.events_avg[0]) < 3:
+            self.ent_coef = self.config.ent_coef * 1.25
+        else:
+            self.ent_coef = self.config.ent_coef
         eval_profiler.stop()
 
         self.total_agent_steps += padded_steps_collected
@@ -717,7 +706,7 @@ def save_checkpoint(self):
         return model_path
 
     def calculate_loss(self, pg_loss, entropy_loss, v_loss):
-        loss = pg_loss - self.config.ent_coef * entropy_loss + v_loss * self.config.vf_coef
+        loss = pg_loss - self.ent_coef * entropy_loss + v_loss * self.config.vf_coef
         self.optimizer.zero_grad()
         loss.backward()
         nn.utils.clip_grad_norm_(self.agent.parameters(), self.config.max_grad_norm)
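
For illustration only, not part of the patch: a minimal, self-contained sketch of the stall-driven entropy bump the diff introduces. The window length (500), the 1.25x boost, and the threshold of 3 come from the diff above; the names BASE_ENT_COEF, STALL_THRESHOLD, and update_ent_coef are hypothetical stand-ins for config.ent_coef and the trainer state.

    # Sketch of the adaptive entropy coefficient; BASE_ENT_COEF and
    # update_ent_coef are hypothetical names, not part of the repository.
    import math
    from collections import deque

    BASE_ENT_COEF = 0.01      # stand-in for config.ent_coef
    WINDOW = 500              # rolling window length used in the patch
    STALL_THRESHOLD = 3       # minimum event-count progress over the window

    events_avg = deque([0] * WINDOW, maxlen=WINDOW)

    def update_ent_coef(mean_events: float) -> float:
        """Return the entropy coefficient for the next loss computation.

        Appends the latest mean event count and compares it to the oldest
        value in the window; if progress is below STALL_THRESHOLD, the base
        coefficient is scaled by 1.25 to encourage exploration.
        """
        events_avg.append(mean_events)
        if math.fabs(events_avg[-1] - events_avg[0]) < STALL_THRESHOLD:
            return BASE_ENT_COEF * 1.25
        return BASE_ENT_COEF

    # Example usage: one call per evaluation pass.
    for mean_events in (12.0, 12.5, 13.0):
        ent_coef = update_ent_coef(mean_events)

Note that the boost is a fixed 1.25x of the base coefficient rather than a compounding increase, so the entropy bonus falls back to the base value as soon as event progress resumes.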