Skip to content

Commit

Permalink
merge upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
kywch committed Mar 12, 2024
2 parents 5ddc122 + 71d2bfe commit 9f5c968
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def load_from_config(


def make_env_creator(
wrapper_classes: dict[str, ModuleType],
wrapper_classes: list[tuple[str, ModuleType]],
reward_class: RedGymEnv,
) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
def env_creator(
Expand All @@ -70,9 +70,8 @@ def env_creator(
reward_config: pufferlib.namespace,
) -> pufferlib.emulation.GymnasiumPufferEnv:
env = reward_class(env_config, reward_config)
flattened_wrappers_config = {k: v for d in wrappers_config for k, v in d.items()}
for wrapper_name, wrapper_class in wrapper_classes.items():
env = wrapper_class(env, pufferlib.namespace(**flattened_wrappers_config[wrapper_name]))
for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes):
env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0]))
return pufferlib.emulation.GymnasiumPufferEnv(
env=env, postprocessor_cls=pufferlib.emulation.BasicPostprocessor
)
Expand All @@ -87,14 +86,17 @@ def setup_agent(
policy_name: str,
) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
# TODO: Make this less dependent on the name of this repo and its file structure
wrapper_classes = {
k: getattr(
importlib.import_module(f"pokemonred_puffer.wrappers.{k.split('.')[0]}"),
k.split(".")[1],
wrapper_classes = [
(
k,
getattr(
importlib.import_module(f"pokemonred_puffer.wrappers.{k.split('.')[0]}"),
k.split(".")[1],
),
)
for wrapper_dicts in wrappers
for k in wrapper_dicts.keys()
}
]
reward_module, reward_class_name = reward_name.split(".")
reward_class = getattr(
importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name
Expand Down

0 comments on commit 9f5c968

Please sign in to comment.