Skip to content

Commit

Permalink
update. runs at about 1k epoch sps. requires pokegym @ https://github…
Browse files Browse the repository at this point in the history
….com/xinpw8/pokegym.git -b boey_0.6_pokegym. edit all params in config.yaml.
  • Loading branch information
xinpw8 committed Feb 28, 2024
1 parent ec05208 commit d90b6bc
Show file tree
Hide file tree
Showing 22 changed files with 1,325 additions and 2,427 deletions.
334 changes: 140 additions & 194 deletions clean_pufferl.py

Large diffs are not rendered by default.

705 changes: 55 additions & 650 deletions config.yaml

Large diffs are not rendered by default.

96 changes: 41 additions & 55 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import pufferlib
import pufferlib.utils

from clean_pufferl import CleanPuffeRL, rollout, done_training
import clean_pufferl


def load_from_config(env):
Expand All @@ -32,22 +32,30 @@ def load_from_config(env):
for key in default_keys:
env_subconfig = env_config.get(key, {})
pkg_subconfig = pkg_config.get(key, {})


# Override first with pkg then with env configs
combined_config[key] = {**defaults[key], **pkg_subconfig, **env_subconfig}
try:
combined_config[key] = {**defaults[key], **pkg_subconfig, **env_subconfig}
# print(f'combo_config: {combined_config[key]}')
except TypeError as e:
pass
# print(f'combined_config={combined_config}')
# print(f' {type(e)} ')
# print(f'key={type(key)}; combined_config[{key}]=sad')
finally:
# print(f'{key} has caused its last problem.')
pass

return pkg, pufferlib.namespace(**combined_config)

def make_policy(env, env_module, args):
policy = env_module.Policy(env, **args.policy)

if args.force_recurrence or env_module.Recurrent is not None:
policy = env_module.Recurrent(env, policy, **args.recurrent)
policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy)

else:
policy = pufferlib.frameworks.cleanrl.Policy(policy)


return policy.to(args.train.device)

Expand Down Expand Up @@ -106,11 +114,12 @@ def get_init_args(fn):
continue
else:
args[name] = param.default if param.default is not inspect.Parameter.empty else None
# print(f'ARGS LINE116 DEMO.PY: {args}\n\n')
return args

def train(args, env_module, make_env):
if args.backend == 'clean_pufferl':
trainer = CleanPuffeRL(
data = clean_pufferl.create(
config=args.train,
agent_creator=make_policy,
agent_kwargs={'env_module': env_module, 'args': args},
Expand All @@ -121,12 +130,12 @@ def train(args, env_module, make_env):
track=args.track,
)

while not done_training(trainer):
trainer.evaluate()
trainer.train()
while not clean_pufferl.done_training(data):
clean_pufferl.evaluate(data)
clean_pufferl.train(data)

print('Done training. Saving data...')
trainer.close()
clean_pufferl.close(data)
print('Run complete')
elif args.backend == 'sb3':
from stable_baselines3 import PPO
Expand All @@ -150,12 +159,12 @@ def train(args, env_module, make_env):
parser.add_argument('--backend', type=str, default='clean_pufferl', help='Train backend (clean_pufferl, sb3)')
parser.add_argument('--config', type=str, default='pokemon_red', help='Configuration in config.yaml to use')
parser.add_argument('--env', type=str, default=None, help='Name of specific environment to run')
parser.add_argument('--mode', type=str, default='train', help='train/sweep/evaluate')
parser.add_argument('--mode', type=str, default='train', choices='train sweep evaluate'.split())
parser.add_argument('--eval-model-path', type=str, default=None, help='Path to model to evaluate')
parser.add_argument('--baseline', action='store_true', help='Baseline run')
parser.add_argument('--no-render', action='store_true', help='Disable render during evaluate')
parser.add_argument('--exp-name', type=str, default=None, help="Resume from experiment")
parser.add_argument('--vectorization', type=str, default='serial', help='Vectorization method (serial, multiprocessing, ray)')
parser.add_argument('--vectorization', type=str, default='serial', choices='serial multiprocessing ray'.split())
parser.add_argument('--wandb-entity', type=str, default='xinpw8', help='WandB entity')
parser.add_argument('--wandb-project', type=str, default='pufferlib', help='WandB project')
parser.add_argument('--wandb-group', type=str, default='debug', help='WandB group')
Expand All @@ -178,60 +187,36 @@ def train(args, env_module, make_env):

# Update config with environment defaults
config.env = {**get_init_args(make_env), **config.env}
# print(f'config.env={config.env}')
config.policy = {**get_init_args(env_module.Policy.__init__), **config.policy}
# print(f'config.policy={config.policy}')
config.recurrent = {**get_init_args(env_module.Recurrent.__init__), **config.recurrent}

# print(f'config.recurrent={config.recurrent}')

# Generate argparse menu from config
# This is also a reason for Spock/Argbind/OmegaConf/pydantic-cli
for name, sub_config in config.items():
args[name] = {}
for key, value in sub_config.items():
data_key = f"{name}.{key}"
cli_key = f"--{data_key}".replace("_", "-")
data_key = f'{name}.{key}'
cli_key = f'--{data_key}'.replace('_', '-')
if isinstance(value, bool) and value is False:
action = "store_false"
parser.add_argument(cli_key, default=value, action="store_true")
clean_parser.add_argument(cli_key, default=value, action="store_true")
action = 'store_false'
parser.add_argument(cli_key, default=value, action='store_true')
clean_parser.add_argument(cli_key, default=value, action='store_true')
elif isinstance(value, bool) and value is True:
data_key = f"{name}.no_{key}"
cli_key = f"--{data_key}".replace("_", "-")
parser.add_argument(cli_key, default=value, action="store_false")
clean_parser.add_argument(cli_key, default=value, action="store_false")
data_key = f'{name}.no_{key}'
cli_key = f'--{data_key}'.replace('_', '-')
parser.add_argument(cli_key, default=value, action='store_false')
clean_parser.add_argument(cli_key, default=value, action='store_false')
else:
parser.add_argument(cli_key, default=value, type=type(value))
clean_parser.add_argument(cli_key, default=value, metavar="", type=type(value))
clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value))

args[name][key] = getattr(parser.parse_known_args()[0], data_key)
args[name] = pufferlib.namespace(**args[name])

clean_parser.parse_args(sys.argv[1:])
args = pufferlib.namespace(**args)

# # Generate argparse menu from config
# for name, sub_config in config.items():
# args[name] = {}
# for key, value in sub_config.items():
# data_key = f'{name}.{key}'
# cli_key = f'--{data_key}'.replace('_', '-')
# if isinstance(value, bool) and value is False:
# action = 'store_false'
# parser.add_argument(cli_key, default=value, action='store_true')
# clean_parser.add_argument(cli_key, default=value, action='store_true')
# elif isinstance(value, bool) and value is True:
# data_key = f'{name}.no_{key}'
# cli_key = f'--{data_key}'.replace('_', '-')
# parser.add_argument(cli_key, default=value, action='store_false')
# clean_parser.add_argument(cli_key, default=value, action='store_false')
# else:
# parser.add_argument(cli_key, default=value, type=type(value))
# clean_parser.add_argument(cli_key, default=value, metavar='', type=type(value))

# args[name][key] = getattr(parser.parse_known_args()[0], data_key)
# args[name] = pufferlib.namespace(**args[name])

# clean_parser.parse_args(sys.argv[1:])
# args = pufferlib.namespace(**args)

vec = args.vectorization
if vec == 'serial':
Expand All @@ -249,23 +234,24 @@ def train(args, env_module, make_env):
args.exp_name = init_wandb(args, env_module).id
elif args.baseline:
args.track = True
args.exp_name = f'puf-{pufferlib.__version__}-{args.config}'
args.wandb_group = f'puf-{pufferlib.__version__}-baseline'
version = '.'.join(pufferlib.__version__.split('.')[:2])
args.exp_name = f'puf-{version}-{args.config}'
args.wandb_group = f'puf-{version}-baseline'
shutil.rmtree(f'experiments/{args.exp_name}', ignore_errors=True)
run = init_wandb(args, env_module, name=args.exp_name, resume=False)
if args.mode == 'evaluate':
model_name = f'puf-{pufferlib.__version__}-{args.config}_model:latest'
model_name = f'puf-{version}-{args.config}_model:latest'
artifact = run.use_artifact(model_name)
data_dir = artifact.download()
model_file = max(os.listdir(data_dir))
args.eval_model_path = os.path.join(data_dir, model_file)

if args.mode == 'train':
train(args, env_module, make_env)
exit(0)
# exit(0)
elif args.mode == 'sweep':
sweep(args, env_module, make_env)
exit(0)
# exit(0)
elif args.mode == 'evaluate' and pkg != 'pokemon_red':
rollout(
make_env,
Expand All @@ -286,4 +272,4 @@ def train(args, env_module, make_env):
device=args.train.device,
)
elif pkg != 'pokemon_red':
raise ValueError('Mode must be one of train, sweep, or evaluate')
raise ValueError('Mode must be one of train, sweep, or evaluate')
Empty file modified kanto_map_dsv.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 6 additions & 9 deletions pufferlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,23 @@
from pufferlib import version
__version__ = version.__version__

# Shut deepmind_lab up
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)#, module="deepmind_lab")
try:
from deepmind_lab import dmenv_module # Or whatever the actual module is
except ImportError:
pass

import os
import sys

# Shut pygame up
# Silence noisy packages
original_stdout = sys.stdout
original_stderr = sys.stderr
sys.stdout = open(os.devnull, 'w')
sys.stderr = open(os.devnull, 'w')
try:
import gymnasium
import pygame
except ImportError:
pass
sys.stdout.close()
sys.stderr.close()
sys.stdout = original_stdout
sys.stderr = original_stderr


from pufferlib.namespace import namespace, dataclass
Expand Down
Loading

0 comments on commit d90b6bc

Please sign in to comment.