forked from ikostrikov/pytorch-a2c-ppo-acktr-gail
-
Notifications
You must be signed in to change notification settings - Fork 0
/
enjoy.py
95 lines (76 loc) · 2.48 KB
/
enjoy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
# workaround to unpickle old model files
import sys
import numpy as np
import torch
from a2c_ppo_acktr.envs import VecPyTorch, make_vec_envs
from a2c_ppo_acktr.utils import get_render_func, get_vec_normalize
# Workaround for unpickling older model files: make the modules inside
# a2c_ppo_acktr/ importable by their bare names. The entry is a relative path,
# so Python resolves it against the current working directory.
# NOTE(review): this only helps when the script is launched from the directory
# that contains a2c_ppo_acktr/ -- confirm that is the expected invocation.
sys.path.append('a2c_ppo_acktr')

# Command-line options for replaying a trained agent.
parser = argparse.ArgumentParser(description='RL')
parser.add_argument(
    '--seed', type=int, default=1, help='random seed (default: 1)')
# NOTE(review): --log-interval is parsed but never read in this script; it is
# kept so existing command lines continue to be accepted.
parser.add_argument(
    '--log-interval',
    type=int,
    default=10,
    help='log interval, one log per n updates (default: 10)')
parser.add_argument(
    '--env-name',
    default='PongNoFrameskip-v4',
    help='environment to run the trained agent in; also names the model '
    'file <env-name>.pt (default: PongNoFrameskip-v4)')
parser.add_argument(
    '--load-dir',
    default='./trained_models/',
    help='directory containing the trained model <env-name>.pt '
    '(default: ./trained_models/)')
parser.add_argument(
    '--non-det',
    action='store_true',
    default=False,
    help='whether to use a non-deterministic policy')
args = parser.parse_args()

# Act deterministically unless --non-det was given.
args.det = not args.non_det
# Build a single evaluation environment on the CPU. The seed is offset by 1000
# from --seed, presumably so evaluation does not replay training seeds.
# NOTE(review): the two positional None arguments are presumably gamma and
# log_dir of make_vec_envs -- confirm against a2c_ppo_acktr.envs.
env = make_vec_envs(
    args.env_name,
    args.seed + 1000,
    1,
    None,
    None,
    device='cpu',
    allow_early_resets=False)

# Get a render function (None when the env does not provide one).
render_func = get_render_func(env)

# We need to use the same statistics for normalization as used in training.
# The checkpoint is a pickled (policy, observation statistics) pair.
# map_location='cpu' lets GPU-trained checkpoints load on CPU-only machines and
# puts the policy on the same device as the CPU environment created above.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses to
# unpickle full module objects; pass weights_only=False on those versions.
actor_critic, ob_rms = torch.load(
    os.path.join(args.load_dir, args.env_name + ".pt"),
    map_location='cpu')

vec_norm = get_vec_normalize(env)
if vec_norm is not None:
    # Switch normalization to evaluation mode (presumably stops the running
    # statistics from updating -- confirm) and reuse the saved statistics.
    vec_norm.eval()
    vec_norm.ob_rms = ob_rms
# Rollout state for the single environment. A zero mask marks an episode start;
# the rollout loop refills it from `done` after every step.
masks = torch.zeros(1, 1)
recurrent_hidden_states = torch.zeros(
    1, actor_critic.recurrent_hidden_state_size)

obs = env.reset()
if render_func is not None:
    render_func('human')

# PyBullet environments: find the body named "torso" so the debug camera can
# track it. The highest matching body index is kept, or -1 when none matches.
if 'Bullet' in args.env_name:
    import pybullet as p

    torsoId = max(
        (body_index for body_index in range(p.getNumBodies())
         if p.getBodyInfo(body_index)[0].decode() == "torso"),
        default=-1)
# Loop invariants: whether to steer the PyBullet debug camera after the torso,
# and the fixed camera distance / yaw / pitch used for it. torsoId is only
# read when the env name contains 'Bullet' (the `and` short-circuits).
follow_torso = 'Bullet' in args.env_name and torsoId > -1
CAMERA_DISTANCE = 5
CAMERA_YAW = 0
CAMERA_PITCH = -20

# Run the policy forever, rendering every step (stop with Ctrl-C).
while True:
    with torch.no_grad():
        _, action, _, recurrent_hidden_states = actor_critic.act(
            obs, recurrent_hidden_states, masks, deterministic=args.det)

    # Observe reward and next obs
    obs, reward, done, _ = env.step(action)

    # Mask is 0 right after a terminal step and 1 otherwise, presumably so the
    # policy resets its recurrent state for the next episode. Assumes `done`
    # holds exactly one flag (a single env was created).
    masks.fill_(0.0 if done else 1.0)

    if follow_torso:
        # Keep the debug camera centred on the torso's current position.
        humanPos, _ = p.getBasePositionAndOrientation(torsoId)
        p.resetDebugVisualizerCamera(
            CAMERA_DISTANCE, CAMERA_YAW, CAMERA_PITCH, humanPos)

    if render_func is not None:
        render_func('human')