From 8a6640d3ed7bd7e1730465afaf5c33c69e5518c2 Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 27 Jun 2024 20:38:15 -0400 Subject: [PATCH] fix explore map --- config.yaml | 2 +- pokemonred_puffer/cleanrl_puffer.py | 6 +++--- pokemonred_puffer/eval.py | 4 ++-- pokemonred_puffer/global_map.py | 10 ++++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/config.yaml b/config.yaml index 218d5ef..649332f 100644 --- a/config.yaml +++ b/config.yaml @@ -30,7 +30,7 @@ debug: total_timesteps: 100_000_000 save_checkpoint: True checkpoint_interval: 4 - save_overlay: False + save_overlay: True overlay_interval: 1 verbose: False env_pool: False diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 73c5e2f..d66773e 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -304,14 +304,14 @@ def evaluate(self): # Moves into models... maybe. Definitely moves. # You could also just return infos and have it in demo if "pokemon_exploration_map" in k and self.config.save_overlay is True: - if self.epoch % self.config.overlay_interval == 0: + if True or self.epoch % self.config.overlay_interval == 0: overlay = make_pokemon_red_overlay(np.stack(self.infos[k], axis=0)) if self.wandb_client is not None: self.stats["Media/aggregate_exploration_map"] = wandb.Image( overlay, file_type="jpg" ) - elif "state" in k: - continue + elif "state" in k: + continue try: # TODO: Better checks on log data types self.stats[k] = np.mean(v) diff --git a/pokemonred_puffer/eval.py b/pokemonred_puffer/eval.py index 94d214f..e67559a 100644 --- a/pokemonred_puffer/eval.py +++ b/pokemonred_puffer/eval.py @@ -4,12 +4,12 @@ import numpy as np from numba import jit -from pokemonred_puffer.global_map import MAP_PAD +from pokemonred_puffer.global_map import PAD KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png") BACKGROUND = np.array(cv2.imread(KANTO_MAP_PATH)) -BACKGROUND = np.pad(BACKGROUND, MAP_PAD + ((0, 0),)) +BACKGROUND = np.pad(BACKGROUND, ((PAD * 16, PAD * 16), (PAD * 16, PAD * 16), (0, 0))) @jit(nopython=True, nogil=True) diff --git a/pokemonred_puffer/global_map.py b/pokemonred_puffer/global_map.py index 0357b0f..e5d5f9d 100644 --- a/pokemonred_puffer/global_map.py +++ b/pokemonred_puffer/global_map.py @@ -1,11 +1,13 @@ import os import json +KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png") + MAP_PATH = os.path.join(os.path.dirname(__file__), "map_data.json") -MAP_PAD = ((20, 20), (20, 20)) -GLOBAL_MAP_SHAPE = (444 + MAP_PAD[0][0] + MAP_PAD[0][1], 436 + MAP_PAD[1][0] + MAP_PAD[1][1]) -MAP_ROW_OFFSET = MAP_PAD[0][0] -MAP_COL_OFFSET = MAP_PAD[1][0] +PAD = 20 +GLOBAL_MAP_SHAPE = (444 + PAD * 2, 436 + PAD * 2) +MAP_ROW_OFFSET = PAD +MAP_COL_OFFSET = PAD with open(MAP_PATH) as map_data: MAP_DATA = json.load(map_data)["regions"]