diff --git a/config.yaml b/config.yaml index baedabe2..53138d01 100755 --- a/config.yaml +++ b/config.yaml @@ -32,6 +32,8 @@ train: bptt_horizon: 16 #8 vf_clip_coef: 0.1 + debug: False + sweep: method: random name: sweep @@ -671,6 +673,8 @@ pokemon_red: gamma: 0.998 batch_size: 32768 batch_rows: 64 + debug: False + env: name: pokemon_red pokemon-red: diff --git a/demo.py b/demo.py index e4738fe7..f20f3d5e 100755 --- a/demo.py +++ b/demo.py @@ -182,30 +182,57 @@ def train(args, env_module, make_env): config.policy = {**get_init_args(env_module.Policy.__init__), **config.policy} config.recurrent = {**get_init_args(env_module.Recurrent.__init__), **config.recurrent} + # Generate argparse menu from config + # This is also a reason for Spock/Argbind/OmegaConf/pydantic-cli for name, sub_config in config.items(): args[name] = {} for key, value in sub_config.items(): - data_key = f'{name}.{key}' - cli_key = f'--{data_key}'.replace('_', '-') + data_key = f"{name}.{key}" + cli_key = f"--{data_key}".replace("_", "-") if isinstance(value, bool) and value is False: - action = 'store_false' - parser.add_argument(cli_key, default=value, action='store_true') - clean_parser.add_argument(cli_key, default=value, action='store_true') + action = "store_false" + parser.add_argument(cli_key, default=value, action="store_true") + clean_parser.add_argument(cli_key, default=value, action="store_true") elif isinstance(value, bool) and value is True: - data_key = f'{name}.no_{key}' - cli_key = f'--{data_key}'.replace('_', '-') - parser.add_argument(cli_key, default=value, action='store_false') - clean_parser.add_argument(cli_key, default=value, action='store_false') + data_key = f"{name}.no_{key}" + cli_key = f"--{data_key}".replace("_", "-") + parser.add_argument(cli_key, default=value, action="store_false") + clean_parser.add_argument(cli_key, default=value, action="store_false") else: parser.add_argument(cli_key, default=value, type=type(value)) - clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value)) + clean_parser.add_argument(cli_key, default=value, metavar="", type=type(value)) args[name][key] = getattr(parser.parse_known_args()[0], data_key) args[name] = pufferlib.namespace(**args[name]) clean_parser.parse_args(sys.argv[1:]) args = pufferlib.namespace(**args) + + # # Generate argparse menu from config + # for name, sub_config in config.items(): + # args[name] = {} + # for key, value in sub_config.items(): + # data_key = f'{name}.{key}' + # cli_key = f'--{data_key}'.replace('_', '-') + # if isinstance(value, bool) and value is False: + # action = 'store_false' + # parser.add_argument(cli_key, default=value, action='store_true') + # clean_parser.add_argument(cli_key, default=value, action='store_true') + # elif isinstance(value, bool) and value is True: + # data_key = f'{name}.no_{key}' + # cli_key = f'--{data_key}'.replace('_', '-') + # parser.add_argument(cli_key, default=value, action='store_false') + # clean_parser.add_argument(cli_key, default=value, action='store_false') + # else: + # parser.add_argument(cli_key, default=value, type=type(value)) + # clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value)) + + # args[name][key] = getattr(parser.parse_known_args()[0], data_key) + # args[name] = pufferlib.namespace(**args[name]) + + # clean_parser.parse_args(sys.argv[1:]) + # args = pufferlib.namespace(**args) vec = args.vectorization if vec == 'serial': diff --git a/pufferlib/environments/pokemon_red/torch.py b/pufferlib/environments/pokemon_red/torch.py index d84853d7..1f5ad886 100755 --- a/pufferlib/environments/pokemon_red/torch.py +++ b/pufferlib/environments/pokemon_red/torch.py @@ -49,14 +49,33 @@ def __init__( normalized_image: bool = False, ) -> None: # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! - super().__init__(env, features_dim=1) + super().__init__(env) # observation_space.spaces.items() # image (3, 36, 40) # self.image_cnn = NatureCNN(observation_space['image'], features_dim=cnn_output_dim, normalized_image=normalized_image) # nature cnn (4, 36, 40), output_dim = 512 cnn_output_dim - n_input_channels = env['image'].shape[0] + # if not isinstance(env.observation_space, spaces.Dict): + # raise TypeError("Expected env.observation_space to be a gym.spaces.Dict") + breakpoint() + # n_input_channels = env.observation_space.spaces['image'].shape[0] + ''' + (Pdb) dir(env) +['__annotations__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__enter__', '__eq__', '__exit__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_is_protocol', '_np_random', 'action_space', 'box_observation_space', 'close', 'done', 'env', 'flat_action_space', 'flat_action_structure', 'flat_observation_space', 'flat_observation_structure', 'get_wrapper_attr', 'initialized', 'is_action_checked', 'is_observation_checked', 'metadata', 'multidiscrete_action_space', 'np_random', 'observation_space', 'pad_observation', 'postprocessor', 'render', 'render_mode', 'render_modes', 'reset', 'reward_range', 'seed', 'spec', 'step', 'structured_action_space', 'structured_observation_space', 'unpack_batched_obs', 'unwrapped'] +(Pdb) env.render +> +(Pdb) env.observation_space +Box(-3.4028235e+38, 3.4028235e+38, (19591,), float32) +(Pdb) env.observation_space[0] +*** TypeError: 'Box' object is not subscriptable +(Pdb) env.observation_space +Box(-3.4028235e+38, 3.4028235e+38, (19591,), float32) +(Pdb) + ''' + + + n_input_channels = env.observation_space self.cnn = nn.Sequential( nn.Conv2d(n_input_channels, 32*2, kernel_size=8, stride=4, padding=(2, 0)), nn.ReLU(), @@ -70,7 +89,7 @@ def __init__( # Compute shape by doing one forward pass with th.no_grad(): - n_flatten = self.cnn(th.as_tensor(env['image'].sample()[None]).float()).shape[1] + n_flatten = self.cnn(th.as_tensor(env.observation_space['image'].sample()[None]).float()).shape[1] self.cnn_linear = nn.Sequential(nn.Linear(n_flatten, cnn_output_dim), nn.ReLU()) @@ -87,7 +106,7 @@ def __init__( self.minimap_warp_embedding = nn.Embedding(830, warp_emb_dim, padding_idx=0) # minimap (14 + 8 + 8, 9, 10) - n_input_channels = env['minimap'].shape[0] + sprite_emb_dim + warp_emb_dim + n_input_channels = env.observation_space['minimap'].shape[0] + sprite_emb_dim + warp_emb_dim self.minimap_cnn = nn.Sequential( nn.Conv2d(n_input_channels, 32*2, kernel_size=4, stride=1, padding=0), nn.ReLU(),