diff --git a/config.yaml b/config.yaml index 79a56a9..6ad611c 100644 --- a/config.yaml +++ b/config.yaml @@ -7,15 +7,17 @@ debug: env: headless: False stream_wrapper: False - init_state: cut + init_state: victory_road max_steps: 1_000_000 + disable_wild_encounters: True + disable_ai_actions: True train: device: cpu compile: False compile_mode: default - num_envs: 4 + num_envs: 1 envs_per_worker: 1 - envs_per_batch: 4 + envs_per_batch: 1 batch_size: 16 batch_rows: 4 bptt_horizon: 2 @@ -28,8 +30,8 @@ debug: env_pool: False log_frequency: 5000 load_optimizer_state: False - swarm_frequency: 10 - swarm_keep_pct: .1 + # swarm_frequency: 10 + # swarm_keep_pct: .1 env: headless: True @@ -47,6 +49,20 @@ env: reduce_res: True two_bit: True log_frequency: 2000 + auto_flash: True + disable_wild_encounters: True + disable_ai_actions: False + auto_teach_cut: True + auto_use_cut: True + auto_use_surf: True + auto_teach_surf: True + auto_teach_strength: True + auto_solve_strength_puzzles: True + auto_remove_all_nonuseful_items: True + auto_pokeflute: True + infinite_money: True + use_global_map: False + train: seed: 1 @@ -84,13 +100,13 @@ train: save_checkpoint: False checkpoint_interval: 200 save_overlay: True - overlay_interval: 200 + overlay_interval: 100 cpu_offload: True pool_kernel: [0] load_optimizer_state: False - swarm_frequency: 500 - swarm_keep_pct: .8 + # swarm_frequency: 500 + # swarm_keep_pct: .8 wrappers: baseline: @@ -108,16 +124,38 @@ wrappers: bag_menu: 0.998 action_bag_menu: 0.998 forgetting_frequency: 10 + - exploration.OnResetExplorationWrapper: + full_reset_frequency: 1 finite_coords: - stream_wrapper.StreamWrapper: user: thatguy - exploration.MaxLengthWrapper: capacity: 1750 + - exploration.OnResetExplorationWrapper: + full_reset_frequency: 1 + jitter: 0 stream_only: - stream_wrapper.StreamWrapper: user: thatguy + - exploration.OnResetExplorationWrapper: + full_reset_frequency: 1 + jitter: 1 + + fixed_reset_value: + - stream_wrapper.StreamWrapper: + user: 
thatguy + - exploration.OnResetLowerToFixedValueWrapper: + fixed_value: + coords: 0.33 + map_ids: 0.33 + npc: 0.33 + cut: 0.33 + explore: 0.33 + - exploration.OnResetExplorationWrapper: + full_reset_frequency: 25 + jitter: 0 rewards: baseline.BaselineRewardEnv: @@ -139,6 +177,7 @@ rewards: pokemon_menu: 0.1 stats_menu: 0.1 bag_menu: 0.1 + baseline.TeachCutReplicationEnvFork: reward: event: 1.0 @@ -157,25 +196,26 @@ rewards: explore_npcs: 0.02 explore_hidden_objs: 0.02 - baseline.RockTunnelReplicationEnv: + baseline.CutWithObjectRewardsEnv: reward: - level: 1.0 - exploration: 0.02 - taught_cut: 10.0 - event: 3.0 + event: 1.0 + bill_saved: 5.0 seen_pokemon: 4.0 caught_pokemon: 4.0 moves_obtained: 4.0 - cut_coords: 1.0 - cut_tiles: 1.0 - start_menu: 0.005 - pokemon_menu: 0.05 - stats_menu: 0.05 - bag_menu: 0.05 - pokecenter: 5.0 - # Really an addition to event reward - badges: 2.0 - bill_saved: 2.0 + hm_count: 10.0 + level: 1.0 + badges: 10.0 + exploration: 0.02 + cut_coords: 0.0 + cut_tiles: 0.0 + start_menu: 0.00 + pokemon_menu: 0.0 + stats_menu: 0.0 + bag_menu: 0.1 + rocket_hideout_found: 5.0 + explore_hidden_objs: 0.02 + seen_action_bag_menu: 0.1 diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index b74b193..3c62d3e 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -524,13 +524,13 @@ def evaluate(self): overlay = make_pokemon_red_overlay(np.stack(v, axis=0)) if self.wandb is not None: self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay) - elif "cut_exploration_map" in k and config.save_overlay is True: - if self.update % config.overlay_interval == 0: - overlay = make_pokemon_red_overlay(np.stack(v, axis=0)) - if self.wandb is not None: - self.stats["Media/aggregate_cut_exploration_map"] = self.wandb.Image( - overlay - ) + # elif "cut_exploration_map" in k and config.save_overlay is True: + # if self.update % config.overlay_interval == 0: + # overlay = 
make_pokemon_red_overlay(np.stack(v, axis=0)) + # if self.wandb is not None: + # self.stats["Media/aggregate_cut_exploration_map"] = self.wandb.Image( + # overlay + # ) elif "state" in k: pass else: diff --git a/pokemonred_puffer/data/events.py b/pokemonred_puffer/data/events.py new file mode 100644 index 0000000..c7db94a --- /dev/null +++ b/pokemonred_puffer/data/events.py @@ -0,0 +1,3 @@ +EVENT_FLAGS_START = 0xD747 +EVENTS_FLAGS_LENGTH = 320 +MUSEUM_TICKET = (0xD754, 0) diff --git a/pokemonred_puffer/data/field_moves.py b/pokemonred_puffer/data/field_moves.py new file mode 100644 index 0000000..86d2e83 --- /dev/null +++ b/pokemonred_puffer/data/field_moves.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class FieldMoves(Enum): + CUT = 1 + FLY = 2 + SURF = 3 + SURF_2 = 4 + STRENGTH = 5 + FLASH = 6 + DIG = 7 + TELEPORT = 8 + SOFTBOILED = 9 diff --git a/pokemonred_puffer/data/items.py b/pokemonred_puffer/data/items.py new file mode 100644 index 0000000..16640df --- /dev/null +++ b/pokemonred_puffer/data/items.py @@ -0,0 +1,205 @@ +from enum import Enum + + +MAX_ITEM_CAPACITY = 20 +# Starts at 0x1 + + +class Items(Enum): + MASTER_BALL = 0x01 + ULTRA_BALL = 0x02 + GREAT_BALL = 0x03 + POKE_BALL = 0x04 + TOWN_MAP = 0x05 + BICYCLE = 0x06 + SURFBOARD = 0x07 # + SAFARI_BALL = 0x08 + POKEDEX = 0x09 + MOON_STONE = 0x0A + ANTIDOTE = 0x0B + BURN_HEAL = 0x0C + ICE_HEAL = 0x0D + AWAKENING = 0x0E + PARLYZ_HEAL = 0x0F + FULL_RESTORE = 0x10 + MAX_POTION = 0x11 + HYPER_POTION = 0x12 + SUPER_POTION = 0x13 + POTION = 0x14 + BOULDERBADGE = 0x15 + CASCADEBADGE = 0x16 + SAFARI_BAIT = 0x15 # overload + SAFARI_ROCK = 0x16 # overload + THUNDERBADGE = 0x17 + RAINBOWBADGE = 0x18 + SOULBADGE = 0x19 + MARSHBADGE = 0x1A + VOLCANOBADGE = 0x1B + EARTHBADGE = 0x1C + ESCAPE_ROPE = 0x1D + REPEL = 0x1E + OLD_AMBER = 0x1F + FIRE_STONE = 0x20 + THUNDER_STONE = 0x21 + WATER_STONE = 0x22 + HP_UP = 0x23 + PROTEIN = 0x24 + IRON = 0x25 + CARBOS = 0x26 + CALCIUM = 0x27 + RARE_CANDY = 0x28 + DOME_FOSSIL = 0x29 
+ HELIX_FOSSIL = 0x2A + SECRET_KEY = 0x2B + UNUSED_ITEM = 0x2C # "?????" + BIKE_VOUCHER = 0x2D + X_ACCURACY = 0x2E + LEAF_STONE = 0x2F + CARD_KEY = 0x30 + NUGGET = 0x31 + PP_UP_2 = 0x32 + POKE_DOLL = 0x33 + FULL_HEAL = 0x34 + REVIVE = 0x35 + MAX_REVIVE = 0x36 + GUARD_SPEC = 0x37 + SUPER_REPEL = 0x38 + MAX_REPEL = 0x39 + DIRE_HIT = 0x3A + COIN = 0x3B + FRESH_WATER = 0x3C + SODA_POP = 0x3D + LEMONADE = 0x3E + S_S_TICKET = 0x3F + GOLD_TEETH = 0x40 + X_ATTACK = 0x41 + X_DEFEND = 0x42 + X_SPEED = 0x43 + X_SPECIAL = 0x44 + COIN_CASE = 0x45 + OAKS_PARCEL = 0x46 + ITEMFINDER = 0x47 + SILPH_SCOPE = 0x48 + POKE_FLUTE = 0x49 + LIFT_KEY = 0x4A + EXP_ALL = 0x4B + OLD_ROD = 0x4C + GOOD_ROD = 0x4D + SUPER_ROD = 0x4E + PP_UP = 0x4F + ETHER = 0x50 + MAX_ETHER = 0x51 + ELIXER = 0x52 + MAX_ELIXER = 0x53 + FLOOR_B2F = 0x54 + FLOOR_B1F = 0x55 + FLOOR_1F = 0x56 + FLOOR_2F = 0x57 + FLOOR_3F = 0x58 + FLOOR_4F = 0x59 + FLOOR_5F = 0x5A + FLOOR_6F = 0x5B + FLOOR_7F = 0x5C + FLOOR_8F = 0x5D + FLOOR_9F = 0x5E + FLOOR_10F = 0x5F + FLOOR_11F = 0x60 + FLOOR_B4F = 0x61 + HM_01 = 0xC4 + HM_02 = 0xC5 + HM_03 = 0xC6 + HM_04 = 0xC7 + HM_05 = 0xC8 + TM_01 = 0xC9 + TM_02 = 0xCA + TM_03 = 0xCB + TM_04 = 0xCC + TM_05 = 0xCD + TM_06 = 0xCE + TM_07 = 0xCF + TM_08 = 0xD0 + TM_09 = 0xD1 + TM_10 = 0xD2 + TM_11 = 0xD3 + TM_12 = 0xD4 + TM_13 = 0xD5 + TM_14 = 0xD6 + TM_15 = 0xD7 + TM_16 = 0xD8 + TM_17 = 0xD9 + TM_18 = 0xDA + TM_19 = 0xDB + TM_20 = 0xDC + TM_21 = 0xDD + TM_22 = 0xDE + TM_23 = 0xDF + TM_24 = 0xE0 + TM_25 = 0xE1 + TM_26 = 0xE2 + TM_27 = 0xE3 + TM_28 = 0xE4 + TM_29 = 0xE5 + TM_30 = 0xE6 + TM_31 = 0xE7 + TM_32 = 0xE8 + TM_33 = 0xE9 + TM_34 = 0xEA + TM_35 = 0xEB + TM_36 = 0xEC + TM_37 = 0xED + TM_38 = 0xEE + TM_39 = 0xEF + TM_40 = 0xF0 + TM_41 = 0xF1 + TM_42 = 0xF2 + TM_43 = 0xF3 + TM_44 = 0xF4 + TM_45 = 0xF5 + TM_46 = 0xF6 + TM_47 = 0xF7 + TM_48 = 0xF8 + TM_49 = 0xF9 + TM_50 = 0xFA + + +KEY_ITEM_IDS = { + Items.TOWN_MAP.value, + Items.BICYCLE.value, + Items.SURFBOARD.value, + Items.SAFARI_BALL.value, 
+ Items.POKEDEX.value, + Items.BOULDERBADGE.value, + Items.CASCADEBADGE.value, + Items.THUNDERBADGE.value, + Items.RAINBOWBADGE.value, + Items.SOULBADGE.value, + Items.MARSHBADGE.value, + Items.VOLCANOBADGE.value, + Items.EARTHBADGE.value, + Items.OLD_AMBER.value, + Items.DOME_FOSSIL.value, + Items.HELIX_FOSSIL.value, + Items.SECRET_KEY.value, + # Items.ITEM_2C.value, + Items.BIKE_VOUCHER.value, + Items.CARD_KEY.value, + Items.S_S_TICKET.value, + Items.GOLD_TEETH.value, + Items.COIN_CASE.value, + Items.OAKS_PARCEL.value, + Items.ITEMFINDER.value, + Items.SILPH_SCOPE.value, + Items.POKE_FLUTE.value, + Items.LIFT_KEY.value, + Items.OLD_ROD.value, + Items.GOOD_ROD.value, + Items.SUPER_ROD.value, +} + +HM_ITEM_IDS = { + Items.HM_01.value, + Items.HM_02.value, + Items.HM_03.value, + Items.HM_04.value, + Items.HM_05.value, +} diff --git a/pokemonred_puffer/data/map.py b/pokemonred_puffer/data/map.py new file mode 100644 index 0000000..a58ea30 --- /dev/null +++ b/pokemonred_puffer/data/map.py @@ -0,0 +1,16 @@ +RESET_MAP_IDS = { + 0x0, # Pallet Town + 0x1, # Viridian City + 0x2, # Pewter City + 0x3, # Cerulean City + 0x4, # Lavender Town + 0x5, # Vermilion City + 0x6, # Celadon City + 0x7, # Fuchsia City + 0x8, # Cinnabar Island + 0x9, # Indigo Plateau + 0xA, # Saffron City + 0xF, # Route 4 (Mt Moon) + 0x10, # Route 10 (Rock Tunnel) + 0xE9, # Silph Co 9F (Heal station) +} diff --git a/pokemonred_puffer/data/species.py b/pokemonred_puffer/data/species.py new file mode 100644 index 0000000..8704e7e --- /dev/null +++ b/pokemonred_puffer/data/species.py @@ -0,0 +1,194 @@ +from enum import Enum + + +class Species(Enum): + RHYDON = 0x01 + KANGASKHAN = 0x02 + NIDORAN_M = 0x03 + CLEFAIRY = 0x04 + SPEAROW = 0x05 + VOLTORB = 0x06 + NIDOKING = 0x07 + SLOWBRO = 0x08 + IVYSAUR = 0x09 + EXEGGUTOR = 0x0A + LICKITUNG = 0x0B + EXEGGCUTE = 0x0C + GRIMER = 0x0D + GENGAR = 0x0E + NIDORAN_F = 0x0F + NIDOQUEEN = 0x10 + CUBONE = 0x11 + RHYHORN = 0x12 + LAPRAS = 0x13 + ARCANINE = 0x14 + MEW = 
0x15 + GYARADOS = 0x16 + SHELLDER = 0x17 + TENTACOOL = 0x18 + GASTLY = 0x19 + SCYTHER = 0x1A + STARYU = 0x1B + BLASTOISE = 0x1C + PINSIR = 0x1D + TANGELA = 0x1E + MISSINGNO_1F = 0x1F + MISSINGNO_20 = 0x20 + GROWLITHE = 0x21 + ONIX = 0x22 + FEAROW = 0x23 + PIDGEY = 0x24 + SLOWPOKE = 0x25 + KADABRA = 0x26 + GRAVELER = 0x27 + CHANSEY = 0x28 + MACHOKE = 0x29 + MR_MIME = 0x2A + HITMONLEE = 0x2B + HITMONCHAN = 0x2C + ARBOK = 0x2D + PARASECT = 0x2E + PSYDUCK = 0x2F + DROWZEE = 0x30 + GOLEM = 0x31 + MISSINGNO_32 = 0x32 + MAGMAR = 0x33 + MISSINGNO_34 = 0x34 + ELECTABUZZ = 0x35 + MAGNETON = 0x36 + KOFFING = 0x37 + MISSINGNO_38 = 0x38 + MANKEY = 0x39 + SEEL = 0x3A + DIGLETT = 0x3B + TAUROS = 0x3C + MISSINGNO_3D = 0x3D + MISSINGNO_3E = 0x3E + MISSINGNO_3F = 0x3F + FARFETCHD = 0x40 + VENONAT = 0x41 + DRAGONITE = 0x42 + MISSINGNO_43 = 0x43 + MISSINGNO_44 = 0x44 + MISSINGNO_45 = 0x45 + DODUO = 0x46 + POLIWAG = 0x47 + JYNX = 0x48 + MOLTRES = 0x49 + ARTICUNO = 0x4A + ZAPDOS = 0x4B + DITTO = 0x4C + MEOWTH = 0x4D + KRABBY = 0x4E + MISSINGNO_4F = 0x4F + MISSINGNO_50 = 0x50 + MISSINGNO_51 = 0x51 + VULPIX = 0x52 + NINETALES = 0x53 + PIKACHU = 0x54 + RAICHU = 0x55 + MISSINGNO_56 = 0x56 + MISSINGNO_57 = 0x57 + DRATINI = 0x58 + DRAGONAIR = 0x59 + KABUTO = 0x5A + KABUTOPS = 0x5B + HORSEA = 0x5C + SEADRA = 0x5D + MISSINGNO_5E = 0x5E + MISSINGNO_5F = 0x5F + SANDSHREW = 0x60 + SANDSLASH = 0x61 + OMANYTE = 0x62 + OMASTAR = 0x63 + JIGGLYPUFF = 0x64 + WIGGLYTUFF = 0x65 + EEVEE = 0x66 + FLAREON = 0x67 + JOLTEON = 0x68 + VAPOREON = 0x69 + MACHOP = 0x6A + ZUBAT = 0x6B + EKANS = 0x6C + PARAS = 0x6D + POLIWHIRL = 0x6E + POLIWRATH = 0x6F + WEEDLE = 0x70 + KAKUNA = 0x71 + BEEDRILL = 0x72 + MISSINGNO_73 = 0x73 + DODRIO = 0x74 + PRIMEAPE = 0x75 + DUGTRIO = 0x76 + VENOMOTH = 0x77 + DEWGONG = 0x78 + MISSINGNO_79 = 0x79 + MISSINGNO_7A = 0x7A + CATERPIE = 0x7B + METAPOD = 0x7C + BUTTERFREE = 0x7D + MACHAMP = 0x7E + MISSINGNO_7F = 0x7F + GOLDUCK = 0x80 + HYPNO = 0x81 + GOLBAT = 0x82 + MEWTWO = 0x83 + SNORLAX = 
0x84 + MAGIKARP = 0x85 + MISSINGNO_86 = 0x86 + MISSINGNO_87 = 0x87 + MUK = 0x88 + MISSINGNO_89 = 0x89 + KINGLER = 0x8A + CLOYSTER = 0x8B + MISSINGNO_8C = 0x8C + ELECTRODE = 0x8D + CLEFABLE = 0x8E + WEEZING = 0x8F + PERSIAN = 0x90 + MAROWAK = 0x91 + MISSINGNO_92 = 0x92 + HAUNTER = 0x93 + ABRA = 0x94 + ALAKAZAM = 0x95 + PIDGEOTTO = 0x96 + PIDGEOT = 0x97 + STARMIE = 0x98 + BULBASAUR = 0x99 + VENUSAUR = 0x9A + TENTACRUEL = 0x9B + MISSINGNO_9C = 0x9C + GOLDEEN = 0x9D + SEAKING = 0x9E + MISSINGNO_9F = 0x9F + MISSINGNO_A0 = 0xA0 + MISSINGNO_A1 = 0xA1 + MISSINGNO_A2 = 0xA2 + PONYTA = 0xA3 + RAPIDASH = 0xA4 + RATTATA = 0xA5 + RATICATE = 0xA6 + NIDORINO = 0xA7 + NIDORINA = 0xA8 + GEODUDE = 0xA9 + PORYGON = 0xAA + AERODACTYL = 0xAB + MISSINGNO_AC = 0xAC + MAGNEMITE = 0xAD + MISSINGNO_AE = 0xAE + MISSINGNO_AF = 0xAF + CHARMANDER = 0xB0 + SQUIRTLE = 0xB1 + CHARMELEON = 0xB2 + WARTORTLE = 0xB3 + CHARIZARD = 0xB4 + MISSINGNO_B5 = 0xB5 + FOSSIL_KABUTOPS = 0xB6 + FOSSIL_AERODACTYL = 0xB7 + MON_GHOST = 0xB8 + ODDISH = 0xB9 + GLOOM = 0xBA + VILEPLUME = 0xBB + BELLSPROUT = 0xBC + WEEPINBELL = 0xBD + VICTREEBEL = 0xBE diff --git a/pokemonred_puffer/data/strength_puzzles.py b/pokemonred_puffer/data/strength_puzzles.py new file mode 100644 index 0000000..5d553c8 --- /dev/null +++ b/pokemonred_puffer/data/strength_puzzles.py @@ -0,0 +1,398 @@ +STRENGTH_SOLUTIONS = {} + +################### +# SEAFOAM ISLANDS # +################### + +# Seafoam 1F Left +STRENGTH_SOLUTIONS[(63, 14, 22, 18, 11, 192)] = [ + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "RIGHT", + "UP", + "LEFT", +] +STRENGTH_SOLUTIONS[(63, 14, 22, 19, 10, 192)] = ["DOWN", "LEFT"] + STRENGTH_SOLUTIONS[ + (63, 14, 22, 18, 11, 192) +] +STRENGTH_SOLUTIONS[(63, 14, 22, 18, 9, 192)] = ["RIGHT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 14, 22, 19, 10, 192) +] +STRENGTH_SOLUTIONS[(63, 14, 22, 17, 10, 192)] = ["UP", "RIGHT"] + STRENGTH_SOLUTIONS[ + (63, 14, 
22, 18, 9, 192) +] + +# Seafoam 1F right +STRENGTH_SOLUTIONS[(63, 11, 30, 26, 8, 192)] = [ + "UP", + "RIGHT", + "UP", + "RIGHT", + "UP", + "UP", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", +] +STRENGTH_SOLUTIONS[(63, 11, 30, 27, 7, 192)] = ["DOWN", "LEFT"] + STRENGTH_SOLUTIONS[ + (63, 11, 30, 26, 8, 192) +] +STRENGTH_SOLUTIONS[(63, 11, 30, 26, 6, 192)] = ["RIGHT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 11, 30, 27, 7, 192) +] +STRENGTH_SOLUTIONS[(63, 11, 30, 25, 7, 192)] = ["UP", "RIGHT"] + STRENGTH_SOLUTIONS[ + (63, 11, 30, 26, 6, 192) +] + +# Seafoam B1 left + +STRENGTH_SOLUTIONS[(63, 10, 21, 16, 6, 159)] = ["RIGHT"] +STRENGTH_SOLUTIONS[(63, 10, 21, 17, 5, 159)] = ["LEFT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 10, 21, 16, 6, 159) +] +STRENGTH_SOLUTIONS[(63, 10, 21, 17, 7, 159)] = ["LEFT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 10, 21, 16, 6, 159) +] + +# Seafoam B1 right + +STRENGTH_SOLUTIONS[(63, 10, 26, 21, 6, 159)] = ["RIGHT"] +STRENGTH_SOLUTIONS[(63, 10, 26, 22, 5, 159)] = ["LEFT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 10, 26, 21, 6, 159) +] +STRENGTH_SOLUTIONS[(63, 10, 26, 22, 7, 159)] = ["LEFT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 10, 26, 21, 6, 159) +] + +# Seafoam B2 left + +STRENGTH_SOLUTIONS[(63, 10, 22, 17, 6, 160)] = ["RIGHT"] +STRENGTH_SOLUTIONS[(63, 10, 22, 18, 5, 160)] = ["LEFT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 10, 22, 17, 6, 160) +] +STRENGTH_SOLUTIONS[(63, 10, 22, 18, 7, 160)] = ["LEFT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 10, 22, 17, 6, 160) +] + +# Seafoam B2 right + +STRENGTH_SOLUTIONS[(63, 10, 27, 24, 6, 160)] = ["LEFT"] +STRENGTH_SOLUTIONS[(63, 10, 27, 23, 7, 160)] = ["RIGHT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 10, 27, 24, 6, 160) +] + +# We skip seafoam b3 since that is for articuno +# TODO: Articuno + +################ +# VICTORY ROAD # +################ + +# 1F Switch 1 +STRENGTH_SOLUTIONS[(63, 19, 9, 5, 14, 108)] = [ + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "LEFT", + "DOWN", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + 
"RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "DOWN", + "RIGHT", + "RIGHT", + "UP", + "UP", + "UP", + "UP", + "UP", + "UP", + "LEFT", + "UP", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "DOWN", + "RIGHT", + "UP", + "UP", + "UP", + "UP", + "UP", + "LEFT", + "LEFT", + "UP", + "UP", + "UP", + "UP", + "RIGHT", + "RIGHT", + "RIGHT", + "RIGHT", + "UP", + "RIGHT", + "DOWN", +] + +STRENGTH_SOLUTIONS[(63, 19, 9, 4, 15, 108)] = ["UP", "RIGHT"] + STRENGTH_SOLUTIONS[ + (63, 19, 9, 5, 14, 108) +] +STRENGTH_SOLUTIONS[(63, 19, 9, 5, 16, 108)] = ["LEFT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 19, 9, 4, 15, 108) +] + +# 2F Switch 1 +STRENGTH_SOLUTIONS[(63, 18, 8, 5, 14, 194)] = [ + "LEFT", + "LEFT", + "LEFT", + "UP", + "LEFT", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "RIGHT", + "DOWN", + "LEFT", + "LEFT", + "LEFT", + "LEFT", +] + +STRENGTH_SOLUTIONS[(63, 18, 8, 4, 13, 194)] = ["RIGHT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 18, 8, 5, 14, 194) +] +STRENGTH_SOLUTIONS[(63, 18, 8, 3, 14, 194)] = ["UP", "RIGHT"] + STRENGTH_SOLUTIONS[ + (63, 18, 8, 4, 13, 194) +] +STRENGTH_SOLUTIONS[(63, 18, 8, 4, 15, 194)] = ["LEFT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 18, 8, 3, 14, 194) +] + +# 3F Switch 3 +STRENGTH_SOLUTIONS[(63, 19, 26, 22, 4, 198)] = [ + "UP", + "UP", + "RIGHT", + "UP", + "UP", + "LEFT", + "DOWN", + "DOWN", + "DOWN", + "LEFT", + "LEFT", + "UP", + "UP", + "RIGHT", + "UP", + "UP", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + 
"LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "UP", + "LEFT", + "DOWN", + "DOWN", + "RIGHT", + "DOWN", + "DOWN", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "UP", + "UP", + "LEFT", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "DOWN", + "LEFT", + "DOWN", + "RIGHT", + "RIGHT", +] + +STRENGTH_SOLUTIONS[(63, 19, 26, 23, 3, 198)] = ["DOWN", "LEFT"] + STRENGTH_SOLUTIONS[ + (63, 19, 26, 22, 4, 198) +] +STRENGTH_SOLUTIONS[(63, 19, 26, 22, 2, 198)] = ["RIGHT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 19, 26, 23, 3, 198) +] +STRENGTH_SOLUTIONS[(63, 19, 26, 21, 3, 198)] = ["UP", "RIGHT"] + STRENGTH_SOLUTIONS[ + (63, 19, 26, 22, 2, 198) +] + +# 3F Boulder in hole +STRENGTH_SOLUTIONS[(63, 16, 17, 21, 15, 198)] = ["RIGHT", "RIGHT", "RIGHT"] +STRENGTH_SOLUTIONS[(63, 16, 17, 22, 16, 198)] = ["LEFT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 16, 17, 21, 15, 198) +] +STRENGTH_SOLUTIONS[(63, 16, 17, 22, 14, 198)] = ["LEFT", "DOWN"] + STRENGTH_SOLUTIONS[ + (63, 16, 17, 21, 15, 198) +] + + +# 2F final switch +STRENGTH_SOLUTIONS[(63, 20, 27, 24, 16, 194)] = [ + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", + "LEFT", +] + +STRENGTH_SOLUTIONS[(63, 20, 27, 23, 17, 194)] = ["RIGHT", "UP"] + STRENGTH_SOLUTIONS[ + (63, 20, 27, 24, 16, 194) +] +STRENGTH_SOLUTIONS[(63, 20, 27, 22, 16, 194)] = ["DOWN", "RIGHT"] + STRENGTH_SOLUTIONS[ + (63, 20, 27, 23, 17, 194) +] diff --git a/pokemonred_puffer/data/tilesets.py 
b/pokemonred_puffer/data/tilesets.py new file mode 100644 index 0000000..5c471a2 --- /dev/null +++ b/pokemonred_puffer/data/tilesets.py @@ -0,0 +1,28 @@ +from enum import Enum + + +class Tilesets(Enum): + OVERWORLD = 0 + REDS_HOUSE_1 = 1 + MART = 2 + FOREST = 3 + REDS_HOUSE_2 = 4 + DOJO = 5 + POKECENTER = 6 + GYM = 7 + HOUSE = 8 + FOREST_GATE = 9 + MUSEUM = 10 + UNDERGROUND = 11 + GATE = 12 + SHIP = 13 + SHIP_PORT = 14 + CEMETERY = 15 + INTERIOR = 16 + CAVERN = 17 + LOBBY = 18 + MANSION = 19 + LAB = 20 + CLUB = 21 + FACILITY = 22 + PLATEAU = 23 diff --git a/pokemonred_puffer/data/tm_hm.py b/pokemonred_puffer/data/tm_hm.py new file mode 100644 index 0000000..0916d26 --- /dev/null +++ b/pokemonred_puffer/data/tm_hm.py @@ -0,0 +1,194 @@ +from enum import Enum +from pokemonred_puffer.data.species import Species + + +class TmHmMoves(Enum): + MEGA_PUNCH = (0x5,) + RAZOR_WIND = 0xD + SWORDS_DANCE = 0xE + WHIRLWIND = 0x12 + MEGA_KICK = 0x19 + TOXIC = 0x5C + HORN_DRILL = 0x20 + BODY_SLAM = 0x22 + TAKE_DOWN = 0x24 + DOUBLE_EDGE = 0x26 + BUBBLE_BEAM = 0x3D + WATER_GUN = 0x37 + ICE_BEAM = 0x3A + BLIZZARD = 0x3B + HYPER_BEAM = 0x3F + PAY_DAY = 0x06 + SUBMISSION = 0x42 + COUNTER = 0x44 + SEISMIC_TOSS = 0x45 + RAGE = 0x63 + MEGA_DRAIN = 0x48 + SOLAR_BEAM = 0x4C + DRAGON_RAGE = 0x52 + THUNDERBOLT = 0x55 + THUNDER = 0x57 + EARTHQUAKE = 0x59 + FISSURE = 0x5A + DIG = 0x5B + PSYCHIC = 0x5E + TELEPORT = 0x64 + MIMIC = 0x66 + DOUBLE_TEAM = 0x68 + REFLECT = 0x73 + BIDE = 0x75 + METRONOME = 0x76 + SELFDESTRUCT = 0x78 + EGG_BOMB = 0x79 + FIRE_BLAST = 0x7E + SWIFT = 0x81 + SKULL_BASH = 0x82 + SOFTBOILED = 0x87 + DREAM_EATER = 0x8A + SKY_ATTACK = 0x8F + REST = 0x9C + THUNDER_WAVE = 0x56 + PSYWAVE = 0x95 + EXPLOSION = 0x99 + ROCK_SLIDE = 0x9D + TRI_ATTACK = 0xA1 + SUBSTITUTE = 0xA4 + CUT = 0x0F + FLY = 0x13 + SURF = 0x39 + STRENGTH = 0x46 + FLASH = 0x94 + + +CUT_SPECIES_IDS = { + Species.BULBASAUR.value, + Species.IVYSAUR.value, + Species.VENUSAUR.value, + Species.CHARMANDER.value, + 
Species.CHARMELEON.value, + Species.CHARIZARD.value, + Species.BEEDRILL.value, + Species.SANDSHREW.value, + Species.SANDSLASH.value, + Species.ODDISH.value, + Species.GLOOM.value, + Species.VILEPLUME.value, + Species.PARAS.value, + Species.PARASECT.value, + Species.BELLSPROUT.value, + Species.WEEPINBELL.value, + Species.VICTREEBEL.value, + Species.TENTACOOL.value, + Species.TENTACRUEL.value, + Species.FARFETCHD.value, + Species.KRABBY.value, + Species.KINGLER.value, + Species.LICKITUNG.value, + Species.TANGELA.value, + Species.SCYTHER.value, + Species.PINSIR.value, + Species.MEW.value, +} + +SURF_SPECIES_IDS = { + Species.SQUIRTLE.value, + Species.WARTORTLE.value, + Species.BLASTOISE.value, + Species.NIDOQUEEN.value, + Species.NIDOKING.value, + Species.PSYDUCK.value, + Species.GOLDUCK.value, + Species.POLIWAG.value, + Species.POLIWHIRL.value, + Species.POLIWRATH.value, + Species.TENTACOOL.value, + Species.TENTACRUEL.value, + Species.SLOWPOKE.value, + Species.SLOWBRO.value, + Species.SEEL.value, + Species.DEWGONG.value, + Species.SHELLDER.value, + Species.CLOYSTER.value, + Species.KRABBY.value, + Species.KINGLER.value, + Species.LICKITUNG.value, + Species.RHYDON.value, + Species.KANGASKHAN.value, + Species.HORSEA.value, + Species.SEADRA.value, + Species.GOLDEEN.value, + Species.SEAKING.value, + Species.STARYU.value, + Species.STARMIE.value, + Species.GYARADOS.value, + Species.LAPRAS.value, + Species.VAPOREON.value, + Species.OMANYTE.value, + Species.OMASTAR.value, + Species.KABUTO.value, + Species.KABUTOPS.value, + Species.SNORLAX.value, + Species.DRATINI.value, + Species.DRAGONAIR.value, + Species.DRAGONITE.value, + Species.MEW.value, +} + +STRENGTH_SPECIES_IDS = { + Species.CHARMANDER.value, + Species.CHARMELEON.value, + Species.CHARIZARD.value, + Species.SQUIRTLE.value, + Species.WARTORTLE.value, + Species.BLASTOISE.value, + Species.EKANS.value, + Species.ARBOK.value, + Species.SANDSHREW.value, + Species.SANDSLASH.value, + Species.NIDOQUEEN.value, + 
Species.NIDOKING.value, + Species.CLEFAIRY.value, + Species.CLEFABLE.value, + Species.JIGGLYPUFF.value, + Species.WIGGLYTUFF.value, + Species.PSYDUCK.value, + Species.GOLDUCK.value, + Species.MANKEY.value, + Species.PRIMEAPE.value, + Species.POLIWHIRL.value, + Species.POLIWRATH.value, + Species.MACHOP.value, + Species.MACHOKE.value, + Species.MACHAMP.value, + Species.GEODUDE.value, + Species.GRAVELER.value, + Species.GOLEM.value, + Species.SLOWPOKE.value, + Species.SLOWBRO.value, + Species.SEEL.value, + Species.DEWGONG.value, + Species.GENGAR.value, + Species.ONIX.value, + Species.KRABBY.value, + Species.KINGLER.value, + Species.EXEGGUTOR.value, + Species.CUBONE.value, + Species.MAROWAK.value, + Species.HITMONLEE.value, + Species.HITMONCHAN.value, + Species.LICKITUNG.value, + Species.RHYHORN.value, + Species.RHYDON.value, + Species.CHANSEY.value, + Species.KANGASKHAN.value, + Species.ELECTABUZZ.value, + Species.MAGMAR.value, + Species.PINSIR.value, + Species.TAUROS.value, + Species.GYARADOS.value, + Species.LAPRAS.value, + Species.SNORLAX.value, + Species.DRAGONITE.value, + Species.MEWTWO.value, + Species.MEW.value, +} diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 4ed86db..3af4b2d 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -13,99 +13,30 @@ from gymnasium import Env, spaces from pyboy import PyBoy from pyboy.utils import WindowEvent -from skimage.transform import resize +# from skimage.transform import resize import pufferlib +from pokemonred_puffer.data.events import EVENT_FLAGS_START, EVENTS_FLAGS_LENGTH, MUSEUM_TICKET +from pokemonred_puffer.data.field_moves import FieldMoves +from pokemonred_puffer.data.items import ( + HM_ITEM_IDS, + KEY_ITEM_IDS, + MAX_ITEM_CAPACITY, + Items, +) +from pokemonred_puffer.data.strength_puzzles import STRENGTH_SOLUTIONS +from pokemonred_puffer.data.tilesets import Tilesets +from pokemonred_puffer.data.tm_hm import ( + CUT_SPECIES_IDS, + 
STRENGTH_SPECIES_IDS, + SURF_SPECIES_IDS, + TmHmMoves, +) from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE, local_to_global PIXEL_VALUES = np.array([0, 85, 153, 255], dtype=np.uint8) - -EVENT_FLAGS_START = 0xD747 -EVENTS_FLAGS_LENGTH = 320 -MUSEUM_TICKET = (0xD754, 0) - VISITED_MASK_SHAPE = (144 // 16, 160 // 16, 1) -TM_HM_MOVES = set( - [ - 5, # Mega punch - 0xD, # Razor wind - 0xE, # Swords dance - 0x12, # Whirlwind - 0x19, # Mega kick - 0x5C, # Toxic - 0x20, # Horn drill - 0x22, # Body slam - 0x24, # Take down - 0x26, # Double edge - 0x3D, # Bubble beam - 0x37, # Water gun - 0x3A, # Ice beam - 0x3B, # Blizzard - 0x3F, # Hyper beam - 0x06, # Pay day - 0x42, # Submission - 0x44, # Counter - 0x45, # Seismic toss - 0x63, # Rage - 0x48, # Mega drain - 0x4C, # Solar beam - 0x52, # Dragon rage - 0x55, # Thunderbolt - 0x57, # Thunder - 0x59, # Earthquake - 0x5A, # Fissure - 0x5B, # Dig - 0x5E, # Psychic - 0x64, # Teleport - 0x66, # Mimic - 0x68, # Double team - 0x73, # Reflect - 0x75, # Bide - 0x76, # Metronome - 0x78, # Selfdestruct - 0x79, # Egg bomb - 0x7E, # Fire blast - 0x81, # Swift - 0x82, # Skull bash - 0x87, # Softboiled - 0x8A, # Dream eater - 0x8F, # Sky attack - 0x9C, # Rest - 0x56, # Thunder wave - 0x95, # Psywave - 0x99, # Explosion - 0x9D, # Rock slide - 0xA1, # Tri attack - 0xA4, # Substitute - 0x0F, # Cut - 0x13, # Fly - 0x39, # Surf - 0x46, # Strength - 0x94, # Flash - ] -) - -HM_ITEM_IDS = set([0xC4, 0xC5, 0xC6, 0xC7, 0xC8]) - -RESET_MAP_IDS = set( - [ - 0x0, # Pallet Town - 0x1, # Viridian City - 0x2, # Pewter City - 0x3, # Cerulean City - 0x4, # Lavender Town - 0x5, # Vermilion City - 0x6, # Celadon City - 0x7, # Fuchsia City - 0x8, # Cinnabar Island - 0x9, # Indigo Plateau - 0xA, # Saffron City - 0xF, # Route 4 (Mt Moon) - 0x10, # Route 10 (Rock Tunnel) - 0xE9, # Silph Co 9F (Heal station) - ] -) VALID_ACTIONS = [ WindowEvent.PRESS_ARROW_DOWN, @@ -115,6 +46,7 @@ WindowEvent.PRESS_BUTTON_A, WindowEvent.PRESS_BUTTON_B, 
WindowEvent.PRESS_BUTTON_START, + WindowEvent.PASS, ] VALID_RELEASE_ACTIONS = [ @@ -125,6 +57,7 @@ WindowEvent.RELEASE_BUTTON_A, WindowEvent.RELEASE_BUTTON_B, WindowEvent.RELEASE_BUTTON_START, + WindowEvent.PASS, ] VALID_ACTIONS_STR = ["down", "left", "right", "up", "a", "b", "start"] @@ -158,9 +91,23 @@ def __init__(self, env_config: pufferlib.namespace): self.gb_path = env_config.gb_path self.log_frequency = env_config.log_frequency self.two_bit = env_config.two_bit + self.auto_flash = env_config.auto_flash + self.disable_wild_encounters = env_config.disable_wild_encounters + self.disable_ai_actions = env_config.disable_ai_actions + self.auto_teach_cut = env_config.auto_teach_cut + self.auto_teach_surf = env_config.auto_teach_surf + self.auto_teach_strength = env_config.auto_teach_strength + self.auto_use_cut = env_config.auto_use_cut + self.auto_use_surf = env_config.auto_use_surf + self.auto_solve_strength_puzzles = env_config.auto_solve_strength_puzzles + self.auto_remove_all_nonuseful_items = env_config.auto_remove_all_nonuseful_items + self.auto_pokeflute = env_config.auto_pokeflute + self.infinite_money = env_config.infinite_money + self.use_global_map = env_config.use_global_map self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? 
+ self.global_map_shape = GLOBAL_MAP_SHAPE if self.reduce_res: self.screen_output_shape = (72, 80, 1) else: @@ -171,6 +118,7 @@ def __init__(self, env_config: pufferlib.namespace): self.screen_output_shape[1] // 4, 1, ) + self.global_map_shape = (self.global_map_shape[0], self.global_map_shape[1] // 4, 1) self.coords_pad = 12 self.enc_freqs = 8 @@ -192,28 +140,34 @@ def __init__(self, env_config: pufferlib.namespace): v: i for i, v in enumerate([40, 0, 12, 1, 13, 51, 2, 54, 14, 59, 60, 61, 15, 3, 65]) } - self.observation_space = spaces.Dict( - { - "screen": spaces.Box( - low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 - ), - "visited_mask": spaces.Box( - low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 - ), - "global_map": spaces.Box( - low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 - ), - # Discrete is more apt, but pufferlib is slower at processing Discrete - "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - # "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), - # "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), - # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), - # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), - # "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), - # "badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8), - } - ) + obs_dict = { + "screen": spaces.Box(low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8), + "visited_mask": spaces.Box( + low=0, high=255, shape=self.screen_output_shape, dtype=np.uint8 + ), + # Discrete is more apt, but pufferlib is slower at processing Discrete + "direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), + "blackout_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), + "battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8), + 
"cut_event": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + # "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + # "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8), + # "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16), + "badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8), + "wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8), + "bag_items": spaces.Box( + low=0, high=max(Items._value2member_map_.keys()), shape=(20,), dtype=np.uint8 + ), + "bag_quantity": spaces.Box(low=0, high=100, shape=(20,), dtype=np.uint8), + } + + if self.use_global_map: + obs_dict["global_map"] = spaces.Box( + low=0, high=255, shape=self.global_map_shape, dtype=np.uint8 + ) + self.observation_space = spaces.Dict(obs_dict) self.pyboy = PyBoy( env_config.gb_path, @@ -243,6 +197,8 @@ def __init__(self, env_config: pufferlib.namespace): RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF RedGymEnv.env_id.buf[3] = (env_id) & 0xFF + self.init_mem() + def register_hooks(self): self.pyboy.hook_register(None, "DisplayStartMenu", self.start_menu_hook, None) self.pyboy.hook_register(None, "RedisplayStartMenu", self.start_menu_hook, None) @@ -254,10 +210,31 @@ def register_hooks(self): self.pyboy.hook_register( None, "CheckForHiddenObject.foundMatchingObject", self.hidden_object_hook, None ) + """ + _, addr = self.pyboy.symbol_lookup("IsSpriteOrSignInFrontOfPlayer.retry") + self.pyboy.hook_register( + None, addr-1, self.sign_hook, None + ) + """ self.pyboy.hook_register(None, "HandleBlackOut", self.blackout_hook, None) self.pyboy.hook_register(None, "SetLastBlackoutMap.done", self.blackout_update_hook, None) - self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True) - self.pyboy.hook_register(None, "UsedCut.canCut", self.cut_hook, 
context=False) + # self.pyboy.hook_register(None, "UsedCut.nothingToCut", self.cut_hook, context=True) + # self.pyboy.hook_register(None, "UsedCut.canCut", self.cut_hook, context=False) + if self.disable_wild_encounters: + self.setup_disable_wild_encounters() + + def setup_disable_wild_encounters(self): + bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType") + self.pyboy.hook_register( + bank, + addr + 8, + self.disable_wild_encounter_hook, + None, + ) + + def setup_enable_wild_ecounters(self): + bank, addr = self.pyboy.symbol_lookup("TryDoWildEncounter.gotWildEncounterType") + self.pyboy.hook_deregister(bank, addr) def update_state(self, state: bytes): self.reset(seed=random.randint(0, 10), options={"state": state}) @@ -270,9 +247,9 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = if self.first or options.get("state", None) is not None: self.recent_screens = deque() self.recent_actions = deque() - self.init_mem() # We only init seen hidden objs once cause they can only be found once! 
self.seen_hidden_objs = {} + self.seen_signs = {} if options.get("state", None) is not None: self.pyboy.load_state(io.BytesIO(options["state"])) self.reset_count += 1 @@ -280,8 +257,6 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = with open(self.init_state_path, "rb") as f: self.pyboy.load_state(f) self.reset_count = 0 - self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) - self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.base_event_flags = sum( self.read_m(i).bit_count() for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH) @@ -308,7 +283,7 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.update_pokedex() self.update_tm_hm_moves_obtained() - self.taught_cut = self.check_if_party_has_cut() + self.taught_cut = self.check_if_party_has_hm(0xF) self.levels_satisfied = False self.base_explore = 0 self.max_opponent_level = 0 @@ -341,7 +316,8 @@ def init_mem(self): # Maybe I should preallocate a giant matrix for all map ids # All map ids have the same size, right? 
self.seen_coords = {} - # self.seen_global_coords = np.zeros(GLOBAL_MAP_SHAPE) + self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) + self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32) self.seen_map_ids = np.zeros(256) self.seen_npcs = {} @@ -355,13 +331,6 @@ def init_mem(self): self.seen_action_bag_menu = 0 def reset_mem(self): - self.seen_coords.update((k, 0) for k, _ in self.seen_coords.items()) - self.seen_map_ids *= 0 - self.seen_npcs.update((k, 0) for k, _ in self.seen_npcs.items()) - - self.cut_coords.update((k, 0) for k, _ in self.cut_coords.items()) - self.cut_tiles.update((k, 0) for k, _ in self.cut_tiles.items()) - self.seen_start_menu = 0 self.seen_pokemon_menu = 0 self.seen_stats_menu = 0 @@ -456,12 +425,17 @@ def render(self): ) ).astype(np.uint8) visited_mask = np.expand_dims(visited_mask, -1) - """ global_map = np.expand_dims( 255 * resize(self.explore_map, game_pixels_render.shape, anti_aliasing=False), axis=-1, ).astype(np.uint8) + """ + if self.use_global_map: + global_map = np.expand_dims( + 255 * self.explore_map, + axis=-1, + ).astype(np.uint8) if self.two_bit: game_pixels_render = ( @@ -487,39 +461,48 @@ def render(self): .reshape(game_pixels_render.shape) .astype(np.uint8) ) - global_map = ( - ( - np.digitize( - global_map.reshape((-1, 4)), - np.array([0, 64, 128, 255], dtype=np.uint8), - right=True, - ).astype(np.uint8) - << np.array([6, 4, 2, 0], dtype=np.uint8) + if self.use_global_map: + global_map = ( + ( + np.digitize( + global_map.reshape((-1, 4)), + np.array([0, 64, 128, 255], dtype=np.uint8), + right=True, + ).astype(np.uint8) + << np.array([6, 4, 2, 0], dtype=np.uint8) + ) + .sum(axis=1, dtype=np.uint8) + .reshape(self.global_map_shape) ) - .sum(axis=1, dtype=np.uint8) - .reshape(game_pixels_render.shape) - ) return { "screen": game_pixels_render, "visited_mask": visited_mask, - "global_map": global_map, - } + } | ({"global_map": global_map} if self.use_global_map else {}) def _get_obs(self): # 
player_x, player_y, map_n = self.get_game_coords() - return { - **self.render(), + _, wBagItems = self.pyboy.symbol_lookup("wBagItems") + bag = np.array(self.pyboy.memory[wBagItems : wBagItems + 40], dtype=np.uint8) + numBagItems = self.read_m("wNumBagItems") + # item ids start at 1 so using 0 as the nothing value is okay + bag[2 * numBagItems :] = 0 + + return self.render() | { "direction": np.array( self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8 ), - # "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), + "blackout_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8), "battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8), - # "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8), + "cut_event": np.array(self.read_bit(0xD803, 0), dtype=np.uint8), + "cut_in_party": np.array(self.check_if_party_has_hm(0xF), dtype=np.uint8), # "x": np.array(player_x, dtype=np.uint8), # "y": np.array(player_y, dtype=np.uint8), - # "map_id": np.array(map_n, dtype=np.uint8), - # "badges": np.array(self.get_badges(), dtype=np.uint8), + "map_id": np.array(self.read_m(0xD35E), dtype=np.uint8), + "badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8), + "wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8), + "bag_items": bag[::2].copy(), + "bag_quantity": bag[1::2].copy(), } def set_perfect_iv_dvs(self): @@ -528,12 +511,12 @@ def set_perfect_iv_dvs(self): _, addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Species") self.pyboy.memory[addr + 17 : addr + 17 + 12] = 0xFF - def check_if_party_has_cut(self) -> bool: + def check_if_party_has_hm(self, hm: int) -> bool: party_size = self.read_m("wPartyCount") for i in range(party_size): # PRET 1-indexes _, addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Moves") - if 15 in self.pyboy.memory[addr : addr + 4]: + if hm in self.pyboy.memory[addr : addr + 4]: return True return False @@ -541,6 +524,20 @@ def 
step(self, action): if self.save_video and self.step_count == 0: self.start_video() + _, wMapPalOffset = self.pyboy.symbol_lookup("wMapPalOffset") + if self.auto_flash and self.pyboy.memory[wMapPalOffset] == 6: + self.pyboy.memory[wMapPalOffset] = 0 + + if self.auto_remove_all_nonuseful_items: + self.remove_all_nonuseful_items() + + _, wPlayerMoney = self.pyboy.symbol_lookup("wPlayerMoney") + if ( + self.infinite_money + and int.from_bytes(self.pyboy.memory[wPlayerMoney : wPlayerMoney + 3], "little") < 10000 + ): + self.pyboy.memory[wPlayerMoney : wPlayerMoney + 3] = int(10000).to_bytes(3, "little") + self.run_action_on_emulator(action) self.update_seen_coords() self.update_health() @@ -553,7 +550,7 @@ def step(self, action): self.update_map_progress() if self.perfect_ivs: self.set_perfect_iv_dvs() - self.taught_cut = self.check_if_party_has_cut() + self.taught_cut = self.check_if_party_has_hm(0xF) self.pokecenters[self.read_m("wLastBlackoutMap")] = 1 info = {} @@ -575,16 +572,424 @@ def step(self, action): # self.caught_pokemon[6] == 1 # squirtle ) + # cut mon check + if not self.party_has_cut_capable_mon(): + reset = True + self.first = True + new_reward = -self.total_reward * 0.5 + return obs, new_reward, reset, False, info def run_action_on_emulator(self, action): self.action_hist[action] += 1 # press button then release after some steps # TODO: Add video saving logic - self.pyboy.send_input(VALID_ACTIONS[action]) - self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) + + if not self.disable_ai_actions: + self.pyboy.send_input(VALID_ACTIONS[action]) + self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8) self.pyboy.tick(self.action_freq, render=True) + if self.read_bit(0xD803, 0): + if self.auto_teach_cut and not self.check_if_party_has_hm(0x0F): + self.teach_hm(TmHmMoves.CUT.value, 30, CUT_SPECIES_IDS) + if self.auto_use_cut: + self.cut_if_next() + + if self.read_bit(0xD78E, 0): + if self.auto_teach_surf and not 
self.check_if_party_has_hm(0x39): + self.teach_hm(TmHmMoves.SURF.value, 15, SURF_SPECIES_IDS) + if self.auto_use_surf: + self.surf_if_attempt(VALID_ACTIONS[action]) + + if self.read_bit(0xD857, 0): + if self.auto_teach_strength and not self.check_if_party_has_hm(0x46): + self.teach_hm(TmHmMoves.STRENGTH.value, 15, STRENGTH_SPECIES_IDS) + if self.auto_solve_strength_puzzles: + self.solve_missable_strength_puzzle() + self.solve_switch_strength_puzzle() + + if self.read_bit(0xD76C, 0) and self.auto_pokeflute: + self.use_pokeflute() + + def party_has_cut_capable_mon(self): + # find bulba and replace tackle (first skill) with cut + party_size = self.read_m("wPartyCount") + for i in range(party_size): + # PRET 1-indexes + _, species_addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Species") + poke = self.pyboy.memory[species_addr] + # https://github.com/pret/pokered/blob/d38cf5281a902b4bd167a46a7c9fd9db436484a7/constants/pokemon_constants.asm + if poke in CUT_SPECIES_IDS: + return True + return False + + def teach_hm(self, tmhm: int, pp: int, pokemon_species_ids): + # find bulba and replace tackle (first skill) with cut + party_size = self.read_m("wPartyCount") + for i in range(party_size): + # PRET 1-indexes + _, species_addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Species") + poke = self.pyboy.memory[species_addr] + # https://github.com/pret/pokered/blob/d38cf5281a902b4bd167a46a7c9fd9db436484a7/constants/pokemon_constants.asm + if poke in pokemon_species_ids: + for slot in range(4): + if self.read_m(f"wPartyMon{i+1}Moves") not in {0xF, 0x13, 0x39, 0x46, 0x94}: + _, move_addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Moves") + _, pp_addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}PP") + self.pyboy.memory[move_addr + slot] = tmhm + self.pyboy.memory[pp_addr + slot] = pp + # fill up pp: 30/30 + break + + def use_pokeflute(self): + in_overworld = self.read_m("wCurMapTileset") == Tilesets.OVERWORLD.value + if in_overworld: + _, wBagItems = 
self.pyboy.symbol_lookup("wBagItems") + bag_items = self.pyboy.memory[wBagItems : wBagItems + 40] + if Items.POKE_FLUTE.value not in bag_items[::2]: + return + pokeflute_index = bag_items[::2].index(Items.POKE_FLUTE.value) + + # Check if we're on the snorlax coordinates + + coords = self.get_game_coords() + if coords == (9, 62, 23): + self.pyboy.button("RIGHT", 8) + self.pyboy.tick(self.action_freq, render=True) + elif coords == (10, 63, 23): + self.pyboy.button("UP", 8) + self.pyboy.tick(self.action_freq, render=True) + elif coords == (10, 61, 23): + self.pyboy.button("DOWN", 8) + self.pyboy.tick(self.action_freq, render=True) + elif coords == (27, 10, 27): + self.pyboy.button("LEFT", 8) + self.pyboy.tick(self.action_freq, render=True) + elif coords == (27, 10, 25): + self.pyboy.button("RIGHT", 8) + self.pyboy.tick(self.action_freq, render=True) + else: + return + # Then check if snorlax is a missable object + # Then trigger snorlax + + _, wMissableObjectFlags = self.pyboy.symbol_lookup("wMissableObjectFlags") + _, wMissableObjectList = self.pyboy.symbol_lookup("wMissableObjectList") + missable_objects_list = self.pyboy.memory[ + wMissableObjectList : wMissableObjectList + 34 + ] + missable_objects_list = missable_objects_list[: missable_objects_list.index(0xFF)] + missable_objects_sprite_ids = missable_objects_list[::2] + missable_objects_flags = missable_objects_list[1::2] + for sprite_id in missable_objects_sprite_ids: + picture_id = self.read_m(f"wSprite{sprite_id:02}StateData1PictureID") + flags_bit = missable_objects_flags[missable_objects_sprite_ids.index(sprite_id)] + flags_byte = flags_bit // 8 + flag_bit = flags_bit % 8 + flag_byte_value = self.read_bit(wMissableObjectFlags + flags_byte, flag_bit) + if picture_id == 0x43 and not flag_byte_value: + # open start menu + self.pyboy.button("START", 8) + self.pyboy.tick(self.action_freq, render=True) + # scroll to bag + # 2 is the item index for bag + for _ in range(24): + if self.read_m("wCurrentMenuItem") == 
2: + break + self.pyboy.button("DOWN", 8) + self.pyboy.tick(self.action_freq, render=True) + self.pyboy.button("A", 8) + self.pyboy.tick(self.action_freq, render=True) + + # Scroll until you get to pokeflute + # We'll do this by scrolling all the way up then all the way down + # There is a faster way to do it, but this is easier to think about + # Could also set the menu index manually, but there are like 4 variables + # for that + for _ in range(20): + self.pyboy.button("UP", 8) + self.pyboy.tick(self.action_freq, render=True) + + for _ in range(21): + if ( + self.read_m("wCurrentMenuItem") + self.read_m("wListScrollOffset") + == pokeflute_index + ): + break + self.pyboy.button("DOWN", 8) + self.pyboy.tick(self.action_freq, render=True) + + # press a bunch of times + for _ in range(5): + self.pyboy.button("A", 8) + self.pyboy.tick(4 * self.action_freq, render=True) + + break + + def cut_if_next(self): + # https://github.com/pret/pokered/blob/d38cf5281a902b4bd167a46a7c9fd9db436484a7/constants/tileset_constants.asm#L11C8-L11C11 + in_erika_gym = self.read_m("wCurMapTileset") == Tilesets.GYM.value + in_overworld = self.read_m("wCurMapTileset") == Tilesets.OVERWORLD.value + if in_erika_gym or in_overworld: + _, wTileMap = self.pyboy.symbol_lookup("wTileMap") + tileMap = self.pyboy.memory[wTileMap : wTileMap + 20 * 18] + tileMap = np.array(tileMap, dtype=np.uint8) + tileMap = np.reshape(tileMap, (18, 20)) + y, x = 8, 8 + up, down, left, right = ( + tileMap[y - 2 : y, x : x + 2], # up + tileMap[y + 2 : y + 4, x : x + 2], # down + tileMap[y : y + 2, x - 2 : x], # left + tileMap[y : y + 2, x + 2 : x + 4], # right + ) + + # Gym trees apparently get the same tile map as outside bushes + # GYM = 7 + if (in_overworld and 0x3D in up) or (in_erika_gym and 0x50 in up): + self.pyboy.send_input(WindowEvent.PRESS_ARROW_UP) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_UP, delay=8) + self.pyboy.tick(self.action_freq, render=True) + elif (in_overworld and 0x3D in down) or 
(in_erika_gym and 0x50 in down): + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + elif (in_overworld and 0x3D in left) or (in_erika_gym and 0x50 in left): + self.pyboy.send_input(WindowEvent.PRESS_ARROW_LEFT) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_LEFT, delay=8) + self.pyboy.tick(self.action_freq, render=True) + elif (in_overworld and 0x3D in right) or (in_erika_gym and 0x50 in right): + self.pyboy.send_input(WindowEvent.PRESS_ARROW_RIGHT) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_RIGHT, delay=8) + self.pyboy.tick(self.action_freq, render=True) + else: + return + + # open start menu + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_START) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START, delay=8) + self.pyboy.tick(self.action_freq, render=True) + # scroll to pokemon + # 1 is the item index for pokemon + for _ in range(24): + if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] == 1: + break + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.tick(self.action_freq, render=True) + + # find pokemon with cut + # We run this over all pokemon so we dont end up in an infinite for loop + for _ in range(7): + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + party_mon = self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] + _, addr = self.pyboy.symbol_lookup(f"wPartyMon{party_mon%6+1}Moves") + if 0xF in self.pyboy.memory[addr : addr + 4]: + break + + # Enter submenu + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) + 
self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.tick(4 * self.action_freq, render=True) + + # Scroll until the field move is found + _, wFieldMoves = self.pyboy.symbol_lookup("wFieldMoves") + field_moves = self.pyboy.memory[wFieldMoves : wFieldMoves + 4] + + for _ in range(10): + current_item = self.read_m("wCurrentMenuItem") + if current_item < 4 and FieldMoves.CUT.value == field_moves[current_item]: + break + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + + # press a bunch of times + for _ in range(5): + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.tick(4 * self.action_freq, render=True) + + def surf_if_attempt(self, action: WindowEvent): + if not ( + self.read_m("wWalkBikeSurfState") != 2 + and self.check_if_party_has_hm(0x39) + and action + in [ + WindowEvent.PRESS_ARROW_DOWN, + WindowEvent.PRESS_ARROW_LEFT, + WindowEvent.PRESS_ARROW_RIGHT, + WindowEvent.PRESS_ARROW_UP, + ] + ): + return + + in_overworld = self.read_m("wCurMapTileset") == Tilesets.OVERWORLD.value + in_plateau = self.read_m("wCurMapTileset") == Tilesets.PLATEAU.value + if in_overworld or in_plateau: + _, wTileMap = self.pyboy.symbol_lookup("wTileMap") + tileMap = self.pyboy.memory[wTileMap : wTileMap + 20 * 18] + tileMap = np.array(tileMap, dtype=np.uint8) + tileMap = np.reshape(tileMap, (18, 20)) + y, x = 8, 8 + # This could be made a little faster by only checking the + # direction that matters, but I decided to copy pasta the cut routine + up, down, left, right = ( + tileMap[y - 2 : y, x : x + 2], # up + tileMap[y + 2 : y + 4, x : x + 2], # down + tileMap[y : y + 2, x - 2 : x], # left + tileMap[y : y + 2, x + 2 : x + 4], # right + ) + + # down, up, left, right + direction = self.read_m("wSpritePlayerStateData1FacingDirection") + + if not ( + (direction == 0x4 and 
action == WindowEvent.PRESS_ARROW_UP and 0x14 in up) + or (direction == 0x0 and action == WindowEvent.PRESS_ARROW_DOWN and 0x14 in down) + or (direction == 0x8 and action == WindowEvent.PRESS_ARROW_LEFT and 0x14 in left) + or (direction == 0xC and action == WindowEvent.PRESS_ARROW_RIGHT and 0x14 in right) + ): + return + + # open start menu + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_START) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_START, delay=8) + self.pyboy.tick(self.action_freq, render=True) + # scroll to pokemon + # 1 is the item index for pokemon + for _ in range(24): + if self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] == 1: + break + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.tick(self.action_freq, render=True) + + # find pokemon with surf + # We run this over all pokemon so we dont end up in an infinite for loop + for _ in range(7): + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + party_mon = self.pyboy.memory[self.pyboy.symbol_lookup("wCurrentMenuItem")[1]] + _, addr = self.pyboy.symbol_lookup(f"wPartyMon{party_mon%6+1}Moves") + if 0x39 in self.pyboy.memory[addr : addr + 4]: + break + + # Enter submenu + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.tick(4 * self.action_freq, render=True) + + # Scroll until the field move is found + _, wFieldMoves = self.pyboy.symbol_lookup("wFieldMoves") + field_moves = self.pyboy.memory[wFieldMoves : wFieldMoves + 4] + + for _ in range(10): + current_item = self.read_m("wCurrentMenuItem") + if current_item < 4 and field_moves[current_item] in ( 
+ FieldMoves.SURF.value, + FieldMoves.SURF_2.value, + ): + break + self.pyboy.send_input(WindowEvent.PRESS_ARROW_DOWN) + self.pyboy.send_input(WindowEvent.RELEASE_ARROW_DOWN, delay=8) + self.pyboy.tick(self.action_freq, render=True) + + # press a bunch of times + for _ in range(5): + self.pyboy.send_input(WindowEvent.PRESS_BUTTON_A) + self.pyboy.send_input(WindowEvent.RELEASE_BUTTON_A, delay=8) + self.pyboy.tick(4 * self.action_freq, render=True) + + def solve_missable_strength_puzzle(self): + in_cavern = self.read_m("wCurMapTileset") == Tilesets.CAVERN.value + if in_cavern: + _, wMissableObjectFlags = self.pyboy.symbol_lookup("wMissableObjectFlags") + _, wMissableObjectList = self.pyboy.symbol_lookup("wMissableObjectList") + missable_objects_list = self.pyboy.memory[ + wMissableObjectList : wMissableObjectList + 34 + ] + missable_objects_list = missable_objects_list[: missable_objects_list.index(0xFF)] + missable_objects_sprite_ids = missable_objects_list[::2] + missable_objects_flags = missable_objects_list[1::2] + + for sprite_id in missable_objects_sprite_ids: + flags_bit = missable_objects_flags[missable_objects_sprite_ids.index(sprite_id)] + flags_byte = flags_bit // 8 + flag_bit = flags_bit % 8 + flag_byte_value = self.read_bit(wMissableObjectFlags + flags_byte, flag_bit) + if not flag_byte_value: # True if missable + picture_id = self.read_m(f"wSprite{sprite_id:02}StateData1PictureID") + mapY = self.read_m(f"wSprite{sprite_id:02}StateData2MapY") + mapX = self.read_m(f"wSprite{sprite_id:02}StateData2MapX") + if solution := STRENGTH_SOLUTIONS.get( + (picture_id, mapY, mapX) + self.get_game_coords(), [] + ): + if not self.disable_wild_encounters: + self.setup_disable_wild_encounters() + # Activate strength + _, wd728 = self.pyboy.symbol_lookup("wd728") + self.pyboy.memory[wd728] |= 0b0000_0001 + # Perform solution + current_repel_steps = self.read_m("wRepelRemainingSteps") + for button in solution: + self.pyboy.memory[ + 
self.pyboy.symbol_lookup("wRepelRemainingSteps")[1] + ] = 0xFF + self.pyboy.button(button, 8) + self.pyboy.tick(self.action_freq * 1.5, render=True) + self.pyboy.memory[self.pyboy.symbol_lookup("wRepelRemainingSteps")[1]] = ( + current_repel_steps + ) + if not self.disable_wild_encounters: + self.setup_enable_wild_ecounters() + break + + def solve_switch_strength_puzzle(self): + in_cavern = self.read_m("wCurMapTileset") == Tilesets.CAVERN.value + if in_cavern: + for sprite_id in range(1, self.read_m("wNumSprites") + 1): + picture_id = self.read_m(f"wSprite{sprite_id:02}StateData1PictureID") + mapY = self.read_m(f"wSprite{sprite_id:02}StateData2MapY") + mapX = self.read_m(f"wSprite{sprite_id:02}StateData2MapX") + if solution := STRENGTH_SOLUTIONS.get( + (picture_id, mapY, mapX) + self.get_game_coords(), [] + ): + if not self.disable_wild_encounters: + self.setup_disable_wild_encounters() + # Activate strength + _, wd728 = self.pyboy.symbol_lookup("wd728") + self.pyboy.memory[wd728] |= 0b0000_0001 + # Perform solution + current_repel_steps = self.read_m("wRepelRemainingSteps") + for button in solution: + self.pyboy.memory[self.pyboy.symbol_lookup("wRepelRemainingSteps")[1]] = ( + 0xFF + ) + self.pyboy.button(button, 8) + self.pyboy.tick(self.action_freq * 2, render=True) + self.pyboy.memory[self.pyboy.symbol_lookup("wRepelRemainingSteps")[1]] = ( + current_repel_steps + ) + if not self.disable_wild_encounters: + self.setup_enable_wild_ecounters() + break + + def sign_hook(self, *args, **kwargs): + sign_id = self.pyboy.memory[self.pyboy.symbol_lookup("hSpriteIndexOrTextID")[1]] + map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]] + # We will store this by map id, y, x, + self.seen_hidden_objs[(map_id, sign_id)] = 1 + def hidden_object_hook(self, *args, **kwargs): hidden_object_id = self.pyboy.memory[self.pyboy.symbol_lookup("wHiddenObjectIndex")[1]] map_id = self.pyboy.memory[self.pyboy.symbol_lookup("wCurMap")[1]] @@ -600,8 +1005,8 @@ def 
start_menu_hook(self, *args, **kwargs): self.seen_start_menu = 1 def item_menu_hook(self, *args, **kwargs): - if self.read_m("wIsInBattle") == 0: - self.seen_bag_menu = 1 + # if self.read_m("wIsInBattle") == 0: + self.seen_bag_menu = 1 def pokemon_menu_hook(self, *args, **kwargs): if self.read_m("wIsInBattle") == 0: @@ -612,8 +1017,8 @@ def chose_stats_hook(self, *args, **kwargs): self.seen_stats_menu = 1 def chose_item_hook(self, *args, **kwargs): - if self.read_m("wIsInBattle") == 0: - self.seen_action_bag_menu = 1 + # if self.read_m("wIsInBattle") == 0: + self.seen_action_bag_menu = 1 def blackout_hook(self, *args, **kwargs): self.blackout_count += 1 @@ -639,10 +1044,7 @@ def cut_hook(self, context): self.pyboy.symbol_lookup("wTileInFrontOfPlayer")[1] ] if context: - if wTileInFrontOfPlayer in [ - 0x3D, - 0x50, - ]: + if wTileInFrontOfPlayer in [0x3D, 0x50]: self.cut_coords[coords] = 10 else: self.cut_coords[coords] = 0.001 @@ -652,8 +1054,13 @@ def cut_hook(self, context): self.cut_explore_map[local_to_global(y, x, map_id)] = 1 self.cut_tiles[wTileInFrontOfPlayer] = 1 + def disable_wild_encounter_hook(self, *args, **kwargs): + self.pyboy.memory[self.pyboy.symbol_lookup("wRepelRemainingSteps")[1]] = 0xFF + self.pyboy.memory[self.pyboy.symbol_lookup("wCurEnemyLVL")[1]] = 0x01 + def agent_stats(self, action): levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))] + badges = self.read_m("wObtainedBadges") return { "stats": { "step": self.step_count + self.reset_count * self.max_steps, @@ -685,7 +1092,7 @@ def agent_stats(self, action): "left_bills_house_after_helping": int(self.read_bit(0xD7F2, 7)), "got_hm01": int(self.read_bit(0xD803, 0)), "rubbed_captains_back": int(self.read_bit(0xD803, 1)), - "taught_cut": int(self.check_if_party_has_cut()), + "taught_cut": int(self.check_if_party_has_hm(0xF)), "cut_coords": sum(self.cut_coords.values()), "cut_tiles": len(self.cut_tiles), "start_menu": self.seen_start_menu, @@ -698,7 +1105,10 
@@ def agent_stats(self, action): "reset_count": self.reset_count, "blackout_count": self.blackout_count, "pokecenter": np.sum(self.pokecenters), - }, + "rival3": int(self.read_m(0xD665) == 4), + "rocket_hideout_found": int(self.read_bit(0xD77E, 1)), + } + | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), "pokemon_exploration_map": self.explore_map, @@ -744,11 +1154,12 @@ def get_game_coords(self): return (self.read_m(0xD362), self.read_m(0xD361), self.read_m(0xD35E)) def update_seen_coords(self): - x_pos, y_pos, map_n = self.get_game_coords() - self.seen_coords[(x_pos, y_pos, map_n)] = 1 - self.explore_map[local_to_global(y_pos, x_pos, map_n)] = 1 - # self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1 - self.seen_map_ids[map_n] = 1 + if not (self.read_m("wd736") & 0b1000_0000): + x_pos, y_pos, map_n = self.get_game_coords() + self.seen_coords[(x_pos, y_pos, map_n)] = 1 + self.explore_map[local_to_global(y_pos, x_pos, map_n)] = 1 + # self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1 + self.seen_map_ids[map_n] = 1 def get_explore_map(self): explore_map = np.zeros(GLOBAL_MAP_SHAPE) @@ -854,6 +1265,52 @@ def update_tm_hm_moves_obtained(self): self.moves_obtained[move_id] = 1 """ + def remove_all_nonuseful_items(self): + _, wNumBagItems = self.pyboy.symbol_lookup("wNumBagItems") + if self.pyboy.memory[wNumBagItems] == MAX_ITEM_CAPACITY: + _, wBagItems = self.pyboy.symbol_lookup("wBagItems") + bag_items = self.pyboy.memory[wBagItems : wBagItems + MAX_ITEM_CAPACITY * 2] + # Fun fact: The way they test if an item is an hm in code is by testing the item id + # is greater than or equal to 0xC4 (the item id for HM_01) + + # TODO either remove or check if guard has been given drink + # guard given drink are 4 script pointers to check, NOT an event + new_bag_items = [ + (item, quantity) + for item, quantity in zip(bag_items[::2], 
bag_items[1::2]) + if (0x0 < item < Items.HM_01.value and (item - 1) in KEY_ITEM_IDS) + or item + in { + Items[name] + for name in [ + "LEMONADE", + "SODA_POP", + "FRESH_WATER", + "HM_01", + "HM_02", + "HM_03", + "HM_04", + "HM_05", + ] + } + ] + # Write the new count back to memory + self.pyboy.memory[wNumBagItems] = len(new_bag_items) + # 0 pad + new_bag_items += [(255, 255)] * (20 - len(new_bag_items)) + # now flatten list + new_bag_items = list(sum(new_bag_items, ())) + # now write back to list + self.pyboy.memory[wBagItems : wBagItems + len(new_bag_items)] = new_bag_items + + _, wBagSavedMenuItem = self.pyboy.symbol_lookup("wBagSavedMenuItem") + _, wListScrollOffset = self.pyboy.symbol_lookup("wListScrollOffset") + # TODO: Make this point to the location of the last removed item + # Should be something like the current location - the number of items + # that have been removed - 1 + self.pyboy.memory[wBagSavedMenuItem] = 0 + self.pyboy.memory[wListScrollOffset] = 0 + def read_hp_fraction(self): party_size = self.read_m("wPartyCount") hp_sum = sum(self.read_short(f"wPartyMon{i+1}HP") for i in range(party_size)) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index ced205e..f65ec29 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -4,6 +4,7 @@ import pufferlib.models from pufferlib.emulation import unpack_batched_obs +from pokemonred_puffer.data.items import Items from pokemonred_puffer.environment import PIXEL_VALUES unpack_batched_obs = torch.compiler.disable(unpack_batched_obs) @@ -44,6 +45,15 @@ def __init__( nn.ReLU(), nn.Flatten(), ) + self.global_map_network = nn.Sequential( + nn.LazyConv2d(32, 8, stride=4), + nn.ReLU(), + nn.LazyConv2d(64, 4, stride=2), + nn.ReLU(), + nn.LazyConv2d(64, 3, stride=1), + nn.ReLU(), + nn.Flatten(), + ) self.encode_linear = nn.Sequential( nn.LazyLinear(hidden_size), @@ -54,6 +64,7 @@ def 
__init__( self.value_fn = nn.LazyLinear(1) self.two_bit = env.unwrapped.env.two_bit + self.use_global_map = env.unwrapped.env.use_global_map self.register_buffer( "screen_buckets", torch.tensor(PIXEL_VALUES, dtype=torch.uint8), persistent=False @@ -69,14 +80,31 @@ def __init__( self.register_buffer( "unpack_shift", torch.tensor([6, 4, 2, 0], dtype=torch.uint8), persistent=False ) + self.register_buffer("badge_buffer", torch.arange(8) + 1, persistent=False) + + # pokemon has 0xF7 map ids + # Lets start with 4 dims for now. Could try 8 + self.map_embeddings = torch.nn.Embedding(0xF7, 4, dtype=torch.float32) + # N.B. This is an overestimate + item_count = max(Items._value2member_map_.keys()) + self.item_embeddings = torch.nn.Embedding( + item_count, int(item_count**0.25 + 1), dtype=torch.float32 + ) def encode_observations(self, observations): observations = unpack_batched_obs(observations, self.unflatten_context) screen = observations["screen"] visited_mask = observations["visited_mask"] - global_map = observations["global_map"] restored_shape = (screen.shape[0], screen.shape[1], screen.shape[2] * 4, screen.shape[3]) + if self.use_global_map: + global_map = observations["global_map"] + restored_global_map_shape = ( + global_map.shape[0], + global_map.shape[1], + global_map.shape[2] * 4, + global_map.shape[3], + ) if self.two_bit: screen = torch.index_select( @@ -91,36 +119,60 @@ def encode_observations(self, observations): .flatten() .int(), ).reshape(restored_shape) - global_map = torch.index_select( - self.linear_buckets, - 0, - ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift) - .flatten() - .int(), - ).reshape(restored_shape) + if self.use_global_map: + global_map = torch.index_select( + self.linear_buckets, + 0, + ((global_map.reshape((-1, 1)) & self.unpack_mask) >> self.unpack_shift) + .flatten() + .int(), + ).reshape(restored_global_map_shape) + badges = self.badge_buffer <= observations["badges"] + map_id = 
self.map_embeddings(observations["map_id"].long()) + blackout_map_id = self.map_embeddings(observations["blackout_map_id"].long()) + # The bag quantity can be a value between 1 and 99 + # TODO: Should items be positionally encoded? I dont think it matters + items = self.item_embeddings(observations["bag_items"].squeeze(1).long()).float() * ( + observations["bag_quantity"].squeeze(1).float().unsqueeze(-1) / 100.0 + ) - image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) + # image_observation = torch.cat((screen, visited_mask, global_map), dim=-1) + image_observation = torch.cat((screen, visited_mask), dim=-1) if self.channels_last: image_observation = image_observation.permute(0, 3, 1, 2) + if self.use_global_map: + global_map = global_map.permute(0, 3, 1, 2) if self.downsample > 1: image_observation = image_observation[:, :, :: self.downsample, :: self.downsample] - return self.encode_linear( - torch.cat( + cat_obs = torch.cat( + ( + self.screen_network(image_observation.float() / 255.0).squeeze(1), + one_hot(observations["direction"].long(), 4).float().squeeze(1), + # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), + one_hot(observations["battle_type"].long(), 4).float().squeeze(1), + observations["cut_event"].float(), + observations["cut_in_party"].float(), + # observations["x"].float(), + # observations["y"].float(), + # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), + badges.float().squeeze(1), + map_id.squeeze(1), + blackout_map_id.squeeze(1), + observations["wJoyIgnore"].float(), + items.flatten(start_dim=1), + ), + dim=-1, + ) + if self.use_global_map: + cat_obs = torch.cat( ( - (self.screen_network(image_observation.float() / 255.0).squeeze(1)), - one_hot(observations["direction"].long(), 4).float().squeeze(1), - # one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1), - one_hot(observations["battle_type"].long(), 4).float().squeeze(1), - # observations["cut_in_party"].float(), - 
# observations["x"].float(), - # observations["y"].float(), - # one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1), - # one_hot(observations["badges"].long(), 8).float().squeeze(1), + cat_obs, + self.global_map_network(global_map.float() / 255.0).squeeze(1), ), dim=-1, ) - ), None + return self.encode_linear(cat_obs), None def decode_actions(self, flat_hidden, lookup, concat=None): action = self.actor(flat_hidden) diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index c6da5d2..4f8497a 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -5,8 +5,6 @@ RedGymEnv, ) -import numpy as np - MUSEUM_TICKET = (0xD754, 0) @@ -34,7 +32,7 @@ def get_game_state_reward(self): # "heal": self.total_healing_rew, "explore": sum(self.seen_coords.values()) * 0.012, # "explore_maps": np.sum(self.seen_map_ids) * 0.0001, - "taught_cut": 4 * int(self.check_if_party_has_cut()), + "taught_cut": 4 * int(self.check_if_party_has_hm(0xF)), "cut_coords": sum(self.cut_coords.values()) * 1.0, "cut_tiles": sum(self.cut_tiles.values()) * 1.0, "met_bill": 5 * int(self.read_bit(0xD7F1, 0)), @@ -165,30 +163,10 @@ def get_levels_reward(self): return 15 + (self.max_level_sum - 15) / 4 -class RockTunnelReplicationEnv(BaselineRewardEnv): +class CutWithObjectRewardsEnv(BaselineRewardEnv): def get_game_state_reward(self): return { - "level": self.reward_config["level"] * self.get_levels_reward(), - "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), - "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut), "event": self.reward_config["event"] * self.update_max_event_rew(), - "seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon), - "caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon), - "moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained), - "cut_coords": 
self.reward_config["cut_coords"] * sum(self.cut_coords.values()), - "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles), - "start_menu": ( - self.reward_config["start_menu"] * self.seen_start_menu * int(self.taught_cut) - ), - "pokemon_menu": ( - self.reward_config["pokemon_menu"] * self.seen_pokemon_menu * int(self.taught_cut) - ), - "stats_menu": ( - self.reward_config["stats_menu"] * self.seen_stats_menu * int(self.taught_cut) - ), - "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut), - # "pokecenter": self.reward_config["pokecenter"] * np.sum(self.pokecenters), - "badges": self.reward_config["badges"] * self.get_badges(), "met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)), "used_cell_separator_on_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 3)), @@ -198,7 +176,26 @@ def get_game_state_reward(self): * int(self.read_bit(0xD7F2, 6)), "left_bills_house_after_helping": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 7)), + "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon), + "moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained), + "hm_count": self.reward_config["hm_count"] * self.get_hm_count(), + "level": self.reward_config["level"] * self.get_levels_reward(), + "badges": self.reward_config["badges"] * self.get_badges(), + "exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()), + "cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()), + "cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles.values()), + "start_menu": self.reward_config["start_menu"] * self.seen_start_menu, + "pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu, + "stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu, + "bag_menu": 
class OnResetExplorationWrapper(gym.Wrapper):
    """Periodically wipe the wrapped env's exploration bookkeeping on ``reset``.

    Each call to :meth:`reset` increments an internal counter. When the
    counter plus a random integer drawn from ``[0, jitter]`` reaches
    ``full_reset_frequency``, the exploration maps and the seen/cut
    collections on the unwrapped environment are cleared before the reset is
    forwarded to the wrapped environment.
    """

    def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
        """Store the reset schedule.

        Args:
            env: Environment to wrap.
            reward_config: Wrapper settings. ``full_reset_frequency`` is
                required; ``jitter`` is optional and defaults to 0.
        """
        super().__init__(env)
        self.full_reset_frequency = reward_config.full_reset_frequency
        # Some wrapper configs list only `full_reset_frequency`; treat a
        # missing `jitter` as "no jitter" rather than failing at construction.
        self.jitter = getattr(reward_config, "jitter", 0)
        self.counter = 0

    def reset(self, *args, **kwargs):
        """Clear exploration state when the schedule fires, then reset the env.

        All positional and keyword arguments are passed through unchanged and
        the wrapped environment's ``reset`` result is returned as-is.
        """
        if (self.counter + random.randint(0, self.jitter)) >= self.full_reset_frequency:
            self.counter = 0
            base_env = self.env.unwrapped
            # Fresh arrays for the two maps; the collections are emptied in place.
            base_env.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
            base_env.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
            base_env.seen_coords.clear()
            base_env.seen_map_ids *= 0
            base_env.seen_npcs.clear()
            base_env.cut_coords.clear()
            base_env.cut_tiles.clear()
        # The counter advances on every reset, including the one that cleared.
        self.counter += 1
        return self.env.reset(*args, **kwargs)
class OnResetLowerToFixedValueWrapper(gym.Wrapper):
    """On ``reset``, lower every already-seen exploration entry to a fixed value.

    Entries of the unwrapped environment's exploration bookkeeping whose
    value is strictly positive are overwritten with the configured value for
    their category; zero entries are left untouched. The ``fixed_value``
    mapping is read with the keys ``coords``, ``map_ids``, ``npc``, ``cut``
    and ``explore``.
    """

    def __init__(self, env: RedGymEnv, reward_config: pufferlib.namespace):
        """Store the per-category replacement values.

        Args:
            env: Environment to wrap.
            reward_config: Wrapper settings; must provide ``fixed_value``.
        """
        super().__init__(env)
        self.fixed_value = reward_config.fixed_value

    @staticmethod
    def _lower_positive_entries(entries, value):
        """Overwrite, in place, every strictly positive value in ``entries``."""
        # Collect the keys first so the mapping is not written while iterated.
        for key in [k for k, v in entries.items() if v > 0]:
            entries[key] = value

    def reset(self, *args, **kwargs):
        """Lower the seen values, then forward the reset to the wrapped env.

        All positional and keyword arguments are passed through unchanged and
        the wrapped environment's ``reset`` result is returned as-is.
        """
        base_env = self.env.unwrapped
        self._lower_positive_entries(base_env.seen_coords, self.fixed_value["coords"])
        base_env.seen_map_ids[base_env.seen_map_ids > 0] = self.fixed_value["map_ids"]
        self._lower_positive_entries(base_env.seen_npcs, self.fixed_value["npc"])
        # Each cut collection is lowered from its own entries. Previously both
        # were driven by `seen_npcs`, which inserted NPC keys into them and
        # never touched the real cut entries.
        self._lower_positive_entries(base_env.cut_tiles, self.fixed_value["cut"])
        self._lower_positive_entries(base_env.cut_coords, self.fixed_value["cut"])
        base_env.explore_map[base_env.explore_map > 0] = self.fixed_value["explore"]
        base_env.cut_explore_map[base_env.cut_explore_map > 0] = self.fixed_value["cut"]
        return self.env.reset(*args, **kwargs)
index 0000000..5369652 Binary files /dev/null and b/pyboy_states/victory_road_2.state differ diff --git a/pyboy_states/victory_road_3.state b/pyboy_states/victory_road_3.state new file mode 100644 index 0000000..aec8d32 Binary files /dev/null and b/pyboy_states/victory_road_3.state differ diff --git a/pyboy_states/victory_road_4.state b/pyboy_states/victory_road_4.state new file mode 100644 index 0000000..79c280f Binary files /dev/null and b/pyboy_states/victory_road_4.state differ diff --git a/pyboy_states/victory_road_5.state b/pyboy_states/victory_road_5.state new file mode 100644 index 0000000..cc8984c Binary files /dev/null and b/pyboy_states/victory_road_5.state differ diff --git a/pyboy_states/with_cut_vermillion.state b/pyboy_states/with_cut_vermillion.state new file mode 100644 index 0000000..799a224 Binary files /dev/null and b/pyboy_states/with_cut_vermillion.state differ diff --git a/pyproject.toml b/pyproject.toml index 9c0f70b..35de646 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,13 +13,16 @@ classifiers = [ ] dependencies = [ "einops", - "opencv-python", + "mediapy", "numpy", + "opencv-python", "pyboy>=2", - "pufferlib[cleanrl]>=0.7.3", + "pufferlib[cleanrl]>=0.7.3,<1.0.0", + "scikit-image", "torch>=2.1", "torchvision", - "wandb" + "wandb", + "websockets" ] [tool.setuptools.packages.find]