From e9e0c15aaf71f35e70dcdcac8dacf42308fb42ce Mon Sep 17 00:00:00 2001 From: thatguy11325 <148832074+thatguy11325@users.noreply.github.com> Date: Thu, 29 Aug 2024 23:48:13 -0400 Subject: [PATCH] Add table with agents overview --- pokemonred_puffer/cleanrl_puffer.py | 34 +++++- pokemonred_puffer/data/moves.py | 169 ++++++++++++++++++++++++++++ pokemonred_puffer/data/tm_hm.py | 2 +- pokemonred_puffer/environment.py | 4 + 4 files changed, 203 insertions(+), 6 deletions(-) create mode 100644 pokemonred_puffer/data/moves.py diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index c921264..cc986ff 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -27,6 +27,8 @@ from rich.table import Table import wandb +from pokemonred_puffer.data.moves import Moves +from pokemonred_puffer.data.species import Species from pokemonred_puffer.eval import make_pokemon_red_overlay from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE from pokemonred_puffer.profile import Profile, Utilization @@ -366,13 +368,35 @@ def evaluate(self): overlay = make_pokemon_red_overlay(np.stack(self.infos[k], axis=0)) if self.wandb_client is not None: self.stats["Media/aggregate_exploration_map"] = wandb.Image(overlay) - elif "state" in k: + elif any(s in k for s in ["state", "env_id", "species", "levels", "moves"]): continue + else: + try: # TODO: Better checks on log data types + self.stats[k] = np.mean(v) + except: # noqa: E722 + continue - try: # TODO: Better checks on log data types - self.stats[k] = np.mean(v) - except: # noqa: E722 - continue + if ( + all(k in self.infos.keys() for k in ["env_ids", "species", "levels", "moves"]) + and self.wandb_client is not None + ): + table = {} + # The infos are in order of when they were received so this _should_ work + for env_id, species, levels, moves in zip( + self.infos["env_ids"], + self.infos["species"], + self.infos["levels"], + self.infos["moves"], + ): + table[env_id] = [ + f"{Species(_species).name} @ {level} w/ {[Moves(move).name for move in _moves if move]}" + for _species, level, _moves in zip(species, levels, moves) + ] + + self.stats["party/agents"] = wandb.Table( + columns=["env_id"] + [str(v) for v in range(6)], + data=[[str(k)] + v for k, v in table.items()], + ) if self.config.verbose: self.msg = f"Model Size: {abbreviate(count_params(self.policy))} parameters" diff --git a/pokemonred_puffer/data/moves.py b/pokemonred_puffer/data/moves.py new file mode 100644 index 0000000..03f73a5 --- /dev/null +++ b/pokemonred_puffer/data/moves.py @@ -0,0 +1,169 @@ +from enum import auto, IntEnum + + +class Moves(IntEnum): + POUND = auto() + KARATE_CHOP = auto() + DOUBLESLAP = auto() + COMET_PUNCH = auto() + MEGA_PUNCH = auto() + PAY_DAY = auto() + FIRE_PUNCH = auto() + ICE_PUNCH = auto() + THUNDERPUNCH = auto() + SCRATCH = auto() + VICEGRIP = auto() + GUILLOTINE = auto() + RAZOR_WIND = auto() + SWORDS_DANCE = auto() + CUT = auto() + GUST = auto() + WING_ATTACK = auto() + WHIRLWIND = auto() + FLY = auto() + BIND = auto() + SLAM = auto() + VINE_WHIP = auto() + STOMP = auto() + DOUBLE_KICK = auto() + MEGA_KICK = auto() + JUMP_KICK = auto() + ROLLING_KICK = auto() + SAND_ATTACK = auto() + HEADBUTT = auto() + HORN_ATTACK = auto() + FURY_ATTACK = auto() + HORN_DRILL = auto() + TACKLE = auto() + BODY_SLAM = auto() + WRAP = auto() + TAKE_DOWN = auto() + THRASH = auto() + DOUBLE_EDGE = auto() + TAIL_WHIP = auto() + POISON_STING = auto() + TWINEEDLE = auto() + PIN_MISSILE = auto() + LEER = auto() + BITE = auto() + GROWL = auto() + ROAR = auto() + SING = auto() + SUPERSONIC = auto() + SONICBOOM = auto() + DISABLE = auto() + ACID = auto() + EMBER = auto() + FLAMETHROWER = auto() + MIST = auto() + WATER_GUN = auto() + HYDRO_PUMP = auto() + SURF = auto() + ICE_BEAM = auto() + BLIZZARD = auto() + PSYBEAM = auto() + BUBBLEBEAM = auto() + AURORA_BEAM = auto() + HYPER_BEAM = auto() + PECK = auto() + DRILL_PECK = auto() + SUBMISSION = auto() + LOW_KICK = auto() + COUNTER = auto() + SEISMIC_TOSS = auto() + STRENGTH = auto() + ABSORB = auto() + MEGA_DRAIN = auto() + LEECH_SEED = auto() + GROWTH = auto() + RAZOR_LEAF = auto() + SOLARBEAM = auto() + POISONPOWDER = auto() + STUN_SPORE = auto() + SLEEP_POWDER = auto() + PETAL_DANCE = auto() + STRING_SHOT = auto() + DRAGON_RAGE = auto() + FIRE_SPIN = auto() + THUNDERSHOCK = auto() + THUNDERBOLT = auto() + THUNDER_WAVE = auto() + THUNDER = auto() + ROCK_THROW = auto() + EARTHQUAKE = auto() + FISSURE = auto() + DIG = auto() + TOXIC = auto() + CONFUSION = auto() + PSYCHIC_M = auto() + HYPNOSIS = auto() + MEDITATE = auto() + AGILITY = auto() + QUICK_ATTACK = auto() + RAGE = auto() + TELEPORT = auto() + NIGHT_SHADE = auto() + MIMIC = auto() + SCREECH = auto() + DOUBLE_TEAM = auto() + RECOVER = auto() + HARDEN = auto() + MINIMIZE = auto() + SMOKESCREEN = auto() + CONFUSE_RAY = auto() + WITHDRAW = auto() + DEFENSE_CURL = auto() + BARRIER = auto() + LIGHT_SCREEN = auto() + HAZE = auto() + REFLECT = auto() + FOCUS_ENERGY = auto() + BIDE = auto() + METRONOME = auto() + MIRROR_MOVE = auto() + SELFDESTRUCT = auto() + EGG_BOMB = auto() + LICK = auto() + SMOG = auto() + SLUDGE = auto() + BONE_CLUB = auto() + FIRE_BLAST = auto() + WATERFALL = auto() + CLAMP = auto() + SWIFT = auto() + SKULL_BASH = auto() + SPIKE_CANNON = auto() + CONSTRICT = auto() + AMNESIA = auto() + KINESIS = auto() + SOFTBOILED = auto() + HI_JUMP_KICK = auto() + GLARE = auto() + DREAM_EATER = auto() + POISON_GAS = auto() + BARRAGE = auto() + LEECH_LIFE = auto() + LOVELY_KISS = auto() + SKY_ATTACK = auto() + TRANSFORM = auto() + BUBBLE = auto() + DIZZY_PUNCH = auto() + SPORE = auto() + FLASH = auto() + PSYWAVE = auto() + SPLASH = auto() + ACID_ARMOR = auto() + CRABHAMMER = auto() + EXPLOSION = auto() + FURY_SWIPES = auto() + BONEMERANG = auto() + REST = auto() + ROCK_SLIDE = auto() + HYPER_FANG = auto() + SHARPEN = auto() + CONVERSION = auto() + TRI_ATTACK = auto() + SUPER_FANG = auto() + SLASH = auto() + SUBSTITUTE = auto() + STRUGGLE = auto() diff --git a/pokemonred_puffer/data/tm_hm.py b/pokemonred_puffer/data/tm_hm.py index 0916d26..95bed39 100644 --- a/pokemonred_puffer/data/tm_hm.py +++ b/pokemonred_puffer/data/tm_hm.py @@ -3,7 +3,7 @@ class TmHmMoves(Enum): - MEGA_PUNCH = (0x5,) + MEGA_PUNCH = 0x5 RAZOR_WIND = 0xD SWORDS_DANCE = 0xE WHIRLWIND = 0x12 diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index cbd9b52..082a1d9 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -1395,6 +1395,7 @@ def agent_stats(self, action): ) return { + "env_ids": int(self.env_id), "stats": { "step": self.step_count + self.reset_count * self.max_steps, "max_map_progress": self.max_map_progress, @@ -1462,6 +1463,9 @@ def agent_stats(self, action): # Remove padding "pokemon_exploration_map": self.explore_map, # "cut_exploration_map": self.cut_explore_map, + "species": [pokemon.Species for pokemon in self.party], + "levels": [pokemon.Level for pokemon in self.party], + "moves": [list(int(m) for m in pokemon.Moves) for pokemon in self.party], } def start_video(self):