diff --git a/config.yaml b/config.yaml
index 023bf01..25d3084 100644
--- a/config.yaml
+++ b/config.yaml
@@ -7,13 +7,13 @@ debug:
   env:
     headless: False
     stream_wrapper: False
-    init_state: Bulbasaur
+    init_state: victory_road
     max_steps: 16
     log_frequency: 1
     disable_wild_encounters: True
     disable_ai_actions: True
     use_global_map: False
-    reduce_res: False
+    reduce_res: True
     animate_scripts: True
   train:
     device: cpu
@@ -69,9 +69,9 @@ env:
   auto_pokeflute: True
   infinite_money: True
   use_global_map: False
-  save_state: False
+  save_state: True
   animate_scripts: False
-  exploration_inc: 0.25
+  exploration_inc: 1.0
   exploration_max: 1.0
@@ -103,7 +103,7 @@ train:
   bptt_horizon: 16
   vf_clip_coef: 0.1
-  num_envs: 144
+  num_envs: 288
   num_workers: 24
   env_batch_size: 36
   env_pool: True
@@ -119,11 +119,11 @@ train:
   pool_kernel: [0]
   load_optimizer_state: False
   use_rnn: True
-  async_wrapper: False
-  archive_states: False
+  async_wrapper: True
+  archive_states: True

-  # swarm_frequency: 500
-  # swarm_keep_pct: .8
+  swarm_frequency: 500
+  swarm_keep_pct: .8

 wrappers:
   empty:
@@ -284,11 +284,11 @@ rewards:
       useful_item: 1.0
       pokecenter_heal: 0.2
       exploration: 0.02
-      exploration_gym: 0.05
+      exploration_gym: 0.025
       exploration_facility: 0.05
-      exploration_plateau: 0.03
-      exploration_lobby: 0.03 # for game corner
-      a_press: 0.02
+      exploration_plateau: 0.025
+      exploration_lobby: 0.035 # for game corner
+      a_press: 0.00001
diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py
index 0f9940e..284f1b9 100644
--- a/pokemonred_puffer/cleanrl_puffer.py
+++ b/pokemonred_puffer/cleanrl_puffer.py
@@ -330,28 +330,26 @@ def evaluate(self):
                     max_event_count = len(key)
                     new_state_key = key
                     max_state: deque = self.states[key]
-            to_migrate_keys = []
             if max_event_count > self.max_event_count and len(max_state) == max_state.maxlen:
-                to_migrate_keys = self.event_tracker.keys()
                 self.max_event_count = max_event_count
-            # Need a way not to reset the env id counter for the driver env
-            # Until then env ids are 1-indexed
-            for key in to_migrate_keys:
-                new_state = random.choice(self.states[new_state_key])
-
-                print(f"Environment ID: {key}")
-                print(f"\tEvents count: {self.event_tracker[key]} -> {len(new_state_key)}")
-                print(f"\tNew events: {new_state_key}")
-                self.env_recv_queues[key].put(new_state)
-                # Now copy the hidden state over
-                # This may be a little slow, but so is this whole process
-                # self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :]
-                # self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :]
-            for key in to_migrate_keys:
-                print(f"\tWaiting for message from env-id {key}")
-                self.env_send_queues[key].get()
-            print("State migration complete")
+                # Need a way not to reset the env id counter for the driver env
+                # Until then env ids are 1-indexed
+                for key in self.event_tracker.keys():
+                    new_state = random.choice(self.states[new_state_key])
+
+                    print(f"Environment ID: {key}")
+                    print(f"\tEvents count: {self.event_tracker[key]} -> {len(new_state_key)}")
+                    print(f"\tNew events: {new_state_key}")
+                    self.env_recv_queues[key].put(new_state)
+                    # Now copy the hidden state over
+                    # This may be a little slow, but so is this whole process
+                    # self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :]
+                    # self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :]
+                for key in self.event_tracker.keys():
+                    print(f"\tWaiting for message from env-id {key}")
+                    self.env_send_queues[key].get()
+                print("State migration complete")

         self.stats = {}
@@ -537,7 +535,7 @@ def train(self):
             )

         if self.epoch % self.config.checkpoint_interval == 0 or done_training:
-            self.save_checkpoint()
+            # self.save_checkpoint()
             self.msg = f"Checkpoint saved at update {self.epoch}"

     def close(self):
@@ -546,11 +544,11 @@ def close(self):
         self.utilization.stop()

         if self.wandb_client is not None:
-            artifact_name = f"{self.exp_name}_model"
-            artifact = wandb.Artifact(artifact_name, type="model")
-            model_path = self.save_checkpoint()
-            artifact.add_file(model_path)
-            self.wandb_client.log_artifact(artifact)
+            # artifact_name = f"{self.exp_name}_model"
+            # artifact = wandb.Artifact(artifact_name, type="model")
+            # model_path = self.save_checkpoint()
+            # artifact.add_file(model_path)
+            # self.wandb_client.log_artifact(artifact)
             self.wandb_client.finish()

     def save_checkpoint(self):
@@ -594,7 +592,7 @@ def __enter__(self):
         return self

     def __exit__(self, *args):
         print("Done training.")
-        self.save_checkpoint()
+        # self.save_checkpoint()
         self.close()
         print("Run complete")
diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py
index 4bcebc9..476e4cf 100644
--- a/pokemonred_puffer/policies/multi_convolutional.py
+++ b/pokemonred_puffer/policies/multi_convolutional.py
@@ -38,11 +38,11 @@ def __init__(
         self.channels_last = channels_last
         self.downsample = downsample
         self.screen_network = nn.Sequential(
-            nn.LazyConv2d(32, 8, stride=4),
+            nn.LazyConv2d(32, 8, stride=2),
             nn.ReLU(),
             nn.LazyConv2d(64, 4, stride=2),
             nn.ReLU(),
-            nn.LazyConv2d(64, 3, stride=1),
+            nn.LazyConv2d(64, 3, stride=2),
             nn.ReLU(),
             nn.Flatten(),
         )
@@ -205,9 +205,9 @@ def encode_observations(self, observations):
         )
         party_latent = self.party_network(party_obs)

-        event_obs = (
-            observations["events"].float() @ self.event_embeddings.weight
-        ) / self.event_embeddings.weight.shape[0]
+        # event_obs = (
+        #     observations["events"].float() @ self.event_embeddings.weight
+        # ) / self.event_embeddings.weight.shape[0]
         cat_obs = torch.cat(
             (
                 self.screen_network(image_observation.float() / 255.0).squeeze(1),
@@ -224,7 +224,7 @@
                 blackout_map_id.squeeze(1),
                 items.flatten(start_dim=1),
                 party_latent,
-                event_obs,
+                observations["events"].float().squeeze(1),
             ),
             dim=-1,
         )
diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py
index 0e89b71..e16b864 100644
--- a/pokemonred_puffer/rewards/baseline.py
+++ b/pokemonred_puffer/rewards/baseline.py
@@ -333,7 +333,7 @@ def get_game_state_reward(self):
             }
             | {
                 f"exploration_{tileset.name.lower()}": self.reward_config.get(
-                    tileset.name.lower(), self.reward_config["exploration"]
+                    f"exploration_{tileset.name.lower()}", self.reward_config["exploration"]
                 )
                 * sum(self.seen_coords.get(tileset.value, {}).values())
                 for tileset in Tilesets
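
Note on the rewards/baseline.py hunk: the old lookup `self.reward_config.get(tileset.name.lower(), ...)` queried keys such as "gym", but the config stores per-tileset overrides under prefixed names such as "exploration_gym", so the lookup always fell through to the generic "exploration" default. Below is a minimal standalone sketch of the before/after behavior, using a plain dict in place of the env's reward_config (the 0.02 and 0.025 values are taken from the config.yaml hunk above; everything else is illustrative, not the repo's code):

# Standalone sketch of the key-lookup fix in get_game_state_reward.
reward_config = {"exploration": 0.02, "exploration_gym": 0.025}

tileset_name = "gym"  # e.g. Tilesets.GYM.name.lower()

# Before: "gym" is never a key in the config -> always the 0.02 fallback.
before = reward_config.get(tileset_name, reward_config["exploration"])

# After: "exploration_gym" matches the per-tileset override -> 0.025.
after = reward_config.get(f"exploration_{tileset_name}", reward_config["exploration"])

assert (before, after) == (0.02, 0.025)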
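
Note on the multi_convolutional.py hunks: the learned event projection is commented out and the raw multi-hot event vector is concatenated into the feature vector directly. A standalone sketch contrasting the two encodings follows; the sizes (n_events, embed_dim, batch) are assumptions for illustration, not the repo's actual dimensions:

# Sketch (hypothetical sizes) of the event-encoding change.
import torch
import torch.nn as nn

batch, n_events, embed_dim = 4, 320, 64
events = torch.randint(0, 2, (batch, 1, n_events)).float()  # multi-hot event flags
event_embeddings = nn.Embedding(n_events, embed_dim)

# Before: project active events through the embedding table and scale,
# yielding a (batch, 1, embed_dim) latent for the concat.
event_obs = (events @ event_embeddings.weight) / event_embeddings.weight.shape[0]

# After: feed the (batch, n_events) multi-hot vector to the policy as-is.
raw_events = events.squeeze(1)

print(event_obs.shape, raw_events.shape)  # torch.Size([4, 1, 64]) torch.Size([4, 320])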