Skip to content

Commit

Permalink
Obs for joypad ignore, noop action, rewards for using an item in the bag
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 8, 2024
1 parent 6be64a8 commit d1613ad
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
3 changes: 2 additions & 1 deletion config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -194,9 +194,10 @@ rewards:
start_menu: 0.00
pokemon_menu: 0.0
stats_menu: 0.0
bag_menu: 0.0
bag_menu: 0.1
rocket_hideout_found: 5.0
explore_hidden_objs: 0.02
seen_action_bag_menu: 0.1



Expand Down
18 changes: 12 additions & 6 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
WindowEvent.PRESS_BUTTON_A,
WindowEvent.PRESS_BUTTON_B,
WindowEvent.PRESS_BUTTON_START,
WindowEvent.PASS,
]

VALID_RELEASE_ACTIONS = [
Expand All @@ -156,6 +157,7 @@
WindowEvent.RELEASE_BUTTON_A,
WindowEvent.RELEASE_BUTTON_B,
WindowEvent.RELEASE_BUTTON_START,
WindowEvent.PASS,
]

VALID_ACTIONS_STR = ["down", "left", "right", "up", "a", "b", "start"]
Expand Down Expand Up @@ -246,6 +248,7 @@ def __init__(self, env_config: pufferlib.namespace):
"map_id": spaces.Box(low=0, high=0xF7, shape=(1,), dtype=np.uint8),
# "badges": spaces.Box(low=0, high=np.iinfo(np.uint16).max, shape=(1,), dtype=np.uint16),
"badges": spaces.Box(low=0, high=255, shape=(1,), dtype=np.uint8),
"wJoyIgnore": spaces.Box(low=0, high=1, shape=(1,), dtype=np.uint8),
}
)

Expand Down Expand Up @@ -555,6 +558,7 @@ def _get_obs(self):
# "y": np.array(player_y, dtype=np.uint8),
"map_id": np.array(self.read_m(0xD35E), dtype=np.uint8),
"badges": np.array(self.read_short("wObtainedBadges").bit_count(), dtype=np.uint8),
"wJoyIgnore": np.array(self.read_m("wJoyIgnore"), dtype=np.uint8),
}

def set_perfect_iv_dvs(self):
Expand Down Expand Up @@ -626,8 +630,8 @@ def run_action_on_emulator(self, action):
self.action_hist[action] += 1
# press button then release after some steps
# TODO: Add video saving logic
self.pyboy.send_input(VALID_ACTIONS[action])
self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
# self.pyboy.send_input(VALID_ACTIONS[action])
# self.pyboy.send_input(VALID_RELEASE_ACTIONS[action], delay=8)
self.pyboy.tick(self.action_freq, render=True)

if self.read_bit(0xD803, 0):
Expand Down Expand Up @@ -755,8 +759,8 @@ def start_menu_hook(self, *args, **kwargs):
self.seen_start_menu = 1

def item_menu_hook(self, *args, **kwargs):
if self.read_m("wIsInBattle") == 0:
self.seen_bag_menu = 1
# if self.read_m("wIsInBattle") == 0:
self.seen_bag_menu = 1

def pokemon_menu_hook(self, *args, **kwargs):
if self.read_m("wIsInBattle") == 0:
Expand All @@ -767,8 +771,8 @@ def chose_stats_hook(self, *args, **kwargs):
self.seen_stats_menu = 1

def chose_item_hook(self, *args, **kwargs):
if self.read_m("wIsInBattle") == 0:
self.seen_action_bag_menu = 1
# if self.read_m("wIsInBattle") == 0:
self.seen_action_bag_menu = 1

def blackout_hook(self, *args, **kwargs):
self.blackout_count += 1
Expand Down Expand Up @@ -853,6 +857,8 @@ def agent_stats(self, action):
"reset_count": self.reset_count,
"blackout_count": self.blackout_count,
"pokecenter": np.sum(self.pokecenters),
"rival3": int(self.read_m(0xD665) == 4),
"rocket_hideout_found": int(self.read_bit(0xD77E, 1)),
},
"reward": self.get_game_state_reward(),
"reward/reward_sum": sum(self.get_game_state_reward().values()),
Expand Down
1 change: 1 addition & 0 deletions pokemonred_puffer/policies/multi_convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def encode_observations(self, observations):
badges.float().squeeze(1),
map_id.squeeze(1),
blackout_map_id.squeeze(1),
observations["wJoyIgnore"].float(),
),
dim=-1,
)
Expand Down
2 changes: 2 additions & 0 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def get_game_state_reward(self):
* int(self.read_bit(0xD77E, 1)),
"explore_hidden_objs": sum(self.seen_hidden_objs.values())
* self.reward_config["explore_hidden_objs"],
"seen_action_bag_menu": self.seen_action_bag_menu
* self.reward_config["seen_action_bag_menu"],
}

def get_levels_reward(self):
Expand Down

0 comments on commit d1613ad

Please sign in to comment.