forked from ClarityCoders/MarioPPO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RandomAgent.py
68 lines (60 loc) · 1.91 KB
/
RandomAgent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import retro
import gym
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv
class TimeLimitWrapper(gym.Wrapper):
"""
:param env: (gym.Env) Gym environment that will be wrapped
:param max_steps: (int) Max number of steps per episode
"""
def __init__(self, env, max_steps=10000):
# Call the parent constructor, so we can access self.env later
super(TimeLimitWrapper, self).__init__(env)
self.max_steps = max_steps
# Counter of steps per episode
self.current_step = 0
def reset(self):
"""
Reset the environment
"""
# Reset the counter
self.current_step = 0
return self.env.reset()
def step(self, action):
"""
:param action: ([float] or int) Action taken by the agent
:return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations
"""
self.current_step += 1
obs, reward, done, info = self.env.step(action)
# Overwrite the done signal when
if self.current_step >= self.max_steps:
done = True
# Update the info dict to signal that the limit was exceeded
info['time_limit_reached'] = True
info['Current_Step'] = self.current_step
return obs, reward, done, info
def main():
steps = 0
#env = retro.make(game='MegaMan2-Nes')
env = retro.make(game='SuperMarioBros-Nes')
env = TimeLimitWrapper(env)
env = MaxAndSkipEnv(env, 4)
obs = env.reset()
print(obs.shape)
done = False
while not done:
obs, rew, done, info = env.step(env.action_space.sample())
env.render()
if done:
obs = env.reset()
steps += 1
#print(rew)
if steps % 1000 == 0:
print(f"Total Steps: {steps}")
print(info)
print("Final Info")
print(info)
env.close()
if __name__ == "__main__":
main()