diff --git a/config.yaml b/config.yaml index b22c002..6940567 100644 --- a/config.yaml +++ b/config.yaml @@ -15,11 +15,13 @@ debug: device: cpu compile: False compile_mode: default - num_envs: 2 + num_envs: 1 num_workers: 1 - env_batch_size: 16 + env_batch_size: 4 env_pool: True zero_copy: False + batch_size: 128 + minibatch_size: 128 batch_rows: 4 bptt_horizon: 2 total_timesteps: 100_000_000 @@ -63,6 +65,7 @@ env: auto_pokeflute: True infinite_money: True use_global_map: False + save_state: False train: @@ -73,7 +76,7 @@ train: compile_mode: "reduce-overhead" float32_matmul_precision: "high" total_timesteps: 100_000_000_000 - batch_size: 65536 + batch_size: 131072 # 65536 minibatch_size: 2048 learning_rate: 2.0e-4 anneal_lr: False diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 91fe6ad..c126c9e 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -133,7 +133,6 @@ class CleanPuffeRL: wandb_client: wandb.wandb_sdk.wandb_run.Run | None = None profile: Profile = field(default_factory=lambda: Profile()) losses: Losses = field(default_factory=lambda: Losses()) - utilization: Utilization = field(default_factory=lambda: Utilization()) global_step: int = 0 epoch: int = 0 stats: dict = field(default_factory=lambda: {}) @@ -155,6 +154,8 @@ def __post_init__(self): clear=True, ) + self.utilization = Utilization() + self.vecenv.async_reset(self.config.seed) obs_shape = self.vecenv.single_observation_space.shape obs_dtype = self.vecenv.single_observation_space.dtype @@ -199,6 +200,14 @@ def __post_init__(self): @pufferlib.utils.profile def evaluate(self): + # Clear all self.infos except for the state + for k in list(self.infos.keys()): + if k != "state": + del self.infos[k] + elif len(self.infos["state"]) > 0: + # just in case + self.infos["state"] = self.infos["state"][-1] + # now for a tricky bit: # if we have swarm_frequency, we will take the top swarm_keep_pct envs and evenly distribute # their states to the bottom 90%. @@ -208,6 +217,7 @@ def evaluate(self): and hasattr(self.config, "swarm_keep_pct") and self.epoch % self.config.swarm_frequency == 0 and "reward/event" in self.infos + and "state" in self.infos ): # collect the top swarm_keep_pct % of envs largest = [ @@ -275,11 +285,17 @@ def evaluate(self): actions = actions.cpu().numpy() mask = torch.as_tensor(mask) # * policy.mask) o = o if self.config.cpu_offload else o_device + if self.config.num_workers == 1: + actions = np.expand_dims(actions, 0) + logprob = logprob.unsqueeze(0) self.experience.store(o, value, actions, logprob, r, d, env_id, mask) for i in info: for k, v in pufferlib.utils.unroll_nested_dict(i): - self.infos[k].append(v) + if k == "state": + self.infos[k] = [v] + else: + self.infos[k].append(v) with self.profile.env: self.vecenv.send(actions) @@ -441,16 +457,17 @@ def train(self): done_training = self.global_step >= self.config.total_timesteps if self.profile.update(self) or done_training: - print_dashboard( - self.config.env, - self.utilization, - self.global_step, - self.epoch, - self.profile, - self.losses, - self.stats, - self.msg, - ) + if self.config.verbose: + print_dashboard( + self.config.env, + self.utilization, + self.global_step, + self.epoch, + self.profile, + self.losses, + self.stats, + self.msg, + ) if ( self.wandb_client is not None @@ -475,7 +492,8 @@ def train(self): def close(self): self.vecenv.close() - self.utilization.stop() + if self.config.verbose: + self.utilization.stop() if self.wandb_client is not None: artifact_name = f"{self.exp_name}_model" diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 5b2c2d8..c79a826 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -102,6 +102,7 @@ def __init__(self, env_config: pufferlib.namespace): self.auto_pokeflute = env_config.auto_pokeflute self.infinite_money = env_config.infinite_money self.use_global_map = env_config.use_global_map + self.save_state = env_config.save_state self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? @@ -552,7 +553,7 @@ def step(self, action): self.pokecenters[self.read_m("wLastBlackoutMap")] = 1 info = {} - if self.get_events_sum() > self.max_event_rew: + if self.save_state and self.get_events_sum() > self.max_event_rew: state = io.BytesIO() self.pyboy.save_state(state) state.seek(0)