-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathDeSKO.py
330 lines (254 loc) · 12.9 KB
/
DeSKO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
import tensorflow as tf
from utils import mlp
import numpy as np
import math
import os
import tensorflow_probability as tfp
from base_koopman_operator import base_Koopman
# Clip bounds applied to the encoder's log-standard-deviation output before it
# is exponentiated (see Koopman._create_encoder).
SCALE_DIAG_MIN_MAX = (-20, 2)

# This version uses a stochastic Koopman operator with an observation matrix.
class Koopman(base_Koopman):
    """Deep stochastic Koopman model with an observation matrix.

    Each state is encoded into a diagonal Gaussian over a latent space. The
    latent sample, mean and standard deviation are propagated with the linear
    operators ``A`` (latent transition) and ``B`` (input), and latent vectors
    are mapped back to the state space with ``C``. The training loss subtracts
    an entropy bonus whose temperature ``alpha`` is tuned automatically
    towards ``target_entropy``.

    Attributes:
        A (tf.Variable): Latent transition matrix, shape [latent_dim, latent_dim].
        B (tf.Variable): Input matrix, shape [act_dim, latent_dim].
        C (tf.Variable): Observation matrix, shape [latent_dim, state_dim].
        A_result, B_result, C_result (np.ndarray): NumPy copies of the
            operators, refreshed by ``store_Koopman_operator`` and ``restore``.
        log_alpha (tf.Variable): Log of the entropy temperature.
        alpha (tf.Tensor): Entropy temperature, ``exp(log_alpha)``.
        target_entropy (float): Entropy level targeted by the temperature update.
    """

    # Assign ops used by store_Koopman_operator. They are built lazily, once,
    # so repeated calls do not keep adding new nodes to the graph.
    _koopman_store_ops = None

    def __init__(
            self,
            args,
            **kwargs
    ):
        """Create the entropy-temperature variables, then build the base model.

        Args:
            args (dict): Configuration. Keys read here are ``'target_entropy'``
                (``None`` selects ``-latent_dim``), ``'latent_dim'`` and
                ``'alpha'`` (initial temperature; must be > 0 because its log
                is taken). Other keys are read by the ``_create_*`` methods of
                this class (presumably invoked from ``base_Koopman.__init__``).
            **kwargs: Accepted for interface compatibility; unused.
        """
        if args['target_entropy'] is None:
            # Default heuristic: lower bound of -latent_dim on the entropy.
            self.target_entropy = -args['latent_dim']
        else:
            self.target_entropy = args['target_entropy']
        # The temperature is optimised in log space so that alpha stays positive.
        self.log_alpha = tf.get_variable('alpha', None, tf.float32, initializer=tf.log(args['alpha']))
        self.alpha = tf.exp(self.log_alpha)
        super(Koopman, self).__init__(args)

    def _create_koopman_result_holder(self, args):
        """Create NumPy and non-trainable TF copies of the A, B and C operators.

        The ``*_tensor`` variables are written by ``store_Koopman_operator``
        and read back by ``restore`` after a checkpoint has been loaded.
        """
        latent_dim, act_dim, state_dim = args['latent_dim'], args['act_dim'], args['state_dim']
        self.A_result = np.zeros([latent_dim, latent_dim])
        self.A_tensor = tf.Variable(self.A_result, trainable=False, name="A_tensor", dtype=tf.float32)
        self.B_result = np.zeros([act_dim, latent_dim])
        self.B_tensor = tf.Variable(self.B_result, trainable=False, name="B_tensor", dtype=tf.float32)
        self.C_result = np.zeros([latent_dim, state_dim])
        self.C_tensor = tf.Variable(self.C_result, trainable=False, name="C_tensor", dtype=tf.float32)

    def _create_encoder(self, args):
        """Build the stochastic encoder and a reparameterised latent sample.

        Two MLPs map ``self.x_input`` to the latent mean and the log standard
        deviation. The latter is clipped to ``SCALE_DIAG_MIN_MAX`` and then
        exponentiated. ``self.stochastic_latent`` is ``mean + sigma * epsilon``
        with ``epsilon ~ N(0, I)`` of shape [batch, pred_horizon, latent_dim].
        """
        # The activation is fixed to ReLU; _create_prediction_model reuses
        # these networks with the same activation.
        activation = tf.nn.relu
        self.mean = mlp(self.x_input, args['encoder_struct'] + [args['latent_dim']],
                        activation, name='mean', regularizer=self.l2_reg)
        log_sigma = mlp(self.x_input, args['encoder_struct'] + [args['latent_dim']],
                        activation, name='sigma', regularizer=self.l2_reg)
        log_sigma = tf.clip_by_value(log_sigma, *SCALE_DIAG_MIN_MAX)
        self.sigma = tf.exp(log_sigma)

        base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=tf.zeros(args['latent_dim']), scale_diag=tf.ones(args['latent_dim']))
        epsilon = base_distribution.sample([tf.shape(self.mean)[0], args['pred_horizon']])
        bijector = tfp.bijectors.Affine(shift=self.mean, scale_diag=self.sigma)
        self.stochastic_latent = bijector.forward(epsilon)

    def _create_koopman_operator(self, args):
        """Create the trainable Koopman operators A, B and C (L2-regularised)."""
        with tf.variable_scope('koopman', regularizer=self.l2_reg):
            self.A = tf.get_variable('A', shape=[args['latent_dim'], args['latent_dim']])
            self.B = tf.get_variable('B', shape=[args['act_dim'], args['latent_dim']])
            self.C = tf.get_variable('C', shape=[args['latent_dim'], args['state_dim']])

    def _create_forward_pred(self, args):
        """Roll the Koopman operator forward over the prediction horizon.

        Starting from time step 0, the latent sample, mean and standard
        deviation are each advanced with ``z @ A + u @ B``. The results are
        stacked along axis 1 (``pred_horizon - 1`` steps) as:

        * ``forward_pred``: decoded samples.
        * ``x_mean_forward_pred``: decoded means.
        * ``mean_forward_pred``: latent means.
        * ``sigma_forward_pred``: latent standard deviations.
        """
        phi_t = self.stochastic_latent[:, 0]
        mean_t = self.mean[:, 0]
        sigma_t = self.sigma[:, 0]
        samples, decoded_means, means, sigmas = [], [], [], []
        for t in range(args['pred_horizon'] - 1):
            # Input contribution for this step, shared by all three rollouts.
            bu_t = tf.matmul(self.a_input[:, t], self.B)
            phi_t = tf.matmul(phi_t, self.A) + bu_t
            mean_t = tf.matmul(mean_t, self.A) + bu_t
            # NOTE(review): sigma is propagated with the same affine map,
            # including the input term B*u -- confirm this is intended.
            sigma_t = tf.matmul(sigma_t, self.A) + bu_t
            samples.append(tf.matmul(phi_t, self.C))
            decoded_means.append(tf.matmul(mean_t, self.C))
            means.append(mean_t)
            sigmas.append(sigma_t)
        self.forward_pred = tf.stack(samples, axis=1)
        self.x_mean_forward_pred = tf.stack(decoded_means, axis=1)
        self.mean_forward_pred = tf.stack(means, axis=1)
        self.sigma_forward_pred = tf.stack(sigmas, axis=1)

    def _create_backward_pred(self, args):
        """Backward (past-state) prediction; intentionally builds nothing.

        Predicting past states requires an inverse operator (``A_inv``), which
        ``_create_koopman_operator`` does not create.
        """
        return

    def _create_optimizer(self, args):
        """Build the losses, the model and temperature optimisers and diagnostics."""
        latent_dim = args['latent_dim']
        # Entropy of the encoder distribution, estimated as the negative mean
        # log-probability of the reparameterised samples.
        dist = tfp.distributions.MultivariateNormalDiag(
            loc=tf.reshape(self.mean, [-1, latent_dim]),
            scale_diag=tf.reshape(self.sigma, [-1, latent_dim]))
        sample = tf.reshape(self.stochastic_latent, [-1, latent_dim])
        self.entropy = -tf.reduce_mean(dist.log_prob(sample, name='entropy'))

        # Temperature update: minimising this lowers alpha when the entropy is
        # above the target and raises it when below.
        self.alpha_loss = alpha_loss = self.log_alpha * tf.stop_gradient(self.entropy - self.target_entropy)
        self.alpha_train = tf.train.AdamOptimizer(self.lr).minimize(alpha_loss, var_list=[self.log_alpha])

        # Latent consistency: the predicted latent mean/std should match the
        # encoder output at the following steps. The targets are not
        # back-propagated through.
        forward_pred_loss = (
            tf.losses.mean_squared_error(labels=tf.stop_gradient(self.mean[:, 1:]),
                                         predictions=self.mean_forward_pred)
            + tf.losses.mean_squared_error(labels=tf.stop_gradient(self.sigma[:, 1:]),
                                           predictions=self.sigma_forward_pred))
        # State reconstruction from the sampled latent rollout.
        self.reconstruct_loss = reconstruct_loss = tf.losses.mean_squared_error(
            labels=self.x_input[:, 1:], predictions=self.forward_pred)
        # Final-step reconstruction error; currently weighted by 0 in the loss.
        reconstruct_pred_T_loss = tf.reduce_mean(tf.square(self.x_input[:, -1] - self.forward_pred[:, -1]))
        # Validation metric: reconstruction from the deterministic (mean) rollout.
        self.val_loss = tf.losses.mean_squared_error(labels=self.x_input[:, 1:],
                                                     predictions=self.x_mean_forward_pred)

        self.loss = (1 * forward_pred_loss
                     + 10 * reconstruct_loss
                     + 0 * reconstruct_pred_T_loss
                     - tf.stop_gradient(self.alpha) * self.entropy)

        # log_alpha is in this collection too, but alpha is behind a
        # stop_gradient in self.loss. It therefore gets no gradient here and is
        # only updated by alpha_train.
        params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        self.train = tf.train.AdamOptimizer(self.lr).minimize(self.loss, var_list=params)
        self.grads = [grad for grad in tf.gradients(self.loss, params) if grad is not None]
        grad_norm = tf.reduce_max([tf.norm(grad) for grad in self.grads])

        self.diagnotics.update({
            'loss': self.loss,
            'lagrange multiplier': self.alpha,
            'entropy': self.entropy,
            'gradient': grad_norm
        })
        self.opt_list.extend([self.train, self.alpha_train])

    def _create_prediction_model(self, args):
        """Build the inference graph: encode ``x_t`` and roll out actions ``a_t``.

        Inputs:

        * ``x_t``: shape [None, state_dim], normalised in-graph with
          ``shift``/``scale``.
        * ``a_t``: shape [None, pred_horizon - 1, act_dim], normalised in-graph
          with ``shift_u``/``scale_u``.

        The encoder networks are reused. ``sigma_t`` is the clipped and
        exponentiated standard deviation, matching ``self.sigma``. The rollout
        starts from the latent mean (no sampling), and ``future_states`` holds
        the decoded predictions stacked along axis 1.
        """
        self.x_t = tf.placeholder(tf.float32, [None, args['state_dim']], 'x_t')
        self.a_t = tf.placeholder(tf.float32, [None, args['pred_horizon'] - 1, args['act_dim']], 'a_t')
        self.shifted_x_t = (self.x_t - self.shift) / self.scale
        self.shifted_a_t = (self.a_t - self.shift_u) / self.scale_u
        self.mean_t = mlp(self.shifted_x_t, args['encoder_struct'] + [args['latent_dim']],
                          tf.nn.relu, name='mean', reuse=True)
        log_sigma_t = mlp(self.shifted_x_t, args['encoder_struct'] + [args['latent_dim']],
                          tf.nn.relu, name='sigma', reuse=True)
        # The 'sigma' network outputs a log standard deviation. Apply the same
        # clipping and exponentiation as _create_encoder so that sigma_t is a
        # standard deviation rather than a raw, unclipped log value.
        self.sigma_t = tf.exp(tf.clip_by_value(log_sigma_t, *SCALE_DIAG_MIN_MAX))

        phi_t = self.mean_t
        predictions = []
        for t in range(args['pred_horizon'] - 1):
            phi_t = tf.matmul(phi_t, self.A) + tf.matmul(self.shifted_a_t[:, t], self.B)
            predictions.append(tf.matmul(phi_t, self.C))
        self.future_states = tf.stack(predictions, axis=1)

    def store_Koopman_operator(self, replay_memory):
        """Copy the current A, B, C into the ``*_tensor`` variables and NumPy.

        Args:
            replay_memory: Object providing ``get_all_train_data()``, which
                returns a dict with ``'states'`` and ``'inputs'``.
        """
        batch_dict = replay_memory.get_all_train_data()
        # NOTE(review): A/B/C and the assign ops below do not depend on
        # x_input/a_input, so this feed (and the data fetch) appear unnecessary.
        feed_in = {
            self.x_input: batch_dict['states'],
            self.a_input: batch_dict['inputs'],
        }
        if self._koopman_store_ops is None:
            # Build once. Creating tf.assign ops on every call would keep
            # growing the graph and fails once the graph is finalized.
            self._koopman_store_ops = [
                tf.assign(self.A_tensor, self.A),
                tf.assign(self.B_tensor, self.B),
                tf.assign(self.C_tensor, self.C),
            ]
        out = self.sess.run([self.A, self.B, self.C] + self._koopman_store_ops, feed_in)
        self.A_result, self.B_result, self.C_result = out[0], out[1], out[2]

    def calc_val_loss(self, replay_memory):
        """Return ``val_loss`` (MSE of the mean rollout) on all validation data."""
        batch_dict = replay_memory.get_all_val_data()
        feed_in = {
            self.x_input: batch_dict['states'],
            self.a_input: batch_dict['inputs'],
        }
        return self.sess.run(self.val_loss, feed_in)

    def learn(self, batch_dict, lr, args):
        """Run one optimisation step on a batch and return the diagnostics.

        Args:
            batch_dict (dict): ``'states'`` and ``'inputs'`` arrays, fed to
                ``x_input`` and ``a_input``.
            lr (float): Learning rate fed to ``self.lr``.
            args: Unused; kept for interface compatibility.

        Returns:
            dict: Diagnostic name -> value, evaluated after the update step.

        Raises:
            ValueError: If any diagnostic value is NaN.
        """
        feed_in = {
            self.x_input: batch_dict['states'],
            self.a_input: batch_dict['inputs'],
            self.lr: lr,
            self.loss_weight: self.loss_weight_num,
        }
        self.sess.run(self.opt_list, feed_in)

        keys = list(self.diagnotics.keys())
        values = self.sess.run([self.diagnotics[key] for key in keys], feed_in)
        output = dict(zip(keys, values))
        for key, value in output.items():
            if math.isnan(value):
                print('NaN appears')
                raise ValueError('NaN appears in diagnostic %r' % key)
        return output

    def encode(self, x):
        """Encode states ``x`` into the latent mean and standard deviation.

        ``x`` is fed to the ``x_t`` placeholder and normalised inside the
        graph. Returns ``(mean, sigma)`` as evaluated arrays.
        """
        mean, sigma = self.sess.run([self.mean_t, self.sigma_t], {self.x_t: x})
        return mean, sigma

    def restore(self, path):
        """Restore the latest checkpoint found under ``path + '/model/'``.

        On success, ``A_result``/``B_result``/``C_result`` are refreshed from
        the restored ``*_tensor`` variables.

        Returns:
            bool: False if no checkpoint was found, True otherwise.
        """
        model_file = tf.train.latest_checkpoint(path + '/model/')
        if model_file is None:
            return False
        self.saver.restore(self.sess, model_file)
        self.A_result, self.B_result, self.C_result = self.sess.run(
            [self.A_tensor, self.B_tensor, self.C_tensor], {})
        return True