-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelpers.py
67 lines (52 loc) · 2 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import functools # https://danijar.com/structuring-your-tensorflow-models/
import tensorflow as tf
from scipy.stats import norm
from numpy.linalg import cholesky
def doublewrap(function):
    """
    Turn a decorator with optional arguments into one that can be applied
    both bare (``@deco``) and with arguments (``@deco(...)``).

    The call is treated as the bare form when it receives exactly one
    positional argument, no keyword arguments, and that argument is callable;
    `function` is then invoked with that single argument. Any other call
    returns a one-argument callable that forwards the wrapped object followed
    by the captured arguments to `function`. Consequently all arguments of
    `function` after the first must be optional.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        # Bare usage: the decorated object itself is the only argument.
        used_bare = not kwargs and len(args) == 1 and callable(args[0])
        if used_bare:
            return function(args[0])

        # Parameterised usage: capture the arguments, wait for the target.
        def apply(wrapee):
            return function(wrapee, *args, **kwargs)

        return apply

    return wrapper
@doublewrap
def define_scope(function, scope=None, *args, **kwargs):
    """
    Decorator that exposes a method as a lazily evaluated, cached property
    whose body runs inside a tf.variable_scope().

    On first access the wrapped function is called once within the variable
    scope and its result is stored on the instance under the attribute
    ``'_cache_' + function.__name__``. Later accesses return the stored value
    without calling the function again, so the operations it defines are
    added to the graph only once.

    `scope` names the variable scope and falls back to the wrapped function's
    name when falsy. Any further positional or keyword arguments given to the
    decorator are passed through to tf.variable_scope().
    """
    scope_name = scope if scope else function.__name__
    cache_attr = '_cache_' + function.__name__

    @functools.wraps(function)
    def cached(self):
        # Fast path: the result was already computed for this instance.
        if hasattr(self, cache_attr):
            return getattr(self, cache_attr)
        with tf.variable_scope(scope_name, *args, **kwargs):
            setattr(self, cache_attr, function(self))
        return getattr(self, cache_attr)

    return property(cached)
def xx_t(x):
    """
    Return x * x', the product of `x` with its own transpose.

    Computed as a single tf.matmul with the second operand transposed.
    """
    product = tf.matmul(a=x, b=x, transpose_b=True)
    return product
def x_tx(x):
    """
    Return x' * x, the product of the transpose of `x` with `x`.

    Computed as a single tf.matmul with the first operand transposed.
    """
    product = tf.matmul(a=x, b=x, transpose_a=True)
    return product
def quad_form(x, y):
    """
    Return the quadratic form x' * y * x.

    The right-hand product y * x is formed first, then multiplied from the
    left by the transpose of `x`.
    """
    y_times_x = tf.matmul(y, x)
    return tf.matmul(x, y_times_x, transpose_a=True)
def scaled_I(x, size, dtype=None):
    """
    Return x * I_{size}: a diagonal tensor built from `size` ones scaled by x.

    Args:
        x: Multiplier applied to the vector of ones before it is placed on
            the diagonal. Presumably a scalar (Python number or tensor);
            it must be compatible with `dtype` -- confirm against callers.
        size: Length of the vector of ones, i.e. the side of the result.
        dtype: Element type of the vector of ones. Defaults to tf.float64,
            which was previously hard-coded; pass another dtype to use the
            helper with e.g. float32 tensors.

    Returns:
        The result of tf.diag applied to ``ones([size]) * x``.
    """
    # Resolve the default lazily so the previous float64 behaviour is kept
    # whenever the caller does not ask for a different dtype.
    if dtype is None:
        dtype = tf.float64
    diagonal = tf.ones([size], dtype=dtype) * x
    return tf.diag(diagonal)
def quad_form_trp(x, y):
    """
    Return the transposed quadratic form x * y * x'.

    The right-hand product y * x' is formed first (by transposing the second
    operand inside tf.matmul), then multiplied from the left by `x`.
    """
    y_times_xt = tf.matmul(y, x, transpose_b=True)
    return tf.matmul(x, y_times_xt)