02_train_a2c.py
#!/usr/bin/env python3
import os
import time
import math
import ptan
import gym
import pybullet_envs  # imported for its side effect: registers the PyBullet envs (incl. MinitaurBulletEnv) in gym
import argparse
from tensorboardX import SummaryWriter

from lib import model, common

import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F


ENV_ID = "MinitaurBulletEnv-v0"
GAMMA = 0.99            # discount factor
REWARD_STEPS = 2        # length of the n-step return used by the experience source
BATCH_SIZE = 32         # transitions per optimization step
LEARNING_RATE = 5e-5
ENTROPY_BETA = 1e-4     # weight of the entropy bonus in the loss
TEST_ITERS = 1000       # run a deterministic evaluation every TEST_ITERS training steps

def test_net(net, env, count=10, device="cpu"):
    """Play `count` episodes with the deterministic (mean) policy and return
    the average total reward and average episode length."""
    rewards = 0.0
    steps = 0
    for _ in range(count):
        obs = env.reset()
        while True:
            obs_v = ptan.agent.float32_preprocessor([obs]).to(device)
            mu_v = net(obs_v)[0]
            action = mu_v.squeeze(dim=0).data.cpu().numpy()
            action = np.clip(action, -1, 1)
            obs, reward, done, _ = env.step(action)
            rewards += reward
            steps += 1
            if done:
                break
    return rewards / count, steps / count

def calc_logprob(mu_v, var_v, actions_v):
    p1 = - ((mu_v - actions_v) ** 2) / (2*var_v.clamp(min=1e-3))
    p2 = - torch.log(torch.sqrt(2 * math.pi * var_v))
    return p1 + p2
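
# calc_logprob() above evaluates the log-density of the Gaussian policy:
#   log N(a | mu, var) = -(a - mu)^2 / (2 * var) - log(sqrt(2 * pi * var))
# The variance is clamped from below in the first term purely for numerical
# stability, so the division cannot blow up when the network outputs a
# variance close to zero.
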
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", default=False, action='store_true', help='Enable CUDA')
    parser.add_argument("-n", "--name", required=True, help="Name of the run")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    save_path = os.path.join("saves", "a2c-" + args.name)
    os.makedirs(save_path, exist_ok=True)

    env = gym.make(ENV_ID)
    test_env = gym.make(ENV_ID)

    net = model.ModelA2C(env.observation_space.shape[0], env.action_space.shape[0]).to(device)
    print(net)

    writer = SummaryWriter(comment="-a2c_" + args.name)
    agent = model.AgentA2C(net, device=device)
    exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, GAMMA, steps_count=REWARD_STEPS)
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

    batch = []
    best_reward = None
    with ptan.common.utils.RewardTracker(writer) as tracker:
        with ptan.common.utils.TBMeanTracker(writer, batch_size=10) as tb_tracker:
            for step_idx, exp in enumerate(exp_source):
                rewards_steps = exp_source.pop_rewards_steps()
                if rewards_steps:
                    rewards, steps = zip(*rewards_steps)
                    tb_tracker.track("episode_steps", steps[0], step_idx)
                    tracker.reward(rewards[0], step_idx)

                if step_idx % TEST_ITERS == 0:
                    ts = time.time()
                    rewards, steps = test_net(net, test_env, device=device)
                    print("Test done in %.2f sec, reward %.3f, steps %d" % (
                        time.time() - ts, rewards, steps))
                    writer.add_scalar("test_reward", rewards, step_idx)
                    writer.add_scalar("test_steps", steps, step_idx)
                    if best_reward is None or best_reward < rewards:
                        if best_reward is not None:
                            print("Best reward updated: %.3f -> %.3f" % (best_reward, rewards))
                        name = "best_%+.3f_%d.dat" % (rewards, step_idx)
                        fname = os.path.join(save_path, name)
                        torch.save(net.state_dict(), fname)
                        best_reward = rewards

                batch.append(exp)
                if len(batch) < BATCH_SIZE:
                    continue

                states_v, actions_v, vals_ref_v = \
                    common.unpack_batch_a2c(batch, net, last_val_gamma=GAMMA ** REWARD_STEPS, device=device)
                batch.clear()

                optimizer.zero_grad()
                mu_v, var_v, value_v = net(states_v)

                # critic loss: fit V(s) to the bootstrapped n-step return
                loss_value_v = F.mse_loss(value_v.squeeze(-1), vals_ref_v)

                # actor loss: advantage-weighted log-probability of the taken actions;
                # the critic output is detached so the policy gradient does not flow into the value head
                adv_v = vals_ref_v.unsqueeze(dim=-1) - value_v.detach()
                log_prob_v = adv_v * calc_logprob(mu_v, var_v, actions_v)
                loss_policy_v = -log_prob_v.mean()

                # entropy regularization: adds -ENTROPY_BETA * H to the loss, where
                # H = (log(2*pi*var) + 1) / 2 is the entropy of the Gaussian policy
                entropy_loss_v = ENTROPY_BETA * (-(torch.log(2*math.pi*var_v) + 1)/2).mean()

                loss_v = loss_policy_v + entropy_loss_v + loss_value_v
                loss_v.backward()
                optimizer.step()

                tb_tracker.track("advantage", adv_v, step_idx)
                tb_tracker.track("values", value_v, step_idx)
                tb_tracker.track("batch_rewards", vals_ref_v, step_idx)
                tb_tracker.track("loss_entropy", entropy_loss_v, step_idx)
                tb_tracker.track("loss_policy", loss_policy_v, step_idx)
                tb_tracker.track("loss_value", loss_value_v, step_idx)
                tb_tracker.track("loss_total", loss_v, step_idx)