Skip to content

Commit

Permalink
Save & load baselines with wandb!
Browse files: browse the repository at this point in the history
  • Loading branch information
jsuarez5341 committed Jan 14, 2024
1 parent b7d2384 commit 5d26463
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions demo.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
from pdb import set_trace as T
import argparse
import shutil
import sys
import os

Expand Down Expand Up @@ -47,11 +48,11 @@ def make_policy(env, env_module, args):

return policy.to(args.train.device)

def init_wandb(args, env_module):
def init_wandb(args, env_module, name=None, resume=True):
#os.environ["WANDB_SILENT"] = "true"

import wandb
wandb.init(
return wandb.init(
id=args.exp_name or wandb.util.generate_id(),
project=args.wandb_project,
entity=args.wandb_entity,
Expand All @@ -62,12 +63,11 @@ def init_wandb(args, env_module):
'policy': args.policy,
'recurrent': args.recurrent,
},
name=args.config,
name=name or args.config,
monitor_gym=True,
save_code=True,
resume=True,
resume=resume,
)
return wandb.run.id

def sweep(args, env_module, make_env):
import wandb
Expand Down Expand Up @@ -149,6 +149,7 @@ def train(args, env_module, make_env):
parser.add_argument('--env', type=str, default=None, help='Name of specific environment to run')
parser.add_argument('--mode', type=str, default='train', help='train/sweep/evaluate')
parser.add_argument('--eval-model-path', type=str, default=None, help='Path to model to evaluate')
parser.add_argument('--baseline', action='store_true', help='Baseline run')
parser.add_argument('--no-render', action='store_true', help='Disable render during evaluate')
parser.add_argument('--exp-name', type=str, default=None, help="Resume from experiment")
parser.add_argument('--vectorization', type=str, default='serial', help='Vectorization method (serial, multiprocessing, ray)')
Expand Down Expand Up @@ -215,7 +216,19 @@ def train(args, env_module, make_env):
if args.mode == 'sweep':
args.track = True
elif args.track:
args.exp_name = init_wandb(args, env_module)
args.exp_name = init_wandb(args, env_module).id
elif args.baseline:
args.track = True
args.exp_name = args.config
args.wandb_group = f'puf-{pufferlib.__version__}-baseline'
shutil.rmtree(f'experiments/{args.exp_name}', ignore_errors=True)
run = init_wandb(args, env_module, name=args.exp_name, resume=False)
if args.mode == 'evaluate':
model_name = f'puf{pufferlib.__version__}-{args.config}_model:latest'
artifact = run.use_artifact(model_name)
data_dir = artifact.download()
model_file = max(os.listdir(data_dir))
args.eval_model_path = os.path.join(data_dir, model_file)

if args.mode == 'train':
train(args, env_module, make_env)
Expand Down

0 comments on commit 5d26463

Please sign in to comment.