From e800153cae408abaed19c8564a1f64d26d739ee3 Mon Sep 17 00:00:00 2001 From: Joseph Suarez Date: Mon, 15 Jan 2024 17:10:22 +0000 Subject: [PATCH] Fix path for wandb models --- config.yaml | 1 + demo.py | 4 ++-- pufferlib/version.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/config.yaml b/config.yaml index fc0a7aa9..851c53e7 100644 --- a/config.yaml +++ b/config.yaml @@ -607,6 +607,7 @@ ocean: num_envs: 8 batch_rows: 32 bptt_horizon: 4 + device: cpu env: name: squared bandit: diff --git a/demo.py b/demo.py index 07baf6c5..8d104667 100644 --- a/demo.py +++ b/demo.py @@ -219,12 +219,12 @@ def train(args, env_module, make_env): args.exp_name = init_wandb(args, env_module).id elif args.baseline: args.track = True - args.exp_name = args.config + args.exp_name = f'puf-{pufferlib.__version__}-{args.config}' args.wandb_group = f'puf-{pufferlib.__version__}-baseline' shutil.rmtree(f'experiments/{args.exp_name}', ignore_errors=True) run = init_wandb(args, env_module, name=args.exp_name, resume=False) if args.mode == 'evaluate': - model_name = f'puf{pufferlib.__version__}-{args.config}_model:latest' + model_name = f'puf-{pufferlib.__version__}-{args.config}_model:latest' artifact = run.use_artifact(model_name) data_dir = artifact.download() model_file = max(os.listdir(data_dir)) diff --git a/pufferlib/version.py b/pufferlib/version.py index ef7eb44d..8411e551 100644 --- a/pufferlib/version.py +++ b/pufferlib/version.py @@ -1 +1 @@ -__version__ = '0.6.0' +__version__ = '0.6.1'