-
Notifications
You must be signed in to change notification settings - Fork 0
/
env.py
165 lines (129 loc) · 4.07 KB
/
env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
'''Simple key collection task implementation.
'''
import gym
import numpy as np
import typing
import collections
import skimage
from skimage import draw
class Obj:
    '''Base class for anything that can sit on the grid.

    Attributes:
        name: identifier string for the object.
        pos: current position, or None when it has not been placed yet.
    '''

    def __init__(self, name, pos=None):
        self.name, self.pos = name, pos
class Key(Obj):
    '''Grid object named 'key'.'''

    def __init__(self, pos=None):
        super().__init__(name='key', pos=pos)
class Player(Obj):
    '''Grid object named 'player' that carries an inventory of picked-up items.'''

    def __init__(self, pos=None):
        super().__init__(name='player', pos=pos)
        # Items collected so far.
        self.inv = set()

    def pickup(self, item, world):
        '''Add `item` to the inventory, then ask `world` to remove it.

        `world` is any object exposing a `remove(item)` method.
        '''
        inventory = self.inv
        inventory.add(item)
        world.remove(item)
def get_colors(cmap='Set1', num_colors=9):
    '''Sample `num_colors` entries from the matplotlib colormap named `cmap`.

    The colormap is evaluated at i / num_colors for i in 0..num_colors-1 and
    the results are returned as a list, in that order.
    '''
    # Imported lazily so matplotlib is only needed when colors are requested.
    import matplotlib.pyplot as plt

    colormap = plt.get_cmap(cmap)
    return [colormap(1. * step / num_colors) for step in range(num_colors)]
class KeyTask(gym.Env):
    '''Grid world where the player collects a key and delivers it to a corner.

    Positions are all (x, y) tuples with 0 <= x < width and 0 <= y < height.
    An episode ends when the player reaches the goal corner or when
    `max_steps` steps have been taken.
    '''

    # Action index i corresponds to letter Ns_str[i] and offset Ns[i]
    # (left, right, up, down; y grows downwards in the rendered image).
    Ns_str = 'lrud'
    Ns = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    def __init__(self, seed=42, max_steps=100):
        '''Create a 6x6 world, seed its random generator and reset it.

        Args:
            seed: seed passed to `seed()`.
            max_steps: number of steps after which an episode is reported done.
        '''
        self.max_steps = max_steps
        self.actions = self.Ns
        self.action_space = [0, 1, 2, 3]
        self.width = 6
        self.height = 6
        self.seed(seed)
        self.reset()

    def place(self, obj, pos):
        '''Move `obj` to `pos`, updating both the position map and `obj.pos`.'''
        # There are only 2 objects so no need to complicate with multiple items.
        # Only clear the old cell if this object still owns it; the cell may
        # have been taken over by another object (player standing on the key).
        if self.map.get(obj.pos) is obj:
            del self.map[obj.pos]
        self.map[pos] = obj
        obj.pos = pos

    def remove(self, obj):
        '''Take `obj` out of the world (position map and object list).'''
        # When the player steps onto the key, the key's map entry has already
        # been overwritten by the player; do not delete the player's entry.
        if self.map.get(obj.pos) is obj:
            del self.map[obj.pos]
        self.objects.remove(obj)

    def _render(self, objects):
        '''Draw each object as a filled circle of radius 5 in its 10x10 cell.

        Returns:
            float32 array of shape (3, height*10, width*10), channels first.
        '''
        im = np.zeros((self.height * 10, self.width * 10, 3), dtype=np.float32)
        # skimage.draw.circle was removed in newer scikit-image releases in
        # favour of skimage.draw.disk; fall back to circle on old installs.
        draw_disk = getattr(skimage.draw, 'disk', None)
        for obj in objects:
            x, y = obj.pos
            row, col = y * 10 + 5, x * 10 + 5
            if draw_disk is not None:
                rr, cc = draw_disk((row, col), 5, shape=im.shape[:2])
            else:
                rr, cc = skimage.draw.circle(row, col, 5, im.shape)
            im[rr, cc, :] = self.n2color[obj.name]
        return im.transpose([2, 0, 1])

    def render(self):
        '''Render the objects currently present in the world.'''
        return self._render(self.objects)

    def render_goal_state(self):
        '''player in goal state, no key present'''
        fake_player = Player(pos=self.goal_state)
        return self._render([fake_player])

    def seed(self, seed):
        '''Replace the environment's random generator with one seeded by `seed`.'''
        self.random = np.random.default_rng(seed)

    def reset(self):
        '''Start a new episode and return the rendered initial observation.

        Picks a random goal corner, then places the key and the player on
        distinct random cells, neither of which is the goal.
        '''
        self.curstep = 0
        self.player = Player()
        self.key = Key()
        self.map = {}
        self.objects = [self.player, self.key]

        randint = self.random.integers
        W, H = self.width, self.height
        corners = [(0, 0), (0, H - 1), (W - 1, 0), (W - 1, H - 1)]
        self.goal_state = corners[randint(0, 4)]

        def random_cell():
            # Generator.integers excludes the upper bound, so pass W and H
            # (not W-1 / H-1) to make the last column and row reachable.
            return int(randint(0, W)), int(randint(0, H))

        kpos = random_cell()
        while kpos == self.goal_state:
            kpos = random_cell()
        self.place(self.key, kpos)

        ppos = random_cell()
        while ppos == self.goal_state or ppos == kpos:
            ppos = random_cell()
        self.place(self.player, ppos)

        # NOTE in the paper they use 1 color per channel
        colors = [c[:3] for c in get_colors(num_colors=9)]
        self.n2color = {
            self.player.name: colors[0],
            self.key.name: colors[1],
        }
        return self.render()

    def move(self, obj, delta):
        '''Shift `obj` by `delta` = (dx, dy) if the target cell is valid.'''
        npos = (obj.pos[0] + delta[0], obj.pos[1] + delta[1])
        if self.is_valid(npos):
            self.place(obj, npos)

    def is_valid(self, position):
        '''Return True if `position` is on the grid and is empty or holds the key.'''
        x, y = position
        if not (0 <= x < self.width and 0 <= y < self.height):
            return False
        obj_in_position = self.map.get(position)
        if obj_in_position and obj_in_position != self.key:
            return False
        return True

    @property
    def valid_actions(self):
        '''Indices of the actions that would move the player to a valid cell.'''
        x, y = self.player.pos
        return [
            i
            for i, (dx, dy) in enumerate(self.actions)
            if self.is_valid((x + dx, y + dy))
        ]

    def step(self, action):
        '''Apply `action` (an index into `self.actions`) to the player.

        Returns:
            A tuple (observation, reward, done, won):
            observation is the rendered image; reward is 1 for picking up the
            key, 1 / -1 for reaching the goal with / without the key, and -0.1
            otherwise; done is True at the goal or once `max_steps` is
            reached; won is True only when the goal is reached with the key.
        '''
        self.curstep += 1
        # >= rather than == so that stepping past the limit still reports done.
        out_of_time = self.curstep >= self.max_steps
        self.move(self.player, self.actions[action])

        reward = None
        won = False
        if self.key in self.objects and self.player.pos == self.key.pos:
            self.player.pickup(self.key, self)
            reward = 1  # reward key pickup

        # NOTE(tk) Different dynamics than the paper.
        # Their reward is 0 if reaching goal without the key.
        in_goal = self.player.pos == self.goal_state
        if in_goal:
            reward = 1 if self.key in self.player.inv else -1
            won = reward == 1
        elif reward is None:
            # Small per-step penalty when nothing else happened.
            reward = -.1
        return self.render(), reward, in_goal or out_of_time, won