-
Notifications
You must be signed in to change notification settings - Fork 1
/
map.py
70 lines (64 loc) · 3.43 KB
/
map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import numpy as np
import matplotlib.pyplot as plt
class Map(object):
    """Square grid-world maze plus a matplotlib renderer for it.

    Positions are ``[x, y]`` pairs with the origin at the bottom-left cell,
    so position ``[x, y]`` is stored at ``self.map[self.size - 1 - y][x]``.

    Cell codes in ``self.map``:
        0 -- wall (drawn black, never movable)
        1 -- open floor
        2 -- start cell (the cell at ``init_pos``)
        3 -- goal cell (the cell at ``goal_pos``)
    """

    # Arrow glyph drawn for the greedy action index of a Q-table row.
    _ACTION_ARROWS = {0: "↑", 1: "↓", 2: "→", 3: "←"}

    def __init__(self):
        """Build the fixed 12x12 maze and open a 7x7 inch figure for drawing."""
        self.map = np.array([
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 0],
            [0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0],
            [0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0],
            [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
            [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
            [0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0],
            [0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0],
            [0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0],
            [0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0],
            [0, 2, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ])
        self.size = self.map.shape[0]
        self.init_pos = [1, 1]
        self.goal_pos = [10, 10]
        plt.figure(figsize=(7, 7))

    def _is_open(self, row, col):
        """Return True when (row, col) lies inside the grid and is not a wall."""
        # Explicit bounds check: a negative index would otherwise wrap around
        # to the opposite edge of the grid instead of counting as blocked.
        if not (0 <= row < self.size and 0 <= col < self.size):
            return False
        return bool(self.map[row][col])

    def chack_movable(self, pos):
        """Report which neighbours of ``pos`` can be entered.

        Args:
            pos: ``[x, y]`` position, origin at the bottom-left cell.

        Returns:
            ``[up, down, right, left]`` as plain bools. A neighbour is
            movable when its cell code is non-zero; neighbours outside the
            grid are reported as not movable.
        """
        col = pos[0]
        # Row 0 of the array is the top of the picture, so flip the y axis.
        row = self.size - 1 - pos[1]
        return [
            self._is_open(row - 1, col),  # up
            self._is_open(row + 1, col),  # down
            self._is_open(row, col + 1),  # right
            self._is_open(row, col - 1),  # left
        ]

    # Correctly spelled alias; the original (misspelled) name is kept so
    # existing callers continue to work.
    check_movable = chack_movable

    def _fill_cell(self, i, j, color, alpha):
        """Shade the grid cell whose lower-left corner is at (i, j)."""
        # axvspan takes x in data units but ymin/ymax as axes fractions,
        # hence the division by the grid size.
        plt.axvspan(xmin=i, xmax=i + 1,
                    ymin=j / self.size, ymax=(j + 1) / self.size,
                    color=color, alpha=alpha)

    def _draw_policy(self, i, j, q_row):
        """Overlay the greedy action of ``q_row`` on floor cell (i, j)."""
        best = int(np.argmax(q_row))
        best_q = q_row[best]
        # Cells whose best Q-value is not above 1.0 are left blank.
        if best_q > 1.0:
            # Opacity scales with the Q-value, capped at 0.9.
            alpha = np.clip(best_q / 100, 0, 0.9)
            self._fill_cell(i, j, "y", alpha)
            arrow = self._ACTION_ARROWS.get(best)
            if arrow is not None:
                plt.text(i, j, arrow, size=30, alpha=alpha)

    def plot(self, pos=(1, 1), q_table=()):
        """Draw one frame of the maze, the agent and (optionally) the policy.

        The frame is drawn on the current pyplot axes, shown via a short
        ``plt.pause`` and then cleared with ``plt.cla`` so the next call
        starts from empty axes.

        Args:
            pos: ``[x, y]`` position of the agent, shaded red.
            q_table: Optional table indexed as ``q_table[x * size + y]``,
                each row holding the Q-values of the actions
                (0: up, 1: down, 2: right, 3: left). When empty or ``None``
                no policy overlay is drawn.
        """
        size = self.size
        top = size - 1
        # An empty table (the default) used to raise IndexError on the first
        # floor cell; treat it as "no policy to show" instead.
        have_q = q_table is not None and len(q_table) > 0

        ticks = np.arange(0, size, 1)
        plt.vlines(ticks, 0, size, linewidth=0.3, colors="k")
        plt.hlines(ticks, 0, size, linewidth=0.3, colors="k")
        plt.xlim(0, size)
        plt.ylim(0, size)

        for i in range(size):
            for j in range(size):
                cell = self.map[top - j][i]
                if cell == 0:
                    self._fill_cell(i, j, "k", 0.9)
                elif cell == 1:
                    if have_q:
                        self._draw_policy(i, j, q_table[i * size + j])
                elif cell == 2:
                    self._fill_cell(i, j, "g", 0.5)
                elif cell == 3:
                    self._fill_cell(i, j, "b", 0.5)

        # Agent marker goes on last so it sits above the cell shading.
        self._fill_cell(pos[0], pos[1], "r", 0.5)
        plt.pause(0.0001)
        plt.cla()
if __name__ == "__main__":
    # Manual check when run as a script: build the map and draw a single
    # frame using plot()'s default arguments.
    demo_map = Map()
    demo_map.plot()