trainer.py
import time
import os
import numpy as np
import torch
from tqdm import tqdm
import d4rl


class Trainer:
    def __init__(
        self,
        algo,
        eval_env,
        task,
        epoch,
        step_per_epoch,
        rollout_freq,
        logger,
        log_freq,
        eval_episodes=10
    ):
        self.algo = algo
        self.eval_env = eval_env
        self.task = task
        self._epoch = epoch
        self._step_per_epoch = step_per_epoch
        self._rollout_freq = rollout_freq
        self.logger = logger
        self._log_freq = log_freq
        self._eval_episodes = eval_episodes
    def train_dynamics(self):
        # learn the dynamics model and save it under the logger's directory
        self.algo.learn_dynamics()
        self.algo.save_dynamics_model(
            save_path=os.path.join(self.logger.writer.get_logdir(), "dynamics_model")
        )
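    # train_policy: run `epoch` epochs of `step_per_epoch` policy updates each,
    # generating synthetic rollouts every `rollout_freq` updates, then evaluate
    # and checkpoint the policy at the end of every epoch.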
    def train_policy(self):
        start_time = time.time()
        num_timesteps = 0
        # train loop
        for e in range(1, self._epoch + 1):
            self.algo.policy.train()
            with tqdm(total=self._step_per_epoch, desc=f"Epoch #{e}/{self._epoch}") as t:
                while t.n < t.total:
                    if num_timesteps % self._rollout_freq == 0:
                        self.algo.rollout_transitions()
                    # update the policy with SAC
                    loss = self.algo.learn_policy()
                    t.set_postfix(**loss)
                    # log
                    if num_timesteps % self._log_freq == 0:
                        for k, v in loss.items():
                            self.logger.record(k, v, num_timesteps, printed=False)
                    num_timesteps += 1
                    t.update(1)
            # evaluate current policy
            eval_info = self._evaluate()
            ep_reward_mean, ep_reward_std = np.mean(eval_info["eval/episode_reward"]), np.std(eval_info["eval/episode_reward"])
            ep_length_mean, ep_length_std = np.mean(eval_info["eval/episode_length"]), np.std(eval_info["eval/episode_length"])
            normalized_rewards = d4rl.get_normalized_score(self.task, np.array(eval_info["eval/episode_reward"])) * 100.0
            ep_normalized_returns_mean, ep_normalized_returns_std = normalized_rewards.mean(), normalized_rewards.std()
            self.logger.record("eval/episode_reward", ep_reward_mean, num_timesteps, printed=False)
            self.logger.record("eval/episode_normalized_return", ep_normalized_returns_mean, num_timesteps, printed=False)
            self.logger.record("eval/episode_length", ep_length_mean, num_timesteps, printed=False)
            self.logger.print(f"Epoch #{e}: episode_reward: {ep_reward_mean:.3f} ± {ep_reward_std:.3f}, episode_normalized_return: {ep_normalized_returns_mean:.3f} ± {ep_normalized_returns_std:.3f}, episode_length: {ep_length_mean:.3f} ± {ep_length_std:.3f}")
            # save policy
            torch.save(self.algo.policy.state_dict(), os.path.join(self.logger.writer.get_logdir(), "policy.pth"))
        self.logger.print("total time: {:.3f}s".format(time.time() - start_time))
    def _evaluate(self):
        self.algo.policy.eval()
        obs = self.eval_env.reset()
        eval_ep_info_buffer = []
        num_episodes = 0
        episode_reward, episode_length = 0, 0
        while num_episodes < self._eval_episodes:
            action = self.algo.policy.sample_action(obs, deterministic=True)
            next_obs, reward, terminal, _ = self.eval_env.step(action)
            episode_reward += reward
            episode_length += 1
            obs = next_obs
            if terminal:
                eval_ep_info_buffer.append(
                    {"episode_reward": episode_reward, "episode_length": episode_length}
                )
                num_episodes += 1
                episode_reward, episode_length = 0, 0
                obs = self.eval_env.reset()
        return {
            "eval/episode_reward": [ep_info["episode_reward"] for ep_info in eval_ep_info_buffer],
            "eval/episode_length": [ep_info["episode_length"] for ep_info in eval_ep_info_buffer]
        }
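
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original trainer): how the normalized
# return logged in train_policy is defined. d4rl.get_normalized_score rescales
# a raw episode return against the task's random and expert reference returns;
# the reference values below are hypothetical placeholders so the example runs
# without a D4RL dataset installed.
if __name__ == "__main__":
    def normalized_score(raw_return, random_return, expert_return):
        # D4RL convention: 0.0 matches a random policy, 1.0 an expert policy;
        # train_policy multiplies this by 100 before logging.
        return (raw_return - random_return) / (expert_return - random_return)

    # hypothetical evaluation returns and reference scores
    episode_returns = np.array([1500.0, 1620.0, 1580.0])
    scores = 100.0 * normalized_score(episode_returns, random_return=-100.0, expert_return=3000.0)
    print(f"normalized return: {scores.mean():.3f} ± {scores.std():.3f}")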