Skip to content

Commit

Permalink
pufferbox5
Browse files Browse the repository at this point in the history
  • Loading branch information
xinpw8 committed Mar 5, 2024
1 parent 5bf079a commit fd6c279
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 19 deletions.
15 changes: 9 additions & 6 deletions clean_pufferl.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def create(
if exp_name is None:
exp_name = str(uuid.uuid4())[:8]
# Base directory path
required_resources_dir = Path('/home/bet_adsorption_xinpw8/pokegym/pokegym') # Path('/home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym')
required_resources_dir = Path('/bet_adsorption_xinpw8/PufferLib/pokegym/pokegym') # Path('/home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym')
# Path for the required_resources directory
required_resources_path = required_resources_dir / "required_resources"
required_resources_path.mkdir(parents=True, exist_ok=True)
Expand Down Expand Up @@ -209,7 +209,7 @@ def create(
torch.zeros(shape, device=device),
torch.zeros(shape, device=device),
)
obs=torch.zeros(config.batch_size + 1, *obs_shape)
obs=torch.zeros(config.batch_size + 1, *obs_shape, pin_memory=True) # added , pin_memory=True)
actions=torch.zeros(config.batch_size + 1, *atn_shape, dtype=int)
logprobs=torch.zeros(config.batch_size + 1)
rewards=torch.zeros(config.batch_size + 1)
Expand Down Expand Up @@ -522,7 +522,7 @@ def train(data):
delta + config.gamma * config.gae_lambda * nextnonterminal * lastgaelam
)

data.b_obs = b_obs = torch.Tensor(data.obs_ary[b_idxs])
data.b_obs = b_obs = data.obs[b_idxs].to(data.device, non_blocking=True) # torch.Tensor(data.obs_ary[b_idxs])
b_actions = torch.Tensor(data.actions_ary[b_idxs]).to(data.device, non_blocking=True)
b_logprobs = torch.Tensor(data.logprobs_ary[b_idxs]).to(data.device, non_blocking=True)
b_dones = torch.Tensor(data.dones_ary[b_idxs]).to(data.device, non_blocking=True)
Expand All @@ -537,13 +537,16 @@ def train(data):
train_time = time.time()
pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], []

mb_obs_buffer = torch.zeros_like(b_obs[0], pin_memory=(data.device == "cuda"))
# COMMENTED OUT BET
# mb_obs_buffer = torch.zeros_like(b_obs[0], pin_memory=(data.device == "cuda"))

for epoch in range(config.update_epochs):
lstm_state = None
for mb in range(num_minibatches):
mb_obs_buffer.copy_(b_obs[mb], non_blocking=True)
mb_obs = mb_obs_buffer.to(data.device, non_blocking=True)
mb_obs = b_obs[mb]
# COMMENTED OUT BET
# mb_obs_buffer.copy_(b_obs[mb], non_blocking=True)
# mb_obs = mb_obs_buffer.to(data.device, non_blocking=True)

mb_actions = b_actions[mb].contiguous()
mb_values = b_values[mb].reshape(-1)
Expand Down
6 changes: 3 additions & 3 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pokemon_red:
save_final_state: True
print_rewards: True
headless: True
init_state: /home/bet_adsorption_xinpw8/pokegym/pokegym/save_state_dir/has_pokedex_nballs_noanim.state # /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir/start_from_state_dir/has_pokedex_nballs_noanim.state
init_state: /bet_adsorption_xinpw8/PufferLib/pokegym/pokegym/save_state_dir/has_pokedex_nballs_noanim.state # /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir/start_from_state_dir/has_pokedex_nballs_noanim.state
action_freq: 24
max_steps: 3072000 # 30720000 # Updated to match ep_length
early_stop: True
Expand All @@ -91,10 +91,10 @@ pokemon_red:
swap_button: True
restricted_start_menu: True # False
level_reward_badge_scale: 1.0
save_state_dir: /home/bet_adsorption_xinpw8/pokegym/pokegym/save_state_dir # /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir
save_state_dir: /bet_adsorption_xinpw8/PufferLib/pokegym/pokegym/save_state_dir # /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir
special_exploration_scale: 1.0
enable_item_manager: True # True
enable_stage_manager: False # True
enable_stage_manager: True # True
enable_item_purchaser: True # True
auto_skip_anim: True
auto_skip_anim_frames: 8
Expand Down
121 changes: 121 additions & 0 deletions config_jsuarez.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
train:
seed: 1
torch_deterministic: True
device: cuda
total_timesteps: 800_000_000 # superseded by pokemon_red package
learning_rate: 0.0003
num_steps: 128 # 128
anneal_lr: False # True
gamma: 0.999 # gamma annealing: first 10m steps 0.999, then 0.9996; might have to screw with gamma and steps to make it work
gae_lambda: 0.95
# num_minibatches: 4 # 4
update_epochs: 3 # 2 # 3 # superseded by pokemon_red package
norm_adv: True
clip_coef: 0.1
clip_vloss: True
ent_coef: 0.01
vf_coef: 0.5
max_grad_norm: 0.5
target_kl: ~

num_envs: 48 # 128 # 48 # 512 num_envs, 12 envs/worker # superseded by pokemon_red package
envs_per_worker: 1 # or 2 - time it, see which is faster # 8 # 4 # superseded by pokemon_red package
envs_per_batch: 48 # must be <= num_envs # superseded by pokemon_red package
env_pool: True # superseded by pokemon_red package
verbose: True # superseded by pokemon_red package
data_dir: experiments
checkpoint_interval: 500 # 40960 # 2048 * 10 * 2
pool_kernel: [0]
batch_size: 32768 # 48 # no async to avoid messing with things # 32768 # 128 (?) # superseded by pokemon_red package
batch_rows: 128 # between 128 and 1024 - empirically # 1024 # 256 # 128 # superseded by pokemon_red package
bptt_horizon: 32 # 16
vf_clip_coef: 0.1
compile: True # superseded by pokemon_red package
compile_mode: reduce-overhead

sweep:
method: random
name: sweep
metric:
goal: maximize
name: episodic_return
# Nested parameters name required by WandB API
parameters:
train:
parameters:
learning_rate: {
'distribution': 'log_uniform_values',
'min': 1e-4,
'max': 1e-1,
}
batch_size: {
'values': [128, 256, 512, 1024, 2048],
}
batch_rows: {
'values': [16, 32, 64, 128, 256],
}
bptt_horizon: {
'values': [4, 8, 16, 32],
}

pokemon_red:
package: pokemon_red
train:
total_timesteps: 800_000_000
num_envs: 48 # 256
envs_per_worker: 1
envs_per_batch: 48 # 48 # must be divisible by envs_per_worker
update_epochs: 3 # 10 # 3
gamma: 0.9996
batch_size: 32768 # 65536 # 32768
batch_rows: 128 # 256
compile: True

# Boey-specific env parameters; loaded by environment.py
save_final_state: True
print_rewards: True
headless: True
init_state: /bet_adsorption_xinpw8/PufferLib/pokegym/pokegym/save_state_dir/has_pokedex_nballs_noanim.state # /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir/start_from_state_dir/has_pokedex_nballs_noanim.state
action_freq: 24
max_steps: 30720000 # 30720000 # Updated to match ep_length
early_stop: True
early_stopping_min_reward: 2.0
save_video: False
fast_video: True
explore_weight: 1.5
use_screen_explore: False
sim_frame_dist: 2000000.0 # 2000000.0
reward_scale: 4
extra_buttons: False
noop_button: True
swap_button: True
restricted_start_menu: True # False
level_reward_badge_scale: 1.0
save_state_dir: /bet_adsorption_xinpw8/PufferLib/pokegym/pokegym/save_state_dir # /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir
special_exploration_scale: 1.0
enable_item_manager: True # True
enable_stage_manager: True # True
enable_item_purchaser: True # True
auto_skip_anim: True
auto_skip_anim_frames: 8
total_envs: 48 # 48 # Updated to match num_cpu
gb_path: PokemonRed.gb
debug: False
level_manager_eval_mode: False
sess_id: generate # Updated dynamically, placeholder for dynamic generation
use_wandb_logging: False
cpu_multiplier: 0.25
save_freq: 500 # 40960 # 2048 * 10 * 2
n_steps: 163840 # Calculated as int(5120 // cpu_multiplier) * 1
num_cpu: 48 # number of processes, 1 env per process # 8 # Calculated as int(32 * cpu_multiplier)
env:
name: pokemon_red
pokemon-red:
package: pokemon_red
pokemonred:
package: pokemon_red
pokemon:
package: pokemon_red
pokegym:
package: pokemon_red

12 changes: 6 additions & 6 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def make_policy(env, env_module, args):
policy = pufferlib.frameworks.cleanrl.Policy(policy)

# BET ADDED 1
mode = "default"
if args.train.device == "cuda":
mode = "reduce-overhead"
policy = policy.to(args.train.device, non_blocking=True)
policy.get_value = torch.compile(policy.get_value, mode=mode)
policy.get_action_and_value = torch.compile(policy.get_action_and_value, mode=mode)
# mode = "default"
# if args.train.device == "cuda":
# mode = "reduce-overhead"
# policy = policy.to(args.train.device, non_blocking=True)
# policy.get_value = torch.compile(policy.get_value, mode=mode)
# policy.get_action_and_value = torch.compile(policy.get_action_and_value, mode=mode)

return policy.to(args.train.device)

Expand Down
2 changes: 1 addition & 1 deletion pufferlib/environments/pokemon_red/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def env_creator(name="pokemon_red"):
def make(name, **kwargs,):
"""Pokemon Red"""
env = Environment(kwargs)
env = StreamWrapper(env, stream_metadata={"user": " PUFFERBOX4|BET|PUFFERBOX4 \nPUFFERBOX4|BET|PUFFERBOX4 \n====BOEY====\nPUFFERBOX4|BET|PUFFERBOX4 "})
env = StreamWrapper(env, stream_metadata={"user": "PUFFERBOX5|BET|\n=BOEY=\n"})
# Looks like the following will optionally create the object for you
# Or use the one you pass it. I'll just construct it here.
return pufferlib.emulation.GymnasiumPufferEnv(
Expand Down
2 changes: 1 addition & 1 deletion run.sh
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
#!/bin/bash
python demo.py --backend clean_pufferl --config pokemon_red --no-render --vectorization multiprocessing --mode train --track
python demo.py --backend clean_pufferl --config pokemon_red --no-render --vectorization multiprocessing --mode train --track # --exp-name test4 # --wandb-entity xinpw8
5 changes: 3 additions & 2 deletions stream_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,18 @@ def colors_generator(step=1):
class StreamWrapper(gym.Wrapper):
def __init__(self, env, stream_metadata={}):
super().__init__(env)
self.color_generator = color_generator(step=5) # step=1
self.color_generator = color_generator(step=2) # step=1
# self.ws_address = "wss://poke-ws-test-ulsjzjzwpa-ue.a.run.app/broadcast"
self.ws_address = "wss://transdimensional.xyz/broadcast"
self.stream_metadata = stream_metadata
self.stream_metadata = {**stream_metadata, "env_id": env.env_id,} # env ids listed
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.websocket = None
self.loop.run_until_complete(
self.establish_wc_connection()
)
self.upload_interval = 250
self.upload_interval = 125
self.steam_step_counter = 0
self.coord_list = []
self.start_time = time.time()
Expand Down

0 comments on commit fd6c279

Please sign in to comment.