Skip to content

Commit

Permalink
fix visited mask by padding global map
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 27, 2024
1 parent 011657a commit 11c8129
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
5 changes: 3 additions & 2 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ debug:
log_frequency: 1
disable_wild_encounters: True
disable_ai_actions: True
use_global_map: True
use_global_map: False
reduce_res: False
train:
device: cpu
compile: False
Expand All @@ -29,7 +30,7 @@ debug:
total_timesteps: 100_000_000
save_checkpoint: True
checkpoint_interval: 4
save_overlay: True
save_overlay: False
overlay_interval: 1
verbose: False
env_pool: False
Expand Down
1 change: 1 addition & 0 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,6 +1125,7 @@ def agent_stats(self, action):
"useful_items": {item.name: item.value in bag_item_ids for item in USEFUL_ITEMS},
"reward": self.get_game_state_reward(),
"reward/reward_sum": sum(self.get_game_state_reward().values()),
# Remove padding
"pokemon_exploration_map": self.explore_map,
# "cut_exploration_map": self.cut_explore_map,
}
Expand Down
3 changes: 3 additions & 0 deletions pokemonred_puffer/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import numpy as np
from numba import jit

from pokemonred_puffer.global_map import MAP_PAD


KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png")
BACKGROUND = np.array(cv2.imread(KANTO_MAP_PATH))
BACKGROUND = np.pad(BACKGROUND, MAP_PAD + ((0, 0),))


@jit(nopython=True, nogil=True)
Expand Down
9 changes: 6 additions & 3 deletions pokemonred_puffer/global_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import json

MAP_PATH = os.path.join(os.path.dirname(__file__), "map_data.json")
GLOBAL_MAP_SHAPE = (444, 436)
MAP_PAD = ((20, 20), (20, 20))
GLOBAL_MAP_SHAPE = (444 + MAP_PAD[0][0] + MAP_PAD[0][1], 436 + MAP_PAD[1][0] + MAP_PAD[1][1])
MAP_ROW_OFFSET = MAP_PAD[0][0]
MAP_COL_OFFSET = MAP_PAD[1][0]

with open(MAP_PATH) as map_data:
MAP_DATA = json.load(map_data)["regions"]
Expand All @@ -16,8 +19,8 @@ def local_to_global(r: int, c: int, map_n: int):
map_x,
map_y,
) = MAP_DATA[map_n]["coordinates"]
gy = r + map_y
gx = c + map_x
gy = r + map_y + MAP_ROW_OFFSET
gx = c + map_x + MAP_COL_OFFSET
if 0 <= gy < GLOBAL_MAP_SHAPE[0] and 0 <= gx < GLOBAL_MAP_SHAPE[1]:
return gy, gx
print(f"coord out of bounds! global: ({gx}, {gy}) game: ({r}, {c}, {map_n})")
Expand Down

0 comments on commit 11c8129

Please sign in to comment.