Skip to content

Commit

Permalink
Lazy modules, cut map, fix some stats
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Mar 22, 2024
1 parent ad47462 commit 9ba34c1
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 43 deletions.
21 changes: 18 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,31 @@ rewards:
pokemon_menu: 0.1
stats_menu: 0.1
bag_menu: 0.1
baseline.TeachCutReplicationEnvFork:
reward:
event: 1.0
bill_saved: 5.0
seen_pokemon: 4.0
caught_pokemon: 4.0
moves_obtained: 4.0
hm_count: 10.0
level: 1.0
badges: 10.0
exploration: 0.02
cut_coords: 1.0
cut_tiles: 1.0
start_menu: 0.01
pokemon_menu: 0.1
stats_menu: 0.1
bag_menu: 0.1
taught_cut: 10.0


policies:
multi_convolutional.MultiConvolutionalPolicy:
policy:
input_size: 512
hidden_size: 512
output_size: 512
framestack: 3
flat_size: 1928

recurrent:
# Assumed to be in the same module as the policy
Expand Down
14 changes: 11 additions & 3 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,9 @@ def __init__(

self.reward_buffer = deque(maxlen=1_000)
self.exploration_map_agg = np.zeros((config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32)
self.cut_exploration_map_agg = np.zeros(
(config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32
)
self.taught_cut = False

self.infos = {}
Expand Down Expand Up @@ -479,13 +482,18 @@ def evaluate(self):
self.stats = {}
self.max_stats = {}
for k, v in self.infos["learner"].items():
if "pokemon_exploration_map" in k and config.save_overlay is True:
if "exploration_map" in k and config.save_overlay is True:
if self.update % config.overlay_interval == 0:
# self.exploration_map_agg[env_id, :, :] = v
# overlay = make_pokemon_red_overlay(self.exploration_map_agg)
overlay = make_pokemon_red_overlay(np.stack(v, axis=0))
if self.wandb is not None:
self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay)
if "cut_exploration_map" in k and config.save_overlay is True:
if self.update % config.overlay_interval == 0:
overlay = make_pokemon_red_overlay(np.stack(v, axis=0))
if self.wandb is not None:
self.stats["Media/aggregate_cut_exploration_map"] = self.wandb.Image(
overlay
)
try: # TODO: Better checks on log data types
# self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16)
self.stats[k] = np.mean(v)
Expand Down
20 changes: 15 additions & 5 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def reset(self, seed: Optional[int] = None):
self.caught_pokemon = np.zeros(152, dtype=np.uint8)
self.moves_obtained = np.zeros(0xA5, dtype=np.uint8)
self.explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.cut_explore_map = np.zeros(GLOBAL_MAP_SHAPE, dtype=np.float32)
self.init_mem()
self.reset_count = 0
with open(self.init_state_path, "rb") as f:
Expand All @@ -270,6 +271,7 @@ def reset(self, seed: Optional[int] = None):
self.caught_pokemon.fill(0)
self.moves_obtained.fill(0)
self.explore_map *= 0
self.cut_explore_map *= 0
self.reset_mem()

self.taught_cut = self.check_if_party_has_cut()
Expand Down Expand Up @@ -543,7 +545,7 @@ def pokemon_menu_hook(self, *args, **kwargs):
self.seen_pokemon_menu = 1

def chose_stats_hook(self, *args, **kwargs):
self.chose_stats_hook = 1
self.seen_stats_menu = 1

def chose_item_hook(self, *args, **kwargs):
self.seen_action_bag_menu = 1
Expand All @@ -567,8 +569,12 @@ def cut_hook(self, context):
coords = (x - 1, y, map_id)
if player_direction == 0xC:
coords = (x + 1, y, map_id)

wTileInFrontOfPlayer = self.pyboy.memory[
self.pyboy.symbol_lookup("wTileInFrontOfPlayer")[1]
]
if context:
if self.pyboy.memory[self.pyboy.symbol_lookup("wTileInFrontOfPlayer")[1]] in [
if wTileInFrontOfPlayer in [
0x3D,
0x50,
]:
Expand All @@ -578,6 +584,9 @@ def cut_hook(self, context):
else:
self.cut_coords[coords] = 0.01

self.cut_explore_map[local_to_global(y, x, map_id)] = 1
self.cut_tiles[wTileInFrontOfPlayer] = 1

def agent_stats(self, action):
levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))]
return {
Expand Down Expand Up @@ -626,7 +635,8 @@ def agent_stats(self, action):
},
"reward": self.get_game_state_reward(),
"reward/reward_sum": sum(self.get_game_state_reward().values()),
"pokemon_exploration_map": self.explore_map,
"exploration_map": self.explore_map,
"cut_exploration_map": self.cut_explore_map,
}

def start_video(self):
Expand Down Expand Up @@ -717,12 +727,12 @@ def read_short(self, addr: str | int) -> int:
if isinstance(addr, str):
_, addr = self.pyboy.symbol_lookup(addr)
data = self.pyboy.memory[addr : addr + 2]
return data[0] << 8 + data[1]
return int(data[0] << 8) + int(data[1])

def read_bit(self, addr: str | int, bit: int) -> bool:
# add padding so zero will read '0b100000000' instead of '0b0'
# return bin(256 + self.read_m(addr))[-bit - 1] == "1"
return bool(self.read_m(addr) & 1 << (7 - bit))
return bool(int(self.read_m(addr)) & (1 << (7 - bit)))

def read_event_bits(self):
_, addr = self.pyboy.symbol_lookup("wEventFlags")
Expand Down
38 changes: 6 additions & 32 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,6 @@ class MultiConvolutionalPolicy(pufferlib.models.Policy):
def __init__(
self,
env,
screen_framestack: int = 3,
global_map_frame_stack: int = 1,
screen_flat_size: int = 1928, # 14341,
global_map_flat_size: int = 1600,
input_size: int = 512,
framestack: int = 1,
flat_size: int = 1,
hidden_size=512,
output_size=512,
channels_last: bool = True,
Expand All @@ -42,41 +35,22 @@ def __init__(
self.channels_last = channels_last
self.downsample = downsample
self.screen_network = nn.Sequential(
pufferlib.pytorch.layer_init(nn.Conv2d(screen_framestack, 32, 8, stride=4)),
nn.LazyConv2d(32, 8, stride=4),
nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Conv2d(32, 64, 4, stride=2)),
nn.LazyConv2d(64, 4, stride=2),
nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Conv2d(64, 64, 3, stride=1)),
nn.LazyConv2d(64, 3, stride=1),
nn.ReLU(),
nn.Flatten(),
)

"""
self.global_map_network = nn.Sequential(
pufferlib.pytorch.layer_init(nn.Conv2d(global_map_frame_stack, 32, 16, stride=8)),
nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Conv2d(32, 64, 8, stride=4)),
nn.ReLU(),
pufferlib.pytorch.layer_init(nn.Conv2d(64, 64, 4, stride=2)),
nn.ReLU(),
nn.Flatten(),
)
"""

self.encode_linear = nn.Sequential(
pufferlib.pytorch.layer_init(
nn.Linear(
screen_flat_size,
hidden_size,
),
),
nn.LazyLinear(hidden_size),
nn.ReLU(),
)

self.actor = pufferlib.pytorch.layer_init(
nn.Linear(output_size, self.num_actions), std=0.01
)
self.value_fn = pufferlib.pytorch.layer_init(nn.Linear(output_size, 1), std=1)
self.actor = nn.LazyLinear(self.num_actions)
self.value_fn = nn.LazyLinear(output_size, 1)

def encode_observations(self, observations):
observations = unpack_batched_obs(observations, self.unflatten_context)
Expand Down
50 changes: 50 additions & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,53 @@ def get_all_events_reward(self):
- int(self.read_bit(*MUSEUM_TICKET)),
0,
)


class TeachCutReplicationEnvFork(RedGymEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
super().__init__(env_config)
self.reward_config = reward_config

def get_game_state_reward(self):
return {
"event": self.reward_config["event"] * self.update_max_event_rew(),
"met_bill": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F1, 0)),
"used_cell_separator_on_bill": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 3)),
"ss_ticket": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 4)),
"met_bill_2": self.reward_config["bill_saved"] * int(self.read_bit(0xD7F2, 5)),
"bill_said_use_cell_separator": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 6)),
"left_bills_house_after_helping": self.reward_config["bill_saved"]
* int(self.read_bit(0xD7F2, 7)),
"moves_obtained": self.reward_config["moves_obtained"] * sum(self.moves_obtained),
"hm_count": self.reward_config["hm_count"] * self.get_hm_count(),
"badges": self.reward_config["badges"] * self.get_badges(),
"exploration": self.reward_config["exploration"] * sum(self.seen_coords.values()),
"cut_coords": self.reward_config["cut_coords"] * sum(self.cut_coords.values()),
"cut_tiles": self.reward_config["cut_tiles"] * sum(self.cut_tiles),
"start_menu": self.reward_config["start_menu"] * self.seen_start_menu,
"pokemon_menu": self.reward_config["pokemon_menu"] * self.seen_pokemon_menu,
"stats_menu": self.reward_config["stats_menu"] * self.seen_stats_menu,
"bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu,
"taught_cut": self.reward_config["taught_cut"] * int(self.check_if_party_has_cut()),
}

def update_max_event_rew(self):
cur_rew = self.get_all_events_reward()
self.max_event_rew = max(cur_rew, self.max_event_rew)
return self.max_event_rew

def get_all_events_reward(self):
# adds up all event flags, exclude museum ticket
return max(
sum(
[
self.read_m(i).bit_count()
for i in range(EVENT_FLAGS_START, EVENT_FLAGS_START + EVENTS_FLAGS_LENGTH)
]
)
- self.base_event_flags
- int(self.read_bit(*MUSEUM_TICKET)),
0,
)

0 comments on commit 9ba34c1

Please sign in to comment.