# atari_wrappers.py: taken from OpenAI baselines.
import numpy as np
import gym
from gym.core import ObservationWrapper
from gym.spaces import Box
from scipy.misc import imresize  # NOTE: removed in SciPy >= 1.3; requires an older SciPy with Pillow installed
from gym.core import Wrapper
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
self._skip = skip
def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for i in range(self._skip):
obs, reward, done, info = self.env.step(action)
if i == self._skip - 2: self._obs_buffer[0] = obs
if i == self._skip - 1: self._obs_buffer[1] = obs
total_reward += reward
if done:
break
# Note that the observation on the done=True frame
# doesn't matter
max_frame = self._obs_buffer.max(axis=0)
return max_frame, total_reward, done, info
def reset(self, **kwargs):
return self.env.reset(**kwargs)
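# Usage sketch (not part of the original file): wrap a raw Atari env so that each
# agent-level step repeats the chosen action for `skip` emulator frames and returns
# the pixel-wise max over the last two frames, which removes Atari sprite flicker.
# The env id below is only illustrative.
#
#   env = MaxAndSkipEnv(gym.make("BreakoutNoFrameskip-v4"), skip=4)
#   obs = env.reset()
#   obs, reward, done, info = env.step(env.action_space.sample())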
class ClipRewardEnv(gym.RewardWrapper):
def __init__(self, env):
gym.RewardWrapper.__init__(self, env)
def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)
class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self, **kwargs):
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset(**kwargs)
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset(**kwargs)
return obs
def step(self, ac):
return self.env.step(ac)
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def reset(self, **kwargs):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs
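# Usage sketch (not part of the original file): under EpisodicLifeEnv, `done` fires
# on every lost life, while the true game-over can be read from `was_real_done`
# (e.g. for logging full-episode returns). The env id is only illustrative.
#
#   env = EpisodicLifeEnv(gym.make("BreakoutNoFrameskip-v4"))
#   obs = env.reset()
#   obs, reward, done, info = env.step(0)
#   if done:
#       episode_over = env.was_real_done  # True only on a real game over
#       obs = env.reset()                 # partial reset (no-op step) while lives remain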
# In PyTorch, images have shape [c, h, w] instead of the more common [h, w, c].
class AntiTorchWrapper(gym.ObservationWrapper):
def __init__(self, env):
gym.ObservationWrapper.__init__(self, env)
self.img_size = [env.observation_space.shape[i]
for i in [1, 2, 0]
]
self.observation_space = gym.spaces.Box(0.0, 1.0, self.img_size)
def _observation(self, img):
"""what happens to each observation"""
img = img.transpose(1, 2, 0)
return img
class PreprocessAtariObs(ObservationWrapper):
def __init__(self, env, height=42, width=42, margins=[1,1,1,1]):
"""A gym wrapper that crops, scales image into the desired shapes and grayscales it."""
ObservationWrapper.__init__(self, env)
self.img_size = (1, width, height)
self.desired_img_size= (width, height)
self.margins = margins
self.observation_space = Box(0.0, 1.0, self.img_size)
def _to_grayscale(self, rgb, channel_weights=[0.1, 0.8, 0.1]):
assert rgb.ndim == 3, "is image rgb? ndim: " + str(rgb.ndim)
return np.dot(rgb[..., :3], channel_weights)
def _resize(self, img, desired_size=(64, 64)):
assert len(desired_size) == 2, "desired size is invalid, desired_size: " + str(desired_size)
return imresize(img, desired_size)
def _crop(self, img, margins=(56, 10, 8, 8)):
""" margins: top, left, right, bottom"""
assert len(margins) == 4, "margins array is invalid"
assert img.ndim == 2, "img is not grayscale"
return img[margins[0]:-margins[-1], margins[1]:-margins[2]]
def _observation(self, img):
"""what happens to each observation"""
img = self._to_grayscale(img)
img = self._crop(img, self.margins)
img = self._resize(img, desired_size=self.desired_img_size)
img = np.reshape(img, self.img_size)
img = np.asarray(img, dtype=np.float32) / 255.0
return img
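# Usage sketch (not part of the original file): with the default arguments the wrapper
# above turns raw 210x160x3 uint8 Atari frames into grayscale float32 observations of
# shape (1, 42, 42) scaled to [0, 1]. The env id is illustrative, and the snippet assumes
# the older gym API this file targets (ObservationWrapper dispatching to `_observation`).
#
#   env = PreprocessAtariObs(gym.make("BreakoutDeterministic-v4"))
#   obs = env.reset()
#   assert obs.shape == (1, 42, 42) and obs.dtype == np.float32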
class RewardScale(Wrapper):
    def __init__(self, env, reward_scale=0.1):
        """A gym wrapper that multiplies every reward by a constant factor."""
        super(RewardScale, self).__init__(env)
        self.reward_scale = reward_scale
    def step(self, action):
        """Plays the environment for 1 step and returns the reward scaled by reward_scale."""
        new_img, reward, done, info = self.env.step(action)
        return new_img, reward * self.reward_scale, done, info
class FrameBuffer(Wrapper):
def __init__(self, env, n_frames=4, dim_order='tensorflow'):
"""A gym wrapper that reshapes, crops and scales image into the desired shapes"""
super(FrameBuffer, self).__init__(env)
self.dim_order = dim_order
if dim_order == 'tensorflow':
height, width, n_channels = env.observation_space.shape
obs_shape = [height, width, n_channels * n_frames]
elif dim_order == 'pytorch':
n_channels, height, width = env.observation_space.shape
obs_shape = [n_channels * n_frames, height, width]
else:
raise ValueError('dim_order should be "tensorflow" or "pytorch", got {}'.format(dim_order))
self.observation_space = Box(0.0, 1.0, obs_shape)
self.framebuffer = np.zeros(obs_shape, 'float32')
def reset(self):
"""resets breakout, returns initial frames"""
self.framebuffer = np.zeros_like(self.framebuffer)
self.update_buffer(self.env.reset())
return self.framebuffer
def step(self, action):
"""plays breakout for 1 step, returns frame buffer"""
new_img, reward, done, info = self.env.step(action)
self.update_buffer(new_img)
return self.framebuffer, reward, done, info
    def update_buffer(self, img):
        """Pushes the newest frame to the front of the buffer and drops the oldest one."""
        if self.dim_order == 'tensorflow':
            offset = self.env.observation_space.shape[-1]
            axis = -1
            cropped_framebuffer = self.framebuffer[:, :, :-offset]
        elif self.dim_order == 'pytorch':
            offset = self.env.observation_space.shape[0]
            axis = 0
            cropped_framebuffer = self.framebuffer[:-offset]
        self.framebuffer = np.concatenate([img, cropped_framebuffer], axis=axis)
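# A minimal composition sketch (not part of the original file): how these wrappers are
# typically chained into one preprocessing pipeline, roughly following the DeepMind /
# baselines convention. The env id, the wrapper order and the `make_env` name are
# illustrative assumptions, not something this file prescribes.
def make_env(env_id="BreakoutNoFrameskip-v4", n_frames=4):
    env = gym.make(env_id)
    env = EpisodicLifeEnv(env)              # end of life == end of episode
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)             # press FIRE on reset where the game requires it
    env = MaxAndSkipEnv(env, skip=4)        # frame skipping + flicker removal on raw frames
    env = PreprocessAtariObs(env)           # grayscale, crop, resize, scale to [0, 1]
    env = ClipRewardEnv(env)                # clip rewards to {-1, 0, +1}
    # PreprocessAtariObs emits channel-first (1, 42, 42) frames, so stack in 'pytorch' order
    env = FrameBuffer(env, n_frames=n_frames, dim_order='pytorch')
    return env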