Skip to content

Commit

Permalink
Start refactoring for sweeps
Browse files Browse the repository at this point in the history
  • Loading branch information
thatguy11325 committed Oct 10, 2024
1 parent 58b3dd9 commit 55ebb2f
Show file tree
Hide file tree
Showing 9 changed files with 321 additions and 265 deletions.
5 changes: 2 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ debug:
num_envs: 1
envs_per_worker: 1
num_workers: 1
env_batch_size: 128
env_pool: True
env_batch_size: 20480
zero_copy: False
batch_size: 128
batch_size: 20480
minibatch_size: 4
batch_rows: 4
bptt_horizon: 2
Expand Down
9 changes: 2 additions & 7 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
def rollout(
env_creator,
env_kwargs,
agent_creator,
agent_kwargs,
model_path=None,
device="cuda",
):
Expand All @@ -54,12 +52,9 @@ def rollout(
except: # noqa: E722
env = pufferlib.vector.make(env_creator, env_kwargs=env_kwargs)

if model_path is None:
agent = agent_creator(env, **agent_kwargs)
else:
agent = torch.load(model_path, map_location=device)
agent = torch.load(model_path, map_location=device)

ob, info = env.reset()
ob, _ = env.reset()
driver = env.driver_env
os.system("clear")
state = None
Expand Down
19 changes: 13 additions & 6 deletions pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@

import mediapy as media
import numpy as np
from omegaconf import DictConfig, OmegaConf
from gymnasium import Env, spaces
from pyboy import PyBoy
from pyboy.utils import WindowEvent

import pufferlib
from pokemonred_puffer.data.elevators import NEXT_ELEVATORS
from pokemonred_puffer.data.events import (
EVENT_FLAGS_START,
Expand Down Expand Up @@ -94,8 +94,7 @@ class RedGymEnv(Env):
env_id = shared_memory.SharedMemory(create=True, size=4)
lock = Lock()

def __init__(self, env_config: pufferlib.namespace):
# TODO: Dont use pufferlib.namespace. It seems to confuse __init__
def __init__(self, env_config: DictConfig):
self.video_dir = Path(env_config.video_dir)
self.save_final_state = env_config.save_final_state
self.print_rewards = env_config.print_rewards
Expand All @@ -114,10 +113,11 @@ def __init__(self, env_config: pufferlib.namespace):
self.log_frequency = env_config.log_frequency
self.two_bit = env_config.two_bit
self.auto_flash = env_config.auto_flash
if isinstance(env_config.disable_wild_encounters, bool):
disable_wild_encounters = OmegaConf.to_object(env_config.disable_wild_encounters)
if isinstance(disable_wild_encounters, bool):
self.disable_wild_encounters = env_config.disable_wild_encounters
self.setup_disable_wild_encounters_maps = set([])
elif isinstance(env_config.disable_wild_encounters, list):
elif isinstance(disable_wild_encounters, list):
self.disable_wild_encounters = len(env_config.disable_wild_encounters) > 0
self.disable_wild_encounters_maps = {
MapIds[item].name for item in env_config.disable_wild_encounters
Expand Down Expand Up @@ -225,7 +225,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.observation_space = spaces.Dict(obs_dict)

self.pyboy = PyBoy(
env_config.gb_path,
str(env_config.gb_path),
debug=False,
no_input=False,
window="null" if self.headless else "SDL2",
Expand Down Expand Up @@ -713,6 +713,13 @@ def step(self, action):
self.use_surf = 1
info = {}

# self.memory[0xd16c] = 0xFF
self.pyboy.memory[0xD16D] = 0xFF
self.pyboy.memory[0xD188] = 0xFF
self.pyboy.memory[0xD189] = 0xFF
self.pyboy.memory[0xD18A] = 0xFF
self.pyboy.memory[0xD18B] = 0xFF

required_events = self.get_required_events()
required_items = self.get_required_items()
new_required_events = required_events - self.required_events
Expand Down
19 changes: 11 additions & 8 deletions pokemonred_puffer/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from threading import Thread
import time

from omegaconf import OmegaConf
import psutil
import torch

Expand Down Expand Up @@ -74,14 +75,16 @@ def update(self, data, interval_s=1):


def make_losses():
return pufferlib.namespace(
policy_loss=0,
value_loss=0,
entropy=0,
old_approx_kl=0,
approx_kl=0,
clipfrac=0,
explained_variance=0,
return OmegaConf.create(
dict(
policy_loss=0,
value_loss=0,
entropy=0,
old_approx_kl=0,
approx_kl=0,
clipfrac=0,
explained_variance=0,
)
)


Expand Down
4 changes: 2 additions & 2 deletions pokemonred_puffer/rewards/baseline.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import pufferlib
from omegaconf import DictConfig

from pokemonred_puffer.data.events import REQUIRED_EVENTS
from pokemonred_puffer.data.items import REQUIRED_ITEMS, USEFUL_ITEMS
Expand All @@ -15,7 +15,7 @@


class BaselineRewardEnv(RedGymEnv):
def __init__(self, env_config: pufferlib.namespace, reward_config: pufferlib.namespace):
def __init__(self, env_config: DictConfig, reward_config: DictConfig):
super().__init__(env_config)
self.reward_config = reward_config
self.max_event_rew = 0
Expand Down
Loading

0 comments on commit 55ebb2f

Please sign in to comment.