# utils.py (88 lines, 69 loc, 2.38 KB)
"""Common functions you may find useful in your implementation."""
import semver
import tensorflow as tf
def get_uninitialized_variables(variables=None):
    """Return a list of uninitialized tf variables.

    Parameters
    ----------
    variables: tf.Variable, list(tf.Variable), optional
      Filter variable list to only those that are uninitialized. If no
      variables are specified the list of all variables in the graph
      will be used. A single ``tf.Variable`` is also accepted.

    Returns
    -------
    list(tf.Variable)
      List of uninitialized tf variables, in the order they were given.

    Notes
    -----
    The initialization check is evaluated with ``tf.get_default_session()``,
    so a default session must be active whenever there is at least one
    variable to check.
    """
    if variables is None:
        variables = tf.global_variables()
    elif isinstance(variables, tf.Variable):
        # A lone variable is not iterable, so ``list(variables)`` would fail;
        # wrap it to honour the documented ``tf.Variable`` input.
        variables = [variables]
    else:
        variables = list(variables)

    if not variables:
        return []

    # Older TensorFlow releases only provide ``tf.pack``; newer ones only
    # ``tf.stack``. Detect the op directly rather than parsing
    # ``tf.__version__`` with ``semver.match`` (deprecated, removed in
    # semver 3).
    stack = getattr(tf, 'stack', None) or tf.pack

    sess = tf.get_default_session()
    init_flag = sess.run(
        stack([tf.is_variable_initialized(v) for v in variables]))
    return [v for v, f in zip(variables, init_flag) if not f]
def get_soft_target_model_updates(target, source, tau):
    r"""Return list of target model update ops.

    These are soft target updates. Meaning that the target values are
    slowly adjusted, rather than directly copied over from the source
    model.

    The update is of the form:

    $W' \gets (1- \tau) W' + \tau W$ where $W'$ is the target weight
    and $W$ is the source weight.

    Parameters
    ----------
    target: keras.models.Model
      The target model. Should have same architecture as source model.
    source: keras.models.Model
      The source model. Should have same architecture as target model.
    tau: float
      The weight of the source weights to the target weights used
      during update.

    Returns
    -------
    list(tf.Tensor)
      List of tensor update ops, one per target weight, in the order of
      ``target.weights``.

    Raises
    ------
    ValueError
      If the two models do not expose the same number of weights.
    """
    target_weights = target.weights
    source_weights = source.weights
    # ``zip`` would silently drop unmatched weights; fail loudly instead.
    if len(target_weights) != len(source_weights):
        raise ValueError(
            'target has {} weights but source has {}; models must share '
            'the same architecture'.format(
                len(target_weights), len(source_weights)))

    # Weights are paired positionally, relying on identical architectures
    # listing their weights in the same order.
    return [
        target_weight.assign((1. - tau) * target_weight + tau * source_weight)
        for target_weight, source_weight in zip(target_weights, source_weights)
    ]
def get_hard_target_model_updates(target, source):
    """Return list of target model update ops.

    These are hard target updates. The source weights are copied
    directly to the target network.

    Parameters
    ----------
    target: keras.models.Model
      The target model. Should have same architecture as source model.
    source: keras.models.Model
      The source model. Should have same architecture as target model.

    Returns
    -------
    list(tf.Tensor)
      List of tensor update ops, one per target weight, in the order of
      ``target.weights``.

    Raises
    ------
    ValueError
      If the two models do not expose the same number of weights.
    """
    target_weights = target.weights
    source_weights = source.weights
    # ``zip`` would silently drop unmatched weights; fail loudly instead.
    if len(target_weights) != len(source_weights):
        raise ValueError(
            'target has {} weights but source has {}; models must share '
            'the same architecture'.format(
                len(target_weights), len(source_weights)))

    # Weights are paired positionally, relying on identical architectures
    # listing their weights in the same order.
    return [
        target_weight.assign(source_weight)
        for target_weight, source_weight in zip(target_weights, source_weights)
    ]