train.py (forked from NeuralMMO/baselines) · 139 lines (118 loc) · 5.43 KB

import os
import logging
import torch
from pufferlib.vectorization import Serial, Multiprocessing
from pufferlib.policy_store import DirectoryPolicyStore
from pufferlib.frameworks import cleanrl
import environment
from reinforcement_learning import clean_pufferl, policy, config
# NOTE: this file changes when running curriculum generation track
# Run test_task_encoder.py to regenerate this file (or get it from the repo)
BASELINE_CURRICULUM_FILE = "reinforcement_learning/curriculum_with_embedding.pkl"
CUSTOM_CURRICULUM_FILE = "curriculum_generation/custom_curriculum_with_embedding.pkl"
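# The "rl" track (see __main__ below) trains directly from BASELINE_CURRICULUM_FILE,
# while the "curriculum" track regenerates CUSTOM_CURRICULUM_FILE with fresh task
# embeddings and then trains from it.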


def setup_env(args):
    run_dir = os.path.join(args.runs_dir, args.run_name)
    os.makedirs(run_dir, exist_ok=True)
    logging.info("Training run: %s (%s)", args.run_name, run_dir)
    logging.info("Training args: %s", args)

    policy_store = None
    if args.policy_store_dir is None:
        args.policy_store_dir = os.path.join(run_dir, "policy_store")
        logging.info("Using policy store from %s", args.policy_store_dir)
        policy_store = DirectoryPolicyStore(args.policy_store_dir)

    def make_policy(envs):
        learner_policy = policy.Baseline(
            envs.driver_env,
            input_size=args.input_size,
            hidden_size=args.hidden_size,
            task_size=args.task_size,
        )
        return cleanrl.Policy(learner_policy)

    trainer = clean_pufferl.CleanPuffeRL(
        device=torch.device(args.device),
        seed=args.seed,
        env_creator=environment.make_env_creator(args),
        env_creator_kwargs={},
        agent_creator=make_policy,
        data_dir=run_dir,
        exp_name=args.run_name,
        policy_store=policy_store,
        wandb_entity=args.wandb_entity,
        wandb_project=args.wandb_project,
        wandb_extra_data=args,
        checkpoint_interval=args.checkpoint_interval,
        vectorization=Serial if args.use_serial_vecenv else Multiprocessing,
        total_timesteps=args.train_num_steps,
        num_envs=args.num_envs,
        num_cores=args.num_cores or args.num_envs,
        num_buffers=args.num_buffers,
        batch_size=args.rollout_batch_size,
        learning_rate=args.ppo_learning_rate,
        selfplay_learner_weight=args.learner_weight,
        selfplay_num_policies=args.max_opponent_policies + 1,
        # record_loss=args.record_loss,
    )
    return trainer
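# setup_env() wires the Baseline policy, the environment creator, and the policy store
# into a single CleanPuffeRL trainer. Two arguments worth noting: num_cores falls back
# to one core per environment when unset, and selfplay_num_policies counts the learner
# plus up to args.max_opponent_policies checkpointed opponents.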


def reinforcement_learning_track(trainer, args):
    while not trainer.done_training():
        trainer.evaluate()
        trainer.train(
            update_epochs=args.ppo_update_epochs,
            bptt_horizon=args.bptt_horizon,
            batch_rows=args.ppo_training_batch_size // args.bptt_horizon,
            clip_coef=args.clip_coef,
        )
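# Each loop iteration collects a rollout with evaluate() and then runs
# args.ppo_update_epochs PPO epochs over it. Assuming CleanPuffeRL's usual
# batch_rows x bptt_horizon minibatch layout, each minibatch holds batch_rows
# truncated-BPTT sequences of bptt_horizon steps, i.e. roughly
# args.ppo_training_batch_size timesteps (the integer division above may drop a few).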


def curriculum_generation_track(trainer, args, use_elm=True):
    from curriculum_generation.task_encoder import TaskEncoder
    LLM_CHECKPOINT = "Salesforce/codegen25-7b-instruct"

    if use_elm:
        from curriculum_generation import manual_curriculum
        from curriculum_generation.elm import OpenELMTaskGenerator

        NUM_SEED_TASKS = 20
        NUM_NEW_TASKS = 5
        ELM_DEBUG = True

        task_encoder = TaskEncoder(LLM_CHECKPOINT, manual_curriculum, batch_size=2)
        task_generator = OpenELMTaskGenerator(manual_curriculum.curriculum, LLM_CHECKPOINT)

        # Generating new tasks and evaluating all candidate training tasks
        for _ in range(3):
            # NOTE: adjust NUM_SEED_TASKS to fit your gpu
            seed_task_list = task_generator.sample_tasks(NUM_SEED_TASKS, random_ratio=1)
            new_task_list = task_generator.evolve_tasks(seed_task_list, NUM_NEW_TASKS, debug=ELM_DEBUG)
            task_generator.add_tasks(new_task_list)
            task_encoder.get_task_embedding(seed_task_list + new_task_list, save_to_file=CUSTOM_CURRICULUM_FILE)
            # CHECK ME: the trainer will automatically use the new task embedding file
            _, _, infos = trainer.evaluate()
            task_generator.update(infos)  # update the task stats

        # NOTE: sample_tasks() uses task stats to sample learnable tasks
        curriculum = task_generator.sample_tasks(NUM_SEED_TASKS * 3, random_ratio=0.3)  # NOTE: arbitrary numbers
    else:
        from curriculum_generation import curriculum_tutorial  # custom tutorial
        task_encoder = TaskEncoder(LLM_CHECKPOINT, curriculum_tutorial, batch_size=2)
        curriculum = curriculum_tutorial.curriculum

    # Use the train_task_spec to train agents
    task_encoder.get_task_embedding(curriculum, save_to_file=CUSTOM_CURRICULUM_FILE)
    task_encoder.close()
    trainer.data.sort_keys = []
    reinforcement_learning_track(trainer, args)
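# Curriculum track flow: evolve new tasks from the manual seed curriculum with
# OpenELM, embed every candidate task into CUSTOM_CURRICULUM_FILE, evaluate to
# collect per-task stats, resample a more learnable curriculum from those stats,
# and finally hand off to the standard reinforcement_learning_track() loop.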


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # You can either edit the defaults in config.py or set args
    # from the commandline.
    args = config.create_config(config.Config)

    # Avoid OOMing your machine for local testing
    if args.local_mode:
        args.num_envs = 1
        args.num_buffers = 1
        args.use_serial_vecenv = True
        args.rollout_batch_size = 2**10

    if args.track == "rl":
        args.tasks_path = BASELINE_CURRICULUM_FILE
        trainer = setup_env(args)
        reinforcement_learning_track(trainer, args)
    elif args.track == "curriculum":
        args.tasks_path = CUSTOM_CURRICULUM_FILE
        trainer = setup_env(args)
        curriculum_generation_track(trainer, args, use_elm=True)
    else:
        raise ValueError(f"Unknown track {args.track}, must be 'rl' or 'curriculum'")

    trainer.close()
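# Typical invocations (flag names below are illustrative only; the actual names and
# defaults come from config.Config via config.create_config):
#   python train.py --track rl --run-name my_rl_run
#   python train.py --track curriculum --run-name my_curriculum_run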