Merge branch 'more-repro-reward'
thatguy11325 committed Apr 29, 2024
2 parents 9335909 + dc20780 commit 78e1f2c
Showing 4 changed files with 119 additions and 66 deletions.
config.yaml: 35 changes (28 additions, 7 deletions)
@@ -5,18 +5,18 @@ wandb:

debug:
env:
- headless: True
+ headless: False
stream_wrapper: False
init_state: cut
- max_steps: 4
+ max_steps: 1_000_000
train:
device: cpu
compile: False
compile_mode: default
- num_envs: 10
+ num_envs: 4
envs_per_worker: 1
- envs_per_batch: 1
- batch_size: 8
+ envs_per_batch: 4
+ batch_size: 16
batch_rows: 4
bptt_horizon: 2
total_timesteps: 100_000_000
@@ -28,8 +28,8 @@ debug:
env_pool: False
log_frequency: 5000
load_optimizer_state: False
- swarm_frequency: 5
- swarm_keep_pct: .8
+ swarm_frequency: 10
+ swarm_keep_pct: .1

env:
headless: True
@@ -157,6 +157,27 @@ rewards:
explore_npcs: 0.02
explore_hidden_objs: 0.02

baseline.RockTunnelReplicationEnv:
reward:
level: 1.0
exploration: 0.02
taught_cut: 10.0
event: 3.0
seen_pokemon: 4.0
caught_pokemon: 4.0
moves_obtained: 4.0
cut_coords: 1.0
cut_tiles: 1.0
start_menu: 0.005
pokemon_menu: 0.05
stats_menu: 0.05
bag_menu: 0.05
pokecenter: 5.0
# Really an addition to event reward
badges: 2.0
bill_saved: 2.0



policies:
multi_convolutional.MultiConvolutionalPolicy:
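
The new baseline.RockTunnelReplicationEnv block above supplies per-component reward weights. As a rough illustration of how such weights turn into a scalar (the actual weighting is done by RockTunnelReplicationEnv.get_game_state_reward later in this diff), each weight scales a raw game counter and the environment reports the sum; the counter values below are invented for the example.

# Invented counters; only the weights mirror the config block above.
reward_config = {"level": 1.0, "exploration": 0.02, "taught_cut": 10.0}
raw_signals = {"level": 12.5, "exploration": 340, "taught_cut": 1}
reward_sum = sum(reward_config[k] * raw_signals[k] for k in reward_config)
print(reward_sum)  # 1.0*12.5 + 0.02*340 + 10.0*1 = 29.3
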
pokemonred_puffer/cleanrl_puffer.py: 16 changes (10 additions, 6 deletions)
@@ -386,6 +386,10 @@ def evaluate(self):
)
self.env_recv_queues[i + 1].put(self.infos["learner"]["state"][new_state])
waiting_for.append(i + 1)
# Now copy the hidden state over
# This may be a little slow, but so is this whole process
self.next_lstm_state[0][:, i, :] = self.next_lstm_state[0][:, new_state, :]
self.next_lstm_state[1][:, i, :] = self.next_lstm_state[1][:, new_state, :]
for i in waiting_for:
self.env_send_queues[i].get()
print("State migration complete")
@@ -451,7 +455,7 @@ def evaluate(self):

# Index alive mask with policy pool idxs...
# TODO: Find a way to avoid having to do this
- learner_mask = torch.Tensor(mask * self.policy_pool.mask)
+ learner_mask = torch.as_tensor(mask * self.policy_pool.mask)

# Ensure indices do not exceed batch size
indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy()
@@ -588,11 +592,11 @@ def train(self):
)

# Flatten the batch
- self.b_obs = b_obs = torch.tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
- b_actions = torch.tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
- b_logprobs = torch.tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
- # b_dones = torch.Tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
- b_values = torch.tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
+ self.b_obs = b_obs = torch.as_tensor(self.obs_ary[b_idxs], dtype=torch.uint8)
+ b_actions = torch.as_tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True)
+ b_logprobs = torch.as_tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True)
+ # b_dones = torch.as_tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True)
+ b_values = torch.as_tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True)
b_advantages = advantages.reshape(
config.batch_rows, num_minibatches, config.bptt_horizon
).transpose(0, 1)
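
The torch.tensor -> torch.as_tensor swap avoids an extra copy per minibatch: as_tensor reuses a NumPy array's buffer when no dtype or device conversion is required, whereas torch.tensor always allocates and copies (for b_obs the copy is only skipped if obs_ary is already stored as uint8). A small sketch of the difference:

import numpy as np
import torch

arr = np.zeros((4, 4), dtype=np.uint8)
copied = torch.tensor(arr)     # always copies into a new buffer
shared = torch.as_tensor(arr)  # same dtype and device, so it shares arr's memory
arr[0, 0] = 255
print(copied[0, 0].item(), shared[0, 0].item())  # 0 255
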
pokemonred_puffer/environment.py: 39 changes (32 additions, 7 deletions)
@@ -289,10 +289,11 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] =
self.seen_pokemon = np.zeros(152, dtype=np.uint8)
self.caught_pokemon = np.zeros(152, dtype=np.uint8)
self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
self.pokecenters = np.zeros(252, dtype=np.uint8)
# lazy random seed setting
- if not seed:
- seed = random.randint(0, 4096)
- self.pyboy.tick(seed, render=False)
+ # if not seed:
+ # seed = random.randint(0, 4096)
+ # self.pyboy.tick(seed, render=False)
else:
self.reset_count += 1

@@ -370,8 +371,10 @@ def reset_mem(self):
def render(self):
# (144, 160, 3)
game_pixels_render = np.expand_dims(self.screen.ndarray[:, :, 1], axis=-1)

if self.reduce_res:
game_pixels_render = game_pixels_render[::2, ::2, :]
# game_pixels_render = skimage.measure.block_reduce(game_pixels_render, (2, 2, 1), np.min)

# place an overlay on top of the screen greying out places we haven't visited
# first get our location
@@ -551,11 +554,18 @@ def step(self, action):
if self.perfect_ivs:
self.set_perfect_iv_dvs()
self.taught_cut = self.check_if_party_has_cut()

self.pokecenters[self.read_m("wLastBlackoutMap")] = 1
info = {}

if self.get_events_sum() > self.max_event_rew:
state = io.BytesIO()
self.pyboy.save_state(state)
state.seek(0)
info["state"] = state.read()

# TODO: Make log frequency a configuration parameter
if self.step_count % self.log_frequency == 0:
- info = self.agent_stats(action)
+ info = info | self.agent_stats(action)

obs = self._get_obs()
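
The added block serializes the emulator whenever the event-flag count reaches a new maximum and ships the bytes out through info["state"]; the swarm code in cleanrl_puffer.py above consumes these snapshots when it migrates state between environments. A sketch of the save/restore round trip, assuming a running PyBoy instance and a hypothetical ROM path:

import io

from pyboy import PyBoy

pyboy = PyBoy("pokemon_red.gb")  # hypothetical ROM path

buf = io.BytesIO()
pyboy.save_state(buf)            # snapshot the emulator
buf.seek(0)
state_bytes = buf.read()         # what ends up in info["state"]

pyboy.load_state(io.BytesIO(state_bytes))  # a receiving env resumes from the snapshot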

@@ -635,9 +645,9 @@ def cut_hook(self, context):
]:
self.cut_coords[coords] = 10
else:
- self.cut_coords[coords] = 0.01
+ self.cut_coords[coords] = 0.001
else:
- self.cut_coords[coords] = 0.01
+ self.cut_coords[coords] = 0.001

self.cut_explore_map[local_to_global(y, x, map_id)] = 1
self.cut_tiles[wTileInFrontOfPlayer] = 1
@@ -687,6 +697,7 @@ def agent_stats(self, action):
"item_count": self.read_m(0xD31D),
"reset_count": self.reset_count,
"blackout_count": self.blackout_count,
"pokecenter": np.sum(self.pokecenters),
},
"reward": self.get_game_state_reward(),
"reward/reward_sum": sum(self.get_game_state_reward().values()),
@@ -877,3 +888,17 @@ def get_levels_reward(self):
else:
level_reward = 30 + (self.max_level_sum - 30) / 4
return level_reward

def get_events_sum(self):
# adds up all event flags, exclude museum ticket
return max(
sum(
[
self.read_m(i).bit_count()
for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
]
)
- self.base_event_flags
- int(self.read_bit(*MUSEUM_TICKET)),
0,
)
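
Each game event is one bit in a contiguous run of WRAM bytes, so the new get_events_sum counts set bits with int.bit_count, subtracts the flags that were already set at reset plus the museum-ticket flag, and clamps at zero. A worked example with made-up byte values:

event_bytes = [0b1011, 0b0000, 0b1000_0001]  # 3 + 0 + 2 = 5 set flags
base_event_flags = 2                         # flags already set at reset
museum_ticket = 1                            # museum ticket bit is set
events = max(sum(b.bit_count() for b in event_bytes) - base_event_flags - museum_ticket, 0)
print(events)  # 2
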
pokemonred_puffer/rewards/baseline.py: 95 changes (49 additions, 46 deletions)
@@ -5,12 +5,15 @@
RedGymEnv,
)

import numpy as np

MUSEUM_TICKET = (0xD754, 0)


class BaselineRewardEnv(RedGymEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)
self.reward_config = reward_config

# TODO: make the reward weights configurable
def get_game_state_reward(self):
@@ -80,11 +83,7 @@ def get_levels_reward(self):
return 15 + (self.max_level_sum - 15) / 4


- class TeachCutReplicationEnv(RedGymEnv):
- def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
- super().__init__(env_config)
- self.reward_config = reward_config

+ class TeachCutReplicationEnv(BaselineRewardEnv):
def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
@@ -113,31 +112,8 @@ def get_game_state_reward(self):
"rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4),
}

- def update_max_event_rew(self):
- cur_rew = self.get_all_events_reward()
- self.max_event_rew = max(cur_rew, self.max_event_rew)
- return self.max_event_rew

- def get_all_events_reward(self):
- # adds up all event flags, exclude museum ticket
- return max(
- sum(
- [
- self.read_m(i).bit_count()
- for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
- ]
- )
- - self.base_event_flags
- - int(self.read_bit(*MUSEUM_TICKET)),
- 0,
- )


- class TeachCutReplicationEnvFork(RedGymEnv):
- def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
- super().__init__(env_config)
- self.reward_config = reward_config

+ class TeachCutReplicationEnvFork(BaselineRewardEnv):
def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
@@ -179,24 +155,51 @@ def get_game_state_reward(self):
"level": self.reward_config["level"] * self.get_levels_reward(),
}

- def update_max_event_rew(self):
- cur_rew = self.get_all_events_reward()
- self.max_event_rew = max(cur_rew, self.max_event_rew)
- return self.max_event_rew
def get_levels_reward(self):
party_size = self.read_m("wPartyCount")
party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)]
self.max_level_sum = max(self.max_level_sum, sum(party_levels))
if self.max_level_sum < 15:
return self.max_level_sum
else:
return 15 + (self.max_level_sum - 15) / 4

- def get_all_events_reward(self):
- # adds up all event flags, exclude museum ticket
- return max(
- sum(
- [
- self.read_m(i).bit_count()
- for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
- ]
- )
- - self.base_event_flags
- - int(self.read_bit(*MUSEUM_TICKET)),
- 0,
- )

class RockTunnelReplicationEnv(BaselineRewardEnv):
def get_game_state_reward(self):
return {
"level": self.reward_config["level"] * self.get_levels_reward(),
"exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
"taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut),
"event": self.reward_config["event"] * self.update_max_event_rew(),
"seen_pokemon": self.reward_config["seen_pokemon"] * np.sum(self.seen_pokemon),
"caught_pokemon": self.reward_config["caught_pokemon"] * np.sum(self.caught_pokemon),
"moves_obtained": self.reward_config["moves_obtained"] * np.sum(self.moves_obtained),
"cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles),
"start_menu": (
self.reward_config["start_menu"] * self.seen_start_menu * int(self.taught_cut)
),
"pokemon_menu": (
self.reward_config["pokemon_menu"] * self.seen_pokemon_menu * int(self.taught_cut)
),
"stats_menu": (
self.reward_config["stats_menu"] * self.seen_stats_menu * int(self.taught_cut)
),
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut),
# "pokecenter": self.reward_config["pokecenter"] * np.sum(self.pokecenters),
"badges": self.reward_config["badges"] * self.get_badges(),
"met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
"used_cell_separator_on_bill": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 3)),
"ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
"met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
"bill_said_use_cell_separator": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 6)),
"left_bills_house_after_helping": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 7)),
"rival3": self.reward_config["event"] * int(self.read_m(0xD665) == 4),
}

def get_levels_reward(self):
party_size = self.read_m("wPartyCount")
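
One design note on the new class: the start/pokemon/stats/bag menu terms are all multiplied by int(self.taught_cut), so menu exploration only starts paying out once Cut has been taught. A small illustration using the weights from the config block earlier in this diff and invented menu counters:

weights = {"start_menu": 0.005, "pokemon_menu": 0.05, "stats_menu": 0.05, "bag_menu": 0.05}
seen = {"start_menu": 1, "pokemon_menu": 1, "stats_menu": 0, "bag_menu": 1}  # invented counters

def menu_reward(taught_cut: bool) -> float:
    return sum(weights[k] * seen[k] * int(taught_cut) for k in weights)

print(menu_reward(False), menu_reward(True))  # 0.0 0.105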
