diff --git a/config.yaml b/config.yaml index 44670d6..47c559f 100644 --- a/config.yaml +++ b/config.yaml @@ -132,7 +132,7 @@ rewards: bag_menu: 0.1 baseline.TeachCutReplicationEnvFork: reward: - event: 4.0 + event: 1.0 bill_saved: 5.0 moves_obtained: 4.0 hm_count: 10.0 @@ -147,6 +147,9 @@ rewards: taught_cut: 10.0 explore_npcs: 0.02 explore_hidden_objs: 0.02 + seen_pokemon: 4.0 + caught_pokemon: 4.0 + level: 1.0 policies: diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index a2e6048..44e22fb 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -780,8 +780,8 @@ def update_tm_hm_moves_obtained(self): for i in range(self.read_m("wPartyCount")): _, addr = self.pyboy.symbol_lookup(f"wPartyMon{i+1}Moves") for move_id in self.pyboy.memory[addr : addr + 4]: - if move_id in TM_HM_MOVES: - self.moves_obtained[move_id] = 1 + # if move_id in TM_HM_MOVES: + self.moves_obtained[move_id] = 1 """ # Scan current box (since the box doesn't auto increment in pokemon red) num_moves = 4 diff --git a/pokemonred_puffer/rewards/baseline.py b/pokemonred_puffer/rewards/baseline.py index e3a3069..8eb3d69 100644 --- a/pokemonred_puffer/rewards/baseline.py +++ b/pokemonred_puffer/rewards/baseline.py @@ -172,6 +172,9 @@ def get_game_state_reward(self): ), "bag_menu": self.reward_config["bag_menu"] * self.seen_bag_menu * int(self.taught_cut), "taught_cut": self.reward_config["taught_cut"] * int(self.taught_cut), + "seen_pokemon": self.reward_config["seen_pokemon"] * sum(self.seen_pokemon), + "caught_pokemon": self.reward_config["caught_pokemon"] * sum(self.caught_pokemon), + "level": self.reward_config["level"] * self.get_levels_reward(), } def update_max_event_rew(self): @@ -192,3 +195,12 @@ def get_all_events_reward(self): - int(self.read_bit(*MUSEUM_TICKET)), 0, ) + + def get_levels_reward(self): + party_size = self.read_m("wPartyCount") + party_levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(party_size)] + self.max_level_sum = max(self.max_level_sum, sum(party_levels)) + if self.max_level_sum < 15: + return self.max_level_sum + else: + return 15 + (self.max_level_sum - 15) / 4