Skip to content

Commit

Permalink
Merge branch 'leanke-rep'
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Mar 21, 2024
2 parents 5ec82a0 + dae87d9 commit ad47462
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 54 deletions.
43 changes: 33 additions & 10 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ env:
state_dir: pyboy_states
init_state: Bulbasaur
action_freq: 24
max_steps: 100_000_000
max_steps: 20480
save_video: False
fast_video: False
frame_stacks: 1
Expand All @@ -48,9 +48,9 @@ train:
compile_mode: "reduce-overhead"
float32_matmul_precision: "high"
total_timesteps: 100_000_000_000
batch_size: 32768
learning_rate: 2.5e-4
anneal_lr: True
batch_size: 65536
learning_rate: 2.0e-4
anneal_lr: False
gamma: 0.998
gae_lambda: 0.95
num_minibatches: 4
Expand All @@ -63,12 +63,12 @@ train:
max_grad_norm: 0.5
target_kl: ~
batch_rows: 128
bptt_horizon: 16 #8
vf_clip_coef: 0.12 # 0.1
bptt_horizon: 16
vf_clip_coef: 0.1

num_envs: 150
envs_per_worker: 2
envs_per_batch: 60
num_envs: 96
envs_per_worker: 1
envs_per_batch: 32
env_pool: True

verbose: True
Expand All @@ -79,6 +79,7 @@ train:
overlay_interval: 200
cpu_offload: True
pool_kernel: [0]
log_frequency: 2000

wrappers:
baseline:
Expand All @@ -103,9 +104,31 @@ wrappers:
- exploration.MaxLengthWrapper:
capacity: 1750

stream_only:
- stream_wrapper.StreamWrapper:
user: thatguy

rewards:
baseline.BaselineRewardEnv:
reward:
baseline.TeachCutReplicationEnv:
reward:
event: 1.0
bill_saved: 5.0
seen_pokemon: 4.0
caught_pokemon: 4.0
moves_obtained: 4.0
hm_count: 10.0
level: 1.0
badges: 10.0
exploration: 0.02
cut_coords: 1.0
cut_tiles: 1.0
start_menu: 0.01
pokemon_menu: 0.1
stats_menu: 0.1
bag_menu: 0.1


policies:
multi_convolutional.MultiConvolutionalPolicy:
Expand All @@ -114,7 +137,7 @@ policies:
hidden_size: 512
output_size: 512
framestack: 3
flat_size: 2184
flat_size: 1928

recurrent:
# Assumed to be in the same module as the policy
Expand Down
2 changes: 1 addition & 1 deletion pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def evaluate(self):
if self.wandb is not None:
self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay)
try: # TODO: Better checks on log data types
self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
# self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
self.stats[k] = np.mean(v)
self.max_stats[k] = np.max(v)
except: # noqa
Expand Down
94 changes: 61 additions & 33 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import deque
from multiprocessing import Lock, shared_memory
from pathlib import Path
from typing import Optional
from typing import Iterable, Optional
import uuid

import mediapy as media
Expand Down Expand Up @@ -83,6 +83,8 @@
]
)

HM_ITEM_IDS = set([0xC4, 0xC5, 0xC6, 0xC7, 0xC8])

RESET_MAP_IDS = set(
[
0x0, # Pallet Town
Expand Down Expand Up @@ -112,6 +114,16 @@
WindowEvent.PRESS_BUTTON_START,
]

VALID_RELEASE_ACTIONS = [
WindowEvent.RELEASE_ARROW_DOWN,
WindowEvent.RELEASE_ARROW_LEFT,
WindowEvent.RELEASE_ARROW_RIGHT,
WindowEvent.RELEASE_ARROW_UP,
WindowEvent.RELEASE_BUTTON_A,
WindowEvent.RELEASE_BUTTON_B,
WindowEvent.RELEASE_BUTTON_START,
]

VALID_ACTIONS_STR = ["down", "left", "right", "up", "a", "b", "start"]

ACTION_SPACE = spaces.Discrete(len(VALID_ACTIONS))
Expand Down Expand Up @@ -142,6 +154,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.perfect_ivs = env_config.perfect_ivs
self.reduce_res = env_config.reduce_res
self.gb_path = env_config.gb_path
self.log_frequency = env_config.log_frequency
self.action_space = ACTION_SPACE

# Obs space-related. TODO: avoid hardcoding?
Expand Down Expand Up @@ -177,13 +190,13 @@ def __init__(self, env_config: pufferlib.namespace):
),
# Discrete is more apt, but pufferlib is slower at processing Discrete
"direction": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
# "reset_map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"battle_type": spaces.Box(low=0, high=4, shape=(1,), dtype=np.uint8),
"cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
"x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
"badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8),
# "cut_in_party": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
# "x": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "y": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
# "map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
# "badges": spaces.Box(low=0, high=8, shape=(1,), dtype=np.uint8),
}
)

Expand Down Expand Up @@ -242,23 +255,22 @@ def reset(self, seed: Optional[int] = None):
self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.init_mem()
self.reset_count = 0
with open(self.init_state_path, "rb") as f:
self.pyboy.load_state(f)
# lazy random seed setting
if not seed:
seed = random.randint(0, 4096)
self.pyboy.tick(seed, render=False)
else:
self.recent_screens.clear()
self.recent_actions.clear()
self.seen_pokemon.fill(0)
self.caught_pokemon.fill(0)
self.moves_obtained.fill(0)
self.explore_map *= 0
self.reset_mem()
self.reset_count += 1

with open(self.init_state_path, "rb") as f:
self.pyboy.load_state(f)

# lazy random seed setting
if not seed:
seed = random.randint(0, 4096)
self.pyboy.tick(seed, render=False)
self.recent_screens.clear()
self.recent_actions.clear()
self.seen_pokemon.fill(0)
self.caught_pokemon.fill(0)
self.moves_obtained.fill(0)
self.explore_map *= 0
self.reset_mem()

self.taught_cut = self.check_if_party_has_cut()
self.base_event_flags = sum(
Expand All @@ -284,9 +296,6 @@ def reset(self, seed: Optional[int] = None):

self.action_hist = np.zeros(len(VALID_ACTIONS))

# experiment!
# self.max_steps += 128

self.max_map_progress = 0
self.progress_reward = self.get_game_state_reward()
self.total_reward = sum([val for _, val in self.progress_reward.items()])
Expand Down Expand Up @@ -446,13 +455,13 @@ def _get_obs(self):
"direction": np.array(
self.read_m("wSpritePlayerStateData1FacingDirection") // 4, dtype=np.uint8
),
"reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
# "reset_map_id": np.array(self.read_m("wLastBlackoutMap"), dtype=np.uint8),
"battle_type": np.array(self.read_m("wIsInBattle") + 1, dtype=np.uint8),
"cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8),
"x": np.array(player_x, dtype=np.uint8),
"y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(map_n, dtype=np.uint8),
"badges": np.array(self.get_badges(), dtype=np.uint8),
# "cut_in_party": np.array(self.check_if_party_has_cut(), dtype=np.uint8),
# "x": np.array(player_x, dtype=np.uint8),
# "y": np.array(player_y, dtype=np.uint8),
# "map_id": np.array(map_n, dtype=np.uint8),
# "badges": np.array(self.get_badges(), dtype=np.uint8),
}

def set_perfect_iv_dvs(self):
Expand Down Expand Up @@ -490,14 +499,14 @@ def step(self, action):

info = {}
# TODO: Make log frequency a configuration parameter
if self.step_count % 2000 == 0:
if self.step_count % self.log_frequency == 0:
info = self.agent_stats(action)

obs = self._get_obs()

self.step_count += 1
reset = (
self.step_count > self.max_steps # or
self.step_count >= self.max_steps # or
# self.caught_pokemon[6] == 1 # squirtle
)

Expand All @@ -507,7 +516,8 @@ def run_action_on_emulator(self, action):
self.action_hist[action] += 1
# press button then release after some steps
# TODO: Add video saving logic
self.pyboy.button(VALID_ACTIONS_STR[action], delay=8)
self.pyboy.send_input(VALID_ACTIONS[action])
self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
self.pyboy.tick(self.action_freq, render=True)

if self.save_video and self.fast_video:
Expand Down Expand Up @@ -799,3 +809,21 @@ def get_map_progress(self, map_idx):
return self.essential_map_locations[map_idx]
else:
return -1

def get_items_in_bag(self) -> Iterable[int]:
num_bag_items = self.read_m("wNumBagItems")
_, addr = self.pyboy.symbol_lookup("wBagItems")
return self.pyboy.memory[addr : addr + 2 * num_bag_items][::2]

def get_hm_count(self) -> int:
return len(HM_ITEM_IDS.intersection(self.get_items_in_bag()))

def get_levels_reward(self):
# Level reward
party_levels = self.read_party()
self.max_level_sum = max(self.max_level_sum, sum(party_levels))
if self.max_level_sum < 30:
level_reward = 1 * self.max_level_sum
else:
level_reward = 30 + (self.max_level_sum - 30) / 4
return level_reward
14 changes: 7 additions & 7 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
env,
screen_framestack: int = 3,
global_map_frame_stack: int = 1,
screen_flat_size: int = 2433, # 14341,
screen_flat_size: int = 1928, # 14341,
global_map_flat_size: int = 1600,
input_size: int = 512,
framestack: int = 1,
Expand Down Expand Up @@ -97,13 +97,13 @@ def encode_observations(self, observations):
(
*output,
one_hot(observations["direction"].long(), 4).float().squeeze(1),
one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1),
# one_hot(observations["reset_map_id"].long(), 0xF7).float().squeeze(1),
one_hot(observations["battle_type"].long(), 4).float().squeeze(1),
observations["cut_in_party"].float(),
observations["x"].float(),
observations["y"].float(),
one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1),
one_hot(observations["badges"].long(), 8).float().squeeze(1),
# observations["cut_in_party"].float(),
# observations["x"].float(),
# observations["y"].float(),
# one_hot(observations["map_id"].long(), 0xF7).float().squeeze(1),
# one_hot(observations["badges"].long(), 8).float().squeeze(1),
),
dim=-1,
)
Expand Down
55 changes: 52 additions & 3 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ class BaselineRewardEnv(RedGymEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)

def step(self, action):
return super().step(action)

# TODO: make the reward weights configurable
def get_game_state_reward(self):
# addresses from https://datacrystal.romhacking.net/wiki/Pok%C3%A9mon_Red/Blue:RAM_map
Expand Down Expand Up @@ -80,3 +77,55 @@ def get_levels_reward(self):
return self.max_level_sum
else:
return 15 + (self.max_level_sum - 15) / 4


class TeachCutReplicationEnv(RedGymEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)
self.reward_config = reward_config

def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
"met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
"used_cell_separator_on_bill": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 3)),
"ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
"met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
"bill_said_use_cell_separator": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 6)),
"left_bills_house_after_helping": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 7)),
"seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon),
"moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
"hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
"level": self.reward_config["level"] * self.get_levels_reward(),
"badges": self.reward_config["badges"] * self.get_badges(),
"exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
"cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles),
"start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
"pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
"stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
}

def update_max_event_rew(self):
cur_rew = self.get_all_events_reward()
self.max_event_rew = max(cur_rew, self.max_event_rew)
return self.max_event_rew

def get_all_events_reward(self):
# adds up all event flags, exclude museum ticket
return max(
sum(
[
self.read_m(i).bit_count()
for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
]
)
- self.base_event_flags
- int(self.read_bit(*MUSEUM_TICKET)),
0,
)

0 comments on commit ad47462

Please sign in to comment.