diff --git a/config.yaml b/config.yaml index 353fd91..6e34d9b 100644 --- a/config.yaml +++ b/config.yaml @@ -12,7 +12,8 @@ debug: log_frequency: 1 disable_wild_encounters: True disable_ai_actions: True - use_global_map: True + use_global_map: False + reduce_res: False train: device: cpu compile: False @@ -29,7 +30,7 @@ debug: total_timesteps: 100_000_000 save_checkpoint: True checkpoint_interval: 4 - save_overlay: True + save_overlay: False overlay_interval: 1 verbose: False env_pool: False diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 944bc76..86bb3fa 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -1125,6 +1125,7 @@ def agent_stats(self, action): "useful_items": {item.name: item.value in bag_item_ids for item in USEFUL_ITEMS}, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), + # Remove padding "pokemon_exploration_map": self.explore_map, # "cut_exploration_map": self.cut_explore_map, } diff --git a/pokemonred_puffer/eval.py b/pokemonred_puffer/eval.py index 901c31c..94d214f 100644 --- a/pokemonred_puffer/eval.py +++ b/pokemonred_puffer/eval.py @@ -4,9 +4,12 @@ import numpy as np from numba import jit +from pokemonred_puffer.global_map import MAP_PAD + KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png") BACKGROUND = np.array(cv2.imread(KANTO_MAP_PATH)) +BACKGROUND = np.pad(BACKGROUND, MAP_PAD + ((0, 0),)) @jit(nopython=True, nogil=True) diff --git a/pokemonred_puffer/global_map.py b/pokemonred_puffer/global_map.py index 9388955..0357b0f 100644 --- a/pokemonred_puffer/global_map.py +++ b/pokemonred_puffer/global_map.py @@ -2,7 +2,10 @@ import json MAP_PATH = os.path.join(os.path.dirname(__file__), "map_data.json") -GLOBAL_MAP_SHAPE = (444, 436) +MAP_PAD = ((20, 20), (20, 20)) +GLOBAL_MAP_SHAPE = (444 + MAP_PAD[0][0] + MAP_PAD[0][1], 436 + MAP_PAD[1][0] + MAP_PAD[1][1]) +MAP_ROW_OFFSET = MAP_PAD[0][0] +MAP_COL_OFFSET = MAP_PAD[1][0] with open(MAP_PATH) as map_data: MAP_DATA = json.load(map_data)["regions"] @@ -16,8 +19,8 @@ def local_to_global(r: int, c: int, map_n: int): map_x, map_y, ) = MAP_DATA[map_n]["coordinates"] - gy = r + map_y - gx = c + map_x + gy = r + map_y + MAP_ROW_OFFSET + gx = c + map_x + MAP_COL_OFFSET if 0 <= gy < GLOBAL_MAP_SHAPE[0] and 0 <= gx < GLOBAL_MAP_SHAPE[1]: return gy, gx print(f"coord out of bounds! global: ({gx}, {gy}) game: ({r}, {c}, {map_n})")