diff --git a/config.yaml b/config.yaml index 6ad611c..fd4108a 100644 --- a/config.yaml +++ b/config.yaml @@ -8,27 +8,31 @@ debug: headless: False stream_wrapper: False init_state: victory_road - max_steps: 1_000_000 + max_steps: 16 + log_frequency: 1 disable_wild_encounters: True disable_ai_actions: True + use_global_map: True train: device: cpu compile: False compile_mode: default num_envs: 1 - envs_per_worker: 1 - envs_per_batch: 1 - batch_size: 16 + num_workers: 1 + env_batch_size: 4 + env_pool: True + zero_copy: False + batch_size: 4 + minibatch_size: 4 batch_rows: 4 bptt_horizon: 2 total_timesteps: 100_000_000 save_checkpoint: True checkpoint_interval: 4 - save_overlay: True + save_overlay: False overlay_interval: 4 verbose: False env_pool: False - log_frequency: 5000 load_optimizer_state: False # swarm_frequency: 10 # swarm_keep_pct: .1 @@ -62,6 +66,7 @@ env: auto_pokeflute: True infinite_money: True use_global_map: False + save_state: False train: @@ -73,6 +78,7 @@ train: float32_matmul_precision: "high" total_timesteps: 100_000_000_000 batch_size: 65536 + minibatch_size: 2048 learning_rate: 2.0e-4 anneal_lr: False gamma: 0.998 @@ -90,10 +96,11 @@ train: bptt_horizon: 16 vf_clip_coef: 0.1 - num_envs: 96 - envs_per_worker: 1 - envs_per_batch: 32 + num_envs: 288 + num_workers: 24 + env_batch_size: 72 env_pool: True + zero_copy: False verbose: True data_dir: runs @@ -104,11 +111,15 @@ train: cpu_offload: True pool_kernel: [0] load_optimizer_state: False + use_rnn: True + async_wrapper: False # swarm_frequency: 500 # swarm_keep_pct: .8 wrappers: + empty: [] + baseline: - stream_wrapper.StreamWrapper: user: thatguy @@ -126,6 +137,7 @@ wrappers: forgetting_frequency: 10 - exploration.OnResetExplorationWrapper: full_reset_frequency: 1 + jitter: 0 finite_coords: - stream_wrapper.StreamWrapper: @@ -224,9 +236,10 @@ policies: policy: hidden_size: 512 - recurrent: + rnn: # Assumed to be in the same module as the policy - name: RecurrentMultiConvolutionalWrapper - input_size: 512 - hidden_size: 512 - num_layers: 1 + name: MultiConvolutionalRNN + args: + input_size: 512 + hidden_size: 512 + num_layers: 1 diff --git a/pokemonred_puffer/c_gae.pyx b/pokemonred_puffer/c_gae.pyx new file mode 100644 index 0000000..122762b --- /dev/null +++ b/pokemonred_puffer/c_gae.pyx @@ -0,0 +1,33 @@ +# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION +# cython: language_level=3 +# cython: boundscheck=False +# cython: initializedcheck=False +# cython: wraparound=False +# cython: nonecheck=False + +import numpy as np +cimport numpy as cnp + +def compute_gae(cnp.ndarray dones, cnp.ndarray values, + cnp.ndarray rewards, float gamma, float gae_lambda): + '''Fast Cython implementation of Generalized Advantage Estimation (GAE)''' + cdef int num_steps = len(rewards) + cdef cnp.ndarray advantages = np.zeros(num_steps, dtype=np.float32) + cdef float[:] c_advantages = advantages + cdef float[:] c_dones = dones + cdef float[:] c_values = values + cdef float[:] c_rewards = rewards + + cdef float lastgaelam = 0 + cdef float nextnonterminal, delta + cdef int t, t_cur, t_next + for t in range(num_steps-1): + t_cur = num_steps - 2 - t + t_next = num_steps - 1 - t + nextnonterminal = 1.0 - c_dones[t_next] + delta = c_rewards[t_next] + gamma * c_values[t_next] * nextnonterminal - c_values[t_cur] + lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam + c_advantages[t_cur] = lastgaelam + + return advantages + diff --git a/pokemonred_puffer/cleanrl_puffer.py b/pokemonred_puffer/cleanrl_puffer.py 
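A quick way to sanity-check the new c_gae.pyx above, once pyximport has compiled it, is to compare it against a plain-NumPy mirror of the same recursion. This is only a sketch (compute_gae_reference is an illustrative name, and the Cython memoryviews require contiguous float32 inputs):

import numpy as np

def compute_gae_reference(dones, values, rewards, gamma, gae_lambda):
    # Mirrors the Cython loop: each advantage uses the reward/done/value at t+1,
    # and the advantage of the final step is left at zero.
    num_steps = len(rewards)
    advantages = np.zeros(num_steps, dtype=np.float32)
    lastgaelam = 0.0
    for t_cur in range(num_steps - 2, -1, -1):
        t_next = t_cur + 1
        nextnonterminal = 1.0 - dones[t_next]
        delta = rewards[t_next] + gamma * values[t_next] * nextnonterminal - values[t_cur]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t_cur] = lastgaelam
    return advantages

# Example check with float32 arrays d, v, r:
# np.testing.assert_allclose(compute_gae(d, v, r, 0.998, 0.95),
#                            compute_gae_reference(d, v, r, 0.998, 0.95), rtol=1e-5)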
index 3c62d3e..c574220 100644 --- a/pokemonred_puffer/cleanrl_puffer.py +++ b/pokemonred_puffer/cleanrl_puffer.py @@ -1,69 +1,36 @@ +import argparse import heapq import math -from multiprocessing import Queue import os -import pathlib import random import time -from collections import deque -from types import SimpleNamespace -from typing import Any, Callable -import uuid -from collections import defaultdict -from datetime import timedelta +from collections import defaultdict, deque +from dataclasses import dataclass, field +from multiprocessing import Queue import numpy as np import pufferlib import pufferlib.emulation import pufferlib.frameworks.cleanrl -import pufferlib.policy_pool +import pufferlib.pytorch import pufferlib.utils -import pufferlib.vectorization +import pufferlib.vector + +# Fast Cython GAE implementation +import pyximport +import rich import torch import torch.nn as nn -import torch.optim as optim +from rich.console import Console +from rich.table import Table +import wandb from pokemonred_puffer.eval import make_pokemon_red_overlay from pokemonred_puffer.global_map import GLOBAL_MAP_SHAPE +from pokemonred_puffer.profile import Profile, Utilization - -@pufferlib.dataclass -class Performance: - total_uptime = 0 - total_updates = 0 - total_agent_steps = 0 - epoch_time = 0 - epoch_sps = 0 - eval_time = 0 - eval_sps = 0 - eval_memory = 0 - eval_pytorch_memory = 0 - env_time = 0 - env_sps = 0 - inference_time = 0 - inference_sps = 0 - train_time = 0 - train_sps = 0 - train_memory = 0 - train_pytorch_memory = 0 - - -@pufferlib.dataclass -class Losses: - policy_loss = 0 - value_loss = 0 - entropy = 0 - old_approx_kl = 0 - approx_kl = 0 - clipfrac = 0 - explained_variance = 0 - - -@pufferlib.dataclass -class Charts: - global_step = 0 - SPS = 0 - learning_rate = 0 +pyximport.install(setup_args={"include_dirs": np.get_include()}) +from pokemonred_puffer.c_gae import compute_gae # type: ignore # noqa: E402 def rollout( @@ -73,45 +40,51 @@ def rollout( agent_kwargs, model_path=None, device="cuda", - verbose=True, ): - env = env_creator(**env_kwargs) + # We are just using Serial vecenv to give a consistent + # single-agent/multi-agent API for evaluation + try: + env = pufferlib.vector.make( + env_creator, env_kwargs={"render_mode": "rgb_array", **env_kwargs} + ) + except: # noqa: E722 + env = pufferlib.vector.make(env_creator, env_kwargs=env_kwargs) + if model_path is None: agent = agent_creator(env, **agent_kwargs) else: agent = torch.load(model_path, map_location=device) - terminal = truncated = True + ob, info = env.reset() + driver = env.driver_env + os.system("clear") + state = None while True: - if terminal or truncated: - if verbose: - print("--- Reset ---") - - ob, info = env.reset() - state = None - step = 0 - return_val = 0 + render = driver.render() + if driver.render_mode == "ansi": + print("\033[0;0H" + render + "\n") + time.sleep(0.6) + elif driver.render_mode == "rgb_array": + import cv2 + + render = cv2.cvtColor(render, cv2.COLOR_RGB2BGR) + cv2.imshow("frame", render) + cv2.waitKey(1) + time.sleep(1 / 24) - ob = torch.tensor(ob, device=device).unsqueeze(0) with torch.no_grad(): + ob = torch.from_numpy(ob).to(device) if hasattr(agent, "lstm"): - action, _, _, _, state = agent.get_action_and_value(ob, state) + action, _, _, _, state = agent(ob, state) else: - action, _, _, _ = agent.get_action_and_value(ob) + action, _, _, _ = agent(ob) - ob, reward, terminal, truncated, _ = env.step(action[0].item()) - return_val += reward + action = 
action.cpu().numpy().reshape(env.action_space.shape) - chars = env.render() - print("\033c", end="") - print(chars) - - if verbose: - print(f"Step: {step} Reward: {reward:.4f} Return: {return_val:.2f}") - - time.sleep(0.5) - step += 1 + ob, reward = env.step(action)[:2] + reward = reward.mean() + print(f"Reward: {reward:.4f}") def seed_everything(seed, torch_deterministic): @@ -134,243 +107,122 @@ def unroll_nested_dict(d): yield k, v -def print_dashboard(stats, init_performance, performance): - output = [] - data = {**stats, **init_performance, **performance} - - grouped_data = defaultdict(dict) +def count_params(policy: nn.Module): + return sum(p.numel() for p in policy.parameters() if p.requires_grad) - for k, v in data.items(): - if k == "total_uptime": - v = timedelta(seconds=v) - if "memory" in k: - v = pufferlib.utils.format_bytes(v) - elif "time" in k: - try: - v = f"{v:.2f} s" - except: # noqa - pass - first_word, *rest_words = k.split("_") - rest_words = " ".join(rest_words).title() - - grouped_data[first_word][rest_words] = v - - for main_key, sub_dict in grouped_data.items(): - output.append(f"{main_key.title()}") - for sub_key, sub_value in sub_dict.items(): - output.append(f" {sub_key}: {sub_value}") - - print("\033c", end="") - print("\n".join(output)) - time.sleep(1 / 20) +@dataclass +class Losses: + policy_loss: float = 0.0 + value_loss: float = 0.0 + entropy: float = 0.0 + old_approx_kl: float = 0.0 + approx_kl: float = 0.0 + clipfrac: float = 0.0 + explained_variance: float = 0.0 -# TODO: Make this an unfrozen dataclass with a post_init? +@dataclass class CleanPuffeRL: - def __init__( - self, - config: SimpleNamespace | None = None, - exp_name: str | None = None, - track: bool = False, - # Agent - agent: nn.Module | None = None, - agent_creator: Callable[..., Any] | None = None, - agent_kwargs: dict = None, - # Environment - env_creator: Callable[..., Any] | None = None, - env_creator_kwargs: dict | None = None, - vectorization: ... 
= pufferlib.vectorization.Serial, - # Policy Pool options - policy_selector: Callable[ - [list[Any], int], list[Any] - ] = pufferlib.policy_pool.random_selector, - ): - self.config = config - if self.config is None: - self.config = pufferlib.args.CleanPuffeRL() - - self.exp_name = exp_name - if exp_name is None: - exp_name = str(uuid.uuid4())[:8] - - self.wandb = None - if track: - import wandb - - self.wandb = wandb - - self.start_time = time.time() - seed_everything(config.seed, config.torch_deterministic) - self.total_updates = config.total_timesteps // config.batch_size - self.total_agent_steps = 0 - - self.device = config.device - - # Create environments, agent, and optimizer - init_profiler = pufferlib.utils.Profiler(memory=True) - with init_profiler: - self.pool = vectorization( - env_creator, - env_kwargs=env_creator_kwargs, - num_envs=config.num_envs, - envs_per_worker=config.envs_per_worker, - envs_per_batch=config.envs_per_batch, - env_pool=config.env_pool, - mask_agents=True, + exp_name: str + config: argparse.Namespace + vecenv: pufferlib.vector.Serial | pufferlib.vector.Multiprocessing + policy: nn.Module + env_send_queues: list[Queue] + env_recv_queues: list[Queue] + wandb_client: wandb.wandb_sdk.wandb_run.Run | None = None + profile: Profile = field(default_factory=lambda: Profile()) + losses: Losses = field(default_factory=lambda: Losses()) + global_step: int = 0 + epoch: int = 0 + stats: dict = field(default_factory=lambda: {}) + msg: str = "" + infos: dict = field(default_factory=lambda: defaultdict(list)) + + def __post_init__(self): + seed_everything(self.config.seed, self.config.torch_deterministic) + if self.config.verbose: + print_dashboard( + self.config.env, + self.utilization, + 0, + 0, + self.profile, + self.losses, + {}, + self.msg, + clear=True, ) - obs_shape = self.pool.single_observation_space.shape - atn_shape = self.pool.single_action_space.shape - self.num_agents = self.pool.agents_per_env - total_agents = self.num_agents * config.num_envs - - self.agent = pufferlib.emulation.make_object( - agent, agent_creator, [self.pool.driver_env], agent_kwargs + self.utilization = Utilization() + + self.vecenv.async_reset(self.config.seed) + obs_shape = self.vecenv.single_observation_space.shape + obs_dtype = self.vecenv.single_observation_space.dtype + atn_shape = self.vecenv.single_action_space.shape + total_agents = self.vecenv.num_agents + + self.lstm = self.policy.lstm if hasattr(self.policy, "lstm") else None + self.experience = Experience( + self.config.batch_size, + self.vecenv.agents_per_batch, + self.config.bptt_horizon, + self.config.minibatch_size, + obs_shape, + obs_dtype, + atn_shape, + self.config.cpu_offload, + self.config.device, + self.lstm, + total_agents, ) - self.env_send_queues: list[Queue] = env_creator_kwargs["async_config"]["send_queues"] - self.env_recv_queues: list[Queue] = env_creator_kwargs["async_config"]["recv_queues"] - - # If data_dir is provided, load the resume state - resume_state = {} - path = pathlib.Path(config.data_dir) / exp_name - trainer_path = path / "trainer_state.pt" - if trainer_path.exists(): - resume_state = torch.load(trainer_path) - - model_version = str(resume_state["update"]).zfill(6) - model_filename = f"model_{model_version}_state.pth" - model_path = path / model_filename - if model_path.exists(): - self.agent.load_state_dict(torch.load(model_path, map_location=self.device)) - print( - f'Resumed from update {resume_state["update"]} ' - f'with policy {resume_state["model_name"]}' - ) - else: - print("No checkpoint 
found. Starting fresh.") - else: - print("No checkpoint found. Starting fresh.") - - self.global_step = resume_state.get("global_step", 0) - self.agent_step = resume_state.get("agent_step", 0) - self.update = resume_state.get("update", 0) - self.lr_update = resume_state.get("lr_update", 0) - self.optimizer = optim.Adam(self.agent.parameters(), lr=config.learning_rate, eps=1e-5) - self.opt_state = resume_state.get("optimizer_state_dict", None) + self.uncompiled_policy = self.policy - if config.compile: - self.agent = torch.compile(self.agent, mode=config.compile_mode) - # TODO: Figure out how to compile the optimizer! - # self.calculate_loss = torch.compile(self.calculate_loss, mode=config.compile_mode) + if self.config.compile: + self.policy = torch.compile(self.policy, mode=self.config.compile_mode) - if config.load_optimizer_state is True and self.opt_state is not None: - self.optimizer.load_state_dict(resume_state["optimizer_state_dict"]) - - # Create policy pool - pool_agents = self.num_agents * self.pool.envs_per_batch - self.policy_pool = pufferlib.policy_pool.PolicyPool( - self.agent, - pool_agents, - atn_shape, - self.device, - path, - self.config.pool_kernel, - policy_selector, + self.optimizer = torch.optim.Adam( + self.policy.parameters(), lr=self.config.learning_rate, eps=1e-5 ) - # Allocate Storage - storage_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start() - self.pool.async_reset(config.seed) - self.next_lstm_state = None - if hasattr(self.agent, "lstm"): - shape = (self.agent.lstm.num_layers, total_agents, self.agent.lstm.hidden_size) - self.next_lstm_state = ( - torch.zeros(shape, device=self.device), - torch.zeros(shape, device=self.device), - ) - self.obs = torch.zeros(config.batch_size + 1, *obs_shape, dtype=torch.uint8) - self.actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int) - self.logprobs = torch.zeros(config.batch_size + 1) - self.rewards = torch.zeros(config.batch_size + 1) - self.dones = torch.zeros(config.batch_size + 1) - self.truncateds = torch.zeros(config.batch_size + 1) - self.values = torch.zeros(config.batch_size + 1) - - self.obs_ary = np.asarray(self.obs, dtype=np.uint8) - self.actions_ary = np.asarray(self.actions) - self.logprobs_ary = np.asarray(self.logprobs) - self.rewards_ary = np.asarray(self.rewards) - self.dones_ary = np.asarray(self.dones) - self.truncateds_ary = np.asarray(self.truncateds) - self.values_ary = np.asarray(self.values) - - storage_profiler.stop() - - # "charts/actions": wandb.Histogram(b_actions.cpu().numpy()), - self.init_performance = pufferlib.namespace( - init_time=time.time() - self.start_time, - init_env_time=init_profiler.elapsed, - init_env_memory=init_profiler.memory, - tensor_memory=storage_profiler.memory, - tensor_pytorch_memory=storage_profiler.pytorch_memory, - ) - - self.sort_keys = [] - self.learning_rate = (config.learning_rate,) - self.losses = Losses() - self.performance = Performance() + self.last_log_time = time.time() self.reward_buffer = deque(maxlen=1_000) - self.exploration_map_agg = np.zeros((config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32) + self.exploration_map_agg = np.zeros( + (self.config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32 + ) self.cut_exploration_map_agg = np.zeros( - (config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32 + (self.config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32 ) self.taught_cut = False - - self.infos = {} self.log = False @pufferlib.utils.profile def evaluate(self): - config = self.config - # TODO: Handle update 
on resume - if self.log and self.wandb is not None and self.performance.total_uptime > 0: - self.wandb.log( - { - "SPS": self.SPS, - "global_step": self.global_step, - "learning_rate": self.optimizer.param_groups[0]["lr"], - **{f"losses/{k}": v for k, v in self.losses.items()}, - **{f"performance/{k}": v for k, v in self.performance.items()}, - **{f"stats/{k}": v for k, v in self.stats.items()}, - **{f"max_stats/{k}": v for k, v in self.max_stats.items()}, - **{ - f"skillrank/{policy}": elo - for policy, elo in self.policy_pool.ranker.ratings.items() - }, - }, - ) - self.log = False + # Clear all self.infos except for the state + for k in list(self.infos.keys()): + if k != "state": + del self.infos[k] # now for a tricky bit: # if we have swarm_frequency, we will take the top swarm_keep_pct envs and evenly distribute # their states to the bottom 90%. # we do this here so the environment can remain "pure" if ( - hasattr(self.config, "swarm_frequency") + self.config.async_wrapper + and hasattr(self.config, "swarm_frequency") and hasattr(self.config, "swarm_keep_pct") - and self.update % self.config.swarm_frequency == 0 - and "learner" in self.infos - and "reward/event" in self.infos["learner"] + and self.epoch % self.config.swarm_frequency == 0 + and "reward/event" in self.infos + and "state" in self.infos ): # collect the top swarm_keep_pct % of envs largest = [ x[0] for x in heapq.nlargest( math.ceil(self.config.num_envs * self.config.swarm_keep_pct), - enumerate(self.infos["learner"]["reward/event"]), + enumerate(self.infos["reward/event"]), key=lambda x: x[1], ) ] @@ -382,9 +234,9 @@ def evaluate(self): if i not in largest: new_state = random.choice(largest) print( - f'\t {i+1} -> {new_state+1}, event scores: {self.infos["learner"]["reward/event"][i]} -> {self.infos["learner"]["reward/event"][new_state]}' + f'\t {i+1} -> {new_state+1}, event scores: {self.infos["reward/event"][i]} -> {self.infos["reward/event"][new_state]}' ) - self.env_recv_queues[i + 1].put(self.infos["learner"]["state"][new_state]) + self.env_recv_queues[i + 1].put(self.infos["state"][new_state]) waiting_for.append(i + 1) # Now copy the hidden state over # This may be a little slow, but so is this whole process @@ -394,378 +246,285 @@ def evaluate(self): self.env_send_queues[i].get() print("State migration complete") - self.policy_pool.update_policies() - env_profiler = pufferlib.utils.Profiler() - inference_profiler = pufferlib.utils.Profiler() - eval_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start() - misc_profiler = pufferlib.utils.Profiler() - - ptr = step = padded_steps_collected = agent_steps_collected = 0 - while True: - step += 1 - if ptr == config.batch_size + 1: - break - - with env_profiler: - o, r, d, t, i, env_id, mask = self.pool.recv() - - with misc_profiler: - i = self.policy_pool.update_scores(i, "return") - # TODO: Update this for policy pool - for ii, ee in zip(i["learner"], env_id): - ii["env_id"] = ee - - with inference_profiler, torch.no_grad(): - o = torch.as_tensor(o).to(device=self.device, non_blocking=True) - r = ( - torch.as_tensor(r, dtype=torch.float32) - .to(device=self.device, non_blocking=True) - .view(-1) - ) - d = ( - torch.as_tensor(d, dtype=torch.float32) - .to(device=self.device, non_blocking=True) - .view(-1) - ) - - agent_steps_collected += sum(mask) - padded_steps_collected += len(mask) - - # Multiple policies will not work with new envpool - next_lstm_state = self.next_lstm_state - if next_lstm_state is not None: - next_lstm_state = ( - 
next_lstm_state[0][:, env_id], - next_lstm_state[1][:, env_id], - ) - - actions, logprob, value, next_lstm_state = self.policy_pool.forwards( - o, next_lstm_state - ) + with self.profile.eval_misc: + policy = self.policy + lstm_h, lstm_c = self.experience.lstm_h, self.experience.lstm_c + + while not self.experience.full: + with self.profile.env: + o, r, d, t, info, env_id, mask = self.vecenv.recv() + env_id = env_id.tolist() + + with self.profile.eval_misc: + self.global_step += sum(mask) + + o = torch.as_tensor(o) + o_device = o.to(self.config.device) + r = torch.as_tensor(r) + d = torch.as_tensor(d) + + with self.profile.eval_forward, torch.no_grad(): + # TODO: In place-update should be faster. Leaking 7% speed max + # Also should be using a cuda tensor to index + if lstm_h is not None: + h = lstm_h[:, env_id] + c = lstm_c[:, env_id] + actions, logprob, _, value, (h, c) = policy(o_device, (h, c)) + lstm_h[:, env_id] = h + lstm_c[:, env_id] = c + else: + actions, logprob, _, value = policy(o_device) - if next_lstm_state is not None: - h, c = next_lstm_state - self.next_lstm_state[0][:, env_id] = h - self.next_lstm_state[1][:, env_id] = c + if self.config.device == "cuda": + torch.cuda.synchronize() + with self.profile.eval_misc: value = value.flatten() - - with misc_profiler: actions = actions.cpu().numpy() + mask = torch.as_tensor(mask) # * policy.mask) + o = o if self.config.cpu_offload else o_device + if self.config.num_envs == 1: + actions = np.expand_dims(actions, 0) + logprob = logprob.unsqueeze(0) + self.experience.store(o, value, actions, logprob, r, d, env_id, mask) + + for i in info: + for k, v in pufferlib.utils.unroll_nested_dict(i): + if k == "state": + self.infos[k] = [v] + else: + self.infos[k].append(v) + + with self.profile.env: + self.vecenv.send(actions) + + with self.profile.eval_misc: + self.stats = {} + + for k, v in self.infos.items(): + # Moves into models... maybe. Definitely moves. + # You could also just return infos and have it in demo + if "pokemon_exploration_map" in self.infos and self.config.save_overlay is True: + if self.epoch % self.config.overlay_interval == 0: + overlay = make_pokemon_red_overlay( + np.stack(self.infos["pokemon_exploration_map"], axis=0) + ) + if self.wandb_client is not None: + self.stats["Media/aggregate_exploration_map"] = wandb.Image(overlay) + elif "state" in k: + continue - # Index alive mask with policy pool idxs... 
- # TODO: Find a way to avoid having to do this - learner_mask = torch.as_tensor(mask * self.policy_pool.mask) - - # Ensure indices do not exceed batch size - indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy() - end = ptr + len(indices) - - # Batch indexing - self.obs_ary[ptr:end] = o.cpu().numpy()[indices] - self.values_ary[ptr:end] = value.cpu().numpy()[indices] - self.actions_ary[ptr:end] = actions[indices] - self.logprobs_ary[ptr:end] = logprob.cpu().numpy()[indices] - self.rewards_ary[ptr:end] = r.cpu().numpy()[indices] - self.dones_ary[ptr:end] = d.cpu().numpy()[indices] - self.sort_keys.extend([(env_id[i], step) for i in indices]) - - # Update pointer - ptr += len(indices) - - for policy_name, policy_i in i.items(): - for agent_i in policy_i: - for name, dat in unroll_nested_dict(agent_i): - if policy_name not in self.infos: - self.infos[policy_name] = {} - if name not in self.infos[policy_name]: - self.infos[policy_name][name] = [ - np.zeros_like(dat) - ] * self.config.num_envs - self.infos[policy_name][name][agent_i["env_id"]] = dat - # infos[policy_name][name].append(dat) - with env_profiler: - self.pool.send(actions) - - eval_profiler.stop() - - # Now that we initialized the model, we can get the number of parameters - if self.global_step == 0 and self.config.verbose: - self.n_params = sum(p.numel() for p in self.agent.parameters() if p.requires_grad) - print(f"Model Size: {self.n_params//1000} K parameters") - - self.total_agent_steps += padded_steps_collected - new_step = np.mean(self.infos["learner"]["stats/step"]) - - if new_step > self.global_step: - self.global_step = new_step - self.log = True - self.reward = torch.mean(self.rewards).float().item() - self.SPS = int(padded_steps_collected / eval_profiler.elapsed) - - perf = self.performance - perf.total_uptime = int(time.time() - self.start_time) - perf.total_agent_steps = self.total_agent_steps - perf.env_time = env_profiler.elapsed - perf.env_sps = int(agent_steps_collected / env_profiler.elapsed) - perf.inference_time = inference_profiler.elapsed - perf.inference_sps = int(padded_steps_collected / inference_profiler.elapsed) - perf.eval_time = eval_profiler.elapsed - perf.eval_sps = int(padded_steps_collected / eval_profiler.elapsed) - perf.eval_memory = eval_profiler.end_mem - perf.eval_pytorch_memory = eval_profiler.end_torch_mem - perf.misc_time = misc_profiler.elapsed - - self.stats = {} - self.max_stats = {} - for k, v in self.infos["learner"].items(): - if "pokemon_exploration_map" in k and config.save_overlay is True: - if self.update % config.overlay_interval == 0: - overlay = make_pokemon_red_overlay(np.stack(v, axis=0)) - if self.wandb is not None: - self.stats["Media/aggregate_exploration_map"] = self.wandb.Image(overlay) - # elif "cut_exploration_map" in k and config.save_overlay is True: - # if self.update % config.overlay_interval == 0: - # overlay = make_pokemon_red_overlay(np.stack(v, axis=0)) - # if self.wandb is not None: - # self.stats["Media/aggregate_cut_exploration_map"] = self.wandb.Image( - # overlay - # ) - elif "state" in k: - pass - else: try: # TODO: Better checks on log data types - # self.stats[f"Histogram/{k}"] = self.wandb.Histogram(v, num_bins=16) self.stats[k] = np.mean(v) - self.max_stats[k] = np.max(v) - except: # noqa + except: # noqa: E722 continue - if config.verbose: - print_dashboard(self.stats, self.init_performance, self.performance) + if self.config.verbose: + self.msg = f"Model Size: {abbreviate(count_params(self.policy))} parameters" + 
print_dashboard( + self.config.env, + self.utilization, + self.global_step, + self.epoch, + self.profile, + self.losses, + self.stats, + self.msg, + ) return self.stats, self.infos @pufferlib.utils.profile def train(self): - if self.done_training(): - raise RuntimeError(f"Max training updates {self.total_updates} already reached") - - config = self.config - # assert data.num_steps % bptt_horizon == 0, "num_steps must be divisible by bptt_horizon" - - train_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True) - train_profiler.start() - - if config.anneal_lr: - frac = 1.0 - (self.lr_update - 1.0) / self.total_updates - lrnow = frac * config.learning_rate - self.optimizer.param_groups[0]["lr"] = lrnow - - num_minibatches = config.batch_size // config.bptt_horizon // config.batch_rows - assert ( - num_minibatches > 0 - ), "config.batch_size // config.bptt_horizon // config.batch_rows must be > 0" - idxs = sorted(range(len(self.sort_keys)), key=self.sort_keys.__getitem__) - self.sort_keys = [] - b_idxs = ( - torch.tensor(idxs, dtype=torch.long)[:-1] - .reshape(config.batch_rows, num_minibatches, config.bptt_horizon) - .transpose(0, 1) - ) - - # bootstrap value if not done - with torch.no_grad(): - advantages = torch.zeros(config.batch_size, device=self.device) - lastgaelam = 0 - for t in reversed(range(config.batch_size)): - i, i_nxt = idxs[t], idxs[t + 1] - nextnonterminal = 1.0 - self.dones[i_nxt] - nextvalues = self.values[i_nxt] - delta = ( - self.rewards[i_nxt] - + config.gamma * nextvalues * nextnonterminal - - self.values[i] - ) - advantages[t] = lastgaelam = ( - delta + config.gamma * config.gae_lambda * nextnonterminal * lastgaelam - ) + self.losses = Losses() + losses = self.losses - # Flatten the batch - self.b_obs = b_obs = torch.as_tensor(self.obs_ary[b_idxs], dtype=torch.uint8) - b_actions = torch.as_tensor(self.actions_ary[b_idxs]).to(self.device, non_blocking=True) - b_logprobs = torch.as_tensor(self.logprobs_ary[b_idxs]).to(self.device, non_blocking=True) - # b_dones = torch.as_tensor(self.dones_ary[b_idxs]).to(self.device, non_blocking=True) - b_values = torch.as_tensor(self.values_ary[b_idxs]).to(self.device, non_blocking=True) - b_advantages = advantages.reshape( - config.batch_rows, num_minibatches, config.bptt_horizon - ).transpose(0, 1) - b_returns = b_advantages + b_values - - # Optimizing the policy and value network - train_time = time.time() - pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], [] - mb_obs_buffer = torch.zeros_like( - b_obs[0], pin_memory=(self.device == "cuda"), dtype=torch.uint8 - ) + with self.profile.train_misc: + idxs = self.experience.sort_training_data() + dones_np = self.experience.dones_np[idxs] + values_np = self.experience.values_np[idxs] + rewards_np = self.experience.rewards_np[idxs] + # TODO: bootstrap between segment bounds + advantages_np = compute_gae( + dones_np, values_np, rewards_np, self.config.gamma, self.config.gae_lambda + ) + self.experience.flatten_batch(advantages_np) - for epoch in range(config.update_epochs): + for _ in range(self.config.update_epochs): lstm_state = None - for mb in range(num_minibatches): - mb_obs_buffer.copy_(b_obs[mb], non_blocking=True) - mb_obs = mb_obs_buffer.to(self.device, non_blocking=True) - mb_actions = b_actions[mb].contiguous() - mb_values = b_values[mb].reshape(-1) - mb_advantages = b_advantages[mb].reshape(-1) - mb_returns = b_returns[mb].reshape(-1) - - if hasattr(self.agent, "lstm"): - ( - _, - newlogprob, - entropy, - newvalue, - lstm_state, 
- ) = self.agent.get_action_and_value(mb_obs, state=lstm_state, action=mb_actions) - lstm_state = (lstm_state[0].detach(), lstm_state[1].detach()) - else: - _, newlogprob, entropy, newvalue = self.agent.get_action_and_value( - mb_obs.reshape(-1, *self.pool.single_observation_space.shape), - action=mb_actions, + for mb in range(self.experience.num_minibatches): + with self.profile.train_misc: + obs = self.experience.b_obs[mb] + obs = obs.to(self.config.device) + atn = self.experience.b_actions[mb] + log_probs = self.experience.b_logprobs[mb] + val = self.experience.b_values[mb] + adv = self.experience.b_advantages[mb] + ret = self.experience.b_returns[mb] + + with self.profile.train_forward: + if self.experience.lstm_h is not None: + _, newlogprob, entropy, newvalue, lstm_state = self.policy( + obs, state=lstm_state, action=atn + ) + lstm_state = (lstm_state[0].detach(), lstm_state[1].detach()) + else: + _, newlogprob, entropy, newvalue = self.policy( + obs.reshape(-1, *self.vecenv.single_observation_space.shape), + action=atn, + ) + + if self.config.device == "cuda": + torch.cuda.synchronize() + + with self.profile.train_misc: + logratio = newlogprob - log_probs.reshape(-1) + ratio = logratio.exp() + + with torch.no_grad(): + # calculate approx_kl http://joschu.net/blog/kl-approx.html + old_approx_kl = (-logratio).mean() + approx_kl = ((ratio - 1) - logratio).mean() + clipfrac = ((ratio - 1.0).abs() > self.config.clip_coef).float().mean() + + adv = adv.reshape(-1) + if self.config.norm_adv: + adv = (adv - adv.mean()) / (adv.std() + 1e-8) + + # Policy loss + pg_loss1 = -adv * ratio + pg_loss2 = -adv * torch.clamp( + ratio, 1 - self.config.clip_coef, 1 + self.config.clip_coef ) - - logratio = newlogprob - b_logprobs[mb].reshape(-1) - ratio = logratio.exp() - - with torch.no_grad(): - # calculate approx_kl http://joschu.net/blog/kl-approx.html - old_approx_kl = (-logratio).mean() - old_kls.append(old_approx_kl.item()) - approx_kl = ((ratio - 1) - logratio).mean() - kls.append(approx_kl.item()) - clipfracs += [((ratio - 1.0).abs() > config.clip_coef).float().mean().item()] - - mb_advantages = mb_advantages.reshape(-1) - if config.norm_adv: - mb_advantages = (mb_advantages - mb_advantages.mean()) / ( - mb_advantages.std() + 1e-8 + pg_loss = torch.max(pg_loss1, pg_loss2).mean() + + # Value loss + newvalue = newvalue.view(-1) + if self.config.clip_vloss: + v_loss_unclipped = (newvalue - ret) ** 2 + v_clipped = val + torch.clamp( + newvalue - val, + -self.config.vf_clip_coef, + self.config.vf_clip_coef, + ) + v_loss_clipped = (v_clipped - ret) ** 2 + v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) + v_loss = 0.5 * v_loss_max.mean() + else: + v_loss = 0.5 * ((newvalue - ret) ** 2).mean() + + entropy_loss = entropy.mean() + loss = ( + pg_loss - self.config.ent_coef * entropy_loss + v_loss * self.config.vf_coef ) - # Policy loss - pg_loss1 = -mb_advantages * ratio - pg_loss2 = -mb_advantages * torch.clamp( - ratio, 1 - self.config.clip_coef, 1 + self.config.clip_coef - ) - pg_loss = torch.max(pg_loss1, pg_loss2).mean() - pg_losses.append(pg_loss.item()) - - # Value loss - newvalue = newvalue.view(-1) - if self.config.clip_vloss: - v_loss_unclipped = (newvalue - mb_returns) ** 2 - v_clipped = mb_values + torch.clamp( - newvalue - mb_values, - -self.config.vf_clip_coef, - self.config.vf_clip_coef, + with self.profile.learn: + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_( + self.policy.parameters(), self.config.max_grad_norm ) - v_loss_clipped = (v_clipped - 
mb_returns) ** 2 - v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped) - v_loss = 0.5 * v_loss_max.mean() - else: - v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() - v_losses.append(v_loss.item()) - - entropy_loss = entropy.mean() - entropy_losses.append(entropy_loss.item()) - - self.calculate_loss(pg_loss, entropy_loss, v_loss) - - if config.target_kl is not None: - if approx_kl > config.target_kl: + self.optimizer.step() + if self.config.device == "cuda": + torch.cuda.synchronize() + + with self.profile.train_misc: + losses.policy_loss += pg_loss.item() / self.experience.num_minibatches + losses.value_loss += v_loss.item() / self.experience.num_minibatches + losses.entropy += entropy_loss.item() / self.experience.num_minibatches + losses.old_approx_kl += old_approx_kl.item() / self.experience.num_minibatches + losses.approx_kl += approx_kl.item() / self.experience.num_minibatches + losses.clipfrac += clipfrac.item() / self.experience.num_minibatches + + if self.config.target_kl is not None: + if approx_kl > self.config.target_kl: break - train_profiler.stop() - y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy() - var_y = np.var(y_true) - explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + with self.profile.train_misc: + if self.config.anneal_lr: + frac = 1.0 - self.global_step / self.config.total_timesteps + lrnow = frac * self.config.learning_rate + self.optimizer.param_groups[0]["lr"] = lrnow + + y_pred = self.experience.values_np + y_true = self.experience.returns_np + var_y = np.var(y_true) + explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + losses.explained_variance = explained_var + self.epoch += 1 + + done_training = self.global_step >= self.config.total_timesteps + if self.profile.update(self) or done_training: + if self.config.verbose: + print_dashboard( + self.config.env, + self.utilization, + self.global_step, + self.epoch, + self.profile, + self.losses, + self.stats, + self.msg, + ) - losses = self.losses - losses.policy_loss = np.mean(pg_losses) - losses.value_loss = np.mean(v_losses) - losses.entropy = np.mean(entropy_losses) - losses.old_approx_kl = np.mean(old_kls) - losses.approx_kl = np.mean(kls) - losses.clipfrac = np.mean(clipfracs) - losses.explained_variance = explained_var - - perf = self.performance - perf.total_uptime = int(time.time() - self.start_time) - perf.total_updates = self.update + 1 - perf.train_time = time.time() - train_time - perf.train_sps = int(config.batch_size / perf.train_time) - perf.train_memory = train_profiler.end_mem - perf.train_pytorch_memory = train_profiler.end_torch_mem - perf.epoch_time = perf.eval_time + perf.train_time - perf.epoch_sps = int(config.batch_size / perf.epoch_time) - - if config.verbose: - print_dashboard(self.stats, self.init_performance, self.performance) - - self.update += 1 - self.lr_update += 1 - - if self.update % config.checkpoint_interval == 0 or self.done_training(): - self.save_checkpoint() + if ( + self.wandb_client is not None + and self.global_step > 0 + and time.time() - self.last_log_time > 5.0 + ): + self.last_log_time = time.time() + self.wandb_client.log( + { + "Overview/SPS": self.profile.SPS, + "Overview/agent_steps": self.global_step, + "Overview/learning_rate": self.optimizer.param_groups[0]["lr"], + **{f"environment/{k}": v for k, v in self.stats.items()}, + **{f"losses/{k}": v for k, v in self.losses.__dict__.items()}, + **{f"performance/{k}": v for k, v in self.profile}, + } + ) + + if self.epoch % 
self.config.checkpoint_interval == 0 or done_training: + self.save_checkpoint() + self.msg = f"Checkpoint saved at update {self.epoch}" def close(self): - self.pool.close() + self.vecenv.close() + if self.config.verbose: + self.utilization.stop() - if self.wandb is not None: + if self.wandb_client is not None: artifact_name = f"{self.exp_name}_model" - artifact = self.wandb.Artifact(artifact_name, type="model") + artifact = wandb.Artifact(artifact_name, type="model") model_path = self.save_checkpoint() artifact.add_file(model_path) - self.wandb.run.log_artifact(artifact) - self.wandb.finish() + self.wandb_client.log_artifact(artifact) + self.wandb_client.finish() def save_checkpoint(self): - if self.config.save_checkpoint is False: - return - - path = os.path.join(self.config.data_dir, self.exp_name) + config = self.config + path = os.path.join(config.data_dir, config.exp_id) if not os.path.exists(path): os.makedirs(path) - model_name = f"model_{self.update:06d}_state.pth" + model_name = f"model_{self.epoch:06d}.pt" model_path = os.path.join(path, model_name) - - # Already saved if os.path.exists(model_path): return model_path - # To handleboth uncompiled and compiled self.agent, when getting state_dict() - torch.save(getattr(self.agent, "_orig_mod", self.agent).state_dict(), model_path) + torch.save(self.uncompiled_policy, model_path) state = { "optimizer_state_dict": self.optimizer.state_dict(), "global_step": self.global_step, "agent_step": self.global_step, - "update": self.update, + "update": self.epoch, "model_name": model_name, + "exp_id": config.exp_id, } - - if self.wandb: - state["exp_name"] = self.exp_name - state_path = os.path.join(path, "trainer_state.pt") torch.save(state, state_path + ".tmp") os.rename(state_path + ".tmp", state_path) - - # Also save a copy - torch.save(state, os.path.join(path, f"trainer_state_{self.update:06d}.pt")) - - print(f"Model saved to {model_path}") - return model_path def calculate_loss(self, pg_loss, entropy_loss, v_loss): @@ -776,7 +535,7 @@ def calculate_loss(self, pg_loss, entropy_loss, v_loss): self.optimizer.step() def done_training(self): - return self.update >= self.total_updates + return self.global_step >= self.config.total_timesteps def __enter__(self): return self @@ -786,3 +545,301 @@ def __exit__(self, *args): self.save_checkpoint() self.close() print("Run complete") + + +class Experience: + """Flat tensor storage and array views for faster indexing""" + + def __init__( + self, + batch_size: int, + agents_per_batch: int, + bptt_horizon: int, + minibatch_size: int, + obs_shape: tuple[int], + obs_dtype: np.dtype, + atn_shape: tuple[int], + cpu_offload: bool = False, + device: str = "cuda", + lstm: torch.nn.LSTM | None = None, + lstm_total_agents: int = 0, + ): + obs_dtype = pufferlib.pytorch.numpy_to_torch_dtype_dict[obs_dtype] + pin = device == "cuda" and cpu_offload + # obs_device = device if not pin else "cpu" + self.obs = torch.zeros( + batch_size, + *obs_shape, + dtype=obs_dtype, + pin_memory=pin, + device=device if not pin else "cpu", + ) + self.actions = torch.zeros(batch_size, *atn_shape, dtype=int, pin_memory=pin) + self.logprobs = torch.zeros(batch_size, pin_memory=pin) + self.rewards = torch.zeros(batch_size, pin_memory=pin) + self.dones = torch.zeros(batch_size, pin_memory=pin) + self.truncateds = torch.zeros(batch_size, pin_memory=pin) + self.values = torch.zeros(batch_size, pin_memory=pin) + + # self.obs_np = np.asarray(self.obs) + self.actions_np = np.asarray(self.actions) + self.logprobs_np = np.asarray(self.logprobs) 
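+        # np.asarray on these CPU tensors returns shared-memory views, so writes
+        # through the *_np arrays fill the torch buffers in place (no copies).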
+ self.rewards_np = np.asarray(self.rewards) + self.dones_np = np.asarray(self.dones) + self.truncateds_np = np.asarray(self.truncateds) + self.values_np = np.asarray(self.values) + + self.lstm_h = self.lstm_c = None + if lstm is not None: + assert lstm_total_agents > 0 + shape = (lstm.num_layers, lstm_total_agents, lstm.hidden_size) + self.lstm_h = torch.zeros(shape).to(device) + self.lstm_c = torch.zeros(shape).to(device) + + num_minibatches = batch_size / minibatch_size + self.num_minibatches = int(num_minibatches) + if self.num_minibatches != num_minibatches: + raise ValueError("batch_size must be divisible by minibatch_size") + + minibatch_rows = minibatch_size / bptt_horizon + self.minibatch_rows = int(minibatch_rows) + if self.minibatch_rows != minibatch_rows: + raise ValueError("minibatch_size must be divisible by bptt_horizon") + + self.batch_size = batch_size + self.bptt_horizon = bptt_horizon + self.minibatch_size = minibatch_size + self.device = device + self.sort_keys = [] + self.ptr = 0 + self.step = 0 + + @property + def full(self): + return self.ptr >= self.batch_size + + def store( + self, + obs: torch.Tensor, + value: torch.Tensor, + action: torch.Tensor, + logprob: torch.Tensor, + reward: torch.Tensor, + done: torch.Tensor, + env_id: torch.Tensor, + mask: torch.Tensor, + ): + # Mask learner and Ensure indices do not exceed batch size + ptr = self.ptr + indices = torch.where(mask)[0].numpy()[: self.batch_size - ptr] + end = ptr + len(indices) + + self.obs[ptr:end] = obs.to(self.obs.device)[indices] + self.values_np[ptr:end] = value.cpu().numpy()[indices] + self.actions_np[ptr:end] = action[indices] + self.logprobs_np[ptr:end] = logprob.cpu().numpy()[indices] + self.rewards_np[ptr:end] = reward.cpu().numpy()[indices] + self.dones_np[ptr:end] = done.cpu().numpy()[indices] + self.sort_keys.extend([(env_id[i], self.step) for i in indices]) + self.ptr = end + self.step += 1 + + def sort_training_data(self): + idxs = np.asarray(sorted(range(len(self.sort_keys)), key=self.sort_keys.__getitem__)) + self.b_idxs_obs = ( + torch.as_tensor( + idxs.reshape( + self.minibatch_rows, self.num_minibatches, self.bptt_horizon + ).transpose(1, 0, -1) + ) + .to(self.obs.device) + .long() + ) + self.b_idxs = self.b_idxs_obs.to(self.device) + self.b_idxs_flat = self.b_idxs.reshape(self.num_minibatches, self.minibatch_size) + self.sort_keys = [] + self.ptr = 0 + self.step = 0 + return idxs + + def flatten_batch(self, advantages_np: np.ndarray): + advantages = torch.from_numpy(advantages_np).to(self.device) + b_idxs, b_flat = self.b_idxs, self.b_idxs_flat + self.b_actions = self.actions.to(self.device, non_blocking=True) + self.b_logprobs = self.logprobs.to(self.device, non_blocking=True) + self.b_dones = self.dones.to(self.device, non_blocking=True) + self.b_values = self.values.to(self.device, non_blocking=True) + self.b_advantages = ( + advantages.reshape(self.minibatch_rows, self.num_minibatches, self.bptt_horizon) + .transpose(0, 1) + .reshape(self.num_minibatches, self.minibatch_size) + ) + self.returns_np = advantages_np + self.values_np + self.b_obs = self.obs[self.b_idxs_obs] + self.b_actions = self.b_actions[b_idxs].contiguous() + self.b_logprobs = self.b_logprobs[b_idxs] + self.b_dones = self.b_dones[b_idxs] + self.b_values = self.b_values[b_flat] + self.b_returns = self.b_advantages + self.b_values + + +ROUND_OPEN = rich.box.Box( + "╭──╮\n" # noqa: F401 + "│ │\n" + "│ │\n" + "│ │\n" + "│ │\n" + "│ │\n" + "│ │\n" + "╰──╯\n" +) + +c1 = "[bright_cyan]" +c2 = "[white]" +c3 = "[cyan]" +b1 = 
"[bright_cyan]" +b2 = "[bright_white]" + + +def abbreviate(num): + if num < 1e3: + return f"{b2}{num:.0f}" + elif num < 1e6: + return f"{b2}{num/1e3:.1f}{c2}k" + elif num < 1e9: + return f"{b2}{num/1e6:.1f}{c2}m" + elif num < 1e12: + return f"{b2}{num/1e9:.1f}{c2}b" + else: + return f"{b2}{num/1e12:.1f}{c2}t" + + +def duration(seconds): + seconds = int(seconds) + h = seconds // 3600 + m = (seconds % 3600) // 60 + s = seconds % 60 + return ( + f"{b2}{h}{c2}h {b2}{m}{c2}m {b2}{s}{c2}s" + if h + else f"{b2}{m}{c2}m {b2}{s}{c2}s" + if m + else f"{b2}{s}{c2}s" + ) + + +def fmt_perf(name, time, uptime): + percent = 0 if uptime == 0 else int(100 * time / uptime - 1e-5) + return f"{c1}{name}", duration(time), f"{b2}{percent:2d}%" + + +# TODO: Add env name to print_dashboard +def print_dashboard( + env_name: str, + utilization: Utilization, + global_step: int, + epoch: int, + profile: Profile, + losses: Losses, + stats, + msg: str, + clear: bool = False, + max_stats=None, +): + if not max_stats: + max_stats = [0] + console = Console() + if clear: + console.clear() + + dashboard = Table(box=ROUND_OPEN, expand=True, show_header=False, border_style="bright_cyan") + + table = Table(box=None, expand=True, show_header=False) + dashboard.add_row(table) + cpu_percent = np.mean(utilization.cpu_util) + dram_percent = np.mean(utilization.cpu_mem) + gpu_percent = np.mean(utilization.gpu_util) + vram_percent = np.mean(utilization.gpu_mem) + table.add_column(justify="left", width=30) + table.add_column(justify="center", width=12) + table.add_column(justify="center", width=12) + table.add_column(justify="center", width=13) + table.add_column(justify="right", width=13) + table.add_row( + f":blowfish: {c1}PufferLib {b2}1.0.0", + f"{c1}CPU: {c3}{cpu_percent:.1f}%", + f"{c1}GPU: {c3}{gpu_percent:.1f}%", + f"{c1}DRAM: {c3}{dram_percent:.1f}%", + f"{c1}VRAM: {c3}{vram_percent:.1f}%", + ) + + s = Table(box=None, expand=True) + s.add_column(f"{c1}Summary", justify="left", vertical="top", width=16) + s.add_column(f"{c1}Value", justify="right", vertical="top", width=8) + s.add_row(f"{c2}Environment", f"{b2}{env_name}") + s.add_row(f"{c2}Agent Steps", abbreviate(global_step)) + s.add_row(f"{c2}SPS", abbreviate(profile.SPS)) + s.add_row(f"{c2}Epoch", abbreviate(epoch)) + s.add_row(f"{c2}Uptime", duration(profile.uptime)) + s.add_row(f"{c2}Remaining", duration(profile.remaining)) + + p = Table(box=None, expand=True, show_header=False) + p.add_column(f"{c1}Performance", justify="left", width=10) + p.add_column(f"{c1}Time", justify="right", width=8) + p.add_column(f"{c1}%", justify="right", width=4) + p.add_row(*fmt_perf("Evaluate", profile.eval_time, profile.uptime)) + p.add_row(*fmt_perf(" Forward", profile.eval_forward_time, profile.uptime)) + p.add_row(*fmt_perf(" Env", profile.env_time, profile.uptime)) + p.add_row(*fmt_perf(" Misc", profile.eval_misc_time, profile.uptime)) + p.add_row(*fmt_perf("Train", profile.train_time, profile.uptime)) + p.add_row(*fmt_perf(" Forward", profile.train_forward_time, profile.uptime)) + p.add_row(*fmt_perf(" Learn", profile.learn_time, profile.uptime)) + p.add_row(*fmt_perf(" Misc", profile.train_misc_time, profile.uptime)) + + l = Table( # noqa: E741 + box=None, + expand=True, + ) + l.add_column(f"{c1}Losses", justify="left", width=16) + l.add_column(f"{c1}Value", justify="right", width=8) + for metric, value in losses.__dict__.items(): + l.add_row(f"{c2}{metric}", f"{b2}{value:.3f}") + + monitor = Table(box=None, expand=True, pad_edge=False) + monitor.add_row(s, p, l) + 
dashboard.add_row(monitor) + + table = Table(box=None, expand=True, pad_edge=False) + dashboard.add_row(table) + left = Table(box=None, expand=True) + right = Table(box=None, expand=True) + table.add_row(left, right) + left.add_column(f"{c1}User Stats", justify="left", width=20) + left.add_column(f"{c1}Value", justify="right", width=10) + right.add_column(f"{c1}User Stats", justify="left", width=20) + right.add_column(f"{c1}Value", justify="right", width=10) + i = 0 + for metric, value in stats.items(): + try: # Discard non-numeric values + int(value) + except: # noqa: E722 + continue + + u = left if i % 2 == 0 else right + u.add_row(f"{c2}{metric}", f"{b2}{value:.3f}") + i += 1 + + for i in range(max_stats[0] - i): + u = left if i % 2 == 0 else right + u.add_row("", "") + + max_stats[0] = max(max_stats[0], i) + + table = Table(box=None, expand=True, pad_edge=False) + dashboard.add_row(table) + table.add_row(f" {c1}Message: {c2}{msg}") + + with console.capture() as capture: + console.print(dashboard) + + print("\033[0;0H" + capture.get()) diff --git a/pokemonred_puffer/environment.py b/pokemonred_puffer/environment.py index 3af4b2d..016cef0 100644 --- a/pokemonred_puffer/environment.py +++ b/pokemonred_puffer/environment.py @@ -73,8 +73,6 @@ class RedGymEnv(Env): def __init__(self, env_config: pufferlib.namespace): # TODO: Dont use pufferlib.namespace. It seems to confuse __init__ self.video_dir = Path(env_config.video_dir) - self.session_path = Path(env_config.session_path) - self.video_path = self.video_dir / self.session_path self.save_final_state = env_config.save_final_state self.print_rewards = env_config.print_rewards self.headless = env_config.headless @@ -104,6 +102,7 @@ def __init__(self, env_config: pufferlib.namespace): self.auto_pokeflute = env_config.auto_pokeflute self.infinite_money = env_config.infinite_money self.use_global_map = env_config.use_global_map + self.save_state = env_config.save_state self.action_space = ACTION_SPACE # Obs space-related. TODO: avoid hardcoding? 
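The env_config.save_state flag wired in above gates the BytesIO snapshots that the hunks below attach to the info dict. As a rough sketch of the round-trip (helper names are illustrative; it assumes PyBoy's file-like save_state/load_state API already used elsewhere in this file):

import io

def snapshot(pyboy) -> bytes:
    buf = io.BytesIO()
    pyboy.save_state(buf)  # serialize the full emulator state into the buffer
    buf.seek(0)
    return buf.read()

def restore(pyboy, raw: bytes) -> None:
    pyboy.load_state(io.BytesIO(raw))  # load_state accepts any file-like object

Presumably this is what lets the swarm-migration block in cleanrl_puffer.py ship a leading environment's state bytes to lagging environments through the async queues.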
@@ -307,10 +306,13 @@ def reset(self, seed: Optional[int] = None, options: Optional[dict[str, Any]] = self.total_reward = sum([val for _, val in self.progress_reward.items()]) self.first = False - state = io.BytesIO() - self.pyboy.save_state(state) - state.seek(0) - return self._get_obs(), {"state": state.read()} + infos = {} + if self.save_state: + state = io.BytesIO() + self.pyboy.save_state(state) + state.seek(0) + infos |= {"state": state.read()} + return self._get_obs(), infos def init_mem(self): # Maybe I should preallocate a giant matrix for all map ids @@ -554,7 +556,7 @@ def step(self, action): self.pokecenters[self.read_m("wLastBlackoutMap")] = 1 info = {} - if self.get_events_sum() > self.max_event_rew: + if self.save_state and self.get_events_sum() > self.max_event_rew: state = io.BytesIO() self.pyboy.save_state(state) state.seek(0) @@ -1061,6 +1063,8 @@ def disable_wild_encounter_hook(self, *args, **kwargs): def agent_stats(self, action): levels = [self.read_m(f"wPartyMon{i+1}Level") for i in range(self.read_m("wPartyCount"))] badges = self.read_m("wObtainedBadges") + explore_map = self.explore_map + explore_map[explore_map > 0] = 1 return { "stats": { "step": self.step_count + self.reset_count * self.max_steps, @@ -1111,7 +1115,7 @@ def agent_stats(self, action): | {f"badge_{i+1}": bool(badges & (1 << i)) for i in range(8)}, "reward": self.get_game_state_reward(), "reward/reward_sum": sum(self.get_game_state_reward().values()), - "pokemon_exploration_map": self.explore_map, + "pokemon_exploration_map": explore_map, "cut_exploration_map": self.cut_explore_map, } @@ -1157,6 +1161,8 @@ def update_seen_coords(self): if not (self.read_m("wd736") & 0b1000_0000): x_pos, y_pos, map_n = self.get_game_coords() self.seen_coords[(x_pos, y_pos, map_n)] = 1 + # TODO: Turn into a wrapper? 
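+            # Decay: tiles seen earlier fade to 0.5 so only the current tile
+            # (set to 1 below) is at full intensity in the exploration overlay.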
+ self.explore_map[self.explore_map > 0] = 0.5 self.explore_map[local_to_global(y_pos, x_pos, map_n)] = 1 # self.seen_global_coords[local_to_global(y_pos, x_pos, map_n)] = 1 self.seen_map_ids[map_n] = 1 diff --git a/pokemonred_puffer/eval.py b/pokemonred_puffer/eval.py index 9d290da..f61fb1b 100644 --- a/pokemonred_puffer/eval.py +++ b/pokemonred_puffer/eval.py @@ -3,7 +3,6 @@ import cv2 import matplotlib.colors as mcolors import numpy as np -import torch KANTO_MAP_PATH = os.path.join(os.path.dirname(__file__), "kanto_map_dsv.png") BACKGROUND = np.array(cv2.imread(KANTO_MAP_PATH)) @@ -36,7 +35,3 @@ def make_pokemon_red_overlay(counts: np.ndarray): render = np.clip(render, 0, 255).astype(np.uint8) return render - - -if torch.cuda.is_available(): - make_pokemon_red_overlay = torch.compile(make_pokemon_red_overlay) diff --git a/pokemonred_puffer/policies/multi_convolutional.py b/pokemonred_puffer/policies/multi_convolutional.py index f65ec29..37cfdc2 100644 --- a/pokemonred_puffer/policies/multi_convolutional.py +++ b/pokemonred_puffer/policies/multi_convolutional.py @@ -1,17 +1,14 @@ +import pufferlib.emulation +import pufferlib.models +import pufferlib.pytorch import torch from torch import nn -import pufferlib.models -from pufferlib.emulation import unpack_batched_obs - from pokemonred_puffer.data.items import Items from pokemonred_puffer.environment import PIXEL_VALUES -unpack_batched_obs = torch.compiler.disable(unpack_batched_obs) # Because torch.nn.functional.one_hot cannot be traced by torch as of 2.2.0 - - def one_hot(tensor, num_classes): index = torch.arange(0, num_classes, device=tensor.device) return (tensor.view([*tensor.shape, 1]) == index.view([1] * tensor.ndim + [num_classes])).to( @@ -19,21 +16,25 @@ def one_hot(tensor, num_classes): ) -class RecurrentMultiConvolutionalWrapper(pufferlib.models.RecurrentWrapper): +class MultiConvolutionalRNN(pufferlib.models.LSTMWrapper): def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): super().__init__(env, policy, input_size, hidden_size, num_layers) -class MultiConvolutionalPolicy(pufferlib.models.Policy): +# We dont inherit from the pufferlib convolutional because we wont be able +# to easily call its __init__ due to our usage of lazy layers +# All that really means is a slightly different forward +class MultiConvolutionalPolicy(nn.Module): def __init__( self, - env, - hidden_size=512, + env: pufferlib.emulation.GymnasiumPufferEnv, + hidden_size: int = 512, channels_last: bool = True, downsample: int = 1, ): - super().__init__(env) - self.num_actions = self.action_space.n + super().__init__() + self.dtype = pufferlib.pytorch.nativize_dtype(env.emulated) + self.num_actions = env.single_action_space.n self.channels_last = channels_last self.downsample = downsample self.screen_network = nn.Sequential( @@ -45,15 +46,8 @@ def __init__( nn.ReLU(), nn.Flatten(), ) - self.global_map_network = nn.Sequential( - nn.LazyConv2d(32, 8, stride=4), - nn.ReLU(), - nn.LazyConv2d(64, 4, stride=2), - nn.ReLU(), - nn.LazyConv2d(64, 3, stride=1), - nn.ReLU(), - nn.Flatten(), - ) + # if channels_last: + # self.screen_network = self.screen_network.to(memory_format=torch.channels_last) self.encode_linear = nn.Sequential( nn.LazyLinear(hidden_size), @@ -66,6 +60,23 @@ def __init__( self.two_bit = env.unwrapped.env.two_bit self.use_global_map = env.unwrapped.env.use_global_map + if self.use_global_map: + self.global_map_network = nn.Sequential( + nn.LazyConv2d(32, 8, stride=4), + nn.ReLU(), + nn.LazyConv2d(64, 4, stride=2), + 
nn.ReLU(), + nn.LazyConv2d(64, 3, stride=1), + nn.ReLU(), + nn.Flatten(), + nn.LazyLinear(480), + nn.ReLU(), + ) + # if channels_last: + # self.global_map_network = self.global_map_network.to( + # memory_format=torch.channels_last + # ) + self.register_buffer( "screen_buckets", torch.tensor(PIXEL_VALUES, dtype=torch.uint8), persistent=False ) @@ -91,8 +102,14 @@ def __init__( item_count, int(item_count**0.25 + 1), dtype=torch.float32 ) + def forward(self, observations): + hidden, lookup = self.encode_observations(observations) + actions, value = self.decode_actions(hidden, lookup) + return actions, value + def encode_observations(self, observations): - observations = unpack_batched_obs(observations, self.unflatten_context) + observations = observations.type(torch.uint8) # Undo bad cleanrl cast + observations = pufferlib.pytorch.nativize_tensor(observations, self.dtype) screen = observations["screen"] visited_mask = observations["visited_mask"] @@ -140,8 +157,10 @@ def encode_observations(self, observations): image_observation = torch.cat((screen, visited_mask), dim=-1) if self.channels_last: image_observation = image_observation.permute(0, 3, 1, 2) + # image_observation = image_observation.to( memory_format=torch.channels_last) if self.use_global_map: global_map = global_map.permute(0, 3, 1, 2) + # global_map = global_map.to(memory_format=torch.channels_last) if self.downsample > 1: image_observation = image_observation[:, :, :: self.downsample, :: self.downsample] diff --git a/pokemonred_puffer/profile.py b/pokemonred_puffer/profile.py new file mode 100644 index 0000000..2332d8c --- /dev/null +++ b/pokemonred_puffer/profile.py @@ -0,0 +1,112 @@ +from collections import deque +from threading import Thread +import time + +import psutil +import torch + +import pufferlib.utils + + +class Profile: + SPS: ... = 0 + uptime: ... = 0 + remaining: ... = 0 + eval_time: ... = 0 + env_time: ... = 0 + eval_forward_time: ... = 0 + eval_misc_time: ... = 0 + train_time: ... = 0 + train_forward_time: ... = 0 + learn_time: ... = 0 + train_misc_time: ... 
= 0 + + def __init__(self): + self.start = time.time() + self.env = pufferlib.utils.Profiler() + self.eval_forward = pufferlib.utils.Profiler() + self.eval_misc = pufferlib.utils.Profiler() + self.train_forward = pufferlib.utils.Profiler() + self.learn = pufferlib.utils.Profiler() + self.train_misc = pufferlib.utils.Profiler() + self.prev_steps = 0 + + def __iter__(self): + yield "SPS", self.SPS + yield "uptime", self.uptime + yield "remaining", self.remaining + yield "eval_time", self.eval_time + yield "env_time", self.env_time + yield "eval_forward_time", self.eval_forward_time + yield "eval_misc_time", self.eval_misc_time + yield "train_time", self.train_time + yield "train_forward_time", self.train_forward_time + yield "learn_time", self.learn_time + yield "train_misc_time", self.train_misc_time + + @property + def epoch_time(self): + return self.train_time + self.eval_time + + def update(self, data, interval_s=1): + global_step = data.global_step + if global_step == 0: + return True + + uptime = time.time() - self.start + if uptime - self.uptime < interval_s: + return False + + self.SPS = (global_step - self.prev_steps) / (uptime - self.uptime) + self.prev_steps = global_step + self.uptime = uptime + + self.remaining = (data.config.total_timesteps - global_step) / self.SPS + self.eval_time = data._timers["evaluate"].elapsed + self.eval_forward_time = self.eval_forward.elapsed + self.env_time = self.env.elapsed + self.eval_misc_time = self.eval_misc.elapsed + self.train_time = data._timers["train"].elapsed + self.train_forward_time = self.train_forward.elapsed + self.learn_time = self.learn.elapsed + self.train_misc_time = self.train_misc.elapsed + return True + + +def make_losses(): + return pufferlib.namespace( + policy_loss=0, + value_loss=0, + entropy=0, + old_approx_kl=0, + approx_kl=0, + clipfrac=0, + explained_variance=0, + ) + + +class Utilization(Thread): + def __init__(self, delay=1, maxlen=20): + super().__init__() + self.cpu_mem = deque(maxlen=maxlen) + self.cpu_util = deque(maxlen=maxlen) + self.gpu_util = deque(maxlen=maxlen) + self.gpu_mem = deque(maxlen=maxlen) + + self.delay = delay + self.stopped = False + self.start() + + def run(self): + while not self.stopped: + self.cpu_util.append(psutil.cpu_percent()) + mem = psutil.virtual_memory() + self.cpu_mem.append(mem.active / mem.total) + if torch.cuda.is_available(): + self.gpu_util.append(torch.cuda.utilization()) + free, total = torch.cuda.mem_get_info() + self.gpu_mem.append(free / total) + time.sleep(self.delay) + + def stop(self): + self.stopped = True diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py index 3a12a0b..02df802 100644 --- a/pokemonred_puffer/train.py +++ b/pokemonred_puffer/train.py @@ -1,48 +1,63 @@ import argparse +import functools import importlib -from multiprocessing import Queue -import pathlib +import os import sys -import time +from multiprocessing import Queue from types import ModuleType from typing import Any, Callable - -import gymnasium as gym -import torch -import wandb -import yaml +import uuid import pufferlib +import pufferlib.emulation +import pufferlib.frameworks.cleanrl +import pufferlib.postprocess import pufferlib.utils -from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL, rollout +import pufferlib.vector +import yaml + +import wandb +from pokemonred_puffer import cleanrl_puffer +from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL from pokemonred_puffer.environment import RedGymEnv from pokemonred_puffer.wrappers.async_io import AsyncWrapper +def 
diff --git a/pokemonred_puffer/train.py b/pokemonred_puffer/train.py
index 3a12a0b..02df802 100644
--- a/pokemonred_puffer/train.py
+++ b/pokemonred_puffer/train.py
@@ -1,48 +1,63 @@
 import argparse
+import functools
 import importlib
-from multiprocessing import Queue
-import pathlib
+import os
 import sys
-import time
+from multiprocessing import Queue
 from types import ModuleType
 from typing import Any, Callable
-
-import gymnasium as gym
-import torch
-import wandb
-import yaml
+import uuid
 
 import pufferlib
+import pufferlib.emulation
+import pufferlib.frameworks.cleanrl
+import pufferlib.postprocess
 import pufferlib.utils
-from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL, rollout
+import pufferlib.vector
+import yaml
+
+import wandb
+from pokemonred_puffer import cleanrl_puffer
+from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL
 from pokemonred_puffer.environment import RedGymEnv
 from pokemonred_puffer.wrappers.async_io import AsyncWrapper
 
 
+def make_policy(env, policy_name, args):
+    policy_module_name, policy_class_name = policy_name.split(".")
+    policy_module = importlib.import_module(f"pokemonred_puffer.policies.{policy_module_name}")
+    policy_class = getattr(policy_module, policy_class_name)
+
+    policy = policy_class(env, **args.policies[policy_name]["policy"])
+    if args.train.use_rnn:
+        rnn_config = args.policies[policy_name]["rnn"]
+        policy_class = getattr(policy_module, rnn_config["name"])
+        policy = policy_class(env, policy, **rnn_config["args"])
+        policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy)
+    else:
+        policy = pufferlib.frameworks.cleanrl.Policy(policy)
+
+    return policy.to(args.train.device)
+
+
 # TODO: Replace with Pydantic or Spock parser
-def load_from_config(
-    yaml_path: str | pathlib.Path,
-    wrappers: str,
-    policy: str,
-    reward: str,
-    debug: bool = False,
-):
-    with open(yaml_path) as f:
+def load_from_config(args: argparse.ArgumentParser):
+    with open(args["yaml"]) as f:
         config = yaml.safe_load(f)
     default_keys = ["env", "train", "policies", "rewards", "wrappers", "wandb"]
     defaults = {key: config.get(key, {}) for key in default_keys}
 
     # Package and subpackage (environment) configs
-    debug_config = config.get("debug", {}) if debug else {}
+    debug_config = config.get("debug", {}) if args["debug"] else {}
 
     # This is overly complicated. Clean it up. Or remove configs entirely
     # if we're gonna start creating an ersatz programming language.
     wrappers_config = {}
-    for wrapper in config["wrappers"][wrappers]:
+    for wrapper in config["wrappers"][args["wrappers_name"]]:
         for k, v in wrapper.items():
             wrappers_config[k] = v
-    reward_config = config["rewards"][reward]
-    policy_config = config["policies"][policy]
+    reward_config = config["rewards"][args["reward_name"]]
+    policy_config = config["policies"][args["policy_name"]]
 
     combined_config = {}
     for key in default_keys:
@@ -65,29 +80,28 @@ def load_from_config(
 def make_env_creator(
     wrapper_classes: list[tuple[str, ModuleType]],
     reward_class: RedGymEnv,
+    async_wrapper: bool = True,
 ) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
     def env_creator(
         env_config: pufferlib.namespace,
         wrappers_config: list[dict[str, Any]],
         reward_config: pufferlib.namespace,
-        async_config: dict[str, Queue],
+        async_config: dict[str, Queue] | None = None,
     ) -> pufferlib.emulation.GymnasiumPufferEnv:
         env = reward_class(env_config, reward_config)
         for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes):
             env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0]))
-        env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
-        return pufferlib.emulation.GymnasiumPufferEnv(
-            env=env, postprocessor_cls=pufferlib.emulation.BasicPostprocessor
-        )
+        if async_wrapper and async_config:
+            env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
+        # env = pufferlib.postprocess.EpisodeStats(env)
+        return pufferlib.emulation.GymnasiumPufferEnv(env=env)
 
     return env_creator
 
 
 # Returns env_creator, agent_creator
 def setup_agent(
-    wrappers: list[str],
-    reward_name: str,
-    policy_name: str,
+    wrappers: list[str], reward_name: str, async_wrapper: bool = True
 ) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
     # TODO: Make this less dependent on the name of this repo and its file structure
     wrapper_classes = [
@@ -106,35 +120,9 @@ def setup_agent(
         importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name
     )
     # NOTE: This assumes reward_module has RewardWrapper(RedGymEnv) class
-    env_creator = make_env_creator(wrapper_classes, reward_class)
-
-    policy_module_name, policy_class_name = policy_name.split(".")
-    policy_module = importlib.import_module(f"pokemonred_puffer.policies.{policy_module_name}")
-    policy_class = getattr(policy_module, policy_class_name)
-
-    def agent_creator(env: gym.Env, args: pufferlib.namespace):
-        policy = policy_class(env, **args.policies[policy_name]["policy"])
-        if "recurrent" in args.policies[policy_name]:
-            recurrent_args = args.policies[policy_name]["recurrent"]
-            recurrent_class_name = recurrent_args["name"]
-            del recurrent_args["name"]
-            policy = getattr(policy_module, recurrent_class_name)(env, policy, **recurrent_args)
-            policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy)
-        else:
-            policy = pufferlib.frameworks.cleanrl.Policy(policy)
-
-        if args.train.device == "cuda":
-            torch.set_float32_matmul_precision(args.train.float32_matmul_precision)
-        policy = policy.to(args.train.device, non_blocking=True)
-        if args.train.compile:
-            policy.get_value = torch.compile(policy.get_value, mode=args.train.compile_mode)
-            policy.get_action_and_value = torch.compile(
-                policy.get_action_and_value, mode=args.train.compile_mode
-            )
-
-        return policy
+    env_creator = make_env_creator(wrapper_classes, reward_class, async_wrapper)
 
-    return env_creator, agent_creator
+    return env_creator
 
 
 def update_args(args: argparse.Namespace):
@@ -144,9 +132,9 @@
     args.env.gb_path = args.rom_path
     if args.vectorization == "serial" or args.debug:
-        args.vectorization = pufferlib.vectorization.Serial
+        args.vectorization = pufferlib.vector.Serial
     elif args.vectorization == "multiprocessing":
-        args.vectorization = pufferlib.vectorization.Multiprocessing
+        args.vectorization = pufferlib.vector.Multiprocessing
     return args
 
 
@@ -180,25 +168,45 @@
 def train(
     args: pufferlib.namespace,
     env_creator: Callable,
-    agent_creator: Callable[[gym.Env, pufferlib.namespace], pufferlib.models.Policy],
+    wandb_client: wandb.wandb_sdk.wandb_run.Run | None,
 ):
+    vec = args.vectorization
+    if vec == "serial":
+        vec = pufferlib.vector.Serial
+    elif vec == "multiprocessing":
+        vec = pufferlib.vector.Multiprocessing
+    elif vec == "ray":
+        vec = pufferlib.vector.Ray
+
     # TODO: Remove the +1 once the driver env doesn't permanently increase the env id
     env_send_queues = [Queue() for _ in range(args.train.num_envs + 1)]
     env_recv_queues = [Queue() for _ in range(args.train.num_envs + 1)]
-    with CleanPuffeRL(
-        config=args.train,
-        agent_creator=agent_creator,
-        agent_kwargs={"args": args},
-        env_creator=env_creator,
-        env_creator_kwargs={
+
+    vecenv = pufferlib.vector.make(
+        env_creator,
+        env_kwargs={
             "env_config": args.env,
             "wrappers_config": args.wrappers[args.wrappers_name],
             "reward_config": args.rewards[args.reward_name]["reward"],
             "async_config": {"send_queues": env_send_queues, "recv_queues": env_recv_queues},
         },
-        vectorization=args.vectorization,
+        num_envs=args.train.num_envs,
+        num_workers=args.train.num_workers,
+        batch_size=args.train.env_batch_size,
+        zero_copy=args.train.zero_copy,
+        backend=vec,
+    )
+    policy = make_policy(vecenv.driver_env, args.policy_name, args)
+
+    args.train.env = "Pokemon Red"
+    with CleanPuffeRL(
         exp_name=args.exp_name,
-        track=args.track,
+        config=args.train,
+        vecenv=vecenv,
+        policy=policy,
+        env_recv_queues=env_recv_queues,
+        env_send_queues=env_send_queues,
+        wandb_client=wandb_client,
     ) as trainer:
         while not trainer.done_training():
             trainer.evaluate()
@@ -227,7 +235,10 @@
         default="baseline",
         help="Wrappers to use _in order of instantiation_.",
     )
-    parser.add_argument("--mode", type=str, default="train", choices=["train", "evaluate"])
+    # TODO: Add evaluate
+    parser.add_argument(
+        "--mode", type=str, default="train", choices=["train", "autotune", "evaluate"]
+    )
     parser.add_argument(
         "--eval-model-path", type=str, default=None, help="Path to model to evaluate"
     )
@@ -245,9 +256,7 @@
     clean_parser = argparse.ArgumentParser(parents=[parser])
     args = parser.parse_known_args()[0].__dict__
 
-    config = load_from_config(
-        args["yaml"], args["wrappers_name"], args["policy_name"], args["reward_name"], args["debug"]
-    )
+    config = load_from_config(args)
 
     # Generate argparse menu from config
     # This is also a reason for Spock/Argbind/OmegaConf/pydantic-cli
@@ -273,29 +282,42 @@
         args[name] = pufferlib.namespace(**args[name])
     clean_parser.parse_args(sys.argv[1:])
     args = update_args(args)
+    args.train.exp_id = f"pokemon-red-{str(uuid.uuid4())[:8]}"
 
-    env_creator, agent_creator = setup_agent(
-        args.wrappers[args.wrappers_name], args.reward_name, args.policy_name
-    )
+    async_wrapper = args.train.async_wrapper
+    env_creator = setup_agent(args.wrappers[args.wrappers_name], args.reward_name, async_wrapper)
+    wandb_client = None
     if args.track:
-        args.exp_name = init_wandb(args).id
-    else:
-        args.exp_name = f"poke_{time.strftime('%Y%m%d_%H%M%S')}"
-    args.env.session_path = args.exp_name
+        wandb_client = init_wandb(args)
 
     if args.mode == "train":
-        train(args, env_creator, agent_creator)
-
-    elif args.mode == "evaluate":
-        # TODO: check if this works
-        rollout(
-            env_creator=env_creator,
-            env_creator_kwargs={"env_config": args.env, "wrappers_config": args.wrappers},
-            agent_creator=agent_creator,
-            agent_kwargs={"args": args},
-            model_path=args.eval_model_path,
-            device=args.train.device,
+        train(args, env_creator, wandb_client)
+    elif args.mode == "autotune":
+        env_kwargs = {
+            "env_config": args.env,
+            "wrappers_config": args.wrappers[args.wrappers_name],
+            "reward_config": args.rewards[args.reward_name]["reward"],
+            "async_config": {},
+        }
+        pufferlib.vector.autotune(
+            functools.partial(env_creator, **env_kwargs), batch_size=args.train.env_batch_size
         )
-    else:
-        raise ValueError("Mode must be one of train or evaluate")
+    elif args.mode == "evaluate":
+        env_kwargs = {
+            "env_config": args.env,
+            "wrappers_config": args.wrappers[args.wrappers_name],
+            "reward_config": args.rewards[args.reward_name]["reward"],
+            "async_config": {},
+        }
+        try:
+            cleanrl_puffer.rollout(
+                env_creator,
+                env_kwargs,
+                agent_creator=make_policy,
+                agent_kwargs={"args": args},
+                model_path=args.eval_model_path,
+                device=args.train.device,
+            )
+        except KeyboardInterrupt:
+            os._exit(0)
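train() now wires one send Queue and one recv Queue per environment and hands both sets to the AsyncWrapper (next hunk) and to CleanPuffeRL. A toy, self-contained sketch of that handshake pattern; the names and messages are illustrative, not the repo's actual protocol:

import threading
from multiprocessing import Queue

send_q, recv_q = Queue(), Queue()

def listener():
    while True:
        msg = send_q.get()          # trainer -> env, e.g. a request to load a new state
        recv_q.put(f"done: {msg}")  # env -> trainer, ack once the slow work finished

# daemon=True mirrors the AsyncWrapper change below: the listener thread
# must not keep the interpreter alive after the main thread exits.
threading.Thread(target=listener, daemon=True).start()

send_q.put("load_state")
print(recv_q.get())  # -> "done: load_state"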
diff --git a/pokemonred_puffer/wrappers/async_io.py b/pokemonred_puffer/wrappers/async_io.py
index 069d518..9fd2329 100644
--- a/pokemonred_puffer/wrappers/async_io.py
+++ b/pokemonred_puffer/wrappers/async_io.py
@@ -15,7 +15,7 @@ def __init__(self, env: RedGymEnv, send_queues: list[Queue], recv_queues: list[Q
         # Now we will spawn a thread that will listen for updates
         # and send back when the new state has been loaded
         # this is a slow process and should rarely happen.
-        self.thread = threading.Thread(target=self.update)
+        self.thread = threading.Thread(target=self.update, daemon=True)
         self.thread.start()
 
         # TODO: Figure out if there's a safe way to exit the thread
diff --git a/pyproject.toml b/pyproject.toml
index 35de646..f620214 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,9 +17,9 @@ dependencies = [
     "numpy",
     "opencv-python",
     "pyboy>=2",
-    "pufferlib[cleanrl]>=0.7.3,<1.0.0",
+    "pufferlib[cleanrl]>=1.0.0",
     "scikit-image",
-    "torch>=2.1",
+    "torch>=2.3",
     "torchvision",
     "wandb",
     "websockets"
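The new code paths lean on modules that only exist under the updated pins (pufferlib.vector and pufferlib.pytorch, plus a newer torch). A quick sanity check after reinstalling the package; this snippet is not part of the repo, just an assumed verification step:

import torch
import pufferlib.vector
import pufferlib.pytorch  # used by encode_observations above

print(torch.__version__)                             # expect >= 2.3
print(hasattr(pufferlib.vector, "Multiprocessing"))  # True on pufferlib >= 1.0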