Skip to content

Commit

Permalink
Fix debug env, dont run utilizaiton if verbose=False, fixed slowdown …
Browse files Browse the repository at this point in the history
…over time
  • Loading branch information
thatguy11325 committed Jun 23, 2024
1 parent 7963802 commit 97d6c6b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
9 changes: 6 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ debug:
device: cpu
compile: False
compile_mode: default
num_envs: 2
num_envs: 1
num_workers: 1
env_batch_size: 16
env_batch_size: 4
env_pool: True
zero_copy: False
batch_size: 128
minibatch_size: 128
batch_rows: 4
bptt_horizon: 2
total_timesteps: 100_000_000
Expand Down Expand Up @@ -63,6 +65,7 @@ env:
auto_pokeflute: True
infinite_money: True
use_global_map: False
save_state: False


train:
Expand All @@ -73,7 +76,7 @@ train:
compile_mode: "reduce-overhead"
float32_matmul_precision: "high"
total_timesteps: 100_000_000_000
batch_size: 65536
batch_size: 131072 # 65536
minibatch_size: 2048
learning_rate: 2.0e-4
anneal_lr: False
Expand Down
44 changes: 31 additions & 13 deletions pokemonred_puffer/cleanrl_puffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ class CleanPuffeRL:
wandb_client: wandb.wandb_sdk.wandb_run.Run | None = None
profile: Profile = field(default_factory=lambda: Profile())
losses: Losses = field(default_factory=lambda: Losses())
utilization: Utilization = field(default_factory=lambda: Utilization())
global_step: int = 0
epoch: int = 0
stats: dict = field(default_factory=lambda: {})
Expand All @@ -155,6 +154,8 @@ def __post_init__(self):
clear=True,
)

self.utilization = Utilization()

self.vecenv.async_reset(self.config.seed)
obs_shape = self.vecenv.single_observation_space.shape
obs_dtype = self.vecenv.single_observation_space.dtype
Expand Down Expand Up @@ -199,6 +200,14 @@ def __post_init__(self):

@pufferlib.utils.profile
def evaluate(self):
# Clear all self.infos except for the state
for k in list(self.infos.keys()):
if k != "state":
del self.infos[k]
elif len(self.infos["state"]) > 0:
# just in case
self.infos["state"] = self.infos["state"][-1]

# now for a tricky bit:
# if we have swarm_frequency, we will take the top swarm_keep_pct envs and evenly distribute
# their states to the bottom 90%.
Expand All @@ -208,6 +217,7 @@ def evaluate(self):
and hasattr(self.config, "swarm_keep_pct")
and self.epoch % self.config.swarm_frequency == 0
and "reward/event" in self.infos
and "state" in self.infos
):
# collect the top swarm_keep_pct % of envs
largest = [
Expand Down Expand Up @@ -275,11 +285,17 @@ def evaluate(self):
actions = actions.cpu().numpy()
mask = torch.as_tensor(mask) # * policy.mask)
o = o if self.config.cpu_offload else o_device
if self.config.num_workers == 1:
actions = np.expand_dims(actions, 0)
logprob = logprob.unsqueeze(0)
self.experience.store(o, value, actions, logprob, r, d, env_id, mask)

for i in info:
for k, v in pufferlib.utils.unroll_nested_dict(i):
self.infos[k].append(v)
if k == "state":
self.infos[k] = [v]
else:
self.infos[k].append(v)

with self.profile.env:
self.vecenv.send(actions)
Expand Down Expand Up @@ -441,16 +457,17 @@ def train(self):

done_training = self.global_step >= self.config.total_timesteps
if self.profile.update(self) or done_training:
print_dashboard(
self.config.env,
self.utilization,
self.global_step,
self.epoch,
self.profile,
self.losses,
self.stats,
self.msg,
)
if self.config.verbose:
print_dashboard(
self.config.env,
self.utilization,
self.global_step,
self.epoch,
self.profile,
self.losses,
self.stats,
self.msg,
)

if (
self.wandb_client is not None
Expand All @@ -475,7 +492,8 @@ def train(self):

def close(self):
self.vecenv.close()
self.utilization.stop()
if self.config.verbose:
self.utilization.stop()

if self.wandb_client is not None:
artifact_name = f"{self.exp_name}_model"
Expand Down
3 changes: 2 additions & 1 deletion pokemonred_puffer/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, env_config: pufferlib.namespace):
self.auto_pokeflute = env_config.auto_pokeflute
self.infinite_money = env_config.infinite_money
self.use_global_map = env_config.use_global_map
self.save_state = env_config.save_state
self.action_space = ACTION_SPACE

# Obs space-related. TODO: avoid hardcoding?
Expand Down Expand Up @@ -552,7 +553,7 @@ def step(self, action):
self.pokecenters[self.read_m("wLastBlackoutMap")] = 1
info = {}

if self.get_events_sum() > self.max_event_rew:
if self.save_state and self.get_events_sum() > self.max_event_rew:
state = io.BytesIO()
self.pyboy.save_state(state)
state.seek(0)
Expand Down

0 comments on commit 97d6c6b

Please sign in to comment.