Skip to content

Commit

Permalink
boey v2 incomplete 2
Browse files Browse the repository at this point in the history
  • Loading branch information
xinpw8 committed Feb 23, 2024
1 parent 75bb946 commit fdf5d9b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 14 deletions.
4 changes: 4 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ train:
bptt_horizon: 16 #8
vf_clip_coef: 0.1

debug: False

sweep:
method: random
name: sweep
Expand Down Expand Up @@ -671,6 +673,8 @@ pokemon_red:
gamma: 0.998
batch_size: 32768
batch_rows: 64
debug: False

env:
name: pokemon_red
pokemon-red:
Expand Down
47 changes: 37 additions & 10 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,30 +182,57 @@ def train(args, env_module, make_env):
config.policy = {**get_init_args(env_module.Policy.__init__), **config.policy}
config.recurrent = {**get_init_args(env_module.Recurrent.__init__), **config.recurrent}


# Generate argparse menu from config
# This boilerplate is also a reason to adopt a config/CLI library such as Spock/Argbind/OmegaConf/pydantic-cli
for name, sub_config in config.items():
args[name] = {}
for key, value in sub_config.items():
data_key = f'{name}.{key}'
cli_key = f'--{data_key}'.replace('_', '-')
data_key = f"{name}.{key}"
cli_key = f"--{data_key}".replace("_", "-")
if isinstance(value, bool) and value is False:
action = 'store_false'
parser.add_argument(cli_key, default=value, action='store_true')
clean_parser.add_argument(cli_key, default=value, action='store_true')
action = "store_false"
parser.add_argument(cli_key, default=value, action="store_true")
clean_parser.add_argument(cli_key, default=value, action="store_true")
elif isinstance(value, bool) and value is True:
data_key = f'{name}.no_{key}'
cli_key = f'--{data_key}'.replace('_', '-')
parser.add_argument(cli_key, default=value, action='store_false')
clean_parser.add_argument(cli_key, default=value, action='store_false')
data_key = f"{name}.no_{key}"
cli_key = f"--{data_key}".replace("_", "-")
parser.add_argument(cli_key, default=value, action="store_false")
clean_parser.add_argument(cli_key, default=value, action="store_false")
else:
parser.add_argument(cli_key, default=value, type=type(value))
clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value))
clean_parser.add_argument(cli_key, default=value, metavar="", type=type(value))

args[name][key] = getattr(parser.parse_known_args()[0], data_key)
args[name] = pufferlib.namespace(**args[name])

clean_parser.parse_args(sys.argv[1:])
args = pufferlib.namespace(**args)

# # Generate argparse menu from config
# for name, sub_config in config.items():
# args[name] = {}
# for key, value in sub_config.items():
# data_key = f'{name}.{key}'
# cli_key = f'--{data_key}'.replace('_', '-')
# if isinstance(value, bool) and value is False:
# action = 'store_false'
# parser.add_argument(cli_key, default=value, action='store_true')
# clean_parser.add_argument(cli_key, default=value, action='store_true')
# elif isinstance(value, bool) and value is True:
# data_key = f'{name}.no_{key}'
# cli_key = f'--{data_key}'.replace('_', '-')
# parser.add_argument(cli_key, default=value, action='store_false')
# clean_parser.add_argument(cli_key, default=value, action='store_false')
# else:
# parser.add_argument(cli_key, default=value, type=type(value))
# clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value))

# args[name][key] = getattr(parser.parse_known_args()[0], data_key)
# args[name] = pufferlib.namespace(**args[name])

# clean_parser.parse_args(sys.argv[1:])
# args = pufferlib.namespace(**args)

vec = args.vectorization
if vec == 'serial':
Expand Down
27 changes: 23 additions & 4 deletions pufferlib/environments/pokemon_red/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,33 @@ def __init__(
normalized_image: bool = False,
) -> None:
# TODO: features_dim is not known at this point (it is only determined after going over all the observation items), so a placeholder value is passed for now. This is a hack.
super().__init__(env, features_dim=1)
super().__init__(env)

# observation_space.spaces.items()

# image (3, 36, 40)
# self.image_cnn = NatureCNN(observation_space['image'], features_dim=cnn_output_dim, normalized_image=normalized_image)
# nature cnn (4, 36, 40), output_dim = 512 cnn_output_dim
n_input_channels = env['image'].shape[0]
# if not isinstance(env.observation_space, spaces.Dict):
# raise TypeError("Expected env.observation_space to be a gym.spaces.Dict")
breakpoint()
# n_input_channels = env.observation_space.spaces['image'].shape[0]
'''
(Pdb) dir(env)
['__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_is_protocol', '_np_random', 'action_space', 'box_observation_space', 'close', 'done', 'env', 'flat_action_space', 'flat_action_structure', 'flat_observation_space', 'flat_observation_structure', 'get_wrapper_attr', 'initialized', 'is_action_checked', 'is_observation_checked', 'metadata', 'multidiscrete_action_space', 'np_random', 'observation_space', 'pad_observation', 'postprocessor', 'render', 'render_mode', 'render_modes', 'reset', 'reward_range', 'seed', 'spec', 'step', 'structured_action_space', 'structured_observation_space', 'unpack_batched_obs', 'unwrapped']
(Pdb) env.render
<bound method GymnasiumPufferEnv.render of <pufferlib.emulation.GymnasiumPufferEnv object at 0x7fb43917df30>>
(Pdb) env.observation_space
Box(-3.4028235e+38, 3.4028235e+38, (19591,), float32)
(Pdb) env.observation_space[0]
*** TypeError: 'Box' object is not subscriptable
(Pdb) env.observation_space
Box(-3.4028235e+38, 3.4028235e+38, (19591,), float32)
(Pdb)
'''


n_input_channels = env.observation_space
self.cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 32*2, kernel_size=8, stride=4, padding=(2, 0)),
nn.ReLU(),
Expand All @@ -70,7 +89,7 @@ def __init__(

# Compute shape by doing one forward pass
with th.no_grad():
n_flatten = self.cnn(th.as_tensor(env['image'].sample()[None]).float()).shape[1]
n_flatten = self.cnn(th.as_tensor(env.observation_space['image'].sample()[None]).float()).shape[1]

self.cnn_linear = nn.Sequential(nn.Linear(n_flatten, cnn_output_dim), nn.ReLU())

Expand All @@ -87,7 +106,7 @@ def __init__(
self.minimap_warp_embedding = nn.Embedding(830, warp_emb_dim, padding_idx=0)

# minimap (14 + 8 + 8, 9, 10)
n_input_channels = env['minimap'].shape[0] + sprite_emb_dim + warp_emb_dim
n_input_channels = env.observation_space['minimap'].shape[0] + sprite_emb_dim + warp_emb_dim
self.minimap_cnn = nn.Sequential(
nn.Conv2d(n_input_channels, 32*2, kernel_size=4, stride=1, padding=0),
nn.ReLU(),
Expand Down

0 comments on commit fdf5d9b

Please sign in to comment.