forked from rmst/ddpg
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathddpg_nets_dm.py
77 lines (61 loc) · 2.52 KB
/
ddpg_nets_dm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tensorflow as tf
import numpy as np
def hist_summaries(*args):
    """Merge histogram summaries for every tensor in ``args`` into one op.

    Each histogram is tagged with the tensor's own ``name``. The merged
    summary op is returned.
    """
    histograms = []
    for tensor in args:
        histograms.append(tf.histogram_summary(tensor.name, tensor))
    return tf.merge_summary(histograms)
def fanin_init(shape, fanin=None):
    """Return a uniform random tensor of ``shape`` scaled by fan-in.

    Values are drawn by ``tf.random_uniform`` between -1/sqrt(fanin) and
    +1/sqrt(fanin).

    Args:
        shape: shape of the tensor to initialise.
        fanin: fan-in used for the bound. When falsy (``None`` or 0) it
            defaults to ``shape[0]``.
    """
    if not fanin:
        fanin = shape[0]
    bound = 1 / np.sqrt(fanin)
    return tf.random_uniform(shape, minval=-bound, maxval=bound)
# Hidden-layer widths shared by the actor (theta_p) and critic (theta_q).
# The original comments tag both as "dm" values -- presumably the layer sizes
# of DeepMind's DDPG setup; verify before relying on that.
l1, l2 = 400, 300
def conv(x, W, b, stride=4):
    """Apply a VALID-padded 2-D convolution, add a bias, then apply ReLU.

    Args:
        x: input tensor to convolve. The strides ``[1, stride, stride, 1]``
            assume NHWC layout (``tf.nn.conv2d``'s default) -- confirm with
            callers.
        W: convolution filter passed to ``tf.nn.conv2d``, presumably shaped
            ``[height, width, in_channels, out_channels]``.
        b: bias added to the convolution output.
        stride: spatial stride used in both the height and width dimensions.

    Returns:
        ``relu(conv2d(x, W) + b)``.
    """
    # Convolve the ``x`` parameter with the TensorFlow op.
    c = tf.nn.conv2d(x, filter=W, strides=[1, stride, stride, 1],
                     padding='VALID')
    z = c + b
    return tf.nn.relu(z)
def theta_p(dimO, dimA):
    """Create the actor (policy) network parameters.

    Args:
        dimO: observation shape as a sequence. Only flat, length-1 shapes
            are supported.
        dimA: action shape as a sequence. Only ``dimA[0]`` is used.

    Returns:
        A list of six ``tf.Variable`` objects created under the ``theta_p``
        variable scope, in the order ``policy`` indexes them:
        ``[1w (dimO, l1), 1b (l1,), 2w (l1, l2), 2b (l2,),
        3w (l2, dimA), 3b (dimA,)]``.

    Raises:
        NotImplementedError: if ``dimO`` has more than one entry (e.g. an
            image observation). The convolutional actor is not written yet.
    """
    # Check the shape's length rather than np.ndim(dimO). np.ndim of a flat
    # shape tuple such as (84, 84, 3) is 1, so it can never detect a
    # multi-dimensional observation.
    if len(dimO) > 1:
        # TODO: convolutional actor for image observations (e.g. built from
        # conv() with tf.contrib.layers.xavier_initializer_conv2d()).
        raise NotImplementedError(
            "theta_p only supports flat observations; got shape %r" % (dimO,))
    dimO = dimO[0]
    dimA = dimA[0]
    with tf.variable_scope("theta_p"):
        # Hidden-layer biases pass the layer's fan-in explicitly, because
        # fanin_init would otherwise use shape[0] (the bias length).
        return [tf.Variable(fanin_init([dimO, l1]), name='1w'),
                tf.Variable(fanin_init([l1], dimO), name='1b'),
                tf.Variable(fanin_init([l1, l2]), name='2w'),
                tf.Variable(fanin_init([l2], l1), name='2b'),
                # Small output-layer range keeps initial pre-tanh outputs
                # near zero.
                tf.Variable(tf.random_uniform([l2, dimA], -3e-3, 3e-3),
                            name='3w'),
                tf.Variable(tf.random_uniform([dimA], -3e-3, 3e-3),
                            name='3b')]
def policy(obs, theta, name='policy'):
    """Build the actor network graph.

    Computes ``tanh(W3 . relu(W2 . relu(W1 . obs + b1) + b2) + b3)`` using
    the six parameters from ``theta`` (ordered as ``theta_p`` returns them).

    Args:
        obs: observation tensor. Presumably ``[batch, dimO]``, since it is
            fed straight into ``tf.matmul`` -- confirm with callers.
        theta: sequence of parameters; only ``theta[0]`` to ``theta[5]`` are
            read.
        name: name / default name of the op scope.

    Returns:
        ``(action, summary)``: the tanh-squashed action tensor and a merged
        histogram summary of the intermediate activations.
    """
    with tf.variable_op_scope([obs], name, name):
        inp = tf.identity(obs, name='h0-obs')
        hidden1 = tf.nn.relu(tf.matmul(inp, theta[0]) + theta[1], name='h1')
        hidden2 = tf.nn.relu(tf.matmul(hidden1, theta[2]) + theta[3],
                             name='h2')
        # Linear pre-activation, named so it shows up in the summaries.
        pre_act = tf.identity(tf.matmul(hidden2, theta[4]) + theta[5],
                              name='h3')
        action = tf.nn.tanh(pre_act, name='h4-action')
        summary = hist_summaries(inp, hidden1, hidden2, pre_act, action)
    return action, summary
def theta_q(dimO, dimA):
    """Create the critic (Q-function) network parameters.

    Args:
        dimO: observation shape as a sequence. Only ``dimO[0]`` is used.
        dimA: action shape as a sequence. Only ``dimA[0]`` is used.

    Returns:
        A list of six ``tf.Variable`` objects created under the ``theta_q``
        variable scope, in the order ``qfunction`` indexes them:
        ``[1w (dimO, l1), 1b (l1,), 2w (l1 + dimA, l2), 2b (l2,),
        3w (l2, 1), 3b (1,)]``.
    """
    obs_dim, act_dim = dimO[0], dimA[0]
    # Layer 2 consumes the first hidden layer concatenated with the action.
    joined_dim = l1 + act_dim
    with tf.variable_scope("theta_q"):
        w1 = tf.Variable(fanin_init([obs_dim, l1]), name='1w')
        b1 = tf.Variable(fanin_init([l1], obs_dim), name='1b')
        w2 = tf.Variable(fanin_init([joined_dim, l2]), name='2w')
        b2 = tf.Variable(fanin_init([l2], joined_dim), name='2b')
        # Small range for the scalar output layer.
        w3 = tf.Variable(tf.random_uniform([l2, 1], -3e-4, 3e-4), name='3w')
        b3 = tf.Variable(tf.random_uniform([1], -3e-4, 3e-4), name='3b')
        return [w1, b1, w2, b2, w3, b3]
def qfunction(obs, act, theta, name="qfunction"):
    """Build the critic network graph.

    Computes ``W3 . relu(W2 . [relu(W1 . obs + b1), act] + b2) + b3`` and
    squeezes away axis 1, using parameters ordered as ``theta_q`` returns
    them.

    Args:
        obs: observation tensor. Presumably ``[batch, dimO]`` -- confirm.
        act: action tensor, concatenated with the first hidden layer along
            axis 1.
        theta: sequence of parameters; only ``theta[0]`` to ``theta[5]`` are
            read.
        name: name / default name of the op scope.

    Returns:
        ``(q, summary)``: the Q-value tensor (axis 1 squeezed) and a merged
        histogram summary of the intermediate activations.
    """
    with tf.variable_op_scope([obs, act], name, name):
        obs_in = tf.identity(obs, name='h0-obs')
        # act_in exists only for the histogram summary. The concat below
        # uses the raw ``act`` tensor.
        act_in = tf.identity(act, name='h0-act')
        hidden1 = tf.nn.relu(tf.matmul(obs_in, theta[0]) + theta[1],
                             name='h1')
        joined = tf.concat(1, [hidden1, act])
        hidden2 = tf.nn.relu(tf.matmul(joined, theta[2]) + theta[3],
                             name='h2')
        q = tf.squeeze(tf.matmul(hidden2, theta[4]) + theta[5], [1],
                       name='h3-q')
        summary = hist_summaries(obs_in, act_in, hidden1, hidden2, q)
    return q, summary