-
Notifications
You must be signed in to change notification settings - Fork 0
/
atari_util.py
43 lines (35 loc) · 1.56 KB
/
atari_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import atari_wrappers
def PrimaryAtariWrap(env,
                     clip_rewards=True,
                     frame_skip=True,
                     fire_reset_event=False,
                     episodic_life=False,
                     width=44,
                     height=44,
                     margins=None,
                     n_frames=4,
                     reward_scale=0):
    """Wrap an Atari environment with the standard preprocessing stack.

    Wrappers from ``atari_wrappers`` are applied in a fixed order: frame
    skipping, episodic life, fire-on-reset, reward clipping, reward
    scaling, observation preprocessing and finally frame buffering.
    The first five are optional; the last two are always applied.

    Args:
        env: The environment to wrap.
        clip_rewards: If true, apply ``ClipRewardEnv``.
        frame_skip: If true, apply ``MaxAndSkipEnv`` with ``skip=4``.
        fire_reset_event: If true, apply ``FireResetEnv``.
        episodic_life: If true, apply ``EpisodicLifeEnv``.
        width: Forwarded to ``PreprocessAtariObs`` as ``width``.
        height: Forwarded to ``PreprocessAtariObs`` as ``height``.
        margins: Forwarded to ``PreprocessAtariObs`` as ``margins``.
            When omitted (``None``), a fresh ``[1, 1, 1, 1]`` list is
            used. The meaning of the four entries is defined by
            ``PreprocessAtariObs`` (presumably crop margins -- confirm
            against that wrapper).
        n_frames: Forwarded to ``FrameBuffer`` as ``n_frames``.
        reward_scale: If non-zero, apply ``RewardScale`` with this value.
            Note that this happens after reward clipping.

    Returns:
        The wrapped environment.
    """
    # Build the default per call: a list literal in the signature would be
    # one shared object, so an in-place change made by a wrapper or caller
    # would silently leak into every later call.
    if margins is None:
        margins = [1, 1, 1, 1]

    # This wrapper holds the same action for <skip> frames and outputs
    # the maximal pixel value of the last 2 frames (to handle blinking
    # in some envs).
    if frame_skip:
        env = atari_wrappers.MaxAndSkipEnv(env, skip=4)

    # This wrapper sends done=True when each life is lost
    # (not only when all the lives given by the game rules are gone).
    # It should make it easier for the agent to understand that losing
    # a life is bad.
    if episodic_life:
        env = atari_wrappers.EpisodicLifeEnv(env)

    # This wrapper launches the ball when an episode starts.
    # Without it the agent has to learn this action, too.
    # It can, but learning would take longer.
    if fire_reset_event:
        env = atari_wrappers.FireResetEnv(env)

    # This wrapper transforms rewards to {-1, 0, 1} according to their sign.
    if clip_rewards:
        env = atari_wrappers.ClipRewardEnv(env)

    # Optional reward scaling; 0 means "disabled".
    if reward_scale != 0:
        env = atari_wrappers.RewardScale(env, reward_scale)

    # Project-specific observation preprocessing, configured by the
    # height / width / margins arguments.
    env = atari_wrappers.PreprocessAtariObs(
        env, height=height, width=width, margins=margins)

    # Frame buffer over n_frames observations, using the 'pytorch'
    # dimension order (presumably channels-first -- confirm in FrameBuffer).
    env = atari_wrappers.FrameBuffer(
        env, n_frames=n_frames, dim_order='pytorch')
    return env