forked from real-stanford/diffusion_policy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo_pusht.py
120 lines (105 loc) · 4.2 KB
/
demo_pusht.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import sys

import click
import numpy as np
import pygame

from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.env.pusht.pusht_keypoints_env import PushTKeypointsEnv
def _make_step_record(img, obs, info, act):
    """Build the record stored for one teleoperated step.

    All values describe the env *before* ``act`` is applied.

    Args:
        img: frame returned by ``env.render(mode='human')``.
        obs: observation array from the env; it is reshaped to ``(2, -1)``
            here, so its size must be even.
        info: env info dict; must provide ``'pos_agent'``, ``'block_pose'``
            and ``'n_contacts'``.
        act: action returned by the teleop agent (not ``None``).

    Returns:
        dict with keys ``img``, ``state``, ``keypoint``, ``action`` and
        ``n_contacts``; every value except ``img`` is cast to float32.
    """
    # state dim 2+3: agent position followed by block pose
    state = np.concatenate([info['pos_agent'], info['block_pose']])
    # Keep only the first half of the observation, viewed as rows of 2
    # values, and only the first 9 of those rows. This discards unused
    # information such as the visibility mask and agent pos, for
    # compatibility.
    keypoint = obs.reshape(2, -1)[0].reshape(-1, 2)[:9]
    return {
        'img': img,
        'state': np.float32(state),
        'keypoint': np.float32(keypoint),
        'action': np.float32(act),
        'n_contacts': np.float32([info['n_contacts']]),
    }


def _stack_episode(episode):
    """Stack per-step records into one array per key.

    Args:
        episode: non-empty list of dicts that all share the same keys.

    Returns:
        dict mapping each key (in the order of ``episode[0]``) to
        ``np.stack`` of that key's values, so axis 0 indexes the step.
    """
    return {key: np.stack([step[key] for step in episode])
            for key in episode[0]}


@click.command()
@click.option('-o', '--output', required=True)
@click.option('-rs', '--render_size', default=96, type=int)
@click.option('-hz', '--control_hz', default=10, type=int)
def main(output, render_size, control_hz):
    """
    Collect demonstration for the Push-T task.

    Usage: python demo_pusht.py -o data/pusht_demo.zarr

    This script is compatible with both Linux and MacOS.
    Hover mouse close to the blue circle to start.
    Push the T block into the green area.
    The episode will automatically terminate once the task succeeds.
    Press "Q" (or close the window) to exit.
    Press "R" to retry.
    Hold "Space" to pause.
    """
    # create replay buffer in read-write mode
    replay_buffer = ReplayBuffer.create_from_path(output, mode='a')

    # create PushT env with keypoints
    # NOTE(review): "genenerate" looks misspelled, but it has to match the
    # method name defined on PushTKeypointsEnv; rename both together.
    kp_kwargs = PushTKeypointsEnv.genenerate_keypoint_manager_params()
    env = PushTKeypointsEnv(
        render_size=render_size, render_action=False, **kp_kwargs)
    agent = env.teleop_agent()
    clock = pygame.time.Clock()

    # episode-level loop
    while True:
        episode = []
        # record in seed order, starting with 0: the next seed is the number
        # of episodes already stored in the replay buffer
        seed = replay_buffer.n_episodes
        print(f'starting seed {seed}')

        env.seed(seed)
        # reset env and get observations (including info and render for recording)
        obs = env.reset()
        info = env._get_info()
        img = env.render(mode='human')

        # loop state
        retry = False
        pause = False
        done = False
        plan_idx = 0
        pygame.display.set_caption(f'plan_idx:{plan_idx}')

        # Step-level loop. `done` only ends the episode once at least one step
        # has been recorded: the env is also stepped with act=None before
        # teleop starts, and a `done` reported then would otherwise leave
        # `episode` empty and make the save below fail on episode[0].
        while not (done and episode):
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # window close button: exit the same way as "Q"
                    sys.exit(0)
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_SPACE:
                        # hold Space to pause; every pause bumps plan_idx
                        plan_idx += 1
                        pygame.display.set_caption(f'plan_idx:{plan_idx}')
                        pause = True
                    elif event.key == pygame.K_r:
                        # press "R" to retry
                        retry = True
                    elif event.key == pygame.K_q:
                        # press "Q" to exit
                        sys.exit(0)
                if event.type == pygame.KEYUP:
                    if event.key == pygame.K_SPACE:
                        pause = False

            if retry:
                break

            if not pause:
                # get action from mouse
                # None if mouse is not close to the agent
                act = agent.act(obs)
                if act is not None:
                    # teleop started: record the pre-step state
                    episode.append(_make_step_record(img, obs, info, act))
                # step env and render
                obs, _, done, info = env.step(act)
                img = env.render(mode='human')

            # Regulate loop frequency on every iteration, including while
            # paused, so polling for the Space release does not busy-spin.
            clock.tick(control_hz)

        if retry:
            print(f'retry seed {seed}')
            continue

        # save episode buffer to replay buffer (on disk)
        replay_buffer.add_episode(_stack_episode(episode), compressors='disk')
        print(f'saved seed {seed}')
if __name__ == "__main__":
    # Script entry point. The click decorators on `main` supply its
    # parameters from the command line, so it is called without arguments.
    main()