-
Notifications
You must be signed in to change notification settings - Fork 490
/
DeepMimic.py
314 lines (239 loc) · 6.98 KB
/
DeepMimic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import numpy as np
import sys
import random
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GLU import *
from env.deepmimic_env import DeepMimicEnv
from learning.rl_world import RLWorld
from util.arg_parser import ArgParser
from util.logger import Logger
import util.mpi_util as MPIUtil
import util.util as Util
# Dimensions of the window we are drawing into.
# reshape() overwrites both values with the size passed to it.
win_width = 800
win_height = int(win_width * 9.0 / 16.0)  # 16:9 aspect ratio -> 450
# Set by reshape() and cleared at the end of draw(); while it is set,
# update_intermediate_buffer() does not resize the environment.
reshaping = False
# Animation settings.
fps = 60
update_timestep = 1.0 / fps  # magnitude of the timestep passed to update_world()
display_anim_time = int(1000 * update_timestep)  # timer period handed to glutTimerFunc (16)
animating = True  # flipped by toggle_animate(); cleared by step_anim()
playback_speed = 1  # signed multiplier; animate() uses a negative timestep when < 0
playback_delta = 0.05  # step applied by the ',' and '.' keys
# FPS counter state, maintained by init_time() and animate().
prev_time = 0  # value of get_curr_time() at the previous animate() call
updates_per_sec = 0
args = []  # command-line arguments, filled in by main()
world = None  # world object, assigned by reload()
def build_arg_parser(args):
    """Create an ArgParser from command-line args and apply the random seed.

    Loads ``args`` into a new ArgParser, merges in the file named by the
    optional 'arg_file' argument, and, when a 'rand_seed' key is present,
    passes the seed (offset by 1000 per MPI process rank) to
    Util.set_global_seeds().

    Args:
        args: Command-line argument strings, handed to ArgParser.load_args().

    Returns:
        The populated ArgParser.

    Raises:
        AssertionError: If 'arg_file' was given but loading it failed.
    """
    arg_parser = ArgParser()
    arg_parser.load_args(args)

    arg_file = arg_parser.parse_string('arg_file', '')
    if arg_file != '':
        succ = arg_parser.load_file(arg_file)
        if not succ:
            # Raised explicitly instead of via `assert` so the check still
            # runs under `python -O` and the exception carries the message.
            msg = 'Failed to load args from: ' + arg_file
            Logger.print(msg)
            raise AssertionError(msg)

    rand_seed_key = 'rand_seed'
    if arg_parser.has_key(rand_seed_key):
        rand_seed = arg_parser.parse_int(rand_seed_key)
        # Offset by rank so every process ends up with a different seed.
        rand_seed += 1000 * MPIUtil.get_proc_rank()
        Util.set_global_seeds(rand_seed)

    return arg_parser
def update_intermediate_buffer():
    """Resize the environment to the window size when the two differ.

    Does nothing while a reshape is pending (``reshaping`` is set).
    """
    if reshaping:
        return

    env = world.env
    size_matches = (win_width == env.get_win_width()
                    and win_height == env.get_win_height())
    if not size_matches:
        env.reshape(win_width, win_height)
def update_world(world, time_elapsed):
    """Advance ``world`` by ``time_elapsed``, split into equal substeps.

    The number of substeps comes from the environment; a zero
    ``time_elapsed`` is applied as a single zero-length update. Stepping
    stops early as soon as the episode is invalid (world is reset) or has
    ended (episode is ended, then world is reset).

    Args:
        world: Object exposing ``env``, ``update``, ``end_episode`` and
            ``reset``.
        time_elapsed: Total time to advance; may be negative.
    """
    substep_count = world.env.get_num_update_substeps()
    timestep = time_elapsed / substep_count
    if time_elapsed == 0:
        substep_count = 1

    for _ in range(substep_count):
        world.update(timestep)

        if not world.env.check_valid_episode():
            # Invalid episode: discard it without ending it.
            world.reset()
            return

        if world.env.is_episode_end():
            world.end_episode()
            world.reset()
            return
def draw():
    """GLUT display callback: draw the environment and swap buffers."""
    global reshaping

    update_intermediate_buffer()
    world.env.draw()
    glutSwapBuffers()

    # The frame is out, so a resize deferred by reshape() may now be
    # applied on the next call.
    reshaping = False
def reshape(w, h):
    """GLUT reshape callback: record the new window size.

    Only the module-level size and the ``reshaping`` flag are updated here;
    the environment itself is resized later by
    update_intermediate_buffer().

    Args:
        w: New window width.
        h: New window height.
    """
    global reshaping, win_width, win_height

    reshaping = True
    win_width, win_height = w, h
def step_anim(timestep):
    """Advance the world by a single ``timestep`` and pause the animation.

    Args:
        timestep: Time to advance; negative values are passed through
            unchanged to update_world().
    """
    global animating

    update_world(world, timestep)
    animating = False  # stay paused after the manual step
    glutPostRedisplay()
def reload():
    """Rebuild the module-level ``world`` from the stored command-line args.

    The current module-level ``playback_speed`` is forwarded so that the new
    environment starts with the same speed that animate() is using, rather
    than build_world()'s default of 1.
    """
    global world

    world = build_world(args, enable_draw=True, playback_speed=playback_speed)
def reset():
    """Reset the module-level world."""
    world.reset()
def get_num_timesteps():
    """Return how many world updates animate() performs per timer tick.

    This is the magnitude of ``playback_speed`` truncated toward zero, with
    a minimum of 1 (so fractional speeds still advance one step per tick).
    """
    # int() truncates toward zero; a truncated 0 falls back to one step.
    return np.abs(int(playback_speed) or 1)
def calc_display_anim_time(num_timestes):
    """Return the timer period for a tick that performs ``num_timestes`` updates.

    The base period ``display_anim_time`` is scaled by the number of updates
    and divided by ``playback_speed``; the result is truncated to an integer
    and made non-negative.

    Args:
        num_timestes: Number of world updates performed in the tick.
    """
    scaled_period = display_anim_time * num_timestes / playback_speed
    return np.abs(int(scaled_period))
def shutdown():
    """Log, shut the world down and exit the process with status 0."""
    Logger.print('Shutting down...')
    world.shutdown()
    sys.exit(0)
def get_curr_time():
    """Return GLUT's elapsed-time counter (``GLUT_ELAPSED_TIME``)."""
    return glutGet(GLUT_ELAPSED_TIME)
def init_time():
    """Start the FPS counter: stamp the current time and zero the rate."""
    global prev_time, updates_per_sec

    prev_time, updates_per_sec = get_curr_time(), 0
def animate(callback_val):
    """GLUT timer callback: step the world and schedule the next tick.

    While ``animating`` is set, this runs get_num_timesteps() world updates,
    refreshes the updates-per-second counter, and re-registers itself with a
    delay reduced by the time the updates took. Independently of that, it
    shuts the program down once the environment reports it is done.

    Args:
        callback_val: Value supplied by glutTimerFunc; unused.
    """
    global prev_time, updates_per_sec

    # Weight of the previous estimate in the FPS average; with 0 the
    # counter always shows the latest measurement.
    counter_decay = 0

    if animating:
        num_steps = get_num_timesteps()
        tick_start = get_curr_time()
        time_elapsed = tick_start - prev_time
        prev_time = tick_start

        if playback_speed < 0:
            timestep = -update_timestep
        else:
            timestep = update_timestep
        for _ in range(num_steps):
            update_world(world, timestep)

        # FPS counting; non-finite rates are ignored.
        update_count = num_steps / (0.001 * time_elapsed)
        if np.isfinite(update_count):
            updates_per_sec = (counter_decay * updates_per_sec
                               + (1 - counter_decay) * update_count)
            world.env.set_updates_per_sec(updates_per_sec)

        # Take the time spent updating off the nominal period, never going
        # below zero.
        timer_step = calc_display_anim_time(num_steps)
        update_dur = get_curr_time() - tick_start
        timer_step = np.maximum(timer_step - update_dur, 0)

        glutTimerFunc(int(timer_step), animate, 0)
        glutPostRedisplay()

    if world.env.is_done():
        shutdown()
def toggle_animate():
    """Flip ``animating``; when turning it on, schedule animate() again."""
    global animating

    animating = not animating
    if not animating:
        return

    glutTimerFunc(display_anim_time, animate, 0)
def change_playback_speed(delta):
    """Add ``delta`` to ``playback_speed`` and push it to the environment.

    When the speed moves from (approximately) zero to a non-zero value,
    animate() is scheduled again.

    Args:
        delta: Signed change applied to the playback speed.
    """
    global playback_speed

    was_stopped = np.abs(playback_speed) < 0.0001
    playback_speed += delta
    world.env.set_playback_speed(playback_speed)

    if was_stopped and np.abs(playback_speed) > 0.0001:
        glutTimerFunc(display_anim_time, animate, 0)
def toggle_training():
    """Flip ``world.enable_training`` and log the new state."""
    world.enable_training = not world.enable_training
    status = 'Training enabled' if world.enable_training else 'Training disabled'
    Logger.print(status)
def keyboard(key, x, y):
    """GLUT keyboard callback.

    The key is first forwarded to the environment as an integer, then any
    matching viewer-level binding is run, and finally a redraw is requested.

    Args:
        key: Pressed key as a bytes object.
        x: Mouse x position supplied by GLUT.
        y: Mouse y position supplied by GLUT.
    """
    key_val = int.from_bytes(key, byteorder='big')
    world.env.keyboard(key_val, x, y)

    # (key, action) pairs, checked in order. The lambdas defer reading the
    # module-level settings until the action actually runs.
    bindings = (
        (b'\x1b', shutdown),  # escape
        (b' ', toggle_animate),
        (b'>', lambda: step_anim(update_timestep)),
        (b'<', lambda: step_anim(-update_timestep)),
        (b',', lambda: change_playback_speed(-playback_delta)),
        (b'.', lambda: change_playback_speed(playback_delta)),
        (b'/', lambda: change_playback_speed(-playback_speed + 1)),
        (b'l', reload),
        (b'r', reset),
        (b't', toggle_training),
    )
    for bound_key, action in bindings:
        if key == bound_key:
            action()
            break

    glutPostRedisplay()
def mouse_click(button, state, x, y):
    """GLUT mouse callback: forward the click to the environment and redraw."""
    env = world.env
    env.mouse_click(button, state, x, y)
    glutPostRedisplay()
def mouse_move(x, y):
    """GLUT motion callback: forward the position to the environment and redraw."""
    env = world.env
    env.mouse_move(x, y)
    glutPostRedisplay()
def init_draw():
    """Initialise GLUT and create the 'DeepMimic' window."""
    glutInit()

    # RGBA colour, double buffering and a depth buffer.
    display_mode = GLUT_RGBA | GLUT_DOUBLE | GLUT_DEPTH
    glutInitDisplayMode(display_mode)

    glutInitWindowSize(win_width, win_height)
    glutCreateWindow(b'DeepMimic')
def setup_draw():
    """Register the GLUT callbacks and size the environment to the window."""
    callbacks = (
        (glutDisplayFunc, draw),
        (glutReshapeFunc, reshape),
        (glutKeyboardFunc, keyboard),
        (glutMouseFunc, mouse_click),
        (glutMotionFunc, mouse_move),
    )
    for register, handler in callbacks:
        register(handler)

    # Start the animation timer.
    glutTimerFunc(display_anim_time, animate, 0)

    reshape(win_width, win_height)
    world.env.reshape(win_width, win_height)
def build_world(args, enable_draw, playback_speed=1):
    """Build an RLWorld around a new DeepMimicEnv.

    Args:
        args: Command-line argument strings, used both for the ArgParser and
            for the environment.
        enable_draw: Passed through to DeepMimicEnv.
        playback_speed: Initial playback speed set on the environment.

    Returns:
        The newly constructed RLWorld.
    """
    arg_parser = build_arg_parser(args)
    rl_world = RLWorld(DeepMimicEnv(args, enable_draw), arg_parser)
    rl_world.env.set_playback_speed(playback_speed)
    return rl_world
def draw_main_loop():
    """Start the FPS clock and hand control to the GLUT main loop."""
    init_time()
    glutMainLoop()
def main():
    """Entry point: create the window, build the world and run the viewer."""
    global args

    # Command line arguments (everything after the script name).
    args = sys.argv[1:]

    init_draw()
    reload()
    setup_draw()
    draw_main_loop()


if __name__ == '__main__':
    main()