tme_5_actor_critic_cartpole.py
import gym
import torch
from torch.utils.tensorboard import SummaryWriter

from agent import ActorCritic
from logger import get_logger
from experiment import Experiment

number_of_episodes = 3000
optimize_every = 10  # Number of steps.
show_every = 100  # Number of episodes.

if __name__ == "__main__":
    env = gym.make("CartPole-v1")

    # Create a new agent here.
    experiment = Experiment.create(
        base_name="actor_critic/actor_critic_CartPole-v1",
        model_class=ActorCritic,
        hp={
            "observation_space": env.observation_space,
            "action_space": env.action_space,
            "learning_rate": 0.0002,
            "gamma": 0.98,
        },
    )
    experiment.save()
    # Or load a previous one.
    # experiment = Experiment.load("...")

    logger = get_logger(experiment.name, file_path=experiment.log_path)
    writer = SummaryWriter(
        log_dir=experiment.writer_path, purge_step=experiment.episode
    )
    experiment.info(logger)

    while experiment.episode < number_of_episodes:
        experiment.episode += 1
        show = (experiment.episode + 1) % show_every == 0
        state = env.reset()
        episode_reward, episode_steps = 0, 0

        while True:
            # Draw an action and act on the environment.
            action = experiment.model.step(torch.from_numpy(state).float())
            end_state, reward, done, info = env.step(action)

            # Record the transition. A time-limit truncation is not a true
            # terminal state, so it is stored as not-done.
            experiment.model.add_transition(
                (
                    state,
                    action,
                    reward,
                    end_state,
                    False if info.get("TimeLimit.truncated") else done,
                )
            )
            state = end_state

            experiment.step += 1
            episode_steps += 1
            episode_reward += reward

            # Optimize if needed.
            if (experiment.step + 1) % optimize_every == 0:
                experiment.model.optimize()

            # Show if needed.
            if show:
                env.render()

            if done:
                break

        # Log.
        if show:
            logger.info(f"Episode {experiment.episode}: reward = {episode_reward}.")
        writer.add_scalars(
            "train",
            {"reward": episode_reward, "steps": episode_steps},
            global_step=experiment.episode,
        )

    env.close()
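
The training loop above depends on a local `agent.ActorCritic` class that is not shown in this file. Below is a hypothetical minimal sketch of the interface the script expects (`step`, `add_transition`, `optimize`), written under the assumption that `Experiment.create` instantiates the model as `ActorCritic(**hp)`; the real implementation in `agent.py` may differ. The sketch uses a shared body with a softmax policy head and a state-value head, and a one-step TD advantage for the update.

    # Hypothetical sketch only -- the real agent.py is not part of this file.
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class ActorCritic(nn.Module):
        """Minimal actor-critic: shared body, softmax policy head, value head."""

        def __init__(self, observation_space, action_space, learning_rate, gamma):
            super().__init__()
            self.gamma = gamma
            self.body = nn.Sequential(
                nn.Linear(observation_space.shape[0], 128), nn.ReLU()
            )
            self.policy_head = nn.Linear(128, action_space.n)  # actor
            self.value_head = nn.Linear(128, 1)                 # critic
            self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
            self.transitions = []

        def step(self, state):
            # Sample an action from the current softmax policy.
            probs = F.softmax(self.policy_head(self.body(state)), dim=-1)
            return torch.distributions.Categorical(probs).sample().item()

        def add_transition(self, transition):
            # transition = (state, action, reward, next_state, done)
            self.transitions.append(transition)

        def optimize(self):
            if not self.transitions:
                return
            states, actions, rewards, next_states, dones = zip(*self.transitions)
            states = torch.as_tensor(np.array(states), dtype=torch.float32)
            actions = torch.as_tensor(actions, dtype=torch.int64).unsqueeze(1)
            rewards = torch.as_tensor(rewards, dtype=torch.float32).unsqueeze(1)
            next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32)
            dones = torch.as_tensor(dones, dtype=torch.float32).unsqueeze(1)

            # One-step TD target and advantage.
            values = self.value_head(self.body(states))
            with torch.no_grad():
                next_values = self.value_head(self.body(next_states))
                targets = rewards + self.gamma * next_values * (1 - dones)
            advantages = targets - values.detach()

            # Policy-gradient loss for the actor, smooth L1 regression for the critic.
            log_probs = F.log_softmax(self.policy_head(self.body(states)), dim=-1)
            log_prob_actions = log_probs.gather(1, actions)
            loss = (-log_prob_actions * advantages).mean() + F.smooth_l1_loss(
                values, targets
            )

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.transitions = []

Because `optimize()` flushes the stored transitions, the call every `optimize_every` steps in the main loop would perform a small on-policy batched update, which matches how the script calls it; this is a sketch of that behavior, not the repository's actual agent.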