Skip to content

Commit

Permalink
update train, eval, autotune
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Jun 21, 2024
1 parent 48dbc3d commit 9505658
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 33 deletions.
56 changes: 31 additions & 25 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,45 +39,51 @@ def rollout(
agent_kwargs,
model_path=None,
device="cuda",
verbose=True,
):
env = env_creator(**env_kwargs)
# We are just using Serial vecenv to give a consistent
# single-agent/multi-agent API for evaluation
try:
env = pufferlib.vector.make(
env_creator, env_kwargs={"render_mode": "rgb_array", **env_kwargs}
)
except: # noqa: E722
env = pufferlib.vector.make(env_creator, env_kwargs=env_kwargs)

if model_path is None:
agent = agent_creator(env, **agent_kwargs)
else:
agent = torch.load(model_path, map_location=device)

terminal = truncated = True
ob, info = env.reset()
driver = env.driver_env
os.system("clear")
state = None

while True:
if terminal or truncated:
if verbose:
print("--- Reset ---")

ob, info = env.reset()
state = None
step = 0
return_val = 0
render = driver.render()
if driver.render_mode == "ansi":
print("\033[0;0H" + render + "\n")
time.sleep(0.6)
elif driver.render_mode == "rgb_array":
import cv2

render = cv2.cvtColor(render, cv2.COLOR_RGB2BGR)
cv2.imshow("frame", render)
cv2.waitKey(1)
time.sleep(1 / 24)

ob = torch.tensor(ob, device=device).unsqueeze(0)
with torch.no_grad():
ob = torch.from_numpy(ob).to(device)
if hasattr(agent, "lstm"):
action, _, _, _, state = agent.get_action_and_value(ob, state)
action, _, _, _, state = agent(ob, state)
else:
action, _, _, _ = agent.get_action_and_value(ob)

ob, reward, terminal, truncated, _ = env.step(action[0].item())
return_val += reward

chars = env.render()
print("\033c", end="")
print(chars)
action, _, _, _ = agent(ob)

if verbose:
print(f"Step: {step} Reward: {reward:.4f} Return: {return_val:.2f}")
action = action.cpu().numpy().reshape(env.action_space.shape)

time.sleep(0.5)
step += 1
ob, reward = env.step(action)[:2]
reward = reward.mean()
print(f"Reward: {reward:.4f}")


def seed_everything(seed, torch_deterministic):
Expand Down
51 changes: 43 additions & 8 deletions pokemonred_puffer/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import functools
import importlib
import os
import sys
from multiprocessing import Queue
from types import ModuleType
Expand All @@ -15,6 +17,7 @@
import yaml

import wandb
from pokemonred_puffer import cleanrl_puffer
from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL
from pokemonred_puffer.environment import RedGymEnv
from pokemonred_puffer.wrappers.async_io import AsyncWrapper
Expand Down Expand Up @@ -77,17 +80,20 @@ def load_from_config(args: argparse.ArgumentParser):
def make_env_creator(
wrapper_classes: list[tuple[str, ModuleType]],
reward_class: RedGymEnv,
async_wrapper: bool = True,
) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
def env_creator(
env_config: pufferlib.namespace,
wrappers_config: list[dict[str, Any]],
reward_config: pufferlib.namespace,
async_config: dict[str, Queue],
async_config: dict[str, Queue] | None = None,
) -> pufferlib.emulation.GymnasiumPufferEnv:
env = reward_class(env_config, reward_config)
for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes):
env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0]))
env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
if async_wrapper and async_config:
print("HEOLAFDAF")
env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
env = pufferlib.postprocess.EpisodeStats(env)
return pufferlib.emulation.GymnasiumPufferEnv(env=env)

Expand All @@ -96,9 +102,7 @@ def env_creator(

# Returns env_creator, agent_creator
def setup_agent(
wrappers: list[str],
reward_name: str,
policy_name: str,
wrappers: list[str], reward_name: str, async_wrapper: bool = True
) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
# TODO: Make this less dependent on the name of this repo and its file structure
wrapper_classes = [
Expand All @@ -117,7 +121,7 @@ def setup_agent(
importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name
)
# NOTE: This assumes reward_module has RewardWrapper(RedGymEnv) class
env_creator = make_env_creator(wrapper_classes, reward_class)
env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper)

return env_creator

Expand Down Expand Up @@ -232,7 +236,9 @@ def train(
help="Wrappers to use _in order of instantiation_.",
)
# TODO: Add evaluate
parser.add_argument("--mode", type=str, default="train", choices=["train"])
parser.add_argument(
"--mode", type=str, default="train", choices=["train", "autotune", "evaluate"]
)
parser.add_argument(
"--eval-model-path", type=str, default=None, help="Path to model to evaluate"
)
Expand Down Expand Up @@ -278,11 +284,40 @@ def train(
args = update_args(args)
args.train.exp_id = f"pokemon-red-{str(uuid.uuid4())[:8]}"

env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, args.policy_name)
async_wrapper = args.mode == "train"
env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, async_wrapper)

wandb_client = None
if args.track:
wandb_client = init_wandb(args).id

if args.mode == "train":
train(args, env_creator, wandb_client)
elif args.mode == "autotune":
env_kwargs = {
"env_config": args.env,
"wrappers_config": args.wrappers[args.wrappers_name],
"reward_config": args.rewards[args.reward_name]["reward"],
"async_config": {},
}
pufferlib.vector.autotune(
functools.partial(env_creator, **env_kwargs), batch_size=args.train.env_batch_size
)
elif args.mode == "evaluate":
env_kwargs = {
"env_config": args.env,
"wrappers_config": args.wrappers[args.wrappers_name],
"reward_config": args.rewards[args.reward_name]["reward"],
"async_config": {},
}
try:
cleanrl_puffer.rollout(
env_creator,
env_kwargs,
agent_creator=make_policy,
agent_kwargs={"args": args},
model_path=args.eval_model_path,
device=args.train.device,
)
except KeyboardInterrupt:
os._exit(0)

0 comments on commit 9505658

Please sign in to comment.