Skip to content

Commit

Permalink
Add table with agents overview
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Aug 30, 2024
1 parent a633e7c commit e9e0c15
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 6 deletions.
34 changes: 29 additions & 5 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from rich.table import Table

import wandb
from pokemonred_puffer.data.moves import Moves
from pokemonred_puffer.data.species import Species
from pokemonred_puffer.eval import make_pokemon_red_overlay
from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE
from pokemonred_puffer.profile import Profile, Utilization
Expand Down Expand Up @@ -366,13 +368,35 @@ def evaluate(self):
overlay = make_pokemon_red_overlay(np.stack(self.infos[k], axis=0))
if self.wandb_client is not None:
self.stats["Media/aggregate_exploration_map"] = wandb.Image(overlay)
elif "state" in k:
elif any(s in k for s in ["state", "env_id", "species", "levels", "moves"]):
continue
else:
try: # TODO: Better checks on log data types
self.stats[k] = np.mean(v)
except: # noqa: E722
continue

try: # TODO: Better checks on log data types
self.stats[k] = np.mean(v)
except: # noqa: E722
continue
if (
all(k in self.infos.keys() for k in ["env_ids", "species", "levels", "moves"])
and self.wandb_client is not None
):
table = {}
# The infos are in order of when they were received so this _should_ work
for env_id, species, levels, moves in zip(
self.infos["env_ids"],
self.infos["species"],
self.infos["levels"],
self.infos["moves"],
):
table[env_id] = [
f"{Species(_species).name} @ {level} w/ {[Moves(move).name for move in _moves if move]}"
for _species, level, _moves in zip(species, levels, moves)
]

self.stats["party/agents"] = wandb.Table(
columns=["env_id"] + [str(v) for v in range(6)],
data=[[str(k)] + v for k, v in table.items()],
)

if self.config.verbose:
self.msg = f"Model Size: {abbreviate(count_params(self.policy))} parameters"
Expand Down
169 changes: 169 additions & 0 deletions pokemonred_puffer/data/moves.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
from enum import auto, IntEnum


class Moves(IntEnum):
POUND = auto()
KARATE_CHOP = auto()
DOUBLESLAP = auto()
COMET_PUNCH = auto()
MEGA_PUNCH = auto()
PAY_DAY = auto()
FIRE_PUNCH = auto()
ICE_PUNCH = auto()
THUNDERPUNCH = auto()
SCRATCH = auto()
VICEGRIP = auto()
GUILLOTINE = auto()
RAZOR_WIND = auto()
SWORDS_DANCE = auto()
CUT = auto()
GUST = auto()
WING_ATTACK = auto()
WHIRLWIND = auto()
FLY = auto()
BIND = auto()
SLAM = auto()
VINE_WHIP = auto()
STOMP = auto()
DOUBLE_KICK = auto()
MEGA_KICK = auto()
JUMP_KICK = auto()
ROLLING_KICK = auto()
SAND_ATTACK = auto()
HEADBUTT = auto()
HORN_ATTACK = auto()
FURY_ATTACK = auto()
HORN_DRILL = auto()
TACKLE = auto()
BODY_SLAM = auto()
WRAP = auto()
TAKE_DOWN = auto()
THRASH = auto()
DOUBLE_EDGE = auto()
TAIL_WHIP = auto()
POISON_STING = auto()
TWINEEDLE = auto()
PIN_MISSILE = auto()
LEER = auto()
BITE = auto()
GROWL = auto()
ROAR = auto()
SING = auto()
SUPERSONIC = auto()
SONICBOOM = auto()
DISABLE = auto()
ACID = auto()
EMBER = auto()
FLAMETHROWER = auto()
MIST = auto()
WATER_GUN = auto()
HYDRO_PUMP = auto()
SURF = auto()
ICE_BEAM = auto()
BLIZZARD = auto()
PSYBEAM = auto()
BUBBLEBEAM = auto()
AURORA_BEAM = auto()
HYPER_BEAM = auto()
PECK = auto()
DRILL_PECK = auto()
SUBMISSION = auto()
LOW_KICK = auto()
COUNTER = auto()
SEISMIC_TOSS = auto()
STRENGTH = auto()
ABSORB = auto()
MEGA_DRAIN = auto()
LEECH_SEED = auto()
GROWTH = auto()
RAZOR_LEAF = auto()
SOLARBEAM = auto()
POISONPOWDER = auto()
STUN_SPORE = auto()
SLEEP_POWDER = auto()
PETAL_DANCE = auto()
STRING_SHOT = auto()
DRAGON_RAGE = auto()
FIRE_SPIN = auto()
THUNDERSHOCK = auto()
THUNDERBOLT = auto()
THUNDER_WAVE = auto()
THUNDER = auto()
ROCK_THROW = auto()
EARTHQUAKE = auto()
FISSURE = auto()
DIG = auto()
TOXIC = auto()
CONFUSION = auto()
PSYCHIC_M = auto()
HYPNOSIS = auto()
MEDITATE = auto()
AGILITY = auto()
QUICK_ATTACK = auto()
RAGE = auto()
TELEPORT = auto()
NIGHT_SHADE = auto()
MIMIC = auto()
SCREECH = auto()
DOUBLE_TEAM = auto()
RECOVER = auto()
HARDEN = auto()
MINIMIZE = auto()
SMOKESCREEN = auto()
CONFUSE_RAY = auto()
WITHDRAW = auto()
DEFENSE_CURL = auto()
BARRIER = auto()
LIGHT_SCREEN = auto()
HAZE = auto()
REFLECT = auto()
FOCUS_ENERGY = auto()
BIDE = auto()
METRONOME = auto()
MIRROR_MOVE = auto()
SELFDESTRUCT = auto()
EGG_BOMB = auto()
LICK = auto()
SMOG = auto()
SLUDGE = auto()
BONE_CLUB = auto()
FIRE_BLAST = auto()
WATERFALL = auto()
CLAMP = auto()
SWIFT = auto()
SKULL_BASH = auto()
SPIKE_CANNON = auto()
CONSTRICT = auto()
AMNESIA = auto()
KINESIS = auto()
SOFTBOILED = auto()
HI_JUMP_KICK = auto()
GLARE = auto()
DREAM_EATER = auto()
POISON_GAS = auto()
BARRAGE = auto()
LEECH_LIFE = auto()
LOVELY_KISS = auto()
SKY_ATTACK = auto()
TRANSFORM = auto()
BUBBLE = auto()
DIZZY_PUNCH = auto()
SPORE = auto()
FLASH = auto()
PSYWAVE = auto()
SPLASH = auto()
ACID_ARMOR = auto()
CRABHAMMER = auto()
EXPLOSION = auto()
FURY_SWIPES = auto()
BONEMERANG = auto()
REST = auto()
ROCK_SLIDE = auto()
HYPER_FANG = auto()
SHARPEN = auto()
CONVERSION = auto()
TRI_ATTACK = auto()
SUPER_FANG = auto()
SLASH = auto()
SUBSTITUTE = auto()
STRUGGLE = auto()
2 changes: 1 addition & 1 deletion pokemonred_puffer/data/tm_hm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


class TmHmMoves(Enum):
MEGA_PUNCH = (0x5,)
MEGA_PUNCH = 0x5
RAZOR_WIND = 0xD
SWORDS_DANCE = 0xE
WHIRLWIND = 0x12
Expand Down
4 changes: 4 additions & 0 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,6 +1395,7 @@ def agent_stats(self, action):
)

return {
"env_ids": int(self.env_id),
"stats": {
"step": self.step_count + self.reset_count * self.max_steps,
"max_map_progress": self.max_map_progress,
Expand Down Expand Up @@ -1462,6 +1463,9 @@ def agent_stats(self, action):
# Remove padding
"pokemon_exploration_map": self.explore_map,
# "cut_exploration_map": self.cut_explore_map,
"species": [pokemon.Species for pokemon in self.party],
"levels": [pokemon.Level for pokemon in self.party],
"moves": [list(int(m) for m in pokemon.Moves) for pokemon in self.party],
}

def start_video(self):
Expand Down

0 comments on commit e9e0c15

Please sign in to comment.