Skip to content

Commit

Permalink
NOT-WORKING BOEY BASELINES PUFFERLIB PORT
Browse files Browse the repository at this point in the history
  • Loading branch information
xinpw8 committed Jan 17, 2024
1 parent e800153 commit e317984
Show file tree
Hide file tree
Showing 153 changed files with 237 additions and 41 deletions.
Empty file modified .gitignore
100644 → 100755
Empty file.
Empty file modified LICENSE
100644 → 100755
Empty file.
Empty file modified MANIFEST.in
100644 → 100755
Empty file.
Empty file modified README.md
100644 → 100755
Empty file.
Empty file modified clean_pufferl.py
100644 → 100755
Empty file.
Empty file modified cleanrl_ppo_atari.py
100644 → 100755
Empty file.
14 changes: 7 additions & 7 deletions config.yaml
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ train:
torch_deterministic: True
device: cuda
total_timesteps: 10_000_000
learning_rate: 2.5e-4
learning_rate: 0.0004
num_steps: 128
anneal_lr: True
gamma: 0.99
Expand All @@ -22,7 +22,7 @@ train:
envs_per_worker: 1
envs_per_batch: ~
env_pool: True
verbose: True
verbose: False
data_dir: experiments
checkpoint_interval: 200
cpu_offload: True
Expand Down Expand Up @@ -664,13 +664,13 @@ pokemon_red:
package: pokemon_red
train:
total_timesteps: 100_000_000
num_envs: 4
num_envs: 64
envs_per_worker: 1
envpool_batch_size: 4
update_epochs: 3
envpool_batch_size: 32
update_epochs: 10
gamma: 0.998
batch_size: 1024
batch_rows: 16
batch_size: 32768
batch_rows: 64
env:
name: pokemon_red
pokemon-red:
Expand Down
6 changes: 5 additions & 1 deletion demo.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,16 @@ def load_from_config(env):
return pkg, pufferlib.namespace(**combined_config)

def make_policy(env, env_module, args):
    '''Build the environment module's policy, wrap it for CleanRL and move it to the training device.

    Args:
        env: environment instance handed to the policy constructors.
        env_module: module-like object exposing ``Policy`` and, optionally,
            ``Recurrent``.
        args: namespace providing ``policy`` and ``recurrent`` keyword
            dictionaries, the ``force_recurrence`` flag and ``train.device``.

    Returns:
        The wrapped policy after ``.to(args.train.device)``.

    Raises:
        ValueError: if ``args.force_recurrence`` is set but ``env_module``
            provides no ``Recurrent`` class (previously this surfaced as an
            opaque "NoneType is not callable" / AttributeError).
    '''
    policy = env_module.Policy(env, **args.policy)

    # Look the recurrent class up once; a module that does not define the
    # attribute at all is treated the same as one that sets it to None.
    recurrent_cls = getattr(env_module, 'Recurrent', None)

    if recurrent_cls is not None:
        # A recurrent wrapper is used whenever the module provides one.
        policy = recurrent_cls(env, policy, **args.recurrent)
        policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy)
    elif args.force_recurrence:
        raise ValueError(
            'force_recurrence was requested but the environment module '
            'does not define a Recurrent class'
        )
    else:
        policy = pufferlib.frameworks.cleanrl.Policy(policy)

    return policy.to(args.train.device)

Expand Down Expand Up @@ -153,7 +157,7 @@ def train(args, env_module, make_env):
parser.add_argument('--no-render', action='store_true', help='Disable render during evaluate')
parser.add_argument('--exp-name', type=str, default=None, help="Resume from experiment")
parser.add_argument('--vectorization', type=str, default='serial', help='Vectorization method (serial, multiprocessing, ray)')
parser.add_argument('--wandb-entity', type=str, default='jsuarez', help='WandB entity')
parser.add_argument('--wandb-entity', type=str, default='xinpw8', help='WandB entity')
parser.add_argument('--wandb-project', type=str, default='pufferlib', help='WandB project')
parser.add_argument('--wandb-group', type=str, default='debug', help='WandB group')
parser.add_argument('--track', action='store_true', help='Track on WandB')
Expand Down
Empty file modified pokemon_red_eval.py
100644 → 100755
Empty file.
Empty file modified pufferlib/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/emulation.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/atari/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/atari/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/atari/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/bsuite/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/bsuite/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/bsuite/squared.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/bsuite/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/butterfly/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/butterfly/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/butterfly/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/classic_control/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/classic_control/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/classic_control/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/crafter/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/crafter/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/crafter/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/dm_control/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/dm_control/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/dm_control/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/dm_lab/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/dm_lab/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/dm_lab/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/griddly/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/griddly/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/griddly/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/links_awaken/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/links_awaken/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/links_awaken/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/magent/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/magent/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/magent/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/microrts/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/microrts/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/microrts/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minerl/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minerl/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minerl/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minigrid/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minigrid/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minigrid/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minihack/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minihack/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/minihack/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/nethack/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/nethack/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/nethack/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/nmmo/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/nmmo/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/nmmo/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/README.md
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/bandit.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/memory.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/password.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/squared.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/stochastic.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/ocean/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/open_spiel/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/open_spiel/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/open_spiel/gymnasium_environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/open_spiel/pettingzoo_environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/open_spiel/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/open_spiel/utils.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/pokemon_red/__init__.py
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion pufferlib/environments/pokemon_red/environment.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def env_creator(name='pokemon_red'):

def make(name, headless: bool = True, state_path=None):
    '''Pokemon Red

    Create the Pokemon Red ``Environment`` and wrap it in a
    ``GymnasiumPufferEnv`` using the ``BasicPostprocessor``.

    NOTE(review): ``name``, ``headless`` and ``state_path`` are accepted for
    interface compatibility but are not forwarded to ``Environment``; confirm
    whether the environment should receive them.
    '''
    # Construct the environment exactly once. The earlier
    # Environment(headless=..., state_path=...) instance was immediately
    # overwritten by this call and never used.
    env = Environment()
    return pufferlib.emulation.GymnasiumPufferEnv(
        env=env,
        postprocessor_cls=pufferlib.emulation.BasicPostprocessor,
    )
211 changes: 200 additions & 11 deletions pufferlib/environments/pokemon_red/torch.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,19 +1,208 @@
from torch.nn import functional as F
from pdb import set_trace as T
import pufferlib.models
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, is_image_space, get_flattened_obs_dim, NatureCNN, TensorDict, gym
from gymnasium import spaces
import torch as th
from torch import nn


class Recurrent(pufferlib.models.RecurrentWrapper):
    """Recurrent wrapper for the Pokemon Red policy.

    Adds no behavior of its own: every argument is passed positionally and
    unchanged to ``pufferlib.models.RecurrentWrapper``. Only the default
    sizes (512 input, 512 hidden, 1 layer) are chosen here.
    """

    def __init__(
        self,
        env,
        policy,
        input_size=512,
        hidden_size=512,
        num_layers=1,
    ):
        super().__init__(env, policy, input_size, hidden_size, num_layers)

class Policy(pufferlib.models.Convolutional):
def __init__(self, env, input_size=512, hidden_size=512, output_size=512,
framestack=3, flat_size=64*5*6):
super().__init__(
env=env,
input_size=input_size,
hidden_size=hidden_size,
output_size=output_size,
framestack=framestack,
flat_size=flat_size,
channels_last=True,
# class Policy(pufferlib.models.Convolutional):
# def __init__(self, env, input_size=512, hidden_size=512, output_size=512,
# framestack=3, flat_size=64*5*6):
# super().__init__(
# env=env,
# input_size=input_size,
# hidden_size=hidden_size,
# output_size=output_size,
# framestack=framestack,
# flat_size=flat_size,
# channels_last=True,
# )


class Policy(pufferlib.models.Policy):
    """Encoder plus action/value heads for the Pokemon Red Dict observation.

    Each observation key is processed by its own sub-network: a ``NatureCNN``
    for ``image``; embeddings followed by small MLPs and max-pooling for the
    pokemon, item and event tables. The resulting features are concatenated
    with the raw ``vector`` entry and passed through a two-layer MLP whose
    512-wide output feeds the action and value heads.

    Shape comments in this class omit the leading batch dimension and record
    the layout the author expects; nothing here enforces them.

    :param env: environment wrapper. It must expose
        ``flat_observation_space``, ``flat_observation_structure`` and
        ``structured_observation_space`` (with an ``'image'`` entry).
        ``self.action_space`` is presumably set by the base class from
        ``env`` -- TODO confirm.
    :param cnn_output_dim: Number of features produced by the image CNN.
        The width of the combined feature vector is derived from this value.
    :param normalized_image: Forwarded to ``NatureCNN``; per the
        Stable-Baselines3 documentation it controls whether the image space
        is assumed to be already normalized (skipping dtype/bounds checks).
    """

    def __init__(
        self,
        env,
        cnn_output_dim: int = 256,
        normalized_image: bool = False,
    ) -> None:
        super().__init__(env)

        # Needed by encode_observations to unpack the flattened batch.
        self.flat_observation_space = env.flat_observation_space
        self.flat_observation_structure = env.flat_observation_structure

        # image (3, 36, 40)
        self.image_cnn = NatureCNN(
            env.structured_observation_space['image'],
            features_dim=cnn_output_dim,
            normalized_image=normalized_image,
        )

        # poke_move_ids (12, 4) -> (12, 4, 8)
        self.poke_move_ids_embedding = nn.Embedding(167, 8, padding_idx=0)
        # Concatenated with poke_move_pps (12, 4, 2) -> (12, 4, 10).
        self.move_fc_relu = nn.Sequential(
            nn.Linear(10, 8),
            nn.ReLU(),
            nn.Linear(8, 8),
            nn.ReLU(),
        )
        # Max-pool over the 4 moves: (12, 4, 8) -> (12, 1, 16); squeeze(-2)
        # then gives (12, 16).
        # NOTE(review): the pooled width (16) is larger than the 8 features
        # produced by move_fc_relu; poke_fc_relu's 63 inputs rely on it, so
        # it is kept as is -- confirm the widening is intended.
        self.move_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 16))

        # poke_type_ids (12, 2) -> (12, 2, 8), later summed over dim=-2
        # to (12, 8).
        self.poke_type_ids_embedding = nn.Embedding(17, 8, padding_idx=0)

        # poke_ids (12, ) -> (12, 16)
        self.poke_ids_embedding = nn.Embedding(192, 16, padding_idx=0)

        # Per-pokemon MLP; 63 = 16 (moves) + 8 (types) + 16 (ids) + 23 (stats).
        self.poke_fc_relu = nn.Sequential(
            nn.Linear(63, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
        )

        # Party head: applied to the first 6 pokemon, then max-pooled.
        self.poke_party_head = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
        )
        self.poke_party_head_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 32))

        # Opponent head: applied to the last 6 pokemon, then max-pooled.
        self.poke_opp_head = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
        )
        self.poke_opp_head_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 32))

        # item_ids (20, ) -> (20, 16)
        self.item_ids_embedding = nn.Embedding(256, 16, padding_idx=0)
        # Input is the embedding concatenated with item_quantity: (20, 17).
        self.item_ids_fc_relu = nn.Sequential(
            nn.Linear(17, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
        )
        self.item_ids_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 16))

        # event_ids (20, ) -> (20, 16)
        self.event_ids_embedding = nn.Embedding(2570, 16, padding_idx=0)
        # Input is the embedding concatenated with event_step_since: (20, 17).
        self.event_ids_fc_relu = nn.Sequential(
            nn.Linear(17, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
        )
        self.event_ids_max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 16))

        # Width of the concatenated feature vector built in
        # encode_observations: image CNN + party head (32) + opponent head
        # (32) + items (16) + events (16) + raw vector (54). With the default
        # cnn_output_dim of 256 this is 406, the value that used to be
        # hard-coded (and silently wrong for any other cnn_output_dim).
        self._features_dim = cnn_output_dim + 32 + 32 + 16 + 16 + 54

        self.fc1 = nn.Linear(self._features_dim, 512)
        self.fc2 = nn.Linear(512, 512)
        self.action = nn.Linear(512, self.action_space.n)
        self.value_head = nn.Linear(512, 1)

    def encode_observations(self, observations: TensorDict) -> Tuple[th.Tensor, None]:
        """Encode a flattened observation batch into the 512-wide hidden state.

        :param observations: flattened batch; it is unpacked into a dict with
            ``pufferlib.emulation.unpack_batched_obs`` using the stored flat
            space and structure.
        :return: ``(hidden, None)`` -- the hidden features and a ``None``
            placeholder for the lookup value passed on to ``decode_actions``.
        """
        observations = pufferlib.emulation.unpack_batched_obs(
            observations,
            self.flat_observation_space,
            self.flat_observation_structure,
        )

        img = self.image_cnn(observations['image'])  # (cnn_output_dim, )

        # Pokemon
        # Moves
        embedded_poke_move_ids = self.poke_move_ids_embedding(
            observations['poke_move_ids'].to(th.int))
        poke_move_pps = observations['poke_move_pps']
        poke_moves = th.cat([embedded_poke_move_ids, poke_move_pps], dim=-1)
        poke_moves = self.move_fc_relu(poke_moves)
        poke_moves = self.move_max_pool(poke_moves).squeeze(-2)  # (12, 16)
        # Types
        embedded_poke_type_ids = self.poke_type_ids_embedding(
            observations['poke_type_ids'].to(th.int))
        poke_types = th.sum(embedded_poke_type_ids, dim=-2)  # (12, 8)
        # Pokemon ID
        embedded_poke_ids = self.poke_ids_embedding(
            observations['poke_ids'].to(th.int))
        poke_ids = embedded_poke_ids  # (12, 16)
        # Pokemon stats (12, 23)
        poke_stats = observations['poke_all']
        # All pokemon features
        pokemon_concat = th.cat(
            [poke_moves, poke_types, poke_ids, poke_stats], dim=-1)  # (12, 63)
        pokemon_features = self.poke_fc_relu(pokemon_concat)  # (12, 32)

        # Pokemon party head; the leading ... keeps any batch dimensions.
        party_pokemon_features = pokemon_features[..., :6, :]  # (6, 32)
        poke_party_head = self.poke_party_head(party_pokemon_features)  # (6, 32)
        poke_party_head = self.poke_party_head_max_pool(
            poke_party_head).squeeze(-2)  # (6, 32) -> (32, )

        # Pokemon opponent head
        opp_pokemon_features = pokemon_features[..., 6:, :]  # (6, 32)
        poke_opp_head = self.poke_opp_head(opp_pokemon_features)  # (6, 32)
        poke_opp_head = self.poke_opp_head_max_pool(
            poke_opp_head).squeeze(-2)  # (6, 32) -> (32, )

        # Items
        embedded_item_ids = self.item_ids_embedding(
            observations['item_ids'].to(th.int))  # (20, 16)
        item_quantity = observations['item_quantity']  # (20, 1)
        item_concat = th.cat([embedded_item_ids, item_quantity], dim=-1)  # (20, 17)
        item_features = self.item_ids_fc_relu(item_concat)  # (20, 16)
        item_features = self.item_ids_max_pool(
            item_features).squeeze(-2)  # (20, 16) -> (16, )

        # Events
        embedded_event_ids = self.event_ids_embedding(
            observations['event_ids'].to(th.int))  # (20, 16)
        event_step_since = observations['event_step_since']  # (20, 1)
        event_concat = th.cat(
            [embedded_event_ids, event_step_since], dim=-1)  # (20, 17)
        event_features = self.event_ids_fc_relu(event_concat)  # (20, 16)
        event_features = self.event_ids_max_pool(
            event_features).squeeze(-2)  # (20, 16) -> (16, )

        # Raw vector
        vector = observations['vector']  # (54, )

        # Concatenate all features: (self._features_dim, )
        all_features = th.cat(
            [img, poke_party_head, poke_opp_head, item_features,
             event_features, vector],
            dim=-1,
        )

        # No activation is applied after fc2; the heads consume its raw output.
        hidden = self.fc2(F.relu(self.fc1(all_features)))
        return hidden, None

    def decode_actions(self, hidden, lookup):
        """Map the hidden state to action logits and a value estimate.

        :param hidden: hidden features produced by ``encode_observations``.
        :param lookup: unused here; accepted to match the call signature.
        :return: ``(action, value)`` from the two linear heads.
        """
        action = self.action(hidden)
        value = self.value_head(hidden)
        return action, value
Empty file modified pufferlib/environments/procgen/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/procgen/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/procgen/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/smac/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/smac/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/smac/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/stable_retro/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/stable_retro/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/stable_retro/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/test/__init__.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/test/environment.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/test/mock_environments.py
100644 → 100755
Empty file.
Empty file modified pufferlib/environments/test/torch.py
100644 → 100755
Empty file.
Empty file modified pufferlib/evaluation.py
100644 → 100755
Empty file.
Empty file modified pufferlib/exceptions.py
100644 → 100755
Empty file.
Loading

0 comments on commit e317984

Please sign in to comment.