diff --git a/config.yaml b/config.yaml index 5423483..59e2139 100644 --- a/config.yaml +++ b/config.yaml @@ -73,6 +73,7 @@ env: auto_pokeflute: True auto_next_elevator_floor: False skip_safari_zone: True + insert_saffron_guard_drinks: True infinite_money: True use_global_map: False save_state: True diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index b8489e7..d94889c 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -123,6 +123,7 @@ def __init__(self, env_config: pufferlib.namespace): self.auto_pokeflute = env_config.auto_pokeflute self.auto_next_elevator_floor = env_config.auto_next_elevator_floor self.skip_safari_zone = env_config.skip_safari_zone + self.insert_saffron_guard_drinks = env_config.insert_saffron_guard_drinks self.infinite_money = env_config.infinite_money self.use_global_map = env_config.use_global_map self.save_state = env_config.save_state @@ -1256,6 +1257,34 @@ def next_elevator_floor(self): self.pyboy.button("up", 8) self.pyboy.tick(self.action_freq, render=self.animate_scripts) + def insert_guard_drinks(self): + if not self.wd728.get_bit("GAVE_SAFFRON_GUARD_DRINK") and self.wd728.MapIds( + self.read_m("wCurMap") + ) in [ + MapIds.CELADON_MART_1F, + MapIds.CELADON_MART_2F, + MapIds.CELADON_MART_3F, + MapIds.CELADON_MART_4F, + MapIds.CELADON_MART_5F, + MapIds.CELADON_MART_ELEVATOR, + MapIds.CELADON_MART_ROOF, + ]: + _, wBagItems = self.pyboy.symbol_lookup("wBagItems") + _, wNumBagItems = self.pyboy.symbol_lookup("wNumBagItems") + numBagItems = self.read_m(wNumBagItems) + bag = np.array(self.pyboy.memory[wBagItems : wBagItems + 40], dtype=np.uint8) + if not { + Items.LEMONADE.value, + Items.FRESH_WATER.value, + Items.SODA_POP.value, + }.intersection(bag[::2]): + bag[numBagItems * 2] = Items.LEMONADE.value + bag[numBagItems * 2 + 1] = 1 + numBagItems += 1 + bag[numBagItems * 2 :] = 0xFF + self.pyboy.memory[wBagItems : wBagItems + 40] = bag + self.pyboy.memory[wNumBagItems] = numBagItems + def sign_hook(self, *args, **kwargs): sign_id = self.read_m("hSpriteIndexOrTextID") map_id = self.read_m("wCurMap")