diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py index 2d496d2..77c9421 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -39,45 +39,51 @@ def rollout( agent_kwargs, model_path=None, device="cuda", - verbose=True, ): - env = env_creator(**env_kwargs) + # We are just using Serial vecenv to give a consistent + # single-agent/multi-agent API for evaluation + try: + env = pufferlib.vector.make( + env_creator, env_kwargs={"render_mode": "rgb_array", **env_kwargs} + ) + except: # noqa: E722 + env = pufferlib.vector.make(env_creator, env_kwargs=env_kwargs) + if model_path is None: agent = agent_creator(env, **agent_kwargs) else: agent = torch.load(model_path, map_location=device) - terminal = truncated = True + ob, info = env.reset() + driver = env.driver_env + os.system("clear") + state = None while True: - if terminal or truncated: - if verbose: - print("--- Reset ---") - - ob, info = env.reset() - state = None - step = 0 - return_val = 0 + render = driver.render() + if driver.render_mode == "ansi": + print("\033[0;0H" + render + "\n") + time.sleep(0.6) + elif driver.render_mode == "rgb_array": + import cv2 + + render = cv2.cvtColor(render, cv2.COLOR_RGB2BGR) + cv2.imshow("frame", render) + cv2.waitKey(1) + time.sleep(1 / 24) - ob = torch.tensor(ob, device=device).unsqueeze(0) with torch.no_grad(): + ob = torch.from_numpy(ob).to(device) if hasattr(agent, "lstm"): - action, _, _, _, state = agent.get_action_and_value(ob, state) + action, _, _, _, state = agent(ob, state) else: - action, _, _, _ = agent.get_action_and_value(ob) - - ob, reward, terminal, truncated, _ = env.step(action[0].item()) - return_val += reward - - chars = env.render() - print("\033c", end="") - print(chars) + action, _, _, _ = agent(ob) - if verbose: - print(f"Step: {step} Reward: {reward:.4f} Return: {return_val:.2f}") + action = action.cpu().numpy().reshape(env.action_space.shape) - time.sleep(0.5) - step += 1 + ob, 
reward = env.step(action)[:2] + reward = reward.mean() + print(f"Reward: {reward:.4f}") def seed_everything(seed, torch_deterministic): diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index f8ce407..ae6cbdc 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -1,5 +1,7 @@ import argparse +import functools import importlib +import os import sys from multiprocessing import Queue from types import ModuleType @@ -15,6 +17,7 @@ import yaml import wandb +from pokemonred_puffer import cleanrl_puffer from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL from pokemonred_puffer.environment import RedGymEnv from pokemonred_puffer.wrappers.async_io import AsyncWrapper @@ -77,17 +80,20 @@ def load_from_config(args: argparse.ArgumentParser): def make_env_creator( wrapper_classes: list[tuple[str, ModuleType]], reward_class: RedGymEnv, + async_wrapper: bool = True, ) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]: def env_creator( env_config: pufferlib.namespace, wrappers_config: list[dict[str, Any]], reward_config: pufferlib.namespace, - async_config: dict[str, Queue], + async_config: dict[str, Queue] | None = None, ) -> pufferlib.emulation.GymnasiumPufferEnv: env = reward_class(env_config, reward_config) for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes): env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0])) - env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"]) + if async_wrapper and async_config: + # NOTE: only attach the async IPC wrapper when queues were provided (train mode) + env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"]) env = pufferlib.postprocess.EpisodeStats(env) return pufferlib.emulation.GymnasiumPufferEnv(env=env) @@ -96,9 +102,7 @@ def env_creator( # Returns env_creator, agent_creator def setup_agent( - wrappers: list[str], - reward_name: str, - policy_name: str, + wrappers: list[str], reward_name: str, async_wrapper: bool = 
True ) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]: # TODO: Make this less dependent on the name of this repo and its file structure wrapper_classes = [ @@ -117,7 +121,7 @@ def setup_agent( importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name ) # NOTE: This assumes reward_module has RewardWrapper(RedGymEnv) class - env_creator = make_env_creator(wrapper_classes, reward_class) + env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper) return env_creator @@ -232,7 +236,9 @@ def train( help="Wrappers to use _in order of instantiation_.", ) # TODO: Add evaluate - parser.add_argument("--mode", type=str, default="train", choices=["train"]) + parser.add_argument( + "--mode", type=str, default="train", choices=["train", "autotune", "evaluate"] + ) parser.add_argument( "--eval-model-path", type=str, default=None, help="Path to model to evaluate" ) @@ -278,7 +284,8 @@ def train( args = update_args(args) args.train.exp_id = f"pokemon-red-{str(uuid.uuid4())[:8]}" - env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, args.policy_name) + async_wrapper = args.mode == "train" + env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, async_wrapper) wandb_client = None if args.track: @@ -286,3 +293,31 @@ def train( if args.mode == "train": train(args, env_creator, wandb_client) + elif args.mode == "autotune": + env_kwargs = { + "env_config": args.env, + "wrappers_config": args.wrappers[args.wrappers_name], + "reward_config": args.rewards[args.reward_name]["reward"], + "async_config": {}, + } + pufferlib.vector.autotune( + functools.partial(env_creator, **env_kwargs), batch_size=args.train.env_batch_size + ) + elif args.mode == "evaluate": + env_kwargs = { + "env_config": args.env, + "wrappers_config": args.wrappers[args.wrappers_name], + "reward_config": args.rewards[args.reward_name]["reward"], + "async_config": {}, + } + try: 
+ cleanrl_puffer.rollout( + env_creator, + env_kwargs, + agent_creator=make_policy, + agent_kwargs={"args": args}, + model_path=args.eval_model_path, + device=args.train.device, + ) + except KeyboardInterrupt: + os._exit(0)