
Pickle-able env creator. Fix Env ID on macos
thatguy11325 committed Apr 29, 2024
1 parent 78e1f2c commit 316fcc9
Showing 6 changed files with 44 additions and 51 deletions.
config.yaml (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ debug:
batch_size: 16
batch_rows: 4
bptt_horizon: 2
-total_timesteps: 100_000_000
+total_timesteps: 100_000_000_000
save_checkpoint: True
checkpoint_interval: 4
save_overlay: True
pokemonred_puffer/cleanrl_puffer.py (12 changes: 5 additions & 7 deletions)
@@ -93,12 +93,12 @@ def rollout(
step = 0
return_val = 0

-ob = torch.tensor(ob, device=device).unsqueeze(0)
+ob = torch.tensor(ob).unsqueeze(0).to(device)
with torch.no_grad():
if hasattr(agent, "lstm"):
-action, _, _, _, state = agent.get_action_and_value(ob, state)
+action, _, _, _, state = agent(ob, state)
else:
-action, _, _, _ = agent.get_action_and_value(ob)
+action, _, _, _ = agent(ob)

ob, reward, terminal, truncated, _ = env.step(action[0].item())
return_val += reward
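A side note on the rollout changes above: `agent(ob, state)` simply goes through the module's `__call__`/forward instead of `get_action_and_value`, and the two observation-tensor constructions give identical results, only the device placement moves to the end. A minimal sketch with a made-up `ob` (not the env's real observation format):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    ob = [[1.0, 2.0], [3.0, 4.0]]  # stand-in observation

    a = torch.tensor(ob, device=device).unsqueeze(0)  # old form: allocate on the device first
    b = torch.tensor(ob).unsqueeze(0).to(device)       # new form: build on CPU, then move

    assert torch.equal(a, b)  # same values, same shape (1, 2, 2)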
@@ -255,7 +255,6 @@ def __init__(
self.global_step = resume_state.get("global_step", 0)
self.agent_step = resume_state.get("agent_step", 0)
self.update = resume_state.get("update", 0)
-self.lr_update = resume_state.get("lr_update", 0)

self.optimizer = optim.Adam(self.agent.parameters(), lr=config.learning_rate, eps=1e-5)
self.opt_state = resume_state.get("optimizer_state_dict", None)
@@ -318,7 +317,7 @@ def __init__(
)

self.sort_keys = []
-self.learning_rate = (config.learning_rate,)
+self.learning_rate = config.learning_rate
self.losses = Losses()
self.performance = Performance()

@@ -558,7 +557,7 @@ def train(self):
train_profiler.start()

if config.anneal_lr:
-frac = 1.0 - (self.lr_update - 1.0) / self.total_updates
+frac = 1.0 - (self.update - 1.0) / self.total_updates
lrnow = frac * config.learning_rate
self.optimizer.param_groups[0]["lr"] = lrnow
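With `lr_update` removed, `self.update` now drives both checkpointing and the linear learning-rate anneal above. A small sketch of that schedule with hypothetical numbers (`total_updates=100` and `3e-4` are illustrative, not values from config.yaml):

    def annealed_lr(update: int, total_updates: int, base_lr: float) -> float:
        # Same linear decay as the frac computation in train()
        frac = 1.0 - (update - 1.0) / total_updates
        return frac * base_lr

    print(annealed_lr(1, 100, 3e-4))    # ~3e-4, full rate on the first update
    print(annealed_lr(51, 100, 3e-4))   # ~1.5e-4, halved midway
    print(annealed_lr(100, 100, 3e-4))  # ~3e-6, nearly zero on the last update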

@@ -712,7 +711,6 @@ def train(self):
print_dashboard(self.stats, self.init_performance, self.performance)

self.update += 1
-self.lr_update += 1

if self.update % config.checkpoint_interval == 0 or self.done_training():
self.save_checkpoint()
pokemonred_puffer/environment.py (20 changes: 3 additions & 17 deletions)
@@ -3,7 +3,6 @@
import os
import random
from collections import deque
-from multiprocessing import Lock, shared_memory
from pathlib import Path
from typing import Any, Iterable, Optional
import uuid
@@ -133,12 +132,12 @@


# TODO: Make global map usage a configuration parameter
-class RedGymEnv(Env):
-env_id = shared_memory.SharedMemory(create=True, size=4)
-lock = Lock()


+class RedGymEnv(Env):
def __init__(self, env_config: pufferlib.namespace):
+# TODO: Dont use pufferlib.namespace. It seems to confuse __init__
+self.env_id = env_config.env_id
self.video_dir = Path(env_config.video_dir)
self.session_path = Path(env_config.session_path)
self.video_path = self.video_dir / self.session_path
@@ -229,19 +228,6 @@ def __init__(self, env_config: pufferlib.namespace):
self.screen = self.pyboy.screen

self.first = True
-with RedGymEnv.lock:
-env_id = (
-(int(RedGymEnv.env_id.buf[0]) << 24)
-+ (int(RedGymEnv.env_id.buf[1]) << 16)
-+ (int(RedGymEnv.env_id.buf[2]) << 8)
-+ (int(RedGymEnv.env_id.buf[3]))
-)
-self.env_id = env_id
-env_id += 1
-RedGymEnv.env_id.buf[0] = (env_id >> 24) & 0xFF
-RedGymEnv.env_id.buf[1] = (env_id >> 16) & 0xFF
-RedGymEnv.env_id.buf[2] = (env_id >> 8) & 0xFF
-RedGymEnv.env_id.buf[3] = (env_id) & 0xFF

def register_hooks(self):
self.pyboy.hook_register(None, "DisplayStartMenu", self.start_menu_hook, None)
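Why this block went away: under the spawn start method (the default on macOS), the class-level SharedMemory and Lock are re-created in every worker process rather than shared, so environment ids were not actually unique across workers, which is presumably the "Fix Env ID on macos" part of the commit title (an inference; the diff itself does not say). Each environment now just reads the id the creator assigned via `env_config.env_id`, and the shared counter lives with the creator in train.py. A standalone sketch of that counter pattern (none of these names are from the repo):

    from multiprocessing import Manager

    def next_env_id(counter, lock) -> int:
        # Mirrors what EnvCreator.__call__ does before constructing each env
        with lock:
            assigned = counter.value
            counter.value += 1
        return assigned

    if __name__ == "__main__":
        manager = Manager()
        counter = manager.Value("b", 0)  # the kind of shared counter handed to EnvCreator
        lock = manager.Lock()
        print(next_env_id(counter, lock))  # 0
        print(next_env_id(counter, lock))  # 1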
pokemonred_puffer/policies/multi_convolutional.py (5 changes: 0 additions & 5 deletions)
@@ -2,12 +2,9 @@
from torch import nn

import pufferlib.models
-from pufferlib.emulation import unpack_batched_obs

from pokemonred_puffer.environment import PIXEL_VALUES

-unpack_batched_obs = torch.compiler.disable(unpack_batched_obs)

# Because torch.nn.functional.one_hot cannot be traced by torch as of 2.2.0


@@ -71,8 +68,6 @@ def __init__(
)

def encode_observations(self, observations):
-observations = unpack_batched_obs(observations, self.unflatten_context)

screen = observations["screen"]
visited_mask = observations["visited_mask"]
global_map = observations["global_map"]
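With `unpack_batched_obs` dropped, `encode_observations` now receives the observation dict directly (the unpacking presumably happens upstream in pufferlib). The surviving comment about `torch.nn.functional.one_hot` not being traceable hints at the kind of workaround used downstream; one trace-friendly alternative is a broadcasted comparison against the palette, sketched below (the `PIXEL_VALUES` here are a guess, not the repo's actual constant):

    import torch

    PIXEL_VALUES = torch.tensor([0, 85, 153, 255])  # assumed 4-shade palette

    def one_hot_by_comparison(screen: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        # Compare every pixel against every palette value; avoids F.one_hot entirely
        return (screen.unsqueeze(-1) == values).to(torch.float32)

    screen = torch.tensor([[0, 85], [153, 255]])
    print(one_hot_by_comparison(screen, PIXEL_VALUES).shape)  # torch.Size([2, 2, 4])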
pokemonred_puffer/train.py (54 changes: 34 additions & 20 deletions)
@@ -1,6 +1,7 @@
import argparse
import importlib
-from multiprocessing import Queue
+from multiprocessing import Manager, Queue
+import multiprocessing
import pathlib
import sys
import time
@@ -14,6 +15,8 @@

import pufferlib
import pufferlib.utils
+import pufferlib.postprocess
+import pufferlib.emulation
from pokemonred_puffer.cleanrl_puffer import CleanPuffeRL, rollout
from pokemonred_puffer.environment import RedGymEnv
from pokemonred_puffer.wrappers.async_io import AsyncWrapper
@@ -62,32 +65,42 @@ def load_from_config(
return pufferlib.namespace(**combined_config)


-def make_env_creator(
-wrapper_classes: list[tuple[str, ModuleType]],
-reward_class: RedGymEnv,
-) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
-def env_creator(
+class EnvCreator:
+def __init__(
+self,
+wrapper_classes: list[tuple[str, ModuleType]],
+reward_class: RedGymEnv,
+env_id: multiprocessing.Value, # technically a ValueProxy
+lock: multiprocessing.Lock,
+):
+self.wrapper_classes = wrapper_classes
+self.reward_class = reward_class
+self.env_id = env_id
+self.lock = lock

+def __call__(
+self,
env_config: pufferlib.namespace,
wrappers_config: list[dict[str, Any]],
reward_config: pufferlib.namespace,
async_config: dict[str, Queue],
) -> pufferlib.emulation.GymnasiumPufferEnv:
-env = reward_class(env_config, reward_config)
-for cfg, (_, wrapper_class) in zip(wrappers_config, wrapper_classes):
+with self.lock:
+env_id = self.env_id.value
+self.env_id.value += 1
+print(f"Creating environment {env_id}")
+env_config.env_id = env_id
+env = self.reward_class(env_config, reward_config)
+for cfg, (_, wrapper_class) in zip(wrappers_config, self.wrapper_classes):
env = wrapper_class(env, pufferlib.namespace(**[x for x in cfg.values()][0]))
env = AsyncWrapper(env, async_config["send_queues"], async_config["recv_queues"])
-return pufferlib.emulation.GymnasiumPufferEnv(
-env=env, postprocessor_cls=pufferlib.emulation.BasicPostprocessor
-)

-return env_creator
+env = pufferlib.postprocess.EpisodeStats(env)
+return pufferlib.emulation.GymnasiumPufferEnv(env=env)


# Returns env_creator, agent_creator
def setup_agent(
-wrappers: list[str],
-reward_name: str,
-policy_name: str,
+wrappers: list[str], reward_name: str, policy_name: str, manager: multiprocessing.Manager
) -> Callable[[pufferlib.namespace, pufferlib.namespace], pufferlib.emulation.GymnasiumPufferEnv]:
# TODO: Make this less dependent on the name of this repo and its file structure
wrapper_classes = [
@@ -106,7 +119,7 @@ def setup_agent(
importlib.import_module(f"pokemonred_puffer.rewards.{reward_module}"), reward_class_name
)
# NOTE: This assumes reward_module has RewardWrapper(RedGymEnv) class
-env_creator = make_env_creator(wrapper_classes, reward_class)
+env_creator = EnvCreator(wrapper_classes, reward_class, manager.Value("b", 0), manager.Lock())

policy_module_name, policy_class_name = policy_name.split(".")
policy_module = importlib.import_module(f"pokemonred_puffer.policies.{policy_module_name}")
@@ -143,7 +156,7 @@ def update_args(args: argparse.Namespace):
args.track = args.track
args.env.gb_path = args.rom_path

-if args.vectorization == "serial" or args.debug:
+if args.vectorization == "serial":
args.vectorization = pufferlib.vectorization.Serial
elif args.vectorization == "multiprocessing":
args.vectorization = pufferlib.vectorization.Multiprocessing
@@ -179,7 +192,7 @@ def init_wandb(args, resume=True):

def train(
args: pufferlib.namespace,
-env_creator: Callable,
+env_creator: EnvCreator,
agent_creator: Callable[[gym.Env, pufferlib.namespace], pufferlib.models.Policy],
):
# TODO: Remove the +1 once the driver env doesn't permanently increase the env id
@@ -274,8 +287,9 @@ def train(
clean_parser.parse_args(sys.argv[1:])
args = update_args(args)

+manager = Manager()
env_creator, agent_creator = setup_agent(
-args.wrappers[args.wrappers_name], args.reward_name, args.policy_name
+args.wrappers[args.wrappers_name], args.reward_name, args.policy_name, manager
)

if args.track:
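The train.py change is the "Pickle-able env creator" from the commit title: `make_env_creator` returned a nested closure, which cannot be pickled and therefore cannot be shipped to workers started with the spawn method, while an instance of a module-level class with `__call__` pickles fine. A minimal illustration of the difference (nothing below is from the repo):

    import pickle

    def make_creator(name):
        def creator():  # nested function: pickled by qualified name, which fails
            return f"env-{name}"
        return creator

    class Creator:  # module-level class: instances carry their state and pickle fine
        def __init__(self, name):
            self.name = name

        def __call__(self):
            return f"env-{self.name}"

    try:
        pickle.dumps(make_creator("red"))
    except (pickle.PicklingError, AttributeError) as err:
        print("closure fails:", err)

    print(len(pickle.dumps(Creator("red"))), "bytes")  # the callable instance round-trips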
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@ dependencies = [
"opencv-python",
"numpy",
"pyboy>=2",
"pufferlib[cleanrl]>=0.7.3",
"pufferlib[cleanrl] @ git+https://github.com/PufferAI/[email protected]",
"torch>=2.1",
"torchvision",
"wandb"
