train_dagger.py
import datetime
import os
import copy
import json
import importlib
import numpy as np
import sys
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import alfworld.agents.environment
import alfworld.agents.modules.generic as generic
from alfworld.agents.eval import evaluate_dagger
from alfworld.agents.agent import TextDAggerAgent
from alfworld.agents.utils.misc import extract_admissible_commands
from alfworld.agents.modules.generic import HistoryScoreCache, EpisodicCountingMemory, ObjCentricEpisodicMemory
def train():
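    """DAgger training loop for ALFWorld text agents.

    Builds a TextDAggerAgent and the training (and optional in/out-of-distribution
    evaluation) environments from the loaded config, then repeatedly plays batches
    of games while mixing expert, random, and agent actions, stores the resulting
    trajectories with expert labels in the DAgger replay memory, and periodically
    updates, evaluates, checkpoints, and logs the agent.
    """
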
    time_1 = datetime.datetime.now()
    step_time = []
    config = generic.load_config()
    agent = TextDAggerAgent(config)

    alfred_env = getattr(alfworld.agents.environment, config["env"]["type"])(config, train_eval="train")
    env = alfred_env.init_env(batch_size=agent.batch_size)

    id_eval_env, num_id_eval_game = None, 0
    ood_eval_env, num_ood_eval_game = None, 0
    if agent.run_eval:
        # in distribution
        if config['dataset']['eval_id_data_path'] is not None:
            alfred_env = getattr(alfworld.agents.environment, config["general"]["evaluate"]["env"]["type"])(config, train_eval="eval_in_distribution")
            id_eval_env = alfred_env.init_env(batch_size=agent.eval_batch_size)
            num_id_eval_game = alfred_env.num_games
        # out of distribution
        if config['dataset']['eval_ood_data_path'] is not None:
            alfred_env = getattr(alfworld.agents.environment, config["general"]["evaluate"]["env"]["type"])(config, train_eval="eval_out_of_distribution")
            ood_eval_env = alfred_env.init_env(batch_size=agent.eval_batch_size)
            num_ood_eval_game = alfred_env.num_games

    output_dir = config["general"]["save_path"]
    data_dir = config["general"]["save_path"]
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # visdom
    if config["general"]["visdom"]:
        import visdom
        viz = visdom.Visdom()
        reward_win, step_win = None, None
        loss_win = None
        viz_game_points, viz_game_step, viz_loss = [], [], []
        viz_student_points, viz_student_step = [], []
        viz_id_eval_game_points, viz_id_eval_step = [], []
        viz_ood_eval_game_points, viz_ood_eval_step = [], []

    step_in_total = 0
    episode_no = 0
    running_avg_game_points = HistoryScoreCache(capacity=500)
    running_avg_student_points = HistoryScoreCache(capacity=500)
    running_avg_game_steps = HistoryScoreCache(capacity=500)
    running_avg_student_steps = HistoryScoreCache(capacity=500)
    running_avg_dagger_loss = HistoryScoreCache(capacity=500)

    json_file_name = agent.experiment_tag.replace(" ", "_")
    best_performance_so_far, best_ood_performance_so_far = 0.0, 0.0

    # load model from checkpoint
    if agent.load_pretrained:
        print("Checking {} for {}".format(data_dir, agent.load_from_tag))
        if os.path.exists(data_dir + "/" + agent.load_from_tag + ".pt"):
            agent.load_pretrained_model(data_dir + "/" + agent.load_from_tag + ".pt")
            agent.update_target_net()
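
    # Main training loop: each iteration resets a batch of games, collects one
    # episode per game, and (on non-report episodes) pushes the trajectories
    # into the DAgger replay memory.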
    while True:
        if episode_no > agent.max_episode:
            break
        np.random.seed(episode_no)
        env.seed(episode_no)
        obs, infos = env.reset()
        game_names = infos["extra.gamefile"]
        batch_size = len(obs)

        agent.train()
        agent.init(batch_size)
        previous_dynamics = None

        execute_actions = []
        prev_step_dones, prev_rewards = [], []
        for _ in range(batch_size):
            execute_actions.append("restart")
            prev_step_dones.append(0.0)
            prev_rewards.append(0.0)

        observation_strings = list(obs)
        task_desc_strings, observation_strings = agent.get_task_and_obs(observation_strings)
        task_desc_strings = agent.preprocess_task(task_desc_strings)
        observation_strings = agent.preprocess_observation(observation_strings)
        first_sight_strings = copy.deepcopy(observation_strings)
        agent.observation_pool.push_first_sight(first_sight_strings)
        if agent.action_space == "exhaustive":
            action_candidate_list = [extract_admissible_commands(intro, obs) for intro, obs in zip(first_sight_strings, observation_strings)]
        else:
            action_candidate_list = list(infos["admissible_commands"])
        action_candidate_list = agent.preprocess_action_candidates(action_candidate_list)
        observation_strings = [item + " [SEP] " + a for item, a in zip(observation_strings, execute_actions)]  # append the action chosen at the previous step to the observation
        # Transitions must be pushed into memory as ordered sequences, so we
        # cache what the agent returns during the episode and push everything
        # into memory together at the end of the game.
        transition_cache = []
        still_running_mask = []
        sequence_game_points = []
        print_actions = []
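        # On "report" episodes (when episode_no crosses a report_frequency boundary)
        # the agent acts purely on its own policy, nothing is pushed to the replay
        # memory, and the resulting scores are tracked as "student" statistics.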
        report = agent.report_frequency > 0 and (episode_no % agent.report_frequency <= (episode_no - batch_size) % agent.report_frequency)
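
        # roll out the current batch of games, one environment step at a time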
        for step_no in range(agent.max_nb_steps_per_episode):
            # push obs into observation pool
            agent.observation_pool.push_batch(observation_strings)
            # get most recent k observations
            most_recent_observation_strings = agent.observation_pool.get()

            # predict actions
            if agent.action_space == "generation":
                agent_actions, current_dynamics = agent.command_generation_greedy_generation(most_recent_observation_strings, task_desc_strings, previous_dynamics)
            elif agent.action_space in ["admissible", "exhaustive"]:
                agent_actions, _, current_dynamics = agent.admissible_commands_greedy_generation(most_recent_observation_strings, task_desc_strings, action_candidate_list, previous_dynamics)
            else:
                raise NotImplementedError()
            random_actions, _ = agent.admissible_commands_random_generation(action_candidate_list)
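
            # DAgger supervision: query the environment's expert plan for the next
            # ground-truth action, falling back to "look" if no plan is available.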
            expert_actions, expert_indices = [], []
            for b in range(batch_size):
                if "expert_plan" in infos and len(infos["expert_plan"][b]) > 0:
                    next_action = infos["expert_plan"][b][0]
                    expert_actions.append(next_action)
                    expert_indices.append(action_candidate_list[b].index(next_action) if agent.action_space in ["admissible", "exhaustive"] else -1)
                else:
                    expert_actions.append("look")
                    expert_indices.append(action_candidate_list[b].index("look") if agent.action_space in ["admissible", "exhaustive"] else -1)
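
            # Action mixing: with probability fraction_assist execute the expert action,
            # with probability fraction_random a random admissible action, otherwise the
            # agent's own prediction. On report episodes the agent's action is always used.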
            from_which = np.random.uniform(low=0.0, high=1.0, size=(batch_size,))
            execute_actions = []
            for b in range(batch_size):
                if not report and from_which[b] <= agent.fraction_assist:
                    execute_actions.append(expert_actions[b])
                elif not report and from_which[b] <= agent.fraction_assist + agent.fraction_random:
                    execute_actions.append(random_actions[b])
                else:
                    execute_actions.append(agent_actions[b])

            replay_info = [most_recent_observation_strings, task_desc_strings, action_candidate_list, expert_actions, expert_indices]
            transition_cache.append(replay_info)

            env_step_start_time = datetime.datetime.now()
            obs, _, dones, infos = env.step(execute_actions)
            env_step_stop_time = datetime.datetime.now()
            step_time.append((env_step_stop_time - env_step_start_time).microseconds / (float(batch_size)))

            scores = [float(item) for item in infos["won"]]
            dones = [float(item) for item in dones]

            observation_strings = list(obs)
            observation_strings = agent.preprocess_observation(observation_strings)
            if agent.action_space == "exhaustive":
                action_candidate_list = [extract_admissible_commands(intro, obs) for intro, obs in zip(first_sight_strings, observation_strings)]
            else:
                action_candidate_list = list(infos["admissible_commands"])
            action_candidate_list = agent.preprocess_action_candidates(action_candidate_list)
            observation_strings = [item + " [SEP] " + a for item, a in zip(observation_strings, execute_actions)]  # append the action chosen at the previous step to the observation
            previous_dynamics = current_dynamics

            if step_in_total % agent.dagger_update_per_k_game_steps == 0:
                dagger_loss = agent.update_dagger()
                if dagger_loss is not None:
                    running_avg_dagger_loss.push(dagger_loss)

            if step_no == agent.max_nb_steps_per_episode - 1:
                # terminate the game because DQN requires one extra step
                dones = [1.0 for _ in dones]

            step_in_total += 1
            still_running = [1.0 - float(item) for item in prev_step_dones]  # list of float
            prev_step_dones = dones
            step_rewards = [float(curr) - float(prev) for curr, prev in zip(scores, prev_rewards)]  # list of float
            prev_rewards = scores
            sequence_game_points.append(step_rewards)
            still_running_mask.append(still_running)
            print_actions.append(execute_actions[0] if still_running[0] else "--")

            # if all ended, break
            if np.sum(still_running) == 0:
                break
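
        # still_running_mask marks, per step and per game, whether that game was still
        # active; multiplying by it zeroes out rewards collected after a game finished.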
        still_running_mask_np = np.array(still_running_mask)
        game_points_np = np.array(sequence_game_points) * still_running_mask_np  # step x batch

        # push experience into replay buffer (dagger)
        if not report:
            for b in range(batch_size):
                trajectory = []
                for i in range(len(transition_cache)):
                    observation_strings, task_strings, action_candidate_list, expert_actions, expert_indices = transition_cache[i]
                    trajectory.append([observation_strings[b], task_strings[b], action_candidate_list[b],
                                       expert_actions[b], expert_indices[b]])
                    if still_running_mask_np[i][b] == 0.0:
                        break
                agent.dagger_memory.push(trajectory)

        for b in range(batch_size):
            if report:
                running_avg_student_points.push(np.sum(game_points_np, 0)[b])
                running_avg_student_steps.push(np.sum(still_running_mask_np, 0)[b])
            else:
                running_avg_game_points.push(np.sum(game_points_np, 0)[b])
                running_avg_game_steps.push(np.sum(still_running_mask_np, 0)[b])

        # finish game
        agent.finish_of_episode(episode_no, batch_size)
        episode_no += batch_size

        if not report:
            continue
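
        # Everything below runs only on report episodes: console logging, evaluation,
        # checkpointing, visdom plotting, and appending metrics to the json log.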
        time_2 = datetime.datetime.now()
        time_spent_seconds = (time_2 - time_1).seconds
        eps_per_sec = float(episode_no) / time_spent_seconds
        avg_step_time = np.mean(np.array(step_time))
        print("Model: {:s} | Episode: {:3d} | {:s} | time spent: {:s} | eps/sec : {:2.3f} | avg step time: {:2.10f} | loss: {:2.3f} | game points: {:2.3f} | used steps: {:2.3f} | student points: {:2.3f} | student steps: {:2.3f} | fraction assist: {:2.3f} | fraction random: {:2.3f}".format(agent.experiment_tag, episode_no, game_names[0], str(time_2 - time_1).rsplit(".")[0], eps_per_sec, avg_step_time, running_avg_dagger_loss.get_avg(), running_avg_game_points.get_avg(), running_avg_game_steps.get_avg(), running_avg_student_points.get_avg(), running_avg_student_steps.get_avg(), agent.fraction_assist, agent.fraction_random))
        # print(game_id + ": " + " | ".join(print_actions))
        print(" | ".join(print_actions))
        # evaluate
        id_eval_game_points, id_eval_game_step = 0.0, 0.0
        ood_eval_game_points, ood_eval_game_step = 0.0, 0.0
        if agent.run_eval:
            if id_eval_env is not None:
                id_eval_res = evaluate_dagger(id_eval_env, agent, num_id_eval_game)
                id_eval_game_points, id_eval_game_step = id_eval_res['average_points'], id_eval_res['average_steps']
            if ood_eval_env is not None:
                ood_eval_res = evaluate_dagger(ood_eval_env, agent, num_ood_eval_game)
                ood_eval_game_points, ood_eval_game_step = ood_eval_res['average_points'], ood_eval_res['average_steps']
            if id_eval_game_points >= best_performance_so_far:
                best_performance_so_far = id_eval_game_points
                agent.save_model_to_path(output_dir + "/" + agent.experiment_tag + "_id.pt")
            if ood_eval_game_points >= best_ood_performance_so_far:
                best_ood_performance_so_far = ood_eval_game_points
                agent.save_model_to_path(output_dir + "/" + agent.experiment_tag + "_ood.pt")
        else:
            if running_avg_student_points.get_avg() >= best_performance_so_far:
                best_performance_so_far = running_avg_student_points.get_avg()
                agent.save_model_to_path(output_dir + "/" + agent.experiment_tag + ".pt")
        # plot using visdom
        if config["general"]["visdom"]:
            viz_game_points.append(running_avg_game_points.get_avg())
            viz_game_step.append(running_avg_game_steps.get_avg())
            viz_student_points.append(running_avg_student_points.get_avg())
            viz_student_step.append(running_avg_student_steps.get_avg())
            viz_loss.append(running_avg_dagger_loss.get_avg())
            viz_id_eval_game_points.append(id_eval_game_points)
            viz_id_eval_step.append(id_eval_game_step)
            viz_ood_eval_game_points.append(ood_eval_game_points)
            viz_ood_eval_step.append(ood_eval_game_step)
            viz_x = np.arange(len(viz_game_points)).tolist()

            if reward_win is None:
                reward_win = viz.line(X=viz_x, Y=viz_game_points,
                                      opts=dict(title=agent.experiment_tag + "_game_points"),
                                      name="game points")
                viz.line(X=viz_x, Y=viz_student_points,
                         opts=dict(title=agent.experiment_tag + "_student_points"),
                         win=reward_win, update='append', name="student points")
                viz.line(X=viz_x, Y=viz_id_eval_game_points,
                         opts=dict(title=agent.experiment_tag + "_id_eval_game_points"),
                         win=reward_win, update='append', name="id eval game points")
                viz.line(X=viz_x, Y=viz_ood_eval_game_points,
                         opts=dict(title=agent.experiment_tag + "_ood_eval_game_points"),
                         win=reward_win, update='append', name="ood eval game points")
            else:
                viz.line(X=[len(viz_game_points) - 1], Y=[viz_game_points[-1]],
                         opts=dict(title=agent.experiment_tag + "_game_points"),
                         win=reward_win,
                         update='append', name="game points")
                viz.line(X=[len(viz_student_points) - 1], Y=[viz_student_points[-1]],
                         opts=dict(title=agent.experiment_tag + "_student_points"),
                         win=reward_win,
                         update='append', name="student points")
                viz.line(X=[len(viz_id_eval_game_points) - 1], Y=[viz_id_eval_game_points[-1]],
                         opts=dict(title=agent.experiment_tag + "_id_eval_game_points"),
                         win=reward_win,
                         update='append', name="id eval game points")
                viz.line(X=[len(viz_ood_eval_game_points) - 1], Y=[viz_ood_eval_game_points[-1]],
                         opts=dict(title=agent.experiment_tag + "_ood_eval_game_points"),
                         win=reward_win,
                         update='append', name="ood eval game points")

            if step_win is None:
                step_win = viz.line(X=viz_x, Y=viz_game_step,
                                    opts=dict(title=agent.experiment_tag + "_game_step"),
                                    name="game step")
                viz.line(X=viz_x, Y=viz_student_step,
                         opts=dict(title=agent.experiment_tag + "_student_step"),
                         win=step_win, update='append', name="student step")
                viz.line(X=viz_x, Y=viz_id_eval_step,
                         opts=dict(title=agent.experiment_tag + "_id_eval_step"),
                         win=step_win, update='append', name="id eval step")
                viz.line(X=viz_x, Y=viz_ood_eval_step,
                         opts=dict(title=agent.experiment_tag + "_ood_eval_step"),
                         win=step_win, update='append', name="ood eval step")
            else:
                viz.line(X=[len(viz_game_step) - 1], Y=[viz_game_step[-1]],
                         opts=dict(title=agent.experiment_tag + "_game_step"),
                         win=step_win,
                         update='append', name="game step")
                viz.line(X=[len(viz_student_step) - 1], Y=[viz_student_step[-1]],
                         opts=dict(title=agent.experiment_tag + "_student_step"),
                         win=step_win,
                         update='append', name="student step")
                viz.line(X=[len(viz_id_eval_step) - 1], Y=[viz_id_eval_step[-1]],
                         opts=dict(title=agent.experiment_tag + "_id_eval_step"),
                         win=step_win,
                         update='append', name="id eval step")
                viz.line(X=[len(viz_ood_eval_step) - 1], Y=[viz_ood_eval_step[-1]],
                         opts=dict(title=agent.experiment_tag + "_ood_eval_step"),
                         win=step_win,
                         update='append', name="ood eval step")

            if loss_win is None:
                loss_win = viz.line(X=viz_x, Y=viz_loss,
                                    opts=dict(title=agent.experiment_tag + "_loss"),
                                    name="loss")
            else:
                viz.line(X=[len(viz_loss) - 1], Y=[viz_loss[-1]],
                         opts=dict(title=agent.experiment_tag + "_loss"),
                         win=loss_win,
                         update='append', name="loss")
        # write accuracies down into file
        _s = json.dumps({"time spent": str(time_2 - time_1).rsplit(".")[0],
                         "time spent seconds": time_spent_seconds,
                         "episodes": episode_no,
                         "episodes per second": eps_per_sec,
                         "loss": str(running_avg_dagger_loss.get_avg()),
                         "train game points": str(running_avg_game_points.get_avg()),
                         "train game steps": str(running_avg_game_steps.get_avg()),
                         "train student points": str(running_avg_student_points.get_avg()),
                         "train student steps": str(running_avg_student_steps.get_avg()),
                         "id eval game points": str(id_eval_game_points),
                         "id eval steps": str(id_eval_game_step),
                         "ood eval game points": str(ood_eval_game_points),
                         "ood eval steps": str(ood_eval_game_step)})
        with open(output_dir + "/" + json_file_name + '.json', 'a+') as outfile:
            outfile.write(_s + '\n')
            outfile.flush()

    agent.save_model_to_path(output_dir + "/" + agent.experiment_tag + "_final.pt")

if __name__ == '__main__':
    train()