# model.py
import logging

import gin
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

L = logging.getLogger(__name__)


def to_one_hot(indices, max_index):
    '''(B,) integer indices -> (B, max_index) one-hot float vectors.'''
    zeros = torch.zeros(
        indices.shape[0], max_index, dtype=torch.float32,
        device=indices.device)
    return zeros.scatter_(1, indices.unsqueeze(-1), 1)
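
# Example (illustrative, not from the original file):
#   to_one_hot(torch.tensor([1, 3]), max_index=4) gives
#   [[0., 1., 0., 0.],
#    [0., 0., 0., 1.]]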


def pairwise_distance(x, y):
    '''BxNxD, BxMxD -> BxNxM with dist[b,i,j] = 0.5 * ||x[b,i,:] - y[b,j,:]||^2.
    https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3'''
    x_norm = (x ** 2).sum(2).view(x.shape[0], x.shape[1], 1)
    y_t = y.permute(0, 2, 1).contiguous()
    y_norm = (y ** 2).sum(2).view(y.shape[0], 1, y.shape[1])
    dist = x_norm + y_norm - 2.0 * torch.bmm(x, y_t)
    dist[dist != dist] = 0  # replace NaN values with 0
    return 0.5 * torch.clamp(dist, 0.0, np.inf)


def pairwise_softmax(z1, z2, temperature=1.0):
    '''Eq. (9, 12): row-wise softmax of negative pairwise distance.
    (A, N, D), (A, M, D) -> (A, N, M); each row sums to 1 over the M entries.'''
    return F.softmax(-pairwise_distance(z1, z2) / temperature, dim=-1)
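

# Note (inferred from the code, not stated in this file): as temperature -> 0,
# pairwise_softmax approaches a hard argmin, so each row becomes nearly one-hot
# on the closest vector.  ValueIteration exploits this: a tiny
# interpolation_param makes Q-interpolation an (almost) nearest-prototype
# lookup, while trans_temperature = 1.0 keeps the transition kernel of
# Eq. (9) soft.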


class ValueIteration:
    '''Runs value iteration over a discrete MDP whose states are the (unique)
    latent codes of the given prototype states.'''

    def __init__(self, learner, prototype_states, backup=500):
        self.device = learner.device
        self.learner = learner
        #> we opt for a low value for n to assign most Q-value to the closest state
        self.interpolation_param = 1e-20
        #> temperature is grid-searched over [1.0, 0.1, 0.001, ..., 1e-20]
        self.trans_temperature = 1.0
        self.backup = backup  # maximum number of value-iteration sweeps
        action_dim = learner.action_dim
        self.actions = torch.arange(0, action_dim).to(self.device)
        self.prototype_states = self.learner.Z_theta(torch.tensor(
            prototype_states).to(self.device)).unique(dim=0)
        self.prototype_values = torch.zeros(len(self.prototype_states)).to(self.device)
        # L.debug('Unique prototypes: %s', self.prototype_states.shape)

    def pi(self, state, valid_actions):
        '''Greedy policy: z* = Z(s); returns argmax_a Q(z*, a) over valid_actions.'''
        state = torch.tensor(state).to(self.device)
        valid_actions = torch.tensor(valid_actions).to(self.device)
        z = self.learner.Z_theta(state.unsqueeze(0))
        zs = z.unsqueeze(0).expand(valid_actions.shape[0], 1, z.shape[1])
        Q_zs_a = self.Q(zs, valid_actions)
        return valid_actions[Q_zs_a.argmax(dim=0)]

    def Q(self, zs, actions):
        '''Interpolates Q for each action from the nearest prototypes' Q-values:
        Q(z*, a) = sum_x w(z*, x) * Q(x, a), where x ranges over prototypes.'''
        w_zs_x = self.w(zs)
        # self.qvals is cached by _plan(); one could keep only the prototype
        # values and recompute Q-values on demand if memory becomes an issue.
        Q_x_a = self.qvals[actions].unsqueeze(1).expand_as(w_zs_x)
        return torch.sum(w_zs_x * Q_x_a, dim=-1)  # (A, Nz, Np) -> (A, Nz)

    def w(self, zs):
        '''Eq. (12): softmax weights over prototypes; returns (A, Nz, Np).'''
        A, Nz, D = zs.shape
        Np, D = self.prototype_states.shape
        x = self.prototype_states.unsqueeze(0).expand(A, Np, D)
        return pairwise_softmax(zs, x, temperature=self.interpolation_param)

    def closest(self, zs):
        '''Indices of the prototypes closest to `zs`; returns (A, Nz).'''
        A, Nz, D = zs.shape
        Np, D = self.prototype_states.shape
        x = self.prototype_states.unsqueeze(0).expand(A, Np, D)
        return pairwise_distance(zs, x).argmin(dim=-1)

    def plan(self, goal_state, discount=0.9, stop_iter_epsilon=1e-50):
        '''Adds the goal state to the prototypes, then plans in the discretized
        latent-space MDP.'''
        with torch.no_grad():
            goal_state = torch.tensor(goal_state).to(self.device).unsqueeze(0)
            z = self.learner.Z_theta(goal_state)
            idx = self.closest(z.unsqueeze(0)).item()
            already_proto = torch.max((z - self.prototype_states[idx]) ** 2) < 1e-9
            if not already_proto.item():
                q = torch.tensor([0.]).to(self.device)
                self.prototype_states = torch.cat((self.prototype_states, z), dim=0)
                self.prototype_values = torch.cat((self.prototype_values, q), dim=0)
                idx = self.closest(z.unsqueeze(0)).item()  # sanity check ("unit test")
                assert idx == len(self.prototype_values) - 1
            self.goal_idx = idx
            return self._plan(discount, stop_iter_epsilon)

    def _plan(self, discount=0.9, stop_iter_epsilon=1e-20):
        A, = self.actions.shape
        Np, D = self.prototype_states.shape
        qstates = self.prototype_states.unsqueeze(0).expand(A, Np, D)
        actions = self.actions.unsqueeze(1).expand(A, Np)
        # Discretize: find the *closest prototype* for each predicted next latent state.
        next_states_predicted = self.learner.T(qstates, actions)
        # Eq. (9): assume states connected in state space are also connected in
        # latent space: T(zj|zi,a) = softmax(-d(zj, zi + A(zi,a)) / t) for zi in X.
        # NOTE(tk) I would assume that high-reward states might be far away.
        T = pairwise_softmax(next_states_predicted, qstates, self.trans_temperature)
        next_idxs = T.argmax(-1)
        #> We use this reward function in planning: R(x) = 1 if x = Z(sg) else 0.
        # NOTE(tk) try using the learned reward function here?
        reward = self.learner.R(qstates, actions)  # (A, Np, 1) since R has output_dim=1
        reward[next_idxs == self.goal_idx] = 1.0
        # Expected immediate reward: sum_j T[a,i,j] * reward[a,j]  -> (A, Np)
        reward = (T * reward.expand_as(T).transpose(-1, -2)).sum(-1)
        for _ in range(self.backup):
            # Expected next-state value: sum_j T[a,i,j] * V[j]  -> (A, Np)
            next_vals = self.prototype_values.unsqueeze(-1).unsqueeze(0)
            next_vals = (T * next_vals.expand_as(T).transpose(-1, -2)).sum(-1)
            self.qvals = reward + discount * next_vals
            vnew, _ = self.qvals.max(dim=0)
            delta = (self.prototype_values - vnew).abs().max().item()
            self.prototype_values = vnew
            if delta < stop_iter_epsilon:
                break
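
    # The backup above, written out (shapes as used in this file):
    #   T:      (A, Np, Np)  with T[a, i, j] ~ P(x_j | x_i, a)        (Eq. 9)
    #   reward: (A, Np)      expected immediate reward incl. the goal indicator
    #   Q[a, i] = reward[a, i] + discount * sum_j T[a, i, j] * V[j]
    #   V[i]    = max_a Q[a, i]
    # The sweep stops once max_i |V_new[i] - V[i]| < stop_iter_epsilon.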


@gin.configurable
class Learner(nn.Module):
    '''Mostly follows the notation of the paper.'''

    def __init__(self, device, in_shape=(3, 60, 60), latent_dim=50,
                 negative_samples=1, action_dim=4):
        super().__init__()
        self.action_dim = action_dim
        self.device = device
        # J=1 negative sample is fine if the reward loss is taken into account.
        self.J = negative_samples
        self.hinge = 1.0
        self.latent_dim = latent_dim
        self.A = LatentTransform(
            latent_dim=self.latent_dim,
            output_dim=self.latent_dim,
            action_dim=self.action_dim).to(device)
        self.R = LatentTransform(
            latent_dim=self.latent_dim,
            output_dim=1,
            action_dim=self.action_dim).to(device)
        self.Z_theta = ObservationToLatent(
            latent_dim=self.latent_dim,
            in_shape=in_shape).to(device)

    def T(self, zs, act):
        '''Transition model: next latent = current latent + learned delta A(zs, act).'''
        return zs + self.A(zs, act)

    def loss(self, obs, act, reward, obs_next):
        B = obs.shape[0]
        # B x latent_dim
        zs = self.Z_theta(obs)
        zs_next = self.Z_theta(obs_next)
        zs_next_pred = self.T(zs, act)
        distance = lambda a, b: 0.5 * torch.sum((a - b) ** 2, dim=-1)
        # (B,): transition in latent space vs. encoding of the true next state
        loss_T = distance(zs_next_pred, zs_next)
        # scalar: predicted latent reward vs. reward observed in the original MDP
        loss_R = distance(reward, self.R(zs, act).flatten())
        # negative samples: pair each prediction with the latent of a randomly
        # permuted (most likely unrelated) state from the same batch
        zs_neg = zs.repeat(self.J, 1)[np.random.permutation(B * self.J)]
        zs_next_pred = zs_next_pred.repeat(self.J, 1)
        zeros = torch.zeros(B * self.J).to(self.device)
        loss_neg = torch.max(zeros, self.hinge - distance(zs_neg, zs_next_pred))
        return (loss_T.sum() + loss_R + loss_neg.sum()) / B
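
    # Note (explanatory, not in the original): loss_neg is a contrastive hinge.
    # Each predicted next latent must stay at least `hinge` away, in the
    # 0.5 * squared-distance sense, from a randomly drawn state's latent, which
    # keeps Z_theta from collapsing to a constant embedding.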


class ObservationToLatent(nn.Module):
    '''Encoder Z_theta: maps an observation (C, H, W) to a latent vector.'''

    def __init__(self, latent_dim, in_shape):
        super().__init__()
        in_channels, w, h = in_shape
        self.cnn1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, (3, 3), padding=1),
            nn.ReLU())
        self.cnn2 = nn.Sequential(
            nn.Conv2d(16, 16, (3, 3), padding=1),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Linear(w * h * 16, 64),  # padding=1 convs preserve spatial size
            nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU())
        self.fc3 = nn.Linear(32, latent_dim)

    def forward(self, obs):
        cnn = self.cnn1(obs)
        cnn = self.cnn2(cnn)
        flat = cnn.flatten(start_dim=-3)  # (*B, C, H, W) -> (*B, C*H*W)
        return self.fc3(self.fc2(self.fc1(flat)))


class LatentTransform(nn.Module):
    '''MLP on a latent state and a one-hot action; used both for the transition
    delta A (output_dim=latent_dim) and the reward model R (output_dim=1).'''

    def __init__(self, latent_dim, output_dim, action_dim=4):
        super().__init__()
        self.action_dim = action_dim
        self.fc1 = nn.Sequential(
            nn.Linear(latent_dim + action_dim, 64),
            nn.ReLU())
        self.fc2 = nn.Linear(64, output_dim)

    def forward(self, obs, act):  # (*B, latent_dim), (*B,) -> (*B, output_dim)
        act_shape = act.shape
        act = to_one_hot(act.reshape(-1), max_index=self.action_dim)
        obs = torch.cat([obs, act.reshape(*act_shape, self.action_dim)], dim=-1)
        return self.fc2(self.fc1(obs))
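

# A minimal usage sketch (hypothetical, not part of the original file): the
# batch, prototype set, and hyperparameters below are illustrative assumptions.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    learner = Learner(device, in_shape=(3, 60, 60), action_dim=4)
    opt = torch.optim.Adam(learner.parameters(), lr=1e-3)
    # One gradient step on a random dummy batch of B=8 transitions.
    obs = torch.rand(8, 3, 60, 60, device=device)
    act = torch.randint(0, 4, (8,), device=device)
    reward = torch.rand(8, device=device)
    obs_next = torch.rand(8, 3, 60, 60, device=device)
    opt.zero_grad()
    learner.loss(obs, act, reward, obs_next).backward()
    opt.step()
    # Plan towards a random goal over 32 random prototype observations.
    protos = np.random.rand(32, 3, 60, 60).astype(np.float32)
    planner = ValueIteration(learner, protos, backup=100)
    planner.plan(np.random.rand(3, 60, 60).astype(np.float32))
    action = planner.pi(np.random.rand(3, 60, 60).astype(np.float32),
                        valid_actions=[0, 1, 2, 3])
    print('greedy action:', action)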