diff --git a/pokemonred_puffer/data/items.py b/pokemonred_puffer/data/items.py index 18ff459..fca6fcb 100644 --- a/pokemonred_puffer/data/items.py +++ b/pokemonred_puffer/data/items.py @@ -185,44 +185,44 @@ class Items(Enum): Items.HM_04, } -KEY_ITEM_IDS = { - Items.TOWN_MAP.value, - Items.BICYCLE.value, - Items.SURFBOARD.value, - Items.SAFARI_BALL.value, - Items.POKEDEX.value, - Items.BOULDERBADGE.value, - Items.CASCADEBADGE.value, - Items.THUNDERBADGE.value, - Items.RAINBOWBADGE.value, - Items.SOULBADGE.value, - Items.MARSHBADGE.value, - Items.VOLCANOBADGE.value, - Items.EARTHBADGE.value, - Items.OLD_AMBER.value, - Items.DOME_FOSSIL.value, - Items.HELIX_FOSSIL.value, - Items.SECRET_KEY.value, - # Items.ITEM_2C.value, - Items.BIKE_VOUCHER.value, - Items.CARD_KEY.value, - Items.S_S_TICKET.value, - Items.GOLD_TEETH.value, - Items.COIN_CASE.value, - Items.OAKS_PARCEL.value, - Items.ITEMFINDER.value, - Items.SILPH_SCOPE.value, - Items.POKE_FLUTE.value, - Items.LIFT_KEY.value, - Items.OLD_ROD.value, - Items.GOOD_ROD.value, - Items.SUPER_ROD.value, +KEY_ITEMS = { + Items.TOWN_MAP, + Items.BICYCLE, + Items.SURFBOARD, + Items.SAFARI_BALL, + Items.POKEDEX, + Items.BOULDERBADGE, + Items.CASCADEBADGE, + Items.THUNDERBADGE, + Items.RAINBOWBADGE, + Items.SOULBADGE, + Items.MARSHBADGE, + Items.VOLCANOBADGE, + Items.EARTHBADGE, + Items.OLD_AMBER, + Items.DOME_FOSSIL, + Items.HELIX_FOSSIL, + Items.SECRET_KEY, + # Items.ITEM_2C, + Items.BIKE_VOUCHER, + Items.CARD_KEY, + Items.S_S_TICKET, + Items.GOLD_TEETH, + Items.COIN_CASE, + Items.OAKS_PARCEL, + Items.ITEMFINDER, + Items.SILPH_SCOPE, + Items.POKE_FLUTE, + Items.LIFT_KEY, + Items.OLD_ROD, + Items.GOOD_ROD, + Items.SUPER_ROD, } -HM_ITEM_IDS = { - Items.HM_01.value, - Items.HM_02.value, - Items.HM_03.value, - Items.HM_04.value, - Items.HM_05.value, +HM_ITEMS = { + Items.HM_01, + Items.HM_02, + Items.HM_03, + Items.HM_04, + Items.HM_05, } diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index cdba343..d153518 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -24,8 +24,8 @@ ) from pokemonred_puffer.data.field_moves import FieldMoves from pokemonred_puffer.data.items import ( - HM_ITEM_IDS, - KEY_ITEM_IDS, + HM_ITEMS, + KEY_ITEMS, MAX_ITEM_CAPACITY, REQUIRED_ITEMS, USEFUL_ITEMS, @@ -1359,21 +1359,7 @@ def remove_all_nonuseful_items(self): new_bag_items = [ (item, quantity) for item, quantity in zip(bag_items[::2], bag_items[1::2]) - if (0x0 < item < Items.HM_01.value and (item - 1) in KEY_ITEM_IDS) - or item - in { - Items[name] - for name in [ - "LEMONADE", - "SODA_POP", - "FRESH_WATER", - "HM_01", - "HM_02", - "HM_03", - "HM_04", - "HM_05", - ] - } + if Items(item) in KEY_ITEMS | REQUIRED_ITEMS | USEFUL_ITEMS | HM_ITEMS ] # Write the new count back to memory self.pyboy.memory[wNumBagItems] = len(new_bag_items) @@ -1409,13 +1395,13 @@ def get_map_progress(self, map_idx): else: return -1 - def get_items_in_bag(self) -> Iterable[int]: + def get_items_in_bag(self) -> Iterable[Items]: num_bag_items = self.read_m("wNumBagItems") _, addr = self.pyboy.symbol_lookup("wBagItems") - return self.pyboy.memory[addr : addr + 2 * num_bag_items][::2] + return [Items(i) for i in self.pyboy.memory[addr : addr + 2 * num_bag_items][::2]] def get_hm_count(self) -> int: - return len(HM_ITEM_IDS.intersection(self.get_items_in_bag())) + return len(HM_ITEMS.intersection(self.get_items_in_bag())) def get_levels_reward(self): # Level reward diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 420e2f2..9a41fe7 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -1,4 +1,5 @@ import argparse +from contextlib import contextmanager import functools import importlib import os @@ -138,30 +139,35 @@ def update_args(args: argparse.Namespace): return args +@contextmanager def init_wandb(args, resume=True): - assert args.wandb.project is not None, "Please set the wandb project in config.yaml" - assert args.wandb.entity is not None, "Please set the wandb entity in config.yaml" - wandb_kwargs = { - "id": args.exp_name or wandb.util.generate_id(), - "project": args.wandb.project, - "entity": args.wandb.entity, - "group": args.wandb.group, - "config": { - "cleanrl": args.train, - "env": args.env, - "reward_module": args.reward_name, - "policy_module": args.policy_name, - "reward": args.rewards[args.reward_name], - "policy": args.policies[args.policy_name], - "wrappers": args.wrappers[args.wrappers_name], - "recurrent": "recurrent" in args.policies[args.policy_name], - }, - "name": args.exp_name, - "monitor_gym": True, - "save_code": True, - "resume": resume, - } - return wandb.init(**wandb_kwargs) + if not args.track: + yield None + else: + assert args.wandb.project is not None, "Please set the wandb project in config.yaml" + assert args.wandb.entity is not None, "Please set the wandb entity in config.yaml" + wandb_kwargs = { + "id": args.exp_name or wandb.util.generate_id(), + "project": args.wandb.project, + "entity": args.wandb.entity, + "group": args.wandb.group, + "config": { + "cleanrl": args.train, + "env": args.env, + "reward_module": args.reward_name, + "policy_module": args.policy_name, + "reward": args.rewards[args.reward_name], + "policy": args.policies[args.policy_name], + "wrappers": args.wrappers[args.wrappers_name], + "recurrent": "recurrent" in args.policies[args.policy_name], + }, + "name": args.exp_name, + "monitor_gym": True, + "save_code": True, + "resume": resume, + } + with wandb.init(**wandb_kwargs) as client: + yield client def train( @@ -286,37 +292,34 @@ def train( async_wrapper = args.train.async_wrapper env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, async_wrapper) - wandb_client = None - if args.track: - wandb_client = init_wandb(args) - - if args.mode == "train": - train(args, env_creator, wandb_client) - elif args.mode == "autotune": - env_kwargs = { - "env_config": args.env, - "wrappers_config": args.wrappers[args.wrappers_name], - "reward_config": args.rewards[args.reward_name]["reward"], - "async_config": {}, - } - pufferlib.vector.autotune( - functools.partial(env_creator, **env_kwargs), batch_size=args.train.env_batch_size - ) - elif args.mode == "evaluate": - env_kwargs = { - "env_config": args.env, - "wrappers_config": args.wrappers[args.wrappers_name], - "reward_config": args.rewards[args.reward_name]["reward"], - "async_config": {}, - } - try: - cleanrl_puffer.rollout( - env_creator, - env_kwargs, - agent_creator=make_policy, - agent_kwargs={"args": args}, - model_path=args.eval_model_path, - device=args.train.device, + with init_wandb(args) as wandb_client: + if args.mode == "train": + train(args, env_creator, wandb_client) + elif args.mode == "autotune": + env_kwargs = { + "env_config": args.env, + "wrappers_config": args.wrappers[args.wrappers_name], + "reward_config": args.rewards[args.reward_name]["reward"], + "async_config": {}, + } + pufferlib.vector.autotune( + functools.partial(env_creator, **env_kwargs), batch_size=args.train.env_batch_size ) - except KeyboardInterrupt: - os._exit(0) + elif args.mode == "evaluate": + env_kwargs = { + "env_config": args.env, + "wrappers_config": args.wrappers[args.wrappers_name], + "reward_config": args.rewards[args.reward_name]["reward"], + "async_config": {}, + } + try: + cleanrl_puffer.rollout( + env_creator, + env_kwargs, + agent_creator=make_policy, + agent_kwargs={"args": args}, + model_path=args.eval_model_path, + device=args.train.device, + ) + except KeyboardInterrupt: + os._exit(0)