diff --git a/clean_pufferl.py b/clean_pufferl.py index e423577a..6e388a46 100755 --- a/clean_pufferl.py +++ b/clean_pufferl.py @@ -24,6 +24,7 @@ from collections import deque from pokegym.global_map import GLOBAL_MAP_SHAPE from pokegym.eval import make_pokemon_red_overlay +from pathlib import Path @pufferlib.dataclass class Performance: @@ -44,6 +45,7 @@ class Performance: train_sps = 0 train_memory = 0 train_pytorch_memory = 0 + misc_time = 0 @pufferlib.dataclass class Losses: @@ -61,31 +63,48 @@ class Charts: SPS = 0 learning_rate = 0 -def init( +def create( self: object = None, config: pufferlib.namespace = None, exp_name: str = None, track: bool = False, - # Agent agent: nn.Module = None, agent_creator: callable = None, agent_kwargs: dict = None, - # Environment env_creator: callable = None, env_creator_kwargs: dict = None, vectorization: ... = pufferlib.vectorization.Serial, - # Policy Pool options policy_selector: callable = pufferlib.policy_pool.random_selector, ): - if config is None: - config = pufferlib.args.CleanPuffeRL() + if config is None: + config = pufferlib.args.CleanPuffeRL() + # Check if exp_name is set, otherwise generate a new one if exp_name is None: - exp_name = str(uuid.uuid4())[:8] - + exp_name = str(uuid.uuid4())[:8] + # Base directory path + required_resources_dir = Path('/home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym') + # Path for the required_resources directory + required_resources_path = required_resources_dir / "required_resources" + required_resources_path.mkdir(parents=True, exist_ok=True) + # Files to be created in the required_resources directory + files = ["running_experiment.txt", "test_exp.txt", "stats.txt"] + # Create the files if they do not exist + for file_name in files: + file_path = required_resources_path / file_name + file_path.touch(exist_ok=True) # Creates the file if it doesn't exist, without erasing content if it does + # Now, you can correctly specify the file path for each file + running_experiment_file_path = required_resources_path / "running_experiment.txt" + test_exp_file_path = required_resources_path / "test_exp.txt" + # Write the experiment name to "running_experiment.txt" for environment.py folder logic + # TODO: write to json for easier reading + exp_name = f"{exp_name}" + with open(running_experiment_file_path, 'w') as file: + file.write(f"{exp_name}") + wandb = None if track: import wandb @@ -95,7 +114,10 @@ def init( total_updates = config.total_timesteps // config.batch_size device = config.device - # obs_device = 'cpu' if config.cpu_offload else device ## BET ADDED 19 + + # Write parsed config to file; environment.py reads for initialization + with open(test_exp_file_path, 'w') as file: + file.write(f"{config}") # Create environments, agent, and optimizer init_profiler = pufferlib.utils.Profiler(memory=True) @@ -107,7 +129,9 @@ def init( envs_per_worker=config.envs_per_worker, envs_per_batch=config.envs_per_batch, env_pool=config.env_pool, + mask_agents=True, ) + print(f'pool=cprl {pool}') obs_shape = pool.single_observation_space.shape atn_shape = pool.single_action_space.shape @@ -115,37 +139,43 @@ def init( total_agents = num_agents * config.num_envs # If data_dir is provided, load the resume state - resume_state = {} - path = os.path.join(config.data_dir, exp_name) - if os.path.exists(path): - trainer_path = os.path.join(path, 'trainer_state.pt') - resume_state = torch.load(trainer_path) - model_path = os.path.join(path, resume_state["model_name"]) - agent = torch.load(model_path, map_location=device) 
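            # Note on the resume path above: torch.load unpickles the entire policy
            # module (save_checkpoint stores the whole object, not just a state_dict),
            # so the original policy class must still be importable when resuming;
            # map_location=device places the restored weights on the training device.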
- print(f'Resumed from update {resume_state["update"]} ' - f'with policy {resume_state["model_name"]}') - else: - agent = pufferlib.emulation.make_object( - agent, agent_creator, [pool.driver_env], agent_kwargs) + try: + resume_state = {} + path = os.path.join(config.data_dir, exp_name) + if os.path.exists(path): + trainer_path = os.path.join(path, 'trainer_state.pt') + resume_state = torch.load(trainer_path) + model_path = os.path.join(path, resume_state["model_name"]) + agent = torch.load(model_path, map_location=device) + print(f'Resumed from update {resume_state["update"]} ' + f'with policy {resume_state["model_name"]}') + else: + agent = pufferlib.emulation.make_object( + agent, agent_creator, [pool.driver_env], agent_kwargs) + except: + pass + # Some data to preserve run parameters when loading a saved model global_step = resume_state.get("global_step", 0) agent_step = resume_state.get("agent_step", 0) update = resume_state.get("update", 0) lr_update = resume_state.get("lr_update", 0) # BET ADDED 20 + agent = pufferlib.emulation.make_object( + agent, agent_creator, [pool.driver_env], agent_kwargs) + optimizer = optim.Adam(agent.parameters(), lr=config.learning_rate, eps=1e-5) + uncompiled_agent = agent # Needed to save the model opt_state = resume_state.get("optimizer_state_dict", None) - # BET ADDED 21 (through line 144) if config.compile: - agent = torch.compile(agent, mode=config.compile_mode) - # TODO: Figure out how to compile the optimizer! - # self.calculate_loss = torch.compile(self.calculate_loss, mode=config.compile_mode) + agent = torch.compile(agent, mode=config.compile_mode) if config.verbose: n_params = sum(p.numel() for p in agent.parameters() if p.requires_grad) print(f"Model Size: {n_params//1000} K parameters") + opt_state = resume_state.get("optimizer_state_dict", None) if opt_state is not None: optimizer.load_state_dict(resume_state["optimizer_state_dict"]) @@ -163,7 +193,7 @@ def init( # Allocate Storage storage_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True).start() - # next_lstm_state = [] ## BET ADDED 13 + next_lstm_state = [] pool.async_reset(config.seed) next_lstm_state = None @@ -174,13 +204,13 @@ def init( torch.zeros(shape, device=device), torch.zeros(shape, device=device), ) - obs = torch.zeros(config.batch_size + 1, *obs_shape) - actions = torch.zeros(config.batch_size + 1, *atn_shape, dtype=int) - logprobs = torch.zeros(config.batch_size + 1) - rewards = torch.zeros(config.batch_size + 1) - dones = torch.zeros(config.batch_size + 1) - truncateds = torch.zeros(config.batch_size + 1) - values = torch.zeros(config.batch_size + 1) + obs=torch.zeros(config.batch_size + 1, *obs_shape) + actions=torch.zeros(config.batch_size + 1, *atn_shape, dtype=int) + logprobs=torch.zeros(config.batch_size + 1) + rewards=torch.zeros(config.batch_size + 1) + dones=torch.zeros(config.batch_size + 1) + truncateds=torch.zeros(config.batch_size + 1) + values=torch.zeros(config.batch_size + 1) obs_ary = np.asarray(obs) actions_ary = np.asarray(actions) @@ -189,25 +219,9 @@ def init( dones_ary = np.asarray(dones) truncateds_ary = np.asarray(truncateds) values_ary = np.asarray(values) - - ## BET ADDED 14 - # if hasattr(agent, 'lstm'): - # shape = (agent.lstm.num_layers, total_agents, agent.lstm.hidden_size) - # next_lstm_state = ( - # torch.zeros(shape).to(device), - # torch.zeros(shape).to(device), - # ) - # obs=torch.zeros(config.batch_size + 1, *obs_shape).to(obs_device) - # actions=torch.zeros(config.batch_size + 1, *atn_shape, dtype=int).to(device) - # 
logprobs=torch.zeros(config.batch_size + 1).to(device) - # rewards=torch.zeros(config.batch_size + 1).to(device) - # dones=torch.zeros(config.batch_size + 1).to(device) - # truncateds=torch.zeros(config.batch_size + 1).to(device) - # values=torch.zeros(config.batch_size + 1).to(device) storage_profiler.stop() - #"charts/actions": wandb.Histogram(b_actions.cpu().numpy()), init_performance = pufferlib.namespace( init_time = time.time() - start_time, init_env_time = init_profiler.elapsed, @@ -216,26 +230,12 @@ def init( tensor_pytorch_memory = storage_profiler.pytorch_memory, ) - return pufferlib.namespace(self, - - - # BET ADDED 22 - reward_buffer = deque(maxlen=1_000), - exploration_map_agg = np.zeros((config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32), - taught_cut = False, - infos = {}, - obs_ary = obs_ary, - actions_ary = actions_ary, - logprobs_ary = logprobs_ary, - rewards_ary = rewards_ary, - dones_ary = dones_ary, - truncateds_ary = truncateds_ary, - values_ary = values_ary, - + return pufferlib.namespace(self, # Agent, Optimizer, and Environment config=config, pool = pool, agent = agent, + uncompiled_agent = uncompiled_agent, optimizer = optimizer, policy_pool = policy_pool, @@ -258,13 +258,24 @@ def init( rewards = rewards, dones = dones, values = values, + # BET ADDED 22 + reward_buffer = deque(maxlen=1_000), + exploration_map_agg = np.zeros((config.num_envs, *GLOBAL_MAP_SHAPE), dtype=np.float32), + taught_cut = False, + infos = {}, + obs_ary = obs_ary, + actions_ary = actions_ary, + logprobs_ary = logprobs_ary, + rewards_ary = rewards_ary, + dones_ary = dones_ary, + truncateds_ary = truncateds_ary, + values_ary = values_ary, # Misc total_updates = total_updates, update = update, global_step = global_step, device = device, - # obs_device = obs_device, ## BET ADDED 24 start_time = start_time, ) @@ -281,7 +292,7 @@ def evaluate(data): **{f'performance/{k}': v for k, v in data.performance.items()}, **{f'stats/{k}': v for k, v in data.stats.items()}, - **{f"max_stats/{k}": v for k, v in data.max_stats.items()}, # BET ADDED 1 + # **{f"max_stats/{k}": v for k, v in data.max_stats.items()}, # BET ADDED 1 **{f'skillrank/{policy}': elo for policy, elo in data.policy_pool.ranker.ratings.items()}, }) @@ -294,7 +305,7 @@ def evaluate(data): misc_profiler = pufferlib.utils.Profiler() # BET ADDED 2 ptr = step = padded_steps_collected = agent_steps_collected = 0 - # infos = defaultdict(lambda: defaultdict(list)) + infos = defaultdict(lambda: defaultdict(list)) while True: step += 1 if ptr == config.batch_size + 1: @@ -303,21 +314,12 @@ def evaluate(data): with env_profiler: o, r, d, t, i, env_id, mask = data.pool.recv() - # i = data.policy_pool.update_scores(i, "return") ## BET ADDED 3 - # BET ADDED 4 with misc_profiler: i = data.policy_pool.update_scores(i, "return") # TODO: Update this for policy pool - for ii, ee in zip(i["learner"], env_id): - ii["env_id"] = ee - - ## BET ADDED 5 - # with inference_profiler, torch.no_grad(): - # o = torch.as_tensor(o) - # r = torch.as_tensor(r).float().to(data.device).view(-1) - # d = torch.as_tensor(d).float().to(data.device).view(-1) - - ## BET ADDED 6 + for ii, ee in zip(i['learner'], env_id): + ii['env_id'] = ee + with inference_profiler, torch.no_grad(): o = torch.as_tensor(o).to(device=data.device, non_blocking=True) r = (torch.as_tensor(r, dtype=torch.float32).to(device=data.device, non_blocking=True).view(-1)) @@ -325,7 +327,7 @@ def evaluate(data): agent_steps_collected += sum(mask) padded_steps_collected += len(mask) - # with 
inference_profiler, torch.no_grad(): ## BET ADDED 7 + # Multiple policies will not work with new envpool next_lstm_state = data.next_lstm_state if next_lstm_state is not None: @@ -344,18 +346,13 @@ def evaluate(data): value = value.flatten() - # BET ADDED 8 with misc_profiler: actions = actions.cpu().numpy() - # Index alive mask with policy pool idxs... - # TODO: Find a way to avoid having to do this - # learner_mask = mask * data.policy_pool.mask ## BET ADDED 9 learner_mask = torch.Tensor(mask * data.policy_pool.mask) # BET ADDED 10 - ## BET ADDED 12 (through 320) # Ensure indices do not exceed batch size - indices = torch.where(learner_mask)[0][: config.batch_size - ptr + 1].numpy() + indices = torch.where(learner_mask)[0][:config.batch_size - ptr + 1].numpy() end = ptr + len(indices) # Batch indexing @@ -373,18 +370,10 @@ def evaluate(data): for policy_name, policy_i in i.items(): for agent_i in policy_i: for name, dat in unroll_nested_dict(agent_i): - if policy_name not in data.infos: - data.infos[policy_name] = {} - if name not in data.infos[policy_name]: - data.infos[policy_name][name] = [ - np.zeros_like(dat) - ] * data.config.num_envs - data.infos[policy_name][name][agent_i["env_id"]] = dat + infos[policy_name][name].append(dat) with env_profiler: - data.pool.send(actions) # BET ADDED 37 - # data.pool.send(actions.cpu().numpy()) # BET ADDED 36 - - # BET ADDED 35 (through line 403) + data.pool.send(actions) + data.reward_buffer.append(r.cpu().sum().numpy()) # Probably should normalize the rewards before trying to take the variance... reward_var = np.var(data.reward_buffer) @@ -403,38 +392,11 @@ def evaluate(data): data.reward_buffer.clear() # reset lr update if the reward starts stalling data.lr_update = 1.0 - - ## BET ADDED 11 - # for idx in np.where(learner_mask)[0]: - # if ptr == config.batch_size + 1: - # break - # data.obs[ptr] = o[idx] - # data.values[ptr] = value[idx] - # data.actions[ptr] = actions[idx] - # data.logprobs[ptr] = logprob[idx] - # data.sort_keys.append((env_id[idx], step)) - # if len(d) != 0: - # data.rewards[ptr] = r[idx] - # data.dones[ptr] = d[idx] - # ptr += 1 - - - # for policy_name, policy_i in i.items(): - # for agent_i in policy_i: - # for name, dat in unroll_nested_dict(agent_i): - # infos[policy_name][name].append(dat) - eval_profiler.stop() - # BET ADDED 23 - # data.global_step += padded_steps_collected - # data.reward = float(torch.mean(data.rewards)) - - # BET ADDED 24 - data.global_step = np.mean(data.infos["learner"]["stats/step"]) - data.reward = torch.mean(data.rewards).float().item() - + data.global_step += padded_steps_collected + data.reward = float(torch.mean(data.rewards)) data.SPS = int(padded_steps_collected / eval_profiler.elapsed) perf = data.performance @@ -451,34 +413,48 @@ def evaluate(data): perf.misc_time = misc_profiler.elapsed # BET ADDED 25 data.stats = {} - data.max_stats = {} # BET ADDED 26 - - # BET ADDED 27 (bunch changed) - for k, v in data.infos['learner'].items(): - if 'Task_eval_fn' in k: - # Temporary hack for NMMO competition - continue - if 'pokemon_exploration_map' in k: - # import cv2 - # from pokemon_red_eval import make_pokemon_red_overlay - # bg = cv2.imread('kanto_map_dsv.png') - # overlay = make_pokemon_red_overlay(bg, sum(v)) - overlay = make_pokemon_red_overlay(np.stack(v, axis=0)) + # data.max_stats = {} # BET ADDED 26 + # BET ADDED 0.7 Original logic: + infos = infos['learner'] + + try: + if 'pokemon_exploration_map' in infos: + for idx, pmap in zip(infos['learner']['env_id'], 
infos['pokemon_exploration_map']): + if not hasattr(data, 'pokemon'): + import pokemon_red_eval + data.map_updater = pokemon_red_eval.map_updater() + data.map_buffer = np.zeros((data.config.num_envs, *pmap.shape)) + data.map_buffer[idx] = pmap + pokemon_map = np.sum(data.map_buffer, axis=0) + rendered = data.map_updater(pokemon_map) + import cv2 + # cv2.imwrite('c_counts_map.png', rendered) + # cv2.wait(1) + data.stats['Media/exploration_map'] = data.wandb.Image(rendered) + except: + pass + + try: + if "stats/step" in infos: + data.global_step = np.mean(infos["stats/step"]) + if 'pokemon_exploration_map' in infos: + overlay = make_pokemon_red_overlay(np.stack(infos['pokemon_exploration_map'], axis=0)) if data.wandb is not None: data.stats['Media/exploration_map'] = data.wandb.Image(overlay) - # @Leanke: Add your infos['learner']['x'] etc - try: # TODO: Better checks on log data types - data.stats[k] = np.mean(v) - data.max_stats[k] = np.max(v) - if data.max_stats["got_hm01"] > 0: - data.taught_cut = True + try: + data.stats['stats'] = np.mean(infos) + # data.max_stats['stats'] = np.max(infos) + # if data.max_stats["got_hm01"] > 0: + # data.taught_cut = True except: - continue + pass + except: + pass if config.verbose: print_dashboard(data.stats, data.init_performance, data.performance) - return data.stats, data.infos # BET ADDED 28 data.stats, infos + return data.stats, infos @pufferlib.utils.profile def train(data): @@ -491,11 +467,6 @@ def train(data): train_profiler = pufferlib.utils.Profiler(memory=True, pytorch_memory=True) train_profiler.start() - # # Anneal learning rate - # frac = 1.0 - (data.update - 1.0) / data.total_updates - # lrnow = frac * config.learning_rate - # data.optimizer.param_groups[0]["lr"] = lrnow - if config.anneal_lr: frac = 1.0 - (data.lr_update - 1.0) / data.total_updates lrnow = frac * config.learning_rate @@ -504,17 +475,12 @@ def train(data): num_minibatches = config.batch_size // config.bptt_horizon // config.batch_rows idxs = sorted(range(len(data.sort_keys)), key=data.sort_keys.__getitem__) data.sort_keys = [] - - # BET ADDED 28 - b_idxs = (torch.tensor(idxs, dtype=torch.long)[:-1] - .reshape(config.batch_rows, num_minibatches, config.bptt_horizon).transpose(0, 1)) - - # BET ADDED 27 - # b_idxs = ( - # torch.Tensor(idxs).long()[:-1] - # .reshape(config.batch_rows, num_minibatches, config.bptt_horizon) - # .transpose(0, 1) - # ) + + b_idxs = ( + torch.Tensor(idxs).long()[:-1] + .reshape(config.batch_rows, num_minibatches, config.bptt_horizon) + .transpose(0, 1) + ) # bootstrap value if not done with torch.no_grad(): @@ -533,15 +499,6 @@ def train(data): delta + config.gamma * config.gae_lambda * nextnonterminal * lastgaelam ) - # Flatten the batch - # BET ADDED 29 - # data.b_obs = b_obs = data.obs[b_idxs] - # b_actions = data.actions[b_idxs] - # b_logprobs = data.logprobs[b_idxs] - # b_dones = data.dones[b_idxs] - # b_values = data.values[b_idxs] - - # BET ADDED 30 (through line 520) data.b_obs = b_obs = torch.Tensor(data.obs_ary[b_idxs]) b_actions = torch.Tensor(data.actions_ary[b_idxs]).to(data.device, non_blocking=True) b_logprobs = torch.Tensor(data.logprobs_ary[b_idxs]).to(data.device, non_blocking=True) @@ -556,16 +513,12 @@ def train(data): # Optimizing the policy and value network train_time = time.time() pg_losses, entropy_losses, v_losses, clipfracs, old_kls, kls = [], [], [], [], [], [] - - # BET ADDED 31 + mb_obs_buffer = torch.zeros_like(b_obs[0], pin_memory=(data.device == "cuda")) for epoch in range(config.update_epochs): lstm_state = None 
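        # The minibatch loop below reuses the single pinned host buffer allocated
        # above: each minibatch of observations (batch_rows x bptt_horizon x *obs_shape)
        # is copied into mb_obs_buffer and then moved to the training device with
        # non_blocking=True, so the host-to-device copy can overlap with compute
        # when the buffer is pinned (i.e. when device == "cuda").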
for mb in range(num_minibatches): - # mb_obs = b_obs[mb].to(data.device) ## BET ADDED 32 - - # BET ADDED 33 mb_obs_buffer.copy_(b_obs[mb], non_blocking=True) mb_obs = mb_obs_buffer.to(data.device, non_blocking=True) @@ -668,7 +621,7 @@ def train(data): print_dashboard(data.stats, data.init_performance, data.performance) data.update += 1 - data.lr_update += 1 # BET ADDED 34 + data.lr_update += 1 if data.update % config.checkpoint_interval == 0 or done_training(data): save_checkpoint(data) @@ -677,13 +630,13 @@ def close(data): data.pool.close() ## BET ADDED 35 - # if data.wandb is not None: - # artifact_name = f"{data.exp_name}_model" - # artifact = data.wandb.Artifact(artifact_name, type="model") - # model_path = save_checkpoint(data) - # artifact.add_file(model_path) - # data.wandb.run.log_artifact(artifact) - # data.wandb.finish() + if data.wandb is not None: + artifact_name = f"{data.exp_name}_model" + artifact = data.wandb.Artifact(artifact_name, type="model") + model_path = save_checkpoint(data) + artifact.add_file(model_path) + data.wandb.run.log_artifact(artifact) + data.wandb.finish() def rollout(env_creator, env_kwargs, agent_creator, agent_kwargs, model_path=None, device='cuda', verbose=True): @@ -740,7 +693,7 @@ def save_checkpoint(data): if os.path.exists(model_path): return model_path - torch.save(data.agent, model_path) + torch.save(data.uncompiled_agent, model_path) state = { "optimizer_state_dict": data.optimizer.state_dict(), @@ -807,10 +760,3 @@ def print_dashboard(stats, init_performance, performance): print("\033c", end="") print('\n'.join(output)) time.sleep(1/20) - -class CleanPuffeRL: - __init__ = init - evaluate = evaluate - train = train - close = close - done_training = done_training diff --git a/config.yaml b/config.yaml index 27fc47cb..3c878f49 100755 --- a/config.yaml +++ b/config.yaml @@ -2,14 +2,14 @@ train: seed: 1 torch_deterministic: True device: cuda - total_timesteps: 10_000_000 - learning_rate: 0.0004 + total_timesteps: 800_000_000 + learning_rate: 0.0003 num_steps: 128 anneal_lr: True - gamma: 0.99 + gamma: 0.999 gae_lambda: 0.95 num_minibatches: 4 - update_epochs: 4 + update_epochs: 2 # 3 norm_adv: True clip_coef: 0.1 clip_vloss: True @@ -18,21 +18,20 @@ train: max_grad_norm: 0.5 target_kl: ~ - num_envs: 8 - envs_per_worker: 1 - envs_per_batch: ~ + num_envs: 128 # 48 + envs_per_worker: 4 + envs_per_batch: 48 # must be <= num_envs env_pool: True verbose: True data_dir: experiments - checkpoint_interval: 200 - cpu_offload: True + checkpoint_interval: 40960 # 2048 * 10 * 2 pool_kernel: [0] - batch_size: 1024 - batch_rows: 32 - bptt_horizon: 16 #8 + batch_size: 32768 # 128 (?) 
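  # Batch geometry check (using num_minibatches = batch_size // bptt_horizon // batch_rows
  # from clean_pufferl.train): 32768 // 16 // 128 = 16 minibatches per update,
  # each holding batch_rows * bptt_horizon = 128 * 16 = 2048 transitions.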
+ batch_rows: 128 + bptt_horizon: 16 vf_clip_coef: 0.1 - - debug: False + compile: True + compile_mode: reduce-overhead sweep: method: random @@ -59,624 +58,56 @@ sweep: 'values': [4, 8, 16, 32], } -### Arcade Learning Environment suite -# Convenience wrappers provided for common test environments -atari: - package: atari - env: - name: BreakoutNoFrameskip-v4 -beamrider: - package: atari - env: - name: BeamRiderNoFrameskip-v4 -breakout: - package: atari - env: - name: BreakoutNoFrameskip-v4 -enduro: - package: atari - env: - name: EnduroNoFrameskip-v4 -pong: - package: atari - env: - name: PongNoFrameskip-v4 -qbert: - package: atari - env: - name: QbertNoFrameskip-v4 -seaquest: - package: atari - env: - name: SeaquestNoFrameskip-v4 -space_invaders: - package: atari - env: - name: SpaceInvadersNoFrameskip-v4 - -box2d: - package: box2d - -### Procgen Suite -# Shared hyperparams (best for all envs) -# Per-env hyperparams from CARBS -procgen: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0005 - num_cores: 1 - num_envs: 1 - batch_size: 16384 - batch_rows: 8 - bptt_horizon: 256 - gamma: 0.999 - update_epochs: 3 - anneal_lr: False - clip_coef: 0.2 - vf_clip_coef: 0.2 - env: - name: bigfish - policy: - cnn_width: 16 - mlp_width: 32 -bigfish: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0001901266338648 - gamma: 0.9990684264891424 - ent_coef: 0.0025487710400836 - vf_coef: 1.1732211834792117 - gae_lambda: 0.8620630095238284 - clip_coef: 0.4104603426698214 - num_cores: 1 - num_envs: 1 - batch_size: 53210 - batch_rows: 5321 - bptt_horizon: 1 - update_epochs: 3 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: bigfish - num_envs: 24 - policy: - cnn_width: 22 - mlp_width: 327 -bossfight: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0001391202716783 - gamma: 0.9989348776761554 - ent_coef: 0.0141638234842547 - vf_coef: 2.3544979860388664 - gae_lambda: 0.8895733311775463 - clip_coef: 0.5642914060539239 - num_cores: 1 - num_envs: 1 - batch_size: 48520 - batch_rows: 6065 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: bossfight - num_envs: 186 - policy: - cnn_width: 34 - mlp_width: 83 -caveflyer: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0003922570060721 - gamma: 0.9974587177630908 - ent_coef: 0.0225727962984408 - vf_coef: 1.6255759569858712 - gae_lambda: 0.9094175213807228 - clip_coef: 0.4508383484491862 - num_cores: 1 - num_envs: 1 - batch_size: 32308 - batch_rows: 8077 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: caveflyer - num_envs: 96 - policy: - cnn_width: 17 - mlp_width: 242 -chaser: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0003508035442326 - gamma: 0.9942435848334558 - ent_coef: 0.0071001859366116 - vf_coef: 2.1530812235373684 - gae_lambda: 0.8186838232115529 - clip_coef: 0.0821348744853704 - num_cores: 1 - num_envs: 1 - batch_size: 17456 - batch_rows: 2182 - bptt_horizon: 1 - update_epochs: 1 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: chaser - num_envs: 89 - policy: - cnn_width: 37 - mlp_width: 198 -climber: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0001217047694837 - gamma: 0.998084323380632 - ent_coef: 0.0171304566412224 - vf_coef: 0.8123888927054865 - gae_lambda: 0.8758003745828604 - clip_coef: 0.3879433119086241 - num_cores: 1 - num_envs: 1 - batch_size: 113288 - batch_rows: 3332 - bptt_horizon: 256 - 
update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: climber - num_envs: 207 - policy: - cnn_width: 29 - mlp_width: 134 -coinrun: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0002171100540455 - gamma: 0.9962953325196714 - ent_coef: 0.0024830293961112 - vf_coef: 0.4045225563446447 - gae_lambda: 0.9708900757395368 - clip_coef: 0.271239381520248 - num_cores: 1 - num_envs: 1 - batch_size: 184170 - batch_rows: 6139 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: coinrun - num_envs: 246 - policy: - cnn_width: 16 - mlp_width: 384 -dodgeball: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0002471773711262 - gamma: 0.9892421826991458 - ent_coef: 0.0061212242920176 - vf_coef: 0.905405768115384 - gae_lambda: 0.929215062387182 - clip_coef: 0.1678680070658446 - num_cores: 1 - num_envs: 1 - batch_size: 233026 - batch_rows: 4958 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: dodgeball - num_envs: 385 - policy: - cnn_width: 24 - mlp_width: 538 -fruitbot: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0005426317191531 - gamma: 0.9988953819963396 - ent_coef: 0.0115430852027873 - vf_coef: 0.5489566038515201 - gae_lambda: 0.7517437269156811 - clip_coef: 0.3909436413913963 - num_cores: 1 - num_envs: 1 - batch_size: 25344 - batch_rows: 4224 - bptt_horizon: 1 - update_epochs: 1 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: fruitbot - num_envs: 184 - policy: - cnn_width: 24 - mlp_width: 600 -heist: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0001460588554421 - gamma: 0.9929899907866796 - ent_coef: 0.0063411167117336 - vf_coef: 1.3750495866441763 - gae_lambda: 0.864713026766495 - clip_coef: 0.0341243664433126 - num_cores: 1 - num_envs: 1 - batch_size: 162233 - batch_rows: 3061 - bptt_horizon: 1 - update_epochs: 1 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: heist - num_envs: 999 - policy: - cnn_width: 60 - mlp_width: 154 -jumper: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0002667825838749 - gamma: 0.996178793124514 - ent_coef: 0.0035712927399072 - vf_coef: 0.2066134576246479 - gae_lambda: 0.9385007945498072 - clip_coef: 0.0589308261206342 - num_cores: 1 - num_envs: 1 - batch_size: 76925 - batch_rows: 3077 - bptt_horizon: 1 - update_epochs: 3 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: jumper - num_envs: 320 - policy: - cnn_width: 24 - mlp_width: 190 -leaper: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.000238551194954 - gamma: 0.9984543257393016 - ent_coef: 0.0264785452036158 - vf_coef: 1.12387183485305 - gae_lambda: 0.8968331903476625 - clip_coef: 0.6941033332120052 - num_cores: 1 - num_envs: 1 - batch_size: 19380 - batch_rows: 6460 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: leaper - num_envs: 252 - policy: - cnn_width: 28 - mlp_width: 100 -maze: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0001711754945436 - gamma: 0.9986484783565428 - ent_coef: 0.0027020733255912 - vf_coef: 0.1236421145384316 - gae_lambda: 0.971943769322524 - clip_coef: 0.2335644352369076 - num_cores: 1 - num_envs: 1 - batch_size: 116008 - batch_rows: 6834 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: maze - num_envs: 820 - policy: - cnn_width: 28 - mlp_width: 526 -miner: - package: procgen - train: - 
total_timesteps: 8_000_000 - learning_rate: 0.000328692228852 - gamma: 0.990897931823388 - ent_coef: 0.0045505824544649 - vf_coef: 6.559292234163336 - gae_lambda: 0.6494040942916905 - clip_coef: 0.2293978935956241 - num_cores: 1 - num_envs: 1 - batch_size: 154512 - batch_rows: 2088 - bptt_horizon: 1 - update_epochs: 3 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: miner - num_envs: 343 - policy: - cnn_width: 38 - mlp_width: 175 -ninja: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0002649776171804 - gamma: 0.998357586821043 - ent_coef: 0.0077158486367147 - vf_coef: 2.171674659769069 - gae_lambda: 0.9664148604540898 - clip_coef: 0.5891635585927152 - num_cores: 1 - num_envs: 1 - batch_size: 45246 - batch_rows: 7541 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: ninja - num_envs: 293 - policy: - cnn_width: 25 - mlp_width: 317 -plunder: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0002630139944456 - gamma: 0.9981502407071172 - ent_coef: 0.0222691283544936 - vf_coef: 4.316832667738928 - gae_lambda: 0.84500339385464 - clip_coef: 0.0914132500563203 - num_cores: 1 - num_envs: 1 - batch_size: 26304 - batch_rows: 4384 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: plunder - num_envs: 127 - policy: - cnn_width: 30 - mlp_width: 288 -starpilot: - package: procgen - train: - total_timesteps: 8_000_000 - learning_rate: 0.0004257280551714 - gamma: 0.9930510505613882 - ent_coef: 0.007836164188961 - vf_coef: 5.482314699746532 - gae_lambda: 0.82792978724664 - clip_coef: 0.2645124138418521 - num_cores: 1 - num_envs: 1 - batch_size: 107440 - batch_rows: 6715 - bptt_horizon: 1 - update_epochs: 2 - anneal_lr: False - vf_clip_coef: 0.2 - env: - name: starpilot - num_envs: 320 - policy: - cnn_width: 25 - mlp_width: 144 - -bsuite: - package: bsuite - train: - total_timesteps: 1_000_000 - num_envs: 1 - env: - name: bandit/0 - -butterfly: - package: butterfly - env: - name: cooperative_pong_v5 - -classic_control: - package: classic_control - train: - num_envs: 16 - env: - name: cartpole -classic-control: - package: classic_control -classiccontrol: - package: classic_control -cartpole: - package: classic_control - -crafter: - package: crafter - env: - name: CrafterReward-v1 - -dm_control: - package: dm_control -dm-control: - package: dm_control -dmcontrol: - package: dm_control -dmc: - package: dm_control - -dm_lab: - package: dm_lab -dm-lab: - package: dm_lab -dmlab: - package: dm_lab -dml: - package: dm_lab - -griddly: - package: griddly - env: - name: GDY-Spiders-v0 - -magent: - package: magent - env: - name: battle_v4 - -microrts: - package: microrts - env: - name: GlobalAgentCombinedRewardEnv - -minerl: - package: minerl - env: - name: MineRLNavigateDense-v0 - -minigrid: - package: minigrid - env: - name: MiniGrid-LavaGapS7-v0 - -minihack: - package: minihack - env: - name: MiniHack-River-v0 - -nethack: - package: nethack - env: - name: NetHackScore-v0 - -nmmo: - package: nmmo - train: - num_envs: 1 - envs_per_batch: 1 - envs_per_worker: 1 - batch_size: 4096 - batch_rows: 128 - env: - name: nmmo - -# Ocean: PufferAI's first party environment suite -ocean: - package: ocean - train: - total_timesteps: 30_000 - learning_rate: 0.017 - num_envs: 8 - batch_rows: 32 - bptt_horizon: 4 - device: cpu - env: - name: squared -bandit: - package: ocean - env: - name: bandit -memory: - package: ocean - env: - name: memory -password: - package: ocean - env: - name: password 
-squared: - package: ocean - env: - name: squared -stochastic: - package: ocean - env: - name: stochastic - -open_spiel: - package: open_spiel - train: - pool_kernel: [0, 1, 1, 0] - num_envs: 32 - batch_size: 4096 - env: - name: connect_four -open-spiel: - package: open_spiel -openspiel: - package: open_spiel -connect_four: - package: open_spiel - env: - name: connect_four -connect-four: - package: open_spiel - env: - name: connect_four -connectfour: - package: open_spiel - env: - name: connect_four -connect4: - package: open_spiel - env: - name: connect_four - pokemon_red: package: pokemon_red train: - total_timesteps: 100_000_000 + total_timesteps: 800_000_000 num_envs: 128 envs_per_worker: 4 - envpool_batch_size: 48 - update_epochs: 1 + envs_per_batch: 48 + update_epochs: 2 # 3 gamma: 0.998 batch_size: 32768 batch_rows: 128 + compile: True + + # Boey-specific env parameters; loaded by environment.py + save_final_state: True + print_rewards: True + headless: True + init_state: /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir/start_from_state_dir/has_pokedex_nballs_noanim.state + action_freq: 24 + max_steps: 30720000 # Updated to match ep_length + early_stop: True + early_stopping_min_reward: 2.0 + save_video: False + fast_video: True + explore_weight: 1.5 + use_screen_explore: False + sim_frame_dist: 1000000.0 # 2000000.0 + reward_scale: 4 + extra_buttons: False + noop_button: True + swap_button: True + restricted_start_menu: True # False + level_reward_badge_scale: 1.0 + save_state_dir: /home/daa/puffer0.5.2_iron/obs_space_experiments/pokegym/pokegym/save_state_dir + special_exploration_scale: 1.0 + enable_item_manager: True + enable_stage_manager: True + enable_item_purchaser: True + auto_skip_anim: True + auto_skip_anim_frames: 8 + total_envs: 48 # Updated to match num_cpu + gb_path: PokemonRed.gb debug: False - compile: True # BET ADDED 1 - compile_mode: reduce-overhead # BET ADDED 2 - + level_manager_eval_mode: False + sess_id: generate # Updated dynamically, placeholder for dynamic generation + use_wandb_logging: False + cpu_multiplier: 0.25 + save_freq: 40960 # 2048 * 10 * 2 + n_steps: 163840 # Calculated as int(5120 // cpu_multiplier) * 1 + num_cpu: 8 # Calculated as int(32 * cpu_multiplier) env: name: pokemon_red pokemon-red: @@ -687,30 +118,4 @@ pokemon: package: pokemon_red pokegym: package: pokemon_red - -links_awaken: - package: links_awaken -links-awaken: - package: links_awaken -linksawaken: - package: links_awaken -zelda: - package: links_awaken - -smac: - package: smac - env: - name: smac -starcraft: - package: smac - -stable_retro: - package: stable_retro - env: - name: Airstriker-Genesis -stable-retro: - package: stable_retro -stableretro: - package: stable_retro -retro: - package: stable_retro + \ No newline at end of file diff --git a/demo.py b/demo.py index b11b638d..61919400 100755 --- a/demo.py +++ b/demo.py @@ -11,7 +11,7 @@ import pufferlib import pufferlib.utils -from clean_pufferl import CleanPuffeRL, rollout, done_training +import clean_pufferl def load_from_config(env): @@ -32,22 +32,30 @@ def load_from_config(env): for key in default_keys: env_subconfig = env_config.get(key, {}) pkg_subconfig = pkg_config.get(key, {}) + # Override first with pkg then with env configs - combined_config[key] = {**defaults[key], **pkg_subconfig, **env_subconfig} + try: + combined_config[key] = {**defaults[key], **pkg_subconfig, **env_subconfig} + # print(f'combo_config: {combined_config[key]}') + except TypeError as e: + pass + # 
print(f'combined_config={combined_config}') + # print(f' {type(e)} ') + # print(f'key={type(key)}; combined_config[{key}]=sad') + finally: + # print(f'{key} has caused its last problem.') + pass return pkg, pufferlib.namespace(**combined_config) def make_policy(env, env_module, args): policy = env_module.Policy(env, **args.policy) - if args.force_recurrence or env_module.Recurrent is not None: policy = env_module.Recurrent(env, policy, **args.recurrent) policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy) - else: policy = pufferlib.frameworks.cleanrl.Policy(policy) - return policy.to(args.train.device) @@ -106,11 +114,12 @@ def get_init_args(fn): continue else: args[name] = param.default if param.default is not inspect.Parameter.empty else None + # print(f'ARGS LINE116 DEMO.PY: {args}\n\n') return args def train(args, env_module, make_env): if args.backend == 'clean_pufferl': - trainer = CleanPuffeRL( + data = clean_pufferl.create( config=args.train, agent_creator=make_policy, agent_kwargs={'env_module': env_module, 'args': args}, @@ -121,12 +130,12 @@ def train(args, env_module, make_env): track=args.track, ) - while not done_training(trainer): - trainer.evaluate() - trainer.train() + while not clean_pufferl.done_training(data): + clean_pufferl.evaluate(data) + clean_pufferl.train(data) print('Done training. Saving data...') - trainer.close() + clean_pufferl.close(data) print('Run complete') elif args.backend == 'sb3': from stable_baselines3 import PPO @@ -150,12 +159,12 @@ def train(args, env_module, make_env): parser.add_argument('--backend', type=str, default='clean_pufferl', help='Train backend (clean_pufferl, sb3)') parser.add_argument('--config', type=str, default='pokemon_red', help='Configuration in config.yaml to use') parser.add_argument('--env', type=str, default=None, help='Name of specific environment to run') - parser.add_argument('--mode', type=str, default='train', help='train/sweep/evaluate') + parser.add_argument('--mode', type=str, default='train', choices='train sweep evaluate'.split()) parser.add_argument('--eval-model-path', type=str, default=None, help='Path to model to evaluate') parser.add_argument('--baseline', action='store_true', help='Baseline run') parser.add_argument('--no-render', action='store_true', help='Disable render during evaluate') parser.add_argument('--exp-name', type=str, default=None, help="Resume from experiment") - parser.add_argument('--vectorization', type=str, default='serial', help='Vectorization method (serial, multiprocessing, ray)') + parser.add_argument('--vectorization', type=str, default='serial', choices='serial multiprocessing ray'.split()) parser.add_argument('--wandb-entity', type=str, default='xinpw8', help='WandB entity') parser.add_argument('--wandb-project', type=str, default='pufferlib', help='WandB project') parser.add_argument('--wandb-group', type=str, default='debug', help='WandB group') @@ -178,60 +187,36 @@ def train(args, env_module, make_env): # Update config with environment defaults config.env = {**get_init_args(make_env), **config.env} + # print(f'config.env={config.env}') config.policy = {**get_init_args(env_module.Policy.__init__), **config.policy} + # print(f'config.policy={config.policy}') config.recurrent = {**get_init_args(env_module.Recurrent.__init__), **config.recurrent} - + # print(f'config.recurrent={config.recurrent}') # Generate argparse menu from config - # This is also a reason for Spock/Argbind/OmegaConf/pydantic-cli for name, sub_config in config.items(): args[name] = {} for key, value in 
sub_config.items(): - data_key = f"{name}.{key}" - cli_key = f"--{data_key}".replace("_", "-") + data_key = f'{name}.{key}' + cli_key = f'--{data_key}'.replace('_', '-') if isinstance(value, bool) and value is False: - action = "store_false" - parser.add_argument(cli_key, default=value, action="store_true") - clean_parser.add_argument(cli_key, default=value, action="store_true") + action = 'store_false' + parser.add_argument(cli_key, default=value, action='store_true') + clean_parser.add_argument(cli_key, default=value, action='store_true') elif isinstance(value, bool) and value is True: - data_key = f"{name}.no_{key}" - cli_key = f"--{data_key}".replace("_", "-") - parser.add_argument(cli_key, default=value, action="store_false") - clean_parser.add_argument(cli_key, default=value, action="store_false") + data_key = f'{name}.no_{key}' + cli_key = f'--{data_key}'.replace('_', '-') + parser.add_argument(cli_key, default=value, action='store_false') + clean_parser.add_argument(cli_key, default=value, action='store_false') else: parser.add_argument(cli_key, default=value, type=type(value)) - clean_parser.add_argument(cli_key, default=value, metavar="", type=type(value)) + clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value)) args[name][key] = getattr(parser.parse_known_args()[0], data_key) args[name] = pufferlib.namespace(**args[name]) clean_parser.parse_args(sys.argv[1:]) args = pufferlib.namespace(**args) - - # # Generate argparse menu from config - # for name, sub_config in config.items(): - # args[name] = {} - # for key, value in sub_config.items(): - # data_key = f'{name}.{key}' - # cli_key = f'--{data_key}'.replace('_', '-') - # if isinstance(value, bool) and value is False: - # action = 'store_false' - # parser.add_argument(cli_key, default=value, action='store_true') - # clean_parser.add_argument(cli_key, default=value, action='store_true') - # elif isinstance(value, bool) and value is True: - # data_key = f'{name}.no_{key}' - # cli_key = f'--{data_key}'.replace('_', '-') - # parser.add_argument(cli_key, default=value, action='store_false') - # clean_parser.add_argument(cli_key, default=value, action='store_false') - # else: - # parser.add_argument(cli_key, default=value, type=type(value)) - # clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value)) - - # args[name][key] = getattr(parser.parse_known_args()[0], data_key) - # args[name] = pufferlib.namespace(**args[name]) - - # clean_parser.parse_args(sys.argv[1:]) - # args = pufferlib.namespace(**args) vec = args.vectorization if vec == 'serial': @@ -249,12 +234,13 @@ def train(args, env_module, make_env): args.exp_name = init_wandb(args, env_module).id elif args.baseline: args.track = True - args.exp_name = f'puf-{pufferlib.__version__}-{args.config}' - args.wandb_group = f'puf-{pufferlib.__version__}-baseline' + version = '.'.join(pufferlib.__version__.split('.')[:2]) + args.exp_name = f'puf-{version}-{args.config}' + args.wandb_group = f'puf-{version}-baseline' shutil.rmtree(f'experiments/{args.exp_name}', ignore_errors=True) run = init_wandb(args, env_module, name=args.exp_name, resume=False) if args.mode == 'evaluate': - model_name = f'puf-{pufferlib.__version__}-{args.config}_model:latest' + model_name = f'puf-{version}-{args.config}_model:latest' artifact = run.use_artifact(model_name) data_dir = artifact.download() model_file = max(os.listdir(data_dir)) @@ -262,10 +248,10 @@ def train(args, env_module, make_env): if args.mode == 'train': train(args, env_module, make_env) - exit(0) 
+ # exit(0) elif args.mode == 'sweep': sweep(args, env_module, make_env) - exit(0) + # exit(0) elif args.mode == 'evaluate' and pkg != 'pokemon_red': rollout( make_env, @@ -286,4 +272,4 @@ def train(args, env_module, make_env): device=args.train.device, ) elif pkg != 'pokemon_red': - raise ValueError('Mode must be one of train, sweep, or evaluate') + raise ValueError('Mode must be one of train, sweep, or evaluate') \ No newline at end of file diff --git a/kanto_map_dsv.png b/kanto_map_dsv.png old mode 100644 new mode 100755 diff --git a/pufferlib/__init__.py b/pufferlib/__init__.py index 40c8fa76..fdebbb74 100755 --- a/pufferlib/__init__.py +++ b/pufferlib/__init__.py @@ -1,26 +1,23 @@ from pufferlib import version __version__ = version.__version__ -# Shut deepmind_lab up -import warnings -warnings.filterwarnings("ignore", category=DeprecationWarning)#, module="deepmind_lab") -try: - from deepmind_lab import dmenv_module # Or whatever the actual module is -except ImportError: - pass - import os import sys -# Shut pygame up +# Silence noisy packages original_stdout = sys.stdout +original_stderr = sys.stderr sys.stdout = open(os.devnull, 'w') +sys.stderr = open(os.devnull, 'w') try: + import gymnasium import pygame except ImportError: pass sys.stdout.close() +sys.stderr.close() sys.stdout = original_stdout +sys.stderr = original_stderr from pufferlib.namespace import namespace, dataclass diff --git a/pufferlib/emulation.py b/pufferlib/emulation.py index c10bea46..140f9d71 100755 --- a/pufferlib/emulation.py +++ b/pufferlib/emulation.py @@ -1,662 +1,3 @@ -# from pdb import set_trace as T - -# import numpy as np -# import warnings - -# import gym -# import gymnasium -# import inspect -# from functools import cached_property -# from collections import OrderedDict -# from collections.abc import Iterable - -# import pufferlib -# import pufferlib.spaces -# from pufferlib import utils, exceptions -# from pufferlib.extensions import flatten, unflatten - -# DICT = 0 -# LIST = 1 -# TUPLE = 2 -# VALUE = 3 - - -# class Postprocessor: -# '''Modify data before it is returned from or passed to the environment - -# For multi-agent environments, each agent has its own stateful postprocessor. -# ''' -# def __init__(self, env, is_multiagent, agent_id=None): -# '''Postprocessors provide full access to the environment - -# This means you can use them to cheat. Don't blame us if you do. -# ''' -# self.env = env -# self.agent_id = agent_id -# self.is_multiagent = is_multiagent - -# @property -# def observation_space(self): -# '''The space of observations output by the postprocessor - -# You will have to implement this function if Postprocessor.observation -# modifies the structure of observations. Defaults to the env's obs space. - -# PufferLib supports heterogeneous observation spaces for multi-agent environments, -# provided that your postprocessor pads or otherwise cannonicalizes the observations. -# ''' -# if self.is_multiagent: -# return self.env.observation_space(self.env.possible_agents[0]) -# return self.env.observation_space - -# def reset(self, observation): -# '''Called at the beginning of each episode''' -# return - -# def observation(self, observation): -# '''Called on each observation after it is returned by the environment - -# You must override Postprocessor.observation_space if this function -# changes the structure of observations. 
-# ''' -# return observation - -# def action(self, action): -# '''Called on each action before it is passed to the environment - -# Actions output by your policy do not need to match the action space, -# but they must be compatible after this function is called. -# ''' -# return action - -# def reward_done_truncated_info(self, reward, done, truncated, info): -# '''Called on the reward, done, truncated, and info after they are returned by the environment''' -# return reward, done, truncated, info - - -# class BasicPostprocessor(Postprocessor): -# '''Basic postprocessor that injects returns and lengths information into infos and -# provides an option to pad to a maximum episode length. Works for single-agent and -# team-based multi-agent environments''' -# def reset(self, obs): -# self.epoch_return = 0 -# self.epoch_length = 0 -# self.done = False - -# def reward_done_truncated_info(self, reward, done, truncated, info): -# if isinstance(reward, (list, np.ndarray)): -# reward = sum(reward.values()) - -# # Env is done -# if self.done: -# return reward, done, truncated, info - -# self.epoch_length += 1 -# self.epoch_return += reward - -# if done or truncated: -# info['return'] = self.epoch_return -# info['length'] = self.epoch_length -# self.done = True - -# return reward, done, truncated, info - -# class GymnasiumPufferEnv(gymnasium.Env): -# def __init__(self, env=None, env_creator=None, env_args=[], env_kwargs={}, -# postprocessor_cls=BasicPostprocessor): -# self.env = make_object(env, env_creator, env_args, env_kwargs) -# self.postprocessor = postprocessor_cls(self.env, is_multiagent=False) - -# self.initialized = False -# self.done = True - -# self.is_observation_checked = False -# self.is_action_checked = False - -# # Cache the observation and action spaces -# self.observation_space -# self.action_space -# self.render_modes = 'human rgb_array'.split() -# self.render_mode = 'rgb_array' - -# @cached_property -# def observation_space(self): -# '''Returns a flattened, single-tensor observation space''' -# self.structured_observation_space = self.postprocessor.observation_space - -# # Flatten the featurized observation space and store -# # it for use in step. Return a box space for the user -# self.flat_observation_space, self.flat_observation_structure, self.box_observation_space, self.pad_observation = ( -# make_flat_and_box_obs_space(self.structured_observation_space)) - -# return self.box_observation_space - -# @cached_property -# def action_space(self): -# '''Returns a flattened, multi-discrete action space''' -# self.structured_action_space = self.env.action_space -# self.flat_action_structure = flatten_structure(self.structured_action_space.sample()) - -# # Store a flat version of the action space for use in step. 
Return a multidiscrete version for the user -# self.flat_action_space, self.multidiscrete_action_space = ( -# make_flat_and_multidiscrete_atn_space(self.env.action_space)) - -# self.sz = [ -# int(np.prod(subspace.shape)) -# for subspace in self.flat_action_space.values() -# ] - -# return self.multidiscrete_action_space - -# def seed(self, seed): -# self.env.seed(seed) - -# def reset(self, seed=None): -# self.initialized = True -# self.done = False - -# ob, info = _seed_and_reset(self.env, seed) - -# # Call user featurizer and flatten the observations -# self.postprocessor.reset(ob) -# processed_ob = concatenate(flatten(self.postprocessor.observation(ob))) - -# if __debug__: -# if not self.is_observation_checked: -# self.is_observation_checked = check_space( -# processed_ob, self.box_observation_space) - -# return processed_ob, info - -# def step(self, action): -# '''Execute an action and return (observation, reward, done, info)''' -# if not self.initialized: -# raise exceptions.APIUsageError('step() called before reset()') -# if self.done: -# raise exceptions.APIUsageError('step() called after environment is done') - -# action = self.postprocessor.action(action) - -# if __debug__: -# if not self.is_action_checked: -# self.is_action_checked = check_space( -# action, self.multidiscrete_action_space) - -# # Unpack actions from multidiscrete into the original action space -# action = unflatten( -# split( -# action, self.flat_action_space, self.sz, batched=False -# ), self.flat_action_structure -# ) - -# ob, reward, done, truncated, info = self.env.step(action) - -# # Call user postprocessors and flatten the observations -# ob = concatenate(flatten(self.postprocessor.observation(ob))) -# reward, done, truncated, info = self.postprocessor.reward_done_truncated_info(reward, done, truncated, info) - -# self.done = done -# return ob, reward, done, truncated, info - -# def render(self): -# return self.env.render() - -# def close(self): -# return self.env.close() - -# def unpack_batched_obs(self, batched_obs): -# return unpack_batched_obs(batched_obs, self.flat_observation_space, self.flat_observation_structure) - - -# class PettingZooPufferEnv: -# def __init__(self, env=None, env_creator=None, env_args=[], env_kwargs={}, -# postprocessor_cls=Postprocessor, postprocessor_kwargs={}, teams=None): -# self.env = make_object(env, env_creator, env_args, env_kwargs) -# self.initialized = False -# self.all_done = True - -# self.is_observation_checked = False -# self.is_action_checked = False - -# self.possible_agents = self.env.possible_agents if teams is None else list(teams.keys()) -# self.teams = teams - -# self.postprocessors = {agent: postprocessor_cls( -# self.env, is_multiagent=True, agent_id=agent, **postprocessor_kwargs) -# for agent in self.possible_agents} - -# # Cache the observation and action spaces -# self.observation_space(self.possible_agents[0]) -# self.action_space(self.possible_agents[0]) - -# @property -# def agents(self): -# return self.env.agents - -# @property -# def done(self): -# return len(self.agents) == 0 or self.all_done - -# @property -# def single_observation_space(self): -# return self.box_observation_space - -# @property -# def single_action_space(self): -# return self.multidiscrete_action_space - -# def observation_space(self, agent): -# '''Returns the observation space for a single agent''' -# if agent not in self.possible_agents: -# raise pufferlib.exceptions.InvalidAgentError(agent, self.possible_agents) - -# # Make a gym space defining observations for the whole team -# 
if self.teams is not None: -# obs_space = make_team_space( -# self.env.observation_space, self.teams[agent]) -# else: -# obs_space = self.env.observation_space(agent) - -# # Call user featurizer and create a corresponding gym space -# self.structured_observation_space = self.postprocessors[agent].observation_space - -# # Flatten the featurized observation space and store it for use in step. Return a box space for the user -# self.flat_observation_space, self.flat_observation_structure, self.box_observation_space, self.pad_observation = ( -# make_flat_and_box_obs_space(self.structured_observation_space)) - -# return self.box_observation_space - -# def action_space(self, agent): -# '''Returns the action space for a single agent''' -# if agent not in self.possible_agents: -# raise pufferlib.exceptions.InvalidAgentError(agent, self.possible_agents) - -# # Make a gym space defining actions for the whole team -# if self.teams is not None: -# atn_space = make_team_space( -# self.env.action_space, self.teams[agent]) -# else: -# atn_space = self.env.action_space(agent) - -# self.structured_action_space = atn_space -# self.flat_action_structure = flatten_structure(atn_space.sample()) - -# # Store a flat version of the action space for use in step. Return a multidiscrete version for the user -# self.flat_action_space, self.multidiscrete_action_space = make_flat_and_multidiscrete_atn_space(atn_space) - -# return self.multidiscrete_action_space - -# def reset(self, seed=None): -# obs, info = self.env.reset(seed=seed) -# self.initialized = True -# self.all_done = False - -# # Group observations into teams -# if self.teams is not None: -# obs = group_into_teams(self.teams, obs) - -# # Call user featurizer and flatten the observations -# postprocessed_obs = {} -# ob = list(obs.values())[0] -# for agent in self.possible_agents: -# post = self.postprocessors[agent] -# post.reset(ob) -# if agent in obs: -# ob = obs[agent] -# postprocessed_obs[agent] = concatenate(flatten(post.observation(ob))) - -# if __debug__: -# if not self.is_observation_checked: -# self.is_observation_checked = check_space( -# next(iter(postprocessed_obs.values())), -# self.box_observation_space -# ) - -# padded_obs = pad_agent_data(postprocessed_obs, -# self.possible_agents, self.pad_observation) - -# # Mask out missing agents -# padded_infos = {} -# for agent in self.possible_agents: -# if agent not in info: -# padded_infos[agent] = {} -# else: -# padded_infos[agent] = info[agent] -# padded_infos[agent]['mask'] = agent in obs - -# return padded_obs, padded_infos - -# def step(self, actions): -# '''Step the environment and return (observations, rewards, dones, infos)''' -# if not self.initialized: -# raise exceptions.APIUsageError('step() called before reset()') -# if self.done: -# raise exceptions.APIUsageError('step() called after environment is done') - -# # Postprocess actions and validate action spaces -# for agent in actions: -# if __debug__: -# if agent not in self.possible_agents: -# raise exceptions.InvalidAgentError(agent, self.agents) - -# actions[agent] = self.postprocessors[agent].action(actions[agent]) - -# if __debug__: -# if not self.is_action_checked: -# self.is_action_checked = check_space( -# next(iter(actions.values())), -# self.multidiscrete_action_space -# ) - -# # Unpack actions from multidiscrete into the original action space -# unpacked_actions = {} -# for agent, atn in actions.items(): -# if agent in self.agents: -# unpacked_actions[agent] = unflatten( -# split(atn, self.flat_action_space, self.sz, 
batched=False), -# self.flat_action_structure -# ) - -# if self.teams is not None: -# unpacked_actions = ungroup_from_teams(self.teams, unpacked_actions) - -# obs, rewards, dones, truncateds, infos = self.env.step(unpacked_actions) -# # TODO: Can add this assert once NMMO Horizon is ported to puffer -# # assert all(dones.values()) == (len(self.env.agents) == 0) - -# if self.teams is not None: -# obs, rewards, truncateds, dones = group_into_teams(self.teams, obs, rewards, truncateds, dones) - -# # Call user postprocessors and flatten the observations -# for agent in obs: -# obs[agent] = concatenate(flatten(self.postprocessors[agent].observation(obs[agent]))) -# rewards[agent], dones[agent], truncateds[agent], infos[agent] = self.postprocessors[agent].reward_done_truncated_info( -# rewards[agent], dones[agent], truncateds[agent], infos[agent]) - -# self.all_done = all(dones.values()) - -# # Mask out missing agents -# for agent in self.possible_agents: -# if agent not in infos: -# infos[agent] = {} -# else: -# infos[agent] = infos[agent] -# infos[agent]['mask'] = agent in obs - -# obs, rewards, dones, truncateds = pad_to_const_num_agents( -# self.env.possible_agents, obs, rewards, dones, truncateds, self.pad_observation) - -# return obs, rewards, dones, truncateds, infos - -# def render(self): -# return self.env.render() - -# def close(self): -# return self.env.close() - -# def unpack_batched_obs(self, batched_obs): -# return unpack_batched_obs(batched_obs, self.flat_observation_space, self.flat_observation_structure, self.sz) - - -# def unpack_batched_obs(batched_obs, flat_observation_space, -# flat_observation_structure, sz): -# unpacked = split(batched_obs, flat_observation_space, self.sz, batched=True) -# unflattened = unflatten(unpacked, flat_observation_structure) -# return unflattened - -# def make_object(object_instance=None, object_creator=None, creator_args=[], creator_kwargs={}): -# if (object_instance is None) == (object_creator is None): -# raise ValueError('Exactly one of object_instance or object_creator must be provided') - -# if object_instance is not None: -# if callable(object_instance) or inspect.isclass(object_instance): -# raise TypeError('object_instance must be an instance, not a function or class') -# return object_instance - -# if object_creator is not None: -# if not callable(object_creator): -# raise TypeError('object_creator must be a callable') - -# if creator_args is None: -# creator_args = [] - -# if creator_kwargs is None: -# creator_kwargs = {} - -# return object_creator(*creator_args, **creator_kwargs) - -# def pad_agent_data(data, agents, pad_value): -# return {agent: data[agent] if agent in data else pad_value -# for agent in agents} - -# def pad_to_const_num_agents(agents, obs, rewards, dones, truncateds, pad_obs): -# padded_obs = pad_agent_data(obs, agents, pad_obs) -# rewards = pad_agent_data(rewards, agents, 0) -# dones = pad_agent_data(dones, agents, False) -# truncateds = pad_agent_data(truncateds, agents, False) -# return padded_obs, rewards, dones, truncateds - -# def make_flat_and_multidiscrete_atn_space(atn_space): -# flat_action_space = flatten_space(atn_space) -# if len(flat_action_space) == 1: -# return flat_action_space, list(flat_action_space.values())[0] -# multidiscrete_space = convert_to_multidiscrete(flat_action_space) -# return flat_action_space, multidiscrete_space - - -# def make_flat_and_box_obs_space(obs_space): -# obs = obs_space.sample() -# flat_observation_structure = flatten_structure(obs) - -# flat_observation_space = 
flatten_space(obs_space) -# obs = obs_space.sample() - -# flat_observation = concatenate(flatten(obs)) - -# mmin, mmax = pufferlib.utils._get_dtype_bounds(flat_observation.dtype) -# pad_obs = flat_observation * 0 -# box_obs_space = gymnasium.spaces.Box( -# low=mmin, high=mmax, -# shape=flat_observation.shape, dtype=flat_observation.dtype -# ) - -# return flat_observation_space, flat_observation_structure, box_obs_space, pad_obs - - -# def make_featurized_obs_and_space(obs_space, postprocessor): -# obs_sample = obs_space.sample() -# featurized_obs = postprocessor.observation(obs_sample) -# featurized_obs_space = make_space_like(featurized_obs) -# return featurized_obs_space, featurized_obs - -# def make_team_space(observation_space, agents): -# return gymnasium.spaces.Dict({agent: observation_space(agent) for agent in agents}) - -# def check_space(data, space): -# try: -# contains = space.contains(data) -# except: -# raise ValueError( -# f'Error checking space {space} with sample :\n{data}') - -# if not contains: -# raise ValueError( -# f'Data:\n{data}\n not in space:\n{space}') - -# return True - -# def check_teams(env, teams): -# if set(env.possible_agents) != {item for team in teams.values() for item in team}: -# raise ValueError(f'Invalid teams: {teams} for possible_agents: {env.possible_agents}') - -# def group_into_teams(teams, *args): -# grouped_data = [] - -# for agent_data in args: -# if __debug__: -# if set(agent_data) != {item for team in teams.values() for item in team}: -# raise ValueError(f'Invalid teams: {teams} for agents: {set(agent_data)}') - -# team_data = {} -# for team_id, team in teams.items(): -# team_data[team_id] = {} -# for agent_id in team: -# if agent_id in agent_data: -# team_data[team_id][agent_id] = agent_data[agent_id] - -# grouped_data.append(team_data) - -# if len(grouped_data) == 1: -# return grouped_data[0] - -# return grouped_data - -# def ungroup_from_teams(team_data): -# agent_data = {} -# for team in team_data.values(): -# for agent_id, data in team.items(): -# agent_data[agent_id] = data -# return agent_data - - -# def flatten_structure(data): -# structure = [] - -# def helper(d): -# if isinstance(d, dict): -# structure.append(DICT) -# structure.append(len(d)) -# for key, value in sorted(d.items()): -# structure.append(key) -# helper(value) -# elif isinstance(d, list): -# structure.append(LIST) -# structure.append(len(d)) -# for item in d: -# helper(item) -# elif isinstance(d, tuple): -# structure.append(TUPLE) -# structure.append(len(d)) -# for item in d: -# helper(item) -# else: -# structure.append(VALUE) - -# helper(data) -# return structure - -# def flatten_space(space): -# def _recursion_helper(current, key): -# if isinstance(current, pufferlib.spaces.Tuple): -# for idx, elem in enumerate(current): -# _recursion_helper(elem, f'{key}T{idx}.') -# elif isinstance(current, pufferlib.spaces.Dict): -# for k, value in current.items(): -# _recursion_helper(value, f'{key}D{k}.') -# else: -# flat[f'{key}V'] = current - -# flat = {} -# _recursion_helper(space, '') -# return flat - -# def concatenate(flat_sample): -# # TODO: This section controls whether to special-case -# # pure tensor obs to retain shape. Consider whether this is good. 
-# if len(flat_sample) == 1: -# flat_sample = flat_sample[0] -# if isinstance(flat_sample,(np.ndarray, -# gymnasium.wrappers.frame_stack.LazyFrames)): -# return flat_sample -# return np.array([flat_sample]) - -# return np.concatenate([ -# e.ravel() if isinstance(e, np.ndarray) else np.array([e]) -# for e in flat_sample] -# ) - -# def split(stacked_sample, flat_space, sz, batched=True): -# if not isinstance(stacked_sample, Iterable): -# return [stacked_sample] - -# if batched: -# batch = stacked_sample.shape[0] - -# leaves = [] -# ptr = 0 -# for sz, subspace in zip(sz, flat_space.values()): -# shape = subspace.shape -# typ = subspace.dtype -# # Patch cached this -# #sz = int(np.prod(shape)) - -# if shape == (): -# shape = (1,) - -# if batched: -# samp = stacked_sample[:, ptr:ptr+sz].reshape(batch, *shape) -# else: -# samp = stacked_sample[ptr:ptr+sz].reshape(*shape).astype(typ) -# if isinstance(subspace, pufferlib.spaces.Discrete): -# samp = int(samp[0]) - -# leaves.append(samp) -# ptr += sz - -# return leaves - -# def convert_to_multidiscrete(flat_space): -# lens = [] -# for e in flat_space.values(): -# if isinstance(e, pufferlib.spaces.Discrete): -# lens.append(e.n) -# elif isinstance(e, pufferlib.spaces.MultiDiscrete): -# lens += e.nvec.tolist() -# else: -# raise ValueError(f'Invalid action space: {e}') - -# return gymnasium.spaces.MultiDiscrete(lens) - -# def make_space_like(ob): -# if type(ob) == np.ndarray: -# mmin, mmax = utils._get_dtype_bounds(ob.dtype) -# return gymnasium.spaces.Box( -# low=mmin, high=mmax, -# shape=ob.shape, dtype=ob.dtype -# ) - -# # TODO: Handle Discrete (how to get max?) -# if type(ob) in (tuple, list): -# return gymnasium.spaces.Tuple([make_space_like(v) for v in ob]) - -# if type(ob) in (dict, OrderedDict): -# return gymnasium.spaces.Dict({k: make_space_like(v) for k, v in ob.items()}) - -# if type(ob) in (int, float): -# # TODO: Tighten bounds -# return gymnasium.spaces.Box(low=-np.inf, high=np.inf, shape=()) - -# raise ValueError(f'Invalid type for featurized obs: {type(ob)}') - -# def _seed_and_reset(env, seed): -# if seed is None: -# # Gym bug: does not reset env correctly -# # when seed is passed as explicit None -# return env.reset() - -# try: -# obs, info = env.reset(seed=seed) -# except: -# try: -# env.seed(seed) -# obs, info = env.reset() -# except: -# obs, info = env.reset() -# warnings.warn('WARNING: Environment does not support seeding.', DeprecationWarning) - -# return obs, info - - from pdb import set_trace as T import numpy as np @@ -665,6 +6,7 @@ import gym import gymnasium import inspect +from functools import cached_property from collections import OrderedDict from collections.abc import Iterable @@ -678,7 +20,6 @@ TUPLE = 2 VALUE = 3 - class Postprocessor: '''Modify data before it is returned from or passed to the environment @@ -762,6 +103,7 @@ def reward_done_truncated_info(self, reward, done, truncated, info): class GymnasiumPufferEnv(gymnasium.Env): def __init__(self, env=None, env_creator=None, env_args=[], env_kwargs={}, postprocessor_cls=BasicPostprocessor): + self.env = make_object(env, env_creator, env_args, env_kwargs) self.postprocessor = postprocessor_cls(self.env, is_multiagent=False) @@ -770,14 +112,24 @@ def __init__(self, env=None, env_creator=None, env_args=[], env_kwargs={}, self.is_observation_checked = False self.is_action_checked = False + + # self.obs_sz = 0 # Cache the observation and action spaces self.observation_space self.action_space + + # BET ADDED 0.7 + self.unflatten_context = pufferlib.namespace( + 
flat_observation_space=self.flat_observation_space, + flat_observation_structure=self.flat_observation_structure, + obs_sz=self.obs_sz, + ) + self.render_modes = 'human rgb_array'.split() self.render_mode = 'rgb_array' - @property + @cached_property def observation_space(self): '''Returns a flattened, single-tensor observation space''' self.structured_observation_space = self.postprocessor.observation_space @@ -787,9 +139,14 @@ def observation_space(self): self.flat_observation_space, self.flat_observation_structure, self.box_observation_space, self.pad_observation = ( make_flat_and_box_obs_space(self.structured_observation_space)) + self.obs_sz = [ + int(np.prod(subspace.shape)) + for subspace in self.flat_observation_space.values() + ] + return self.box_observation_space - @property + @cached_property def action_space(self): '''Returns a flattened, multi-discrete action space''' self.structured_action_space = self.env.action_space @@ -799,6 +156,11 @@ def action_space(self): self.flat_action_space, self.multidiscrete_action_space = ( make_flat_and_multidiscrete_atn_space(self.env.action_space)) + self.atn_sz = [ + int(np.prod(subspace.shape)) + for subspace in self.flat_action_space.values() + ] + return self.multidiscrete_action_space def seed(self, seed): @@ -838,7 +200,7 @@ def step(self, action): # Unpack actions from multidiscrete into the original action space action = unflatten( split( - action, self.flat_action_space, batched=False + action, self.flat_action_space, self.atn_sz, batched=False ), self.flat_action_structure ) @@ -858,12 +220,13 @@ def close(self): return self.env.close() def unpack_batched_obs(self, batched_obs): - return unpack_batched_obs(batched_obs, self.flat_observation_space, self.flat_observation_structure) + return unpack_batched_obs(batched_obs, self.flat_observation_space, self.flat_observation_structure, self.atn_sz) class PettingZooPufferEnv: def __init__(self, env=None, env_creator=None, env_args=[], env_kwargs={}, postprocessor_cls=Postprocessor, postprocessor_kwargs={}, teams=None): + self.env = make_object(env, env_creator, env_args, env_kwargs) self.initialized = False self.all_done = True @@ -882,6 +245,12 @@ def __init__(self, env=None, env_creator=None, env_args=[], env_kwargs={}, self.observation_space(self.possible_agents[0]) self.action_space(self.possible_agents[0]) + self.unflatten_context = pufferlib.namespace( + flat_observation_space=self.flat_observation_space, + flat_observation_structure=self.flat_observation_structure, + obs_sz=self.obs_sz, + ) + @property def agents(self): return self.env.agents @@ -917,6 +286,11 @@ def observation_space(self, agent): self.flat_observation_space, self.flat_observation_structure, self.box_observation_space, self.pad_observation = ( make_flat_and_box_obs_space(self.structured_observation_space)) + self.obs_sz = [ + int(np.prod(subspace.shape)) + for subspace in self.flat_observation_space.values() + ] + return self.box_observation_space def action_space(self, agent): @@ -937,6 +311,11 @@ def action_space(self, agent): # Store a flat version of the action space for use in step. 
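# Illustrative sketch (the Dict space below is an assumption, and assumes it is
# recognized by flatten_space): two Discrete leaves flatten to a
# MultiDiscrete([3, 2]) user-facing space, and the cached per-leaf sizes used
# by split() would be [1, 1]:
#   atn_space = gymnasium.spaces.Dict({
#       'turn': gymnasium.spaces.Discrete(3),
#       'use': gymnasium.spaces.Discrete(2),
#   })
#   flat_atn, multi = make_flat_and_multidiscrete_atn_space(atn_space)
#   sz = [int(np.prod(s.shape)) for s in flat_atn.values()]   # [1, 1]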
Return a multidiscrete version for the user self.flat_action_space, self.multidiscrete_action_space = make_flat_and_multidiscrete_atn_space(atn_space) + self.atn_sz = [ + int(np.prod(subspace.shape)) + for subspace in self.flat_action_space.values() + ] + return self.multidiscrete_action_space def reset(self, seed=None): @@ -1006,7 +385,7 @@ def step(self, actions): for agent, atn in actions.items(): if agent in self.agents: unpacked_actions[agent] = unflatten( - split(atn, self.flat_action_space, batched=False), + split(atn, self.flat_action_space, self.atn_sz, batched=False), self.flat_action_structure ) @@ -1048,12 +427,13 @@ def close(self): return self.env.close() def unpack_batched_obs(self, batched_obs): - return unpack_batched_obs(batched_obs, self.flat_observation_space, self.flat_observation_structure) + return unpack_batched_obs(batched_obs, + self.flat_observation_space, self.flat_observation_structure, self.atn_sz) - -def unpack_batched_obs(batched_obs, flat_observation_space, flat_observation_structure): - unpacked = split(batched_obs, flat_observation_space, batched=True) - unflattened = unflatten(unpacked, flat_observation_structure) +def unpack_batched_obs(batched_obs, unflatten_context): + unpacked = split(batched_obs, unflatten_context.flat_observation_space, + unflatten_context.obs_sz, batched=True) + unflattened = unflatten(unpacked, unflatten_context.flat_observation_structure) return unflattened def make_object(object_instance=None, object_creator=None, creator_args=[], creator_kwargs={}): @@ -1128,18 +508,18 @@ def check_space(data, space): try: contains = space.contains(data) except: - raise ValueError( + raise exceptions.APIUsageError( f'Error checking space {space} with sample :\n{data}') if not contains: - raise ValueError( + raise exceptions.APIUsageError( f'Data:\n{data}\n not in space:\n{space}') return True def check_teams(env, teams): if set(env.possible_agents) != {item for team in teams.values() for item in team}: - raise ValueError(f'Invalid teams: {teams} for possible_agents: {env.possible_agents}') + raise exceptions.APIUsageError(f'Invalid teams: {teams} for possible_agents: {env.possible_agents}') def group_into_teams(teams, *args): grouped_data = [] @@ -1147,7 +527,7 @@ def group_into_teams(teams, *args): for agent_data in args: if __debug__: if set(agent_data) != {item for team in teams.values() for item in team}: - raise ValueError(f'Invalid teams: {teams} for agents: {set(agent_data)}') + raise exceptions.APIUsageError(f'Invalid teams: {teams} for agents: {set(agent_data)}') team_data = {} for team_id, team in teams.items(): @@ -1213,30 +593,37 @@ def _recursion_helper(current, key): return flat def concatenate(flat_sample): + # TODO: This section controls whether to special-case + # pure tensor obs to retain shape. Consider whether this is good. 
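# Illustrative sketch (the space and values below are assumptions): a
# structured sample round-trips through these helpers, with per-leaf sizes
# precomputed the same way obs_sz / atn_sz are cached in the envs above.
#   space = gymnasium.spaces.Dict({
#       'id': gymnasium.spaces.Discrete(4),
#       'pos': gymnasium.spaces.Box(low=0, high=1, shape=(3,), dtype=np.float32),
#   })
#   sample = space.sample()
#   structure = flatten_structure(sample)
#   flat_space = flatten_space(space)
#   sz = [int(np.prod(s.shape)) for s in flat_space.values()]    # e.g. [1, 3]
#   flat = concatenate(flatten(sample))                          # 1D array, length sum(sz)
#   restored = unflatten(split(flat, flat_space, sz, batched=False), structure)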
if len(flat_sample) == 1: flat_sample = flat_sample[0] if isinstance(flat_sample,(np.ndarray, gymnasium.wrappers.frame_stack.LazyFrames)): return flat_sample return np.array([flat_sample]) + return np.concatenate([ e.ravel() if isinstance(e, np.ndarray) else np.array([e]) for e in flat_sample] ) -def split(stacked_sample, flat_space, batched=True): +def split(stacked_sample, flat_space, sz, batched=True): if not isinstance(stacked_sample, Iterable): return [stacked_sample] if batched: batch = stacked_sample.shape[0] + elif len(sz) == 1: + # This probably breaks for dicts with 1 element etc + return [stacked_sample] leaves = [] ptr = 0 - for subspace in flat_space.values(): + for sz, subspace in zip(sz, flat_space.values()): shape = subspace.shape typ = subspace.dtype - sz = int(np.prod(shape)) + # Patch cached this + #sz = int(np.prod(shape)) if shape == (): shape = (1,) @@ -1302,4 +689,4 @@ def _seed_and_reset(env, seed): obs, info = env.reset() warnings.warn('WARNING: Environment does not support seeding.', DeprecationWarning) - return obs, info + return obs, info \ No newline at end of file diff --git a/pufferlib/environments/pokemon_red/environment.py b/pufferlib/environments/pokemon_red/environment.py index c36ed041..1e1afed6 100755 --- a/pufferlib/environments/pokemon_red/environment.py +++ b/pufferlib/environments/pokemon_red/environment.py @@ -1,13 +1,9 @@ # from pdb import set_trace as T - # import gymnasium # import functools - # from pokegym import Environment - # import pufferlib.emulation - # def env_creator(name='pokemon_red'): # return functools.partial(make, name) @@ -17,27 +13,20 @@ # return pufferlib.emulation.GymnasiumPufferEnv(env=env, # postprocessor_cls=pufferlib.emulation.BasicPostprocessor) - - import functools - import pufferlib.emulation - from pokegym import Environment from stream_wrapper import StreamWrapper - def env_creator(name="pokemon_red"): return functools.partial(make, name) - -def make(name, **kwargs): +def make(name, **kwargs,): """Pokemon Red""" env = Environment(kwargs) - - env = StreamWrapper(env, stream_metadata={"user": "BET\nlittleforleanke\nBET"}) + env = StreamWrapper(env, stream_metadata={"user": " BET \n===PUFFERLIB===\n====BOEY====\n BET"}) # Looks like the following will optionally create the object for you - # Or use theo ne you pass it. I'll just construct it here. + # Or use the one you pass it. I'll just construct it here. 
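# Usage sketch (illustrative; any pokegym-specific kwargs are assumptions):
#   creator = env_creator("pokemon_red")      # functools.partial(make, name)
#   env = creator()                           # kwargs are forwarded to pokegym's Environment
#   obs, info = env.reset(seed=0)
#   obs, reward, done, truncated, info = env.step(env.action_space.sample())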
return pufferlib.emulation.GymnasiumPufferEnv( env=env, postprocessor_cls=pufferlib.emulation.BasicPostprocessor ) \ No newline at end of file diff --git a/pufferlib/environments/pokemon_red/torch.py b/pufferlib/environments/pokemon_red/torch.py index 59d100b1..99cd9ed9 100755 --- a/pufferlib/environments/pokemon_red/torch.py +++ b/pufferlib/environments/pokemon_red/torch.py @@ -10,7 +10,7 @@ import numpy as np class Recurrent(pufferlib.models.RecurrentWrapper): - def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): + def __init__(self, env, policy, input_size=512, hidden_size=512, num_layers=1): # input_size=512, hidden_size=512 super().__init__(env, policy, input_size, hidden_size, num_layers) # class Policy(pufferlib.models.Convolutional): @@ -293,25 +293,31 @@ def __init__( # self._features_dim = 406 # self.fc1 = nn.Linear(406,512) - self.fc1 = nn.Linear(1251,512) - self.fc2 = nn.Linear(512,512) + self.fc1 = nn.Linear(1251,512) # self.fc1 = nn.Linear(1251,512) + # self.fc2 = nn.Linear(512,512) + self.fc2 = nn.Linear(512,512) # BET ADDED 2/27/24 self.action = nn.Linear(512, self.action_space.n) self.value_head = nn.Linear(512,1) # breakpoint() + # BET ADDED (this is the forward fn in the CustomCombinedExtractorV2(BaseFeaturesExtractor) class) def encode_observations(self, observations: TensorDict) -> th.Tensor: # sz = [ # int(np.prod(subspace.shape)) # for subspace in self.flat_observation_space.values() # ] - observations = pufferlib.emulation.unpack_batched_obs(observations, - self.flat_observation_space, self.flat_observation_structure) + + # BET ADDED 0.7 + # Adjust the call to unpack_batched_obs to use the unflatten_context + observations = pufferlib.emulation.unpack_batched_obs(observations, self.unflatten_context) + + # observations = pufferlib.emulation.unpack_batched_obs(observations, + # self.flat_observation_space, self.flat_observation_structure) # img = self.image_cnn(observations['image']) # (256, ) img = self.cnn_linear(self.cnn(observations['image'])) # (512, ) - # minimap_sprite minimap_sprite = observations['minimap_sprite'].to(th.int) # (9, 10) embedded_minimap_sprite = self.minimap_sprite_embedding(minimap_sprite) # (9, 10, 8) @@ -325,8 +331,6 @@ def encode_observations(self, observations: TensorDict) -> th.Tensor: minimap = th.cat([minimap, embedded_minimap_sprite, embedded_minimap_warp], dim=1) # (14 + 8 + 8, 9, 10) # minimap minimap = self.minimap_cnn_linear(self.minimap_cnn(minimap)) # (256, ) - - # Pokemon # Moves @@ -366,14 +370,6 @@ def encode_observations(self, observations: TensorDict) -> th.Tensor: item_features = self.item_ids_fc_relu(item_concat) # (20, 32) item_features = self.item_ids_max_pool(item_features).squeeze(-2) # (20, 32) -> (32, ) - # # Items - # embedded_item_ids = self.item_ids_embedding(observations['item_ids'].to(th.int)) # (20, 16) - # # item_quantity - # item_quantity = observations['item_quantity'] # (20, 1) - # item_concat = th.cat([embedded_item_ids, item_quantity], dim=-1) # (20, 17) - # item_features = self.item_ids_fc_relu(item_concat) # (20, 16) - # item_features = self.item_ids_max_pool(item_features).squeeze(-2) # (20, 16) -> (16, ) - # Events embedded_event_ids = self.event_ids_embedding(observations['event_ids'].to(th.int)) # event_step_since @@ -396,8 +392,6 @@ def encode_observations(self, observations: TensorDict) -> th.Tensor: # # Raw vector # vector = observations['vector'] # (54, ) - # Concat all features - # Concat all features all_features = th.cat([img, minimap, poke_party_head, 
poke_opp_head, item_features, event_features, vector, map_features], dim=-1) # (410 + 256, ) hidden = self.fc2(F.relu(self.fc1(all_features))) diff --git a/pufferlib/models.py b/pufferlib/models.py index 47eefe31..dc82ee52 100755 --- a/pufferlib/models.py +++ b/pufferlib/models.py @@ -42,6 +42,9 @@ def __init__(self, env): self.observation_space = env.single_observation_space self.action_space = env.single_action_space + # Used to unflatten observation in forward pass + self.unflatten_context = env.unflatten_context + self.is_multidiscrete = isinstance(self.action_space, pufferlib.spaces.MultiDiscrete) @@ -56,7 +59,7 @@ def encode_observations(self, flat_observations): function to unflatten observations to their original structured form: observations = pufferlib.emulation.unpack_batched_obs( - self.envs.structured_observation_space, env_outputs) + env_outputs, self.unflatten_context) Args: flat_observations: A tensor of shape (batch, ..., obs_size) @@ -125,7 +128,6 @@ def forward(self, x, state): assert state[0].shape[1] == state[1].shape[1] == B x = x.reshape(B*TT, *space_shape) - # breakpoint() hidden, lookup = self.policy.encode_observations(x) assert hidden.shape == (B*TT, self.input_size) @@ -220,3 +222,73 @@ def decode_actions(self, flat_hidden, lookup, concat=None): action = self.actor(flat_hidden) value = self.value_fn(flat_hidden) return action, value + +# ResNet Procgen baseline +# https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/142d09586d2272a17f44481a115c4bd817cf6a94/models/impala_cnn_torch.py +class ResidualBlock(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv0 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1) + + def forward(self, x): + inputs = x + x = nn.functional.relu(x) + x = self.conv0(x) + x = nn.functional.relu(x) + x = self.conv1(x) + return x + inputs + +class ConvSequence(nn.Module): + def __init__(self, input_shape, out_channels): + super().__init__() + self._input_shape = input_shape + self._out_channels = out_channels + self.conv = nn.Conv2d(in_channels=self._input_shape[0], out_channels=self._out_channels, kernel_size=3, padding=1) + self.res_block0 = ResidualBlock(self._out_channels) + self.res_block1 = ResidualBlock(self._out_channels) + + def forward(self, x): + x = self.conv(x) + x = nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1) + x = self.res_block0(x) + x = self.res_block1(x) + assert x.shape[1:] == self.get_output_shape() + return x + + def get_output_shape(self): + _c, h, w = self._input_shape + return (self._out_channels, (h + 1) // 2, (w + 1) // 2) + +class ProcgenResnet(Policy): + def __init__(self, env, cnn_width=16, mlp_width=256): + super().__init__(env) + h, w, c = env.structured_observation_space.shape + shape = (c, h, w) + conv_seqs = [] + for out_channels in [cnn_width, 2*cnn_width, 2*cnn_width]: + conv_seq = ConvSequence(shape, out_channels) + shape = conv_seq.get_output_shape() + conv_seqs.append(conv_seq) + conv_seqs += [ + nn.Flatten(), + nn.ReLU(), + nn.Linear(in_features=shape[0] * shape[1] * shape[2], out_features=mlp_width), + nn.ReLU(), + ] + self.network = nn.Sequential(*conv_seqs) + self.actor = pufferlib.pytorch.layer_init( + nn.Linear(mlp_width, self.action_space.n), std=0.01) + self.value = pufferlib.pytorch.layer_init( + nn.Linear(mlp_width, 1), std=1) + + def encode_observations(self, x): + x = pufferlib.emulation.unpack_batched_obs(x, 
self.unflatten_context) + hidden = self.network(x.permute((0, 3, 1, 2)) / 255.0) + return hidden, None + + def decode_actions(self, hidden, lookup): + '''linear decoder function''' + action = self.actor(hidden) + value = self.value(hidden) + return action, value \ No newline at end of file diff --git a/pufferlib/multi_env.py b/pufferlib/multi_env.py new file mode 100755 index 00000000..4e4ef646 --- /dev/null +++ b/pufferlib/multi_env.py @@ -0,0 +1,353 @@ +from pdb import set_trace as T +import numpy as np + +from pufferlib.namespace import Namespace +import pufferlib.exceptions + + +def create_precheck(env_creator, env_args, env_kwargs): + if env_args is None: + env_args = [] + if env_kwargs is None: + env_kwargs = {} + + if not callable(env_creator): + raise pufferlib.exceptions.APIUsageError('env_creator must be callable') + if not isinstance(env_args, list): + raise pufferlib.exceptions.APIUsageError('env_args must be a list') + # TODO: port namespace to Mapping + if not isinstance(env_kwargs, (dict, Namespace)): + raise pufferlib.exceptions.APIUsageError('env_kwargs must be a dictionary or None') + # print(f'\n\n\nMULTI_ENV LINE21 env_args={env_args}\n\n\n') + # print(f'\n\n\nnMULTI_ENV LINE22 env_kwargs={vars(env_kwargs)}\n\n\n') + return env_args, env_kwargs + +def __init__(self, env_creator: callable = None, env_args: list = [], + env_kwargs: dict = {}, n: int = 1): + env_args, env_kwargs = create_precheck(env_creator, env_args, env_kwargs) + self.envs = [env_creator(*env_args, **env_kwargs) for _ in range(n)] + self.preallocated_obs = None + +def put(state, *args, **kwargs): + for e in state.envs: + e.put(*args, **kwargs) + +def get(state, *args, **kwargs): + return [e.get(*args, **kwargs) for e in state.envs] + +def close(state): + for env in state.envs: + env.close() + +class GymnasiumMultiEnv: + __init__ = __init__ + put = put + get = get + close = close + + def reset(self, seed=None): + if self.preallocated_obs is None: + obs_space = self.envs[0].observation_space + obs_n = obs_space.shape[0] + n_envs = len(self.envs) + + self.preallocated_obs = np.empty( + (n_envs, *obs_space.shape), dtype=obs_space.dtype) + self.preallocated_rewards = np.empty(n_envs, dtype=np.float32) + self.preallocated_dones = np.empty(n_envs, dtype=bool) + self.preallocated_truncateds = np.empty(n_envs, dtype=bool) + + infos = [] + for idx, e in enumerate(self.envs): + if seed is None: + ob, i = e.reset() + else: + ob, i = e.reset(seed=hash(1000*seed + idx)) + + i['mask'] = True + infos.append(i) + self.preallocated_obs[idx] = ob + + self.preallocated_rewards[:] = 0 + self.preallocated_dones[:] = False + self.preallocated_truncateds[:] = False + + return (self.preallocated_obs, self.preallocated_rewards, + self.preallocated_dones, self.preallocated_truncateds, infos) + + def step(self, actions): + infos = [] + for idx, (env, atn) in enumerate(zip(self.envs, actions)): + if env.done: + o, i = env.reset() + self.preallocated_rewards[idx] = 0 + self.preallocated_dones[idx] = False + self.preallocated_truncateds[idx] = False + else: + o, r, d, t, i = env.step(atn) + self.preallocated_rewards[idx] = r + self.preallocated_dones[idx] = d + self.preallocated_truncateds[idx] = t + + i['mask'] = True + infos.append(i) + self.preallocated_obs[idx] = o + + return (self.preallocated_obs, self.preallocated_rewards, + self.preallocated_dones, self.preallocated_truncateds, infos) + +class PettingZooMultiEnv: + __init__ = __init__ + put = put + get = get + close = close + + def reset(self, seed=None): + if 
self.preallocated_obs is None: + obs_space = self.envs[0].single_observation_space + obs_n = obs_space.shape[0] + n_agents = len(self.envs[0].possible_agents) + n_envs = len(self.envs) + n = n_envs * n_agents + + self.preallocated_obs = np.empty( + (n, *obs_space.shape), dtype=obs_space.dtype) + self.preallocated_rewards = np.empty(n, dtype=np.float32) + self.preallocated_dones = np.empty(n, dtype=bool) + self.preallocated_truncateds = np.empty(n, dtype=bool) + + self.agent_keys = [] + infos = [] + ptr = 0 + for idx, e in enumerate(self.envs): + if seed is None: + obs, i = e.reset() + else: + obs, i = e.reset(seed=hash(1000*seed + idx)) + + self.agent_keys.append(list(obs.keys())) + infos.append(i) + + for o in obs.values(): + self.preallocated_obs[ptr] = o + ptr += 1 + + self.preallocated_rewards[:] = 0 + self.preallocated_dones[:] = False + self.preallocated_truncateds[:] = False + + return (self.preallocated_obs, self.preallocated_rewards, + self.preallocated_dones, self.preallocated_truncateds, infos) + + def step(self, actions): + actions = np.array_split(actions, len(self.envs)) + rewards, dones, truncateds, infos = [], [], [], [] + + ptr = 0 + n_envs = len(self.envs) + n_agents = len(self.envs[0].possible_agents) + assert n_envs == len(self.agent_keys) == len(actions) + + for idx in range(n_envs): + a_keys, env, atns = self.agent_keys[idx], self.envs[idx], actions[idx] + start = idx * n_agents + end = start + n_agents + if env.done: + o, i = env.reset() + self.preallocated_rewards[start:end] = 0 + self.preallocated_dones[start:end] = False + self.preallocated_truncateds[start:end] = False + else: + assert len(a_keys) == len(atns) + atns = dict(zip(a_keys, atns)) + o, r, d, t, i = env.step(atns) + self.preallocated_rewards[start:end] = list(r.values()) + self.preallocated_dones[start:end] = list(d.values()) + self.preallocated_truncateds[start:end] = list(t.values()) + + infos.append(i) + self.agent_keys[idx] = list(o.keys()) + + for oo in o.values(): + self.preallocated_obs[ptr] = oo + ptr += 1 + + return (self.preallocated_obs, self.preallocated_rewards, + self.preallocated_dones, self.preallocated_truncateds, infos) + + +# from pdb import set_trace as T +# from collections.abc import Mapping +# import numpy as np + +# import pufferlib.exceptions + + +# def create_precheck(env_creator, env_args, env_kwargs): +# if env_args is None: +# env_args = [] +# if env_kwargs is None: +# env_kwargs = {} + +# if not callable(env_creator): +# raise pufferlib.exceptions.APIUsageError('env_creator must be callable') +# if not isinstance(env_args, list): +# raise pufferlib.exceptions.APIUsageError('env_args must be a list') +# # TODO: port namespace to Mapping +# # if not isinstance(env_kwargs, Mapping): +# # raise pufferlib.exceptions.APIUsageError('env_kwargs must be a dictionary or None') + +# return env_args, env_kwargs + +# def __init__(self, env_creator: callable = None, env_args: list = [], +# env_kwargs: dict = {}, n: int = 1): +# env_args, env_kwargs = create_precheck(env_creator, env_args, env_kwargs) +# self.envs = [env_creator(*env_args, **env_kwargs) for _ in range(n)] +# self.preallocated_obs = None + +# def put(state, *args, **kwargs): +# for e in state.envs: +# e.put(*args, **kwargs) + +# def get(state, *args, **kwargs): +# return [e.get(*args, **kwargs) for e in state.envs] + +# def close(state): +# for env in state.envs: +# env.close() + +# class GymnasiumMultiEnv: +# __init__ = __init__ +# put = put +# get = get +# close = close + +# def reset(self, seed=None): +# if 
self.preallocated_obs is None: +# obs_space = self.envs[0].observation_space +# obs_n = obs_space.shape[0] +# n_envs = len(self.envs) + +# self.preallocated_obs = np.empty( +# (n_envs, *obs_space.shape), dtype=obs_space.dtype) +# self.preallocated_rewards = np.empty(n_envs, dtype=np.float32) +# self.preallocated_dones = np.empty(n_envs, dtype=bool) +# self.preallocated_truncateds = np.empty(n_envs, dtype=bool) + +# infos = [] +# for idx, e in enumerate(self.envs): +# if seed is None: +# ob, i = e.reset() +# else: +# ob, i = e.reset(seed=hash(1000*seed + idx)) + +# i['mask'] = True +# infos.append(i) +# self.preallocated_obs[idx] = ob + +# self.preallocated_rewards[:] = 0 +# self.preallocated_dones[:] = False +# self.preallocated_truncateds[:] = False + +# return (self.preallocated_obs, self.preallocated_rewards, +# self.preallocated_dones, self.preallocated_truncateds, infos) + +# def step(self, actions): +# infos = [] +# for idx, (env, atn) in enumerate(zip(self.envs, actions)): +# if env.done: +# o, i = env.reset() +# self.preallocated_rewards[idx] = 0 +# self.preallocated_dones[idx] = False +# self.preallocated_truncateds[idx] = False +# else: +# o, r, d, t, i = env.step(atn) +# self.preallocated_rewards[idx] = r +# self.preallocated_dones[idx] = d +# self.preallocated_truncateds[idx] = t + +# i['mask'] = True +# infos.append(i) +# self.preallocated_obs[idx] = o + +# return (self.preallocated_obs, self.preallocated_rewards, +# self.preallocated_dones, self.preallocated_truncateds, infos) + +# class PettingZooMultiEnv: +# __init__ = __init__ +# put = put +# get = get +# close = close + +# def reset(self, seed=None): +# if self.preallocated_obs is None: +# obs_space = self.envs[0].single_observation_space +# obs_n = obs_space.shape[0] +# n_agents = len(self.envs[0].possible_agents) +# n_envs = len(self.envs) +# n = n_envs * n_agents + +# self.preallocated_obs = np.empty( +# (n, *obs_space.shape), dtype=obs_space.dtype) +# self.preallocated_rewards = np.empty(n, dtype=np.float32) +# self.preallocated_dones = np.empty(n, dtype=bool) +# self.preallocated_truncateds = np.empty(n, dtype=bool) + +# self.agent_keys = [] +# infos = [] +# ptr = 0 +# for idx, e in enumerate(self.envs): +# if seed is None: +# obs, i = e.reset() +# else: +# obs, i = e.reset(seed=hash(1000*seed + idx)) + +# self.agent_keys.append(list(obs.keys())) +# infos.append(i) + +# for o in obs.values(): +# self.preallocated_obs[ptr] = o +# ptr += 1 + +# self.preallocated_rewards[:] = 0 +# self.preallocated_dones[:] = False +# self.preallocated_truncateds[:] = False + +# return (self.preallocated_obs, self.preallocated_rewards, +# self.preallocated_dones, self.preallocated_truncateds, infos) + +# def step(self, actions): +# actions = np.array_split(actions, len(self.envs)) +# rewards, dones, truncateds, infos = [], [], [], [] + +# ptr = 0 +# n_envs = len(self.envs) +# n_agents = len(self.envs[0].possible_agents) +# assert n_envs == len(self.agent_keys) == len(actions) + +# for idx in range(n_envs): +# a_keys, env, atns = self.agent_keys[idx], self.envs[idx], actions[idx] +# start = idx * n_agents +# end = start + n_agents +# if env.done: +# o, i = env.reset() +# self.preallocated_rewards[start:end] = 0 +# self.preallocated_dones[start:end] = False +# self.preallocated_truncateds[start:end] = False +# else: +# assert len(a_keys) == len(atns) +# atns = dict(zip(a_keys, atns)) +# o, r, d, t, i = env.step(atns) +# self.preallocated_rewards[start:end] = list(r.values()) +# self.preallocated_dones[start:end] = list(d.values()) +# 
self.preallocated_truncateds[start:end] = list(t.values()) + +# infos.append(i) +# self.agent_keys[idx] = list(o.keys()) + +# for oo in o.values(): +# self.preallocated_obs[ptr] = oo +# ptr += 1 + +# return (self.preallocated_obs, self.preallocated_rewards, +# self.preallocated_dones, self.preallocated_truncateds, infos) \ No newline at end of file diff --git a/pufferlib/namespace.py b/pufferlib/namespace.py index a9cfd2d1..76ba03da 100755 --- a/pufferlib/namespace.py +++ b/pufferlib/namespace.py @@ -1,5 +1,6 @@ from pdb import set_trace as T from types import SimpleNamespace +from collections.abc import Mapping def __getitem__(self, key): return self.__dict__[key] @@ -13,8 +14,16 @@ def values(self): def items(self): return self.__dict__.items() -class Namespace(SimpleNamespace): +def __iter__(self): + return iter(self.__dict__) + +def __len__(self): + return len(self.__dict__) + +class Namespace(SimpleNamespace, Mapping): __getitem__ = __getitem__ + __iter__ = __iter__ + __len__ = __len__ keys = keys values = values items = items @@ -33,6 +42,8 @@ def __init__(self, **kwargs): cls.__init__ = __init__ setattr(cls, "__getitem__", __getitem__) + setattr(cls, "__iter__", __iter__) + setattr(cls, "__len__", __len__) setattr(cls, "keys", keys) setattr(cls, "values", values) setattr(cls, "items", items) @@ -41,4 +52,4 @@ def __init__(self, **kwargs): def namespace(self=None, **kwargs): if self is None: return Namespace(**kwargs) - self.__dict__.update(kwargs) + self.__dict__.update(kwargs) \ No newline at end of file diff --git a/pufferlib/vectorization.py b/pufferlib/vectorization.py new file mode 100755 index 00000000..82332083 --- /dev/null +++ b/pufferlib/vectorization.py @@ -0,0 +1,558 @@ +from pdb import set_trace as T + +import numpy as np +import gymnasium +from itertools import chain +import psutil +import time + + +from pufferlib import namespace +from pufferlib.emulation import GymnasiumPufferEnv, PettingZooPufferEnv +from pufferlib.multi_env import create_precheck, GymnasiumMultiEnv, PettingZooMultiEnv +from pufferlib.exceptions import APIUsageError +import pufferlib.spaces + +import argparse + + +RESET = 0 +SEND = 1 +RECV = 2 + +space_error_msg = 'env {env} must be an instance of GymnasiumPufferEnv or PettingZooPufferEnv' + + +def calc_scale_params(num_envs, envs_per_batch, envs_per_worker, agents_per_env): + '''These calcs are simple but easy to mess up and hard to catch downstream. 
+ We do them all at once here to avoid that''' + + if num_envs % envs_per_worker != 0: + raise APIUsageError('num_envs must be divisible by envs_per_worker') + + num_workers = num_envs // envs_per_worker + envs_per_batch = num_envs if envs_per_batch is None else envs_per_batch + + if envs_per_batch > num_envs: + raise APIUsageError('envs_per_batch must be <= num_envs') + if envs_per_batch % envs_per_worker != 0: + raise APIUsageError('envs_per_batch must be divisible by envs_per_worker') + if envs_per_batch < 1: + raise APIUsageError('envs_per_batch must be > 0') + + workers_per_batch = envs_per_batch // envs_per_worker + assert workers_per_batch <= num_workers + + agents_per_batch = envs_per_batch * agents_per_env + agents_per_worker = envs_per_worker * agents_per_env + + return num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker + +def setup(env_creator, env_args, env_kwargs): + # breakpoint() + # print(f'env_args VECTOR52={env_args} || env_kwargs={env_kwargs}') + + if isinstance(env_kwargs, argparse.Namespace): + env_kwargs = vars(env_kwargs) # Convert Namespace to dict + else: + env_kwargs = env_kwargs # Use as is if it's already a dict + + env_args, env_kwargs = create_precheck(env_creator, env_args, env_kwargs) + driver_env = env_creator(*env_args, **env_kwargs) + + if isinstance(driver_env, GymnasiumPufferEnv): + multi_env_cls = GymnasiumMultiEnv + env_agents = 1 + is_multiagent = False + elif isinstance(driver_env, PettingZooPufferEnv): + multi_env_cls = PettingZooMultiEnv + env_agents = len(driver_env.possible_agents) + is_multiagent = True + else: + raise TypeError( + 'env_creator must return an instance ' + 'of GymnasiumPufferEnv or PettingZooPufferEnv' + ) + + obs_space = _single_observation_space(driver_env) + return driver_env, multi_env_cls, env_agents + +def _single_observation_space(env): + if isinstance(env, GymnasiumPufferEnv): + return env.observation_space + elif isinstance(env, PettingZooPufferEnv): + return env.single_observation_space + else: + raise TypeError(space_error_msg.format(env=env)) + +def single_observation_space(state): + return _single_observation_space(state.driver_env) + +def _single_action_space(env): + if isinstance(env, GymnasiumPufferEnv): + return env.action_space + elif isinstance(env, PettingZooPufferEnv): + return env.single_action_space + else: + raise TypeError(space_error_msg.format(env=env)) + +def single_action_space(state): + return _single_action_space(state.driver_env) + +def structured_observation_space(state): + return state.driver_env.structured_observation_space + +def flat_observation_space(state): + return state.driver_env.flat_observation_space + +def unpack_batched_obs(state, obs): + return state.driver_env.unpack_batched_obs(obs) + +def recv_precheck(state): + assert state.flag == RECV, 'Call reset before stepping' + state.flag = SEND + +def send_precheck(state): + assert state.flag == SEND, 'Call reset + recv before send' + state.flag = RECV + +def reset_precheck(state): + assert state.flag == RESET, 'Call reset only once on initialization' + state.flag = RECV + +# BET ADDED 0.7 +def reset(self, seed=None): + self.async_reset(seed) + data = self.recv() + return data[0], data[4] + +def step(self, actions): + self.send(actions) + return self.recv()[:-1] + +def aggregate_recvs(state, recvs): + obs, rewards, dones, truncateds, infos, env_ids = list(zip(*recvs)) + assert all(state.workers_per_batch == len(e) for e in + (obs, rewards, dones, truncateds, infos, env_ids)) + + obs = np.concatenate(obs) 
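# Shape sketch (illustrative numbers): with envs_per_worker=2, agents_per_env=3
# and workers_per_batch=2, each worker contributes an obs chunk of shape
# (agents_per_worker=6, obs_size); concatenating workers_per_batch chunks gives
# (agents_per_batch=12, obs_size), matching the assert on agents_per_batch below.
#   chunks = [np.zeros((6, 4), dtype=np.float32) for _ in range(2)]
#   np.concatenate(chunks).shape   # (12, 4)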
+ rewards = np.concatenate(rewards) + dones = np.concatenate(dones) + truncateds = np.concatenate(truncateds) + infos = [i for ii in infos for i in ii] + + obs_space = state.driver_env.structured_observation_space + if isinstance(obs_space, pufferlib.spaces.Box): + obs = obs.reshape(obs.shape[0], *obs_space.shape) + + # TODO: Masking will break for 1-agent PZ envs + # Replace with check against is_multiagent (add it to state) + if state.agents_per_env > 1: + mask = [e['mask'] for ee in infos for e in ee.values()] + else: + mask = [e['mask'] for e in infos] + + env_ids = np.concatenate([np.arange( # Per-agent env indexing + i*state.agents_per_worker, (i+1)*state.agents_per_worker) for i in env_ids]) + + assert all(state.agents_per_batch == len(e) for e in + (obs, rewards, dones, truncateds, env_ids, mask)) + assert len(infos) == state.envs_per_batch + + # BET ADDED 0.7 + if state.mask_agents: + return obs, rewards, dones, truncateds, infos, env_ids, mask + + return obs, rewards, dones, truncateds, infos, env_ids + +def split_actions(state, actions, env_id=None): + assert isinstance(actions, (list, np.ndarray)) + if type(actions) == list: + actions = np.array(actions) + + assert len(actions) == state.agents_per_batch + return np.array_split(actions, state.workers_per_batch) + + +class Serial: + '''Runs environments in serial on the main process + + Use this vectorization module for debugging environments + ''' + reset = reset + step = step + single_observation_space = property(single_observation_space) + single_action_space = property(single_action_space) + structured_observation_space = property(structured_observation_space) + flat_observation_space = property(flat_observation_space) + unpack_batched_obs = unpack_batched_obs + def __init__(self, + env_creator: callable = None, + env_args: list = [], + env_kwargs: dict = {}, + num_envs: int = 1, + envs_per_worker: int = 1, + envs_per_batch: int = None, + env_pool: bool = False, + mask_agents: bool = False, + ) -> None: + self.driver_env, self.multi_env_cls, self.agents_per_env = setup( + env_creator, env_args, env_kwargs) + + self.num_envs = num_envs + self.num_workers, self.workers_per_batch, self.envs_per_batch, self.agents_per_batch, self.agents_per_worker = calc_scale_params( + num_envs, envs_per_batch, envs_per_worker, self.agents_per_env) + self.envs_per_worker = envs_per_worker + self.mask_agents = mask_agents + + self.multi_envs = [ + self.multi_env_cls( + env_creator, env_args, env_kwargs, envs_per_worker, + ) for _ in range(self.num_workers) + ] + + self.flag = RESET + + def recv(self): + recv_precheck(self) + recvs = [(o, r, d, t, i, env_id) for (o, r, d, t, i), env_id + in zip(self.data, range(self.workers_per_batch))] + return aggregate_recvs(self, recvs) + + def send(self, actions): + send_precheck(self) + actions = split_actions(self, actions) + self.data = [e.step(a) for e, a in zip(self.multi_envs, actions)] + + def async_reset(self, seed=None): + reset_precheck(self) + if seed is None: + self.data = [e.reset() for e in self.multi_envs] + else: + self.data = [e.reset(seed=seed+idx) for idx, e in enumerate(self.multi_envs)] + + def put(self, *args, **kwargs): + for e in self.multi_envs: + e.put(*args, **kwargs) + + def get(self, *args, **kwargs): + return [e.get(*args, **kwargs) for e in self.multi_envs] + + def close(self): + for e in self.multi_envs: + e.close() + +def _unpack_shared_mem(shared_mem, n): + np_buf = np.frombuffer(shared_mem.get_obj(), dtype=float) + obs_arr = np_buf[:-3*n] + rewards_arr = np_buf[-3*n:-2*n] + 
terminals_arr = np_buf[-2*n:-n] + truncated_arr = np_buf[-n:] + + return obs_arr, rewards_arr, terminals_arr, truncated_arr + +def _worker_process(multi_env_cls, env_creator, env_args, env_kwargs, + agents_per_env, envs_per_worker, + worker_idx, shared_mem, send_pipe, recv_pipe): + + # I don't know if this helps. Sometimes it does, sometimes not. + # Need to run more comprehensive tests + #curr_process = psutil.Process() + #curr_process.cpu_affinity([worker_idx]) + + envs = multi_env_cls(env_creator, env_args, env_kwargs, n=envs_per_worker) + obs_arr, rewards_arr, terminals_arr, truncated_arr = _unpack_shared_mem( + shared_mem, agents_per_env * envs_per_worker) + + while True: + request, args, kwargs = recv_pipe.recv() + func = getattr(envs, request) + response = func(*args, **kwargs) + info = {} + + # TODO: Handle put/get + if request in 'step reset'.split(): + obs, reward, done, truncated, info = response + + # TESTED: There is no overhead associated with 4 assignments to shared memory + # vs. 4 assigns to an intermediate numpy array and then 1 assign to shared memory + obs_arr[:] = obs.ravel() + rewards_arr[:] = reward.ravel() + terminals_arr[:] = done.ravel() + truncated_arr[:] = truncated.ravel() + + send_pipe.send(info) + + +class Multiprocessing: + '''Runs environments in parallel using multiprocessing + + Use this vectorization module for most applications + ''' + reset = reset + step = step + single_observation_space = property(single_observation_space) + single_action_space = property(single_action_space) + structured_observation_space = property(structured_observation_space) + flat_observation_space = property(flat_observation_space) + unpack_batched_obs = unpack_batched_obs + + def __init__(self, + env_creator: callable = None, + env_args: list = [], + env_kwargs: dict = {}, + num_envs: int = 1, + envs_per_worker: int = 1, + envs_per_batch: int = None, + env_pool: bool = False, + mask_agents: bool = False, + ) -> None: + driver_env, multi_env_cls, agents_per_env = setup( + env_creator, env_args, env_kwargs) + num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker = calc_scale_params( + num_envs, envs_per_batch, envs_per_worker, agents_per_env) + + agents_per_worker = agents_per_env * envs_per_worker + observation_size = int(np.prod(_single_observation_space(driver_env).shape)) + observation_dtype = _single_observation_space(driver_env).dtype + + # Shared memory for obs, rewards, terminals, truncateds + from multiprocessing import Process, Manager, Pipe, Array + shared_mem = [ + Array('d', agents_per_worker*(3+observation_size)) + for _ in range(num_workers) + ] + main_send_pipes, work_recv_pipes = zip(*[Pipe() for _ in range(num_workers)]) + work_send_pipes, main_recv_pipes = zip(*[Pipe() for _ in range(num_workers)]) + + num_cores = psutil.cpu_count() + processes = [Process( + target=_worker_process, + args=(multi_env_cls, env_creator, env_args, env_kwargs, + agents_per_env, envs_per_worker, + i%(num_cores-1), shared_mem[i], + work_send_pipes[i], work_recv_pipes[i]) + ) for i in range(num_workers) + ] + + for p in processes: + p.start() + + # Register all receive pipes with the selector + import selectors + sel = selectors.DefaultSelector() + for pipe in main_recv_pipes: + sel.register(pipe, selectors.EVENT_READ) + + self.processes = processes + self.sel = sel + self.observation_size = observation_size + self.observation_dtype = observation_dtype + self.shared_mem = shared_mem + self.send_pipes = main_send_pipes + self.recv_pipes = main_recv_pipes 
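# Layout sketch (illustrative): each worker's Array('d', n*(3 + obs_size))
# holds [obs (n*obs_size) | rewards (n) | terminals (n) | truncateds (n)],
# where n = agents_per_env * envs_per_worker; _unpack_shared_mem slices it:
#   n, obs_size = 2, 4
#   buf = np.zeros(n * (3 + obs_size))
#   obs_part, rew, term, trunc = buf[:-3*n], buf[-3*n:-2*n], buf[-2*n:-n], buf[-n:]
#   obs_part.size, rew.size   # (8, 2)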
+ self.driver_env = driver_env + self.num_envs = num_envs + self.num_workers = num_workers + self.workers_per_batch = workers_per_batch + self.envs_per_batch = envs_per_batch + self.envs_per_worker = envs_per_worker + self.agents_per_batch = agents_per_batch + self.agents_per_worker = agents_per_worker + self.agents_per_env = agents_per_env + self.async_handles = None + self.flag = RESET + self.prev_env_id = [] + self.env_pool = env_pool + self.mask_agents = mask_agents + + def recv(self): + recv_precheck(self) + recvs = [] + next_env_id = [] + if self.env_pool: + while len(recvs) < self.workers_per_batch: + for key, _ in self.sel.select(timeout=None): + response_pipe = key.fileobj + env_id = self.recv_pipes.index(response_pipe) + + if response_pipe.poll(): + info = response_pipe.recv() + o, r, d, t = _unpack_shared_mem( + self.shared_mem[env_id], self.agents_per_env * self.envs_per_worker) + o = o.reshape( + self.agents_per_env*self.envs_per_worker, + self.observation_size).astype(self.observation_dtype) + + recvs.append((o, r, d, t, info, env_id)) + next_env_id.append(env_id) + + if len(recvs) == self.workers_per_batch: + break + else: + for env_id in range(self.workers_per_batch): + response_pipe = self.recv_pipes[env_id] + info = response_pipe.recv() + o, r, d, t = _unpack_shared_mem( + self.shared_mem[env_id], self.agents_per_env * self.envs_per_worker) + o = o.reshape( + self.agents_per_env*self.envs_per_worker, + self.observation_size).astype(self.observation_dtype) + + recvs.append((o, r, d, t, info, env_id)) + next_env_id.append(env_id) + + self.prev_env_id = next_env_id + return aggregate_recvs(self, recvs) + + def send(self, actions): + send_precheck(self) + actions = split_actions(self, actions) + for i, atns in zip(self.prev_env_id, actions): + self.send_pipes[i].send(("step", [atns], {})) + + def async_reset(self, seed=None): + reset_precheck(self) + if seed is None: + for pipe in self.send_pipes: + pipe.send(("reset", [], {})) + else: + for idx, pipe in enumerate(self.send_pipes): + pipe.send(("reset", [], {"seed": seed+idx})) + + def put(self, *args, **kwargs): + # TODO: Update this + for queue in self.request_queues: + queue.put(("put", args, kwargs)) + + def get(self, *args, **kwargs): + # TODO: Update this + for queue in self.request_queues: + queue.put(("get", args, kwargs)) + + idx = -1 + recvs = [] + while len(recvs) < self.workers_per_batch // self.envs_per_worker: + idx = (idx + 1) % self.num_workers + queue = self.response_queues[idx] + + if queue.empty(): + continue + + response = queue.get() + if response is not None: + recvs.append(response) + + return recvs + + def close(self): + for pipe in self.send_pipes: + pipe.send(("close", [], {})) + + for p in self.processes: + p.terminate() + + for p in self.processes: + p.join() + +class Ray(): + '''Runs environments in parallel on multiple processes using Ray + + Use this module for distributed simulation on a cluster. It can also be + faster than multiprocessing on a single machine for specific environments. 
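# Usage sketch (illustrative; env_creator and the sizes are assumptions):
#   vec = Ray(env_creator=make_env, num_envs=8, envs_per_worker=2)
#   obs, infos = vec.reset(seed=0)
#   actions = [vec.single_action_space.sample() for _ in range(vec.agents_per_batch)]
#   obs, rewards, dones, truncs, infos = vec.step(actions)
#   vec.close()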
+ ''' + reset = reset + step = step + single_observation_space = property(single_observation_space) + single_action_space = property(single_action_space) + structured_observation_space = property(structured_observation_space) + flat_observation_space = property(flat_observation_space) + unpack_batched_obs = unpack_batched_obs + + def __init__(self, + env_creator: callable = None, + env_args: list = [], + env_kwargs: dict = {}, + num_envs: int = 1, + envs_per_worker: int = 1, + envs_per_batch: int = None, + env_pool: bool = False, + mask_agents: bool = False, + ) -> None: + driver_env, multi_env_cls, agents_per_env = setup( + env_creator, env_args, env_kwargs) + num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker = calc_scale_params( + num_envs, envs_per_batch, envs_per_worker, agents_per_env) + + import ray + if not ray.is_initialized(): + import logging + ray.init( + include_dashboard=False, # WSL Compatibility + logging_level=logging.ERROR, + ) + + multi_envs = [ + ray.remote(multi_env_cls).remote( + env_creator, env_args, env_kwargs, envs_per_worker + ) for _ in range(num_workers) + ] + + self.multi_envs = multi_envs + self.driver_env = driver_env + self.num_envs = num_envs + self.num_workers = num_workers + self.workers_per_batch = workers_per_batch + self.envs_per_batch = envs_per_batch + self.envs_per_worker = envs_per_worker + self.agents_per_batch = agents_per_batch + self.agents_per_worker = agents_per_worker + self.agents_per_env = agents_per_env + self.async_handles = None + self.flag = RESET + self.ray = ray + self.prev_env_id = [] + self.env_pool = env_pool + self.mask_agents = mask_agents + + def recv(self): + recv_precheck(self) + recvs = [] + next_env_id = [] + if self.env_pool: + recvs = self.ray.get(self.async_handles) + env_id = [_ for _ in range(self.workers_per_batch)] + else: + ready, busy = self.ray.wait( + self.async_handles, num_returns=self.workers_per_batch) + env_id = [self.async_handles.index(e) for e in ready] + recvs = self.ray.get(ready) + + recvs = [(o, r, d, t, i, eid) + for (o, r, d, t, i), eid in zip(recvs, env_id)] + self.prev_env_id = env_id + return aggregate_recvs(self, recvs) + + def send(self, actions): + send_precheck(self) + actions = split_actions(self, actions) + self.async_handles = [e.step.remote(a) for e, a in zip(self.multi_envs, actions)] + + def async_reset(self, seed=None): + reset_precheck(self) + if seed is None: + self.async_handles = [e.reset.remote() for e in self.multi_envs] + else: + self.async_handles = [e.reset.remote(seed=seed+idx) + for idx, e in enumerate(self.multi_envs)] + + def put(self, *args, **kwargs): + for e in self.multi_envs: + e.put.remote(*args, **kwargs) + + def get(self, *args, **kwargs): + return self.ray.get([e.get.remote(*args, **kwargs) for e in self.multi_envs]) + + def close(self): + self.ray.get([e.close.remote() for e in self.multi_envs]) + self.ray.shutdown() \ No newline at end of file diff --git a/pufferlib/vectorization/__init__.py b/pufferlib/vectorization/__init__.py deleted file mode 100755 index 574092b2..00000000 --- a/pufferlib/vectorization/__init__.py +++ /dev/null @@ -1,68 +0,0 @@ -import gym - -from pufferlib.vectorization import vec_env, serial_vec_env, multiprocessing_vec_env, ray_vec_env - - -class Serial: - '''Runs environments in serial on the main process - - Use this vectorization module for debugging environments - ''' - __init__ = serial_vec_env.init - single_observation_space = property(vec_env.single_observation_space) - single_action_space = 
property(vec_env.single_action_space) - structured_observation_space = property(vec_env.structured_observation_space) - flat_observation_space = property(vec_env.flat_observation_space) - unpack_batched_obs = vec_env.unpack_batched_obs - send = serial_vec_env.send - recv = serial_vec_env.recv - async_reset = serial_vec_env.async_reset - profile = serial_vec_env.profile - reset = serial_vec_env.reset - step = serial_vec_env.step - put = serial_vec_env.put - get = serial_vec_env.get - close = serial_vec_env.close - -class Multiprocessing: - '''Runs environments in parallel using multiprocessing - - Use this vectorization module for most applications - ''' - __init__ = multiprocessing_vec_env.init - single_observation_space = property(vec_env.single_observation_space) - single_action_space = property(vec_env.single_action_space) - structured_observation_space = property(vec_env.structured_observation_space) - flat_observation_space = property(vec_env.flat_observation_space) - unpack_batched_obs = vec_env.unpack_batched_obs - send = multiprocessing_vec_env.send - recv = multiprocessing_vec_env.recv - async_reset = multiprocessing_vec_env.async_reset - profile = multiprocessing_vec_env.profile - reset = multiprocessing_vec_env.reset - step = multiprocessing_vec_env.step - put = multiprocessing_vec_env.put - get = multiprocessing_vec_env.get - close = multiprocessing_vec_env.close - -class Ray: - '''Runs environments in parallel on multiple processes using Ray - - Use this module for distributed simulation on a cluster. It can also be - faster than multiprocessing on a single machine for specific environments. - ''' - __init__ = ray_vec_env.init - single_observation_space = property(vec_env.single_observation_space) - single_action_space = property(vec_env.single_action_space) - structured_observation_space = property(vec_env.structured_observation_space) - flat_observation_space = property(vec_env.flat_observation_space) - unpack_batched_obs = vec_env.unpack_batched_obs - send = ray_vec_env.send - recv = ray_vec_env.recv - async_reset = ray_vec_env.async_reset - profile = ray_vec_env.profile - reset = ray_vec_env.reset - step = ray_vec_env.step - put = ray_vec_env.put - get = ray_vec_env.get - close = ray_vec_env.close diff --git a/pufferlib/vectorization/gym_multi_env.py b/pufferlib/vectorization/gym_multi_env.py deleted file mode 100755 index 36c692b5..00000000 --- a/pufferlib/vectorization/gym_multi_env.py +++ /dev/null @@ -1,63 +0,0 @@ -from pdb import set_trace as T -import numpy as np - -from pufferlib.vectorization.multi_env import ( - init, - profile, - put, - get, - close, -) - - -def reset(state, seed=None): - infos = [] - for idx, e in enumerate(state.envs): - if seed is None: - ob, i = e.reset() - else: - ob, i = e.reset(seed=hash(1000*seed + idx)) - - i['mask'] = True - infos.append(i) - if state.preallocated_obs is None: - state.preallocated_obs = np.empty( - (len(state.envs), *ob.shape), dtype=ob.dtype) - - state.preallocated_obs[idx] = ob - - rewards = [0] * len(state.preallocated_obs) - dones = [False] * len(state.preallocated_obs) - truncateds = [False] * len(state.preallocated_obs) - - return state.preallocated_obs, rewards, dones, truncateds, infos - -def step(state, actions): - rewards, dones, truncateds, infos = [], [], [], [] - - for idx, (env, atns) in enumerate(zip(state.envs, actions)): - if env.done: - o, i = env.reset() - rewards.append(0) - dones.append(False) - truncateds.append(False) - else: - o, r, d, t, i = env.step(atns) - rewards.append(r) - dones.append(d) - 
truncateds.append(t) - - i['mask'] = True - infos.append(i) - state.preallocated_obs[idx] = o - - return state.preallocated_obs, rewards, dones, truncateds, infos - -class GymMultiEnv: - __init__ = init - reset = reset - step = step - profile = profile - put = put - get = get - close = close diff --git a/pufferlib/vectorization/multi_env.py b/pufferlib/vectorization/multi_env.py deleted file mode 100755 index d3beee3c..00000000 --- a/pufferlib/vectorization/multi_env.py +++ /dev/null @@ -1,40 +0,0 @@ -from pufferlib import namespace - - -def create_precheck(env_creator, env_args, env_kwargs): - if env_args is None: - env_args = [] - if env_kwargs is None: - env_kwargs = {} - - assert callable(env_creator) - assert isinstance(env_args, list) - #assert isinstance(env_kwargs, dict) - - return env_args, env_kwargs - -def init(self, - env_creator: callable = None, - env_args: list = [], - env_kwargs: dict = {}, - n: int = 1, - ) -> None: - env_args, env_kwargs = create_precheck(env_creator, env_args, env_kwargs) - return namespace(self, - envs = [env_creator(*env_args, **env_kwargs) for _ in range(n)], - preallocated_obs = None, - ) - -def put(state, *args, **kwargs): - for e in state.envs: - e.put(*args, **kwargs) - -def get(state, *args, **kwargs): - return [e.get(*args, **kwargs) for e in state.envs] - -def close(state): - for env in state.envs: - env.close() - -def profile(state): - return [e.timers for e in state.envs] diff --git a/pufferlib/vectorization/multiprocessing_vec_env.py b/pufferlib/vectorization/multiprocessing_vec_env.py deleted file mode 100755 index b52d8f75..00000000 --- a/pufferlib/vectorization/multiprocessing_vec_env.py +++ /dev/null @@ -1,188 +0,0 @@ -from pdb import set_trace as T -import time - -import selectors -from multiprocessing import Process, Queue, Manager, Pipe -from queue import Empty - -from pufferlib import namespace -from pufferlib.vectorization.vec_env import ( - RESET, - calc_scale_params, - setup, - single_observation_space, - single_action_space, - single_action_space, - structured_observation_space, - flat_observation_space, - unpack_batched_obs, - reset_precheck, - recv_precheck, - send_precheck, - aggregate_recvs, - split_actions, - aggregate_profiles, -) - - -def init(self: object = None, - env_creator: callable = None, - env_args: list = [], - env_kwargs: dict = {}, - num_envs: int = 1, - envs_per_worker: int = 1, - envs_per_batch: int = None, - env_pool: bool = False, - ) -> None: - driver_env, multi_env_cls, agents_per_env = setup( - env_creator, env_args, env_kwargs) - num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker = calc_scale_params( - num_envs, envs_per_batch, envs_per_worker, agents_per_env) - - - - main_send_pipes, work_recv_pipes = zip(*[Pipe() for _ in range(num_workers)]) - work_send_pipes, main_recv_pipes = zip(*[Pipe() for _ in range(num_workers)]) - - processes = [Process( - target=_worker_process, - args=(multi_env_cls, env_creator, env_args, env_kwargs, - envs_per_worker, work_send_pipes[i], work_recv_pipes[i])) - for i in range(num_workers)] - - for p in processes: - p.start() - - # Register all receive pipes with the selector - sel = selectors.DefaultSelector() - for pipe in main_recv_pipes: - sel.register(pipe, selectors.EVENT_READ) - - return namespace(self, - processes = processes, - sel = sel, - send_pipes = main_send_pipes, - recv_pipes = main_recv_pipes, - driver_env = driver_env, - num_envs = num_envs, - num_workers = num_workers, - workers_per_batch = workers_per_batch, - 
envs_per_batch = envs_per_batch, - envs_per_worker = envs_per_worker, - agents_per_batch = agents_per_batch, - agents_per_worker = agents_per_worker, - agents_per_env = agents_per_env, - async_handles = None, - flag = RESET, - prev_env_id = [], - env_pool = env_pool, - ) - -def _worker_process(multi_env_cls, env_creator, env_args, env_kwargs, n, send_pipe, recv_pipe): - envs = multi_env_cls(env_creator, env_args, env_kwargs, n=n) - - while True: - request, args, kwargs = recv_pipe.recv() - func = getattr(envs, request) - response = func(*args, **kwargs) - send_pipe.send(response) - -def recv(state): - recv_precheck(state) - - recvs = [] - next_env_id = [] - if state.env_pool: - for env_id in range(state.workers_per_batch): - response_pipe = state.recv_pipes[env_id] - response = response_pipe.recv() - - o, r, d, t, i = response - recvs.append((o, r, d, t, i, env_id)) - next_env_id.append(env_id) - else: - while len(recvs) < state.workers_per_batch: - for key, _ in state.sel.select(timeout=None): - response_pipe = key.fileobj - env_id = state.recv_pipes.index(response_pipe) - - if response_pipe.poll(): # Check if data is available - response = response_pipe.recv() - - o, r, d, t, i = response - recvs.append((o, r, d, t, i, env_id)) - next_env_id.append(env_id) - - if len(recvs) == state.workers_per_batch: - break - - state.prev_env_id = next_env_id - return aggregate_recvs(state, recvs) - -def send(state, actions): - send_precheck(state) - actions = split_actions(state, actions) - assert len(actions) == state.workers_per_batch - for i, atns in zip(state.prev_env_id, actions): - state.send_pipes[i].send(("step", [atns], {})) - -def async_reset(state, seed=None): - reset_precheck(state) - if seed is None: - for pipe in state.send_pipes: - pipe.send(("reset", [], {})) - else: - for idx, pipe in enumerate(state.send_pipes): - pipe.send(("reset", [], {"seed": seed+idx})) - -def reset(state, seed=None): - async_reset(state) - obs, _, _, _, info, env_id, mask = recv(state) - return obs, info, env_id, mask - -def step(state, actions): - send(state, actions) - return recv(state) - -def profile(state): - # TODO: Update this - for queue in state.request_queues: - queue.put(("profile", [], {})) - - return aggregate_profiles([queue.get() for queue in state.response_queues]) - -def put(state, *args, **kwargs): - # TODO: Update this - for queue in state.request_queues: - queue.put(("put", args, kwargs)) - -def get(state, *args, **kwargs): - # TODO: Update this - for queue in state.request_queues: - queue.put(("get", args, kwargs)) - - idx = -1 - recvs = [] - while len(recvs) < state.workers_per_batch // state.envs_per_worker: - idx = (idx + 1) % state.num_workers - queue = state.response_queues[idx] - - if queue.empty(): - continue - - response = queue.get() - if response is not None: - recvs.append(response) - - return recvs - - -def close(state): - for pipe in state.send_pipes: - pipe.send(("close", [], {})) - - for p in state.processes: - p.terminate() - - for p in state.processes: - p.join() diff --git a/pufferlib/vectorization/pettingzoo_multi_env.py b/pufferlib/vectorization/pettingzoo_multi_env.py deleted file mode 100755 index 480d582e..00000000 --- a/pufferlib/vectorization/pettingzoo_multi_env.py +++ /dev/null @@ -1,77 +0,0 @@ -from pdb import set_trace as T - -import numpy as np - -from pufferlib.vectorization.multi_env import ( - init, - profile, - put, - get, - close, -) - - -def reset(state, seed=None): - state.agent_keys = [] - infos = [] - - ptr = 0 - for idx, e in enumerate(state.envs): - 
if seed is None: - obs, i = e.reset() - else: - obs, i = e.reset(seed=hash(1000*seed + idx)) - - state.agent_keys.append(list(obs.keys())) - infos.append(i) - - if state.preallocated_obs is None: - ob = obs[list(obs.keys())[0]] - state.preallocated_obs = np.empty((len(state.envs)*len(obs), *ob.shape), dtype=ob.dtype) - - for o in obs.values(): - state.preallocated_obs[ptr] = o - ptr += 1 - - rewards = [0] * len(state.preallocated_obs) - dones = [False] * len(state.preallocated_obs) - truncateds = [False] * len(state.preallocated_obs) - return state.preallocated_obs, rewards, dones, truncateds, infos - -def step(state, actions): - actions = np.array_split(actions, len(state.envs)) - rewards, dones, truncateds, infos = [], [], [], [] - - ptr = 0 - for idx, (a_keys, env, atns) in enumerate(zip(state.agent_keys, state.envs, actions)): - if env.done: - o, i = env.reset() - num_agents = len(env.possible_agents) - rewards.extend([0] * num_agents) - dones.extend([False] * num_agents) - truncateds.extend([False] * num_agents) - else: - assert len(a_keys) == len(atns) - atns = dict(zip(a_keys, atns)) - o, r, d, t, i = env.step(atns) - rewards.extend(r.values()) - dones.extend(d.values()) - truncateds.extend(t.values()) - - infos.append(i) - state.agent_keys[idx] = list(o.keys()) - - for oo in o.values(): - state.preallocated_obs[ptr] = oo - ptr += 1 - - return state.preallocated_obs, rewards, dones, truncateds, infos - -class PettingZooMultiEnv: - __init__ = init - reset = reset - step = step - profile = profile - put = put - get = get - close = close diff --git a/pufferlib/vectorization/ray_vec_env.py b/pufferlib/vectorization/ray_vec_env.py deleted file mode 100755 index 32e9bf92..00000000 --- a/pufferlib/vectorization/ray_vec_env.py +++ /dev/null @@ -1,125 +0,0 @@ -from pdb import set_trace as T - -import gym - -from pufferlib import namespace -from pufferlib.vectorization.vec_env import ( - RESET, - calc_scale_params, - setup, - single_observation_space, - single_action_space, - single_action_space, - structured_observation_space, - flat_observation_space, - unpack_batched_obs, - reset_precheck, - recv_precheck, - send_precheck, - aggregate_recvs, - split_actions, - aggregate_profiles, -) - - -def init(self: object = None, - env_creator: callable = None, - env_args: list = [], - env_kwargs: dict = {}, - num_envs: int = 1, - envs_per_worker: int = 1, - envs_per_batch: int = None, - env_pool: bool = False, - ) -> None: - driver_env, multi_env_cls, agents_per_env = setup( - env_creator, env_args, env_kwargs) - num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker = calc_scale_params( - num_envs, envs_per_batch, envs_per_worker, agents_per_env) - - import ray - if not ray.is_initialized(): - import logging - ray.init( - include_dashboard=False, # WSL Compatibility - logging_level=logging.ERROR, - ) - - multi_envs = [ - ray.remote(multi_env_cls).remote( - env_creator, env_args, env_kwargs, envs_per_worker - ) for _ in range(num_workers) - ] - - return namespace(self, - multi_envs = multi_envs, - driver_env = driver_env, - num_envs = num_envs, - num_workers = num_workers, - workers_per_batch = workers_per_batch, - envs_per_batch = envs_per_batch, - envs_per_worker = envs_per_worker, - agents_per_batch = agents_per_batch, - agents_per_worker = agents_per_worker, - agents_per_env = agents_per_env, - async_handles = None, - flag = RESET, - ray = ray, # Save a copy for internal use - prev_env_id = [], - env_pool = env_pool, - ) - -def recv(state): - recv_precheck(state) - - 
recvs = [] - next_env_id = [] - if state.env_pool: - recvs = state.ray.get(state.async_handles) - env_id = [_ for _ in range(state.workers_per_batch)] - else: - ready, busy = state.ray.wait( - state.async_handles, num_returns=state.workers_per_batch) - env_id = [state.async_handles.index(e) for e in ready] - recvs = state.ray.get(ready) - - recvs = [(o, r, d, t, i, eid) - for (o, r, d, t, i), eid in zip(recvs, env_id)] - state.prev_env_id = env_id - return aggregate_recvs(state, recvs) - -def send(state, actions): - send_precheck(state) - actions = split_actions(state, actions) - state.async_handles = [e.step.remote(a) for e, a in zip(state.multi_envs, actions)] - -def async_reset(state, seed=None): - reset_precheck(state) - if seed is None: - state.async_handles = [e.reset.remote() for e in state.multi_envs] - else: - state.async_handles = [e.reset.remote(seed=seed+idx) - for idx, e in enumerate(state.multi_envs)] - -def reset(state, seed=None): - async_reset(state) - obs, _, _, _, info, env_id, mask = recv(state) - return obs, info, env_id, mask - -def step(state, actions): - send(state, actions) - return recv(state) - -def profile(state): - return aggregate_profiles( - state.ray.get([e.profile.remote() for e in state.multi_envs])) - -def put(state, *args, **kwargs): - for e in state.multi_envs: - e.put.remote(*args, **kwargs) - -def get(state, *args, **kwargs): - return state.ray.get([e.get.remote(*args, **kwargs) for e in state.multi_envs]) - -def close(state): - state.ray.get([e.close.remote() for e in state.multi_envs]) - state.ray.shutdown() diff --git a/pufferlib/vectorization/serial_vec_env.py b/pufferlib/vectorization/serial_vec_env.py deleted file mode 100755 index 78f0f656..00000000 --- a/pufferlib/vectorization/serial_vec_env.py +++ /dev/null @@ -1,98 +0,0 @@ -from pdb import set_trace as T - -import gym - -from pufferlib import namespace -from pufferlib.vectorization.vec_env import ( - RESET, - calc_scale_params, - setup, - single_observation_space, - single_action_space, - single_action_space, - structured_observation_space, - flat_observation_space, - unpack_batched_obs, - reset_precheck, - recv_precheck, - send_precheck, - aggregate_recvs, - split_actions, - aggregate_profiles, -) - -def init(self: object = None, - env_creator: callable = None, - env_args: list = [], - env_kwargs: dict = {}, - num_envs: int = 1, - envs_per_worker: int = 1, - envs_per_batch: int = None, - env_pool: bool = False, - ) -> None: - driver_env, multi_env_cls, agents_per_env = setup( - env_creator, env_args, env_kwargs) - num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker = calc_scale_params( - num_envs, envs_per_batch, envs_per_worker, agents_per_env) - - multi_envs = [ - multi_env_cls( - env_creator, env_args, env_kwargs, envs_per_worker, - ) for _ in range(num_workers) - ] - - return namespace(self, - multi_envs = multi_envs, - driver_env = driver_env, - num_envs = num_envs, - num_workers = num_workers, - workers_per_batch = workers_per_batch, - envs_per_batch = envs_per_batch, - envs_per_worker = envs_per_worker, - agents_per_batch = agents_per_batch, - agents_per_worker = agents_per_worker, - agents_per_env = agents_per_env, - async_handles = None, - flag = RESET, - ) - -def recv(state): - recv_precheck(state) - recvs = [(o, r, d, t, i, env_id) for (o, r, d, t, i), env_id - in zip(state.data, range(state.workers_per_batch))] - return aggregate_recvs(state, recvs) - -def send(state, actions): - send_precheck(state) - actions = split_actions(state, actions) - 
state.data = [e.step(a) for e, a in zip(state.multi_envs, actions)] - -def async_reset(state, seed=None): - reset_precheck(state) - if seed is None: - state.data = [e.reset() for e in state.multi_envs] - else: - state.data = [e.reset(seed=seed+idx) for idx, e in enumerate(state.multi_envs)] - -def reset(state, seed=None): - async_reset(state) - obs, _, _, _, info, env_id, mask = recv(state) - return obs, info, env_id, mask - -def step(state, actions): - send(state, actions) - return recv(state) - -def profile(state): - return aggregate_profiles([e.profile() for e in state.multi_envs]) - -def put(state, *args, **kwargs): - for e in state.multi_envs: - e.put(*args, **kwargs) - -def get(state, *args, **kwargs): - return [e.get(*args, **kwargs) for e in state.multi_envs] - -def close(state): - for e in state.multi_envs: - e.close() diff --git a/pufferlib/vectorization/vec_env.py b/pufferlib/vectorization/vec_env.py deleted file mode 100755 index 7f268549..00000000 --- a/pufferlib/vectorization/vec_env.py +++ /dev/null @@ -1,139 +0,0 @@ -from pdb import set_trace as T - -import numpy as np -from itertools import chain - -from pufferlib import namespace -from pufferlib.emulation import GymnasiumPufferEnv, PettingZooPufferEnv -from pufferlib.vectorization.multi_env import create_precheck -from pufferlib.vectorization.gym_multi_env import GymMultiEnv -from pufferlib.vectorization.pettingzoo_multi_env import PettingZooMultiEnv - - -RESET = 0 -SEND = 1 -RECV = 2 - -space_error_msg = 'env {env} must be an instance of GymnasiumPufferEnv or PettingZooPufferEnv' - - -def calc_scale_params(num_envs, envs_per_batch, envs_per_worker, agents_per_env): - '''These calcs are simple but easy to mess up and hard to catch downstream. - We do them all at once here to avoid that''' - - assert num_envs % envs_per_worker == 0 - num_workers = num_envs // envs_per_worker - - envs_per_batch = num_envs if envs_per_batch is None else envs_per_batch - assert envs_per_batch % envs_per_worker == 0 - assert envs_per_batch <= num_envs - assert envs_per_batch > 0 - - workers_per_batch = envs_per_batch // envs_per_worker - assert workers_per_batch <= num_workers - - agents_per_batch = envs_per_batch * agents_per_env - agents_per_worker = envs_per_worker * agents_per_env - - return num_workers, workers_per_batch, envs_per_batch, agents_per_batch, agents_per_worker - -def setup(env_creator, env_args, env_kwargs): - env_args, env_kwargs = create_precheck(env_creator, env_args, env_kwargs) - driver_env = env_creator(*env_args, **env_kwargs) - - if isinstance(driver_env, GymnasiumPufferEnv): - multi_env_cls = GymMultiEnv - env_agents = 1 - is_multiagent = False - elif isinstance(driver_env, PettingZooPufferEnv): - multi_env_cls = PettingZooMultiEnv - env_agents = len(driver_env.possible_agents) - is_multiagent = True - else: - raise TypeError( - 'env_creator must return an instance ' - 'of GymnasiumPufferEnv or PettingZooPufferEnv' - ) - - obs_space = _single_observation_space(driver_env) - return driver_env, multi_env_cls, env_agents - -def _single_observation_space(env): - if isinstance(env, GymnasiumPufferEnv): - return env.observation_space - elif isinstance(env, PettingZooPufferEnv): - return env.single_observation_space - else: - raise TypeError(space_error_msg.format(env=env)) - -def single_observation_space(state): - return _single_observation_space(state.driver_env) - -def _single_action_space(env): - if isinstance(env, GymnasiumPufferEnv): - return env.action_space - elif isinstance(env, PettingZooPufferEnv): - return 
env.single_action_space - else: - raise TypeError(space_error_msg.format(env=env)) - -def single_action_space(state): - return _single_action_space(state.driver_env) - -def structured_observation_space(state): - return state.driver_env.structured_observation_space - -def flat_observation_space(state): - return state.driver_env.flat_observation_space - -def unpack_batched_obs(state, obs): - return state.driver_env.unpack_batched_obs(obs) - -def recv_precheck(state): - assert state.flag == RECV, 'Call reset before stepping' - state.flag = SEND - -def send_precheck(state): - assert state.flag == SEND, 'Call reset + recv before send' - state.flag = RECV - -def reset_precheck(state): - assert state.flag == RESET, 'Call reset only once on initialization' - state.flag = RECV - -def aggregate_recvs(state, recvs): - obs, rewards, dones, truncateds, infos, env_ids = list(zip(*recvs)) - assert all(state.workers_per_batch == len(e) for e in - (obs, rewards, dones, truncateds, infos, env_ids)) - - obs = np.stack(list(chain.from_iterable(obs)), 0) - rewards = list(chain.from_iterable(rewards)) - dones = list(chain.from_iterable(dones)) - truncateds = list(chain.from_iterable(truncateds)) - infos = [i for ii in infos for i in ii] - - # TODO: Masking will break for 1-agent PZ envs - # Replace with check against is_multiagent (add it to state) - if state.agents_per_env > 1: - mask = [e['mask'] for ee in infos for e in ee.values()] - else: - mask = [e['mask'] for e in infos] - - env_ids = np.concatenate([np.arange( # Per-agent env indexing - i*state.agents_per_worker, (i+1)*state.agents_per_worker) for i in env_ids]) - - assert all(state.agents_per_batch == len(e) for e in - (obs, rewards, dones, truncateds, env_ids, mask)) - assert len(infos) == state.envs_per_batch - return obs, rewards, dones, truncateds, infos, env_ids, mask - -def split_actions(state, actions, env_id=None): - assert isinstance(actions, (list, np.ndarray)) - if type(actions) == list: - actions = np.array(actions) - - assert len(actions) == state.agents_per_batch - return np.array_split(actions, state.workers_per_batch) - -def aggregate_profiles(profiles): - return list(chain.from_iterable([profiles])) diff --git a/run.sh b/run.sh index 1491868a..d9cbed26 100755 --- a/run.sh +++ b/run.sh @@ -1,5 +1,2 @@ #!/bin/bash -# python demo.py --config pokemon_red --vectorization multiprocessing --mode train --track -# python demo.py --config pokemon_red --vectorization multiprocessing --mode train -# python demo.py --config pokemon_red --mode train -python demo.py --backend clean_pufferl --config pokemon_red --vectorization multiprocessing --mode train --track +python demo.py --backend clean_pufferl --config pokemon_red --no-render --vectorization multiprocessing --mode train --track diff --git a/stream_wrapper.py b/stream_wrapper.py index 5f851799..72e45e17 100755 --- a/stream_wrapper.py +++ b/stream_wrapper.py @@ -38,10 +38,11 @@ def __init__(self, env, stream_metadata={}): self.stream_metadata = stream_metadata self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.websocket = self.loop.run_until_complete( - self.establish_wc_connection() + self.websocket = None + self.loop.run_until_complete( + self.establish_wc_connection() ) - self.upload_interval = 80 + self.upload_interval = 250 self.steam_step_counter = 0 self.coord_list = [] self.start_time = time.time()
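
Note on the stream_wrapper.py hunk above: the constructor now leaves `self.websocket` as `None` and lets the connection coroutine populate it, and the upload interval is raised from 80 to 250 steps. Below is a minimal, illustrative sketch of that lazy-connection pattern, not the wrapper's actual code; the `_connect` helper and `LazySocket` name are placeholders, and the real class uses its own `establish_wc_connection()`.

```python
import asyncio


class LazySocket:
    """Illustrative only: defer the websocket handle until the event loop
    has actually tried to connect, mirroring the stream_wrapper.py change."""

    def __init__(self, uri: str, upload_interval: int = 250):
        self.uri = uri
        self.upload_interval = upload_interval  # steps between uploads (80 -> 250 in the diff)
        self.websocket = None                   # set by the coroutine, not by __init__
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        # Run the connect attempt once; a failure leaves websocket as None
        # instead of raising out of the constructor.
        self.loop.run_until_complete(self._establish_connection())

    async def _establish_connection(self):
        try:
            self.websocket = await self._connect(self.uri)
        except OSError:
            self.websocket = None

    async def _connect(self, uri):
        # Placeholder for an async websocket connect call; the real wrapper
        # performs this inside establish_wc_connection().
        raise OSError("no streaming server available in this sketch")
```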