-
Notifications
You must be signed in to change notification settings - Fork 6
/
flows.py
167 lines (143 loc) · 5.51 KB
/
flows.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import nux
import jax
import jax.numpy as jnp
import jax.random as rnd
import nux.networks as net
from nux.internal.layer import InvertibleLayer
from nux.flows.bijective import Reverse, MAF
from nux.flows.stochastic import ContinuouslyIndexed
from jax.scipy.stats import norm
import optax
from jax import jit, value_and_grad
import haiku as hk
from typing import Optional, Mapping, Tuple
class BatchNorm(InvertibleLayer):
    """Batch-normalising flow layer with a learned log-scale and shift.

    The forward pass standardises ``inputs["x"]`` with the mean and variance
    taken over axis 0, multiplies by ``exp(gamma)`` and adds ``beta``.
    The inverse / sampling direction is not implemented.
    """

    def __init__(self, axis=0, name: str = "batch_norm"):
        """Create the layer.

        Args:
          axis: Axis (int) or axes to apply to. Stored on ``self.axes``.
            NOTE(review): ``call`` currently always reduces over axis 0 and
            does not read ``self.axes``.
          name: Optional name for this module.
        """
        super().__init__(name=name)
        self.axes = (axis,) if isinstance(axis, int) else axis

    def call(
        self,
        inputs,
        rng: jnp.ndarray = None,
        sample: Optional[bool] = False,
        **kwargs,
    ) -> Mapping[str, jnp.ndarray]:
        """Normalise ``inputs["x"]`` and report the log-determinant.

        Args:
          inputs: Mapping holding the array to transform under key ``"x"``.
          rng: Unused by this layer.
          sample: Must be ``False``; any other value raises.
          **kwargs: Ignored.

        Returns:
          A dict with ``"x"`` (the transformed array) and ``"log_det"``
          (a scalar: ``gamma - 0.5 * log(var + eps)`` summed over all
          parameter entries).

        Raises:
          NotImplementedError: If ``sample`` is anything other than ``False``.
        """
        eps = 1e-5
        outputs = {}
        x = inputs["x"]
        x_shape = self.get_unbatched_shapes(sample)["x"]

        # Batch statistics over the leading axis.
        # jnp.var is used instead of jnp.std(...) ** 2: the values agree, but
        # differentiating through sqrt-then-square yields NaN gradients when a
        # feature has exactly zero variance across the batch.
        m = jnp.mean(x, axis=0)
        v = jnp.var(x, axis=0)

        # gamma is a log-scale; it starts at -2, i.e. an initial scale of exp(-2).
        gamma = hk.get_parameter(
            "gamma",
            shape=x_shape,
            dtype=x.dtype,
            init=lambda *args: jnp.ones(*args) * -2,
        )
        beta = hk.get_parameter("beta", shape=x_shape, dtype=x.dtype, init=jnp.zeros)

        # Explicit comparison on purpose: sample=None is rejected as well,
        # exactly as `sample == False` did before.
        if sample != False:  # noqa: E712
            raise NotImplementedError

        outputs["x"] = ((x - m) / jnp.sqrt(v + eps)) * jnp.exp(gamma) + beta

        # The map is elementwise, so the log-determinant is the sum of the
        # per-element log-scales.
        outputs["log_det"] = jnp.sum(gamma - 0.5 * jnp.log(v + eps))
        return outputs
def get_flow_CIF(
    rand_key,
    input_shape,
    num_layers,
    batch_size,
    num_components=8,
    threshold=-1_000,
    init_std=3.0,
    hidden_sizes=None,
    pretrain=True,
    noise_dim=1,
):
    """Build a continuously-indexed MAF flow plus helpers to sample and score it.

    The flow is ``num_layers`` repetitions of
    ``[ContinuouslyIndexed(MAF), Reverse, BatchNorm]``. It is initialised on a
    batch of Gaussian noise scaled by ``init_std`` and, when ``pretrain`` is
    true, trained for 200 steps on fresh batches of the same noise.

    Args:
      rand_key: JAX PRNG key used for the initialisation batch, the flow
        initialisation and (after a split) pretraining.
      input_shape: Size of the trailing feature axis (used as an int in
        ``shape=(batch_size, input_shape)``).
      num_layers: Number of repeated layer triples.
      batch_size: Number of rows in the initialisation / pretraining batches.
      num_components: Currently unused.
      threshold: Lower bound applied to every returned log-probability.
      init_std: Standard deviation of the initialisation / pretraining noise.
      hidden_sizes: Currently unused (the MAF hidden sizes are fixed to
        ``[32, 32]``).
      pretrain: Whether to run the pretraining loop.
      noise_dim: Currently unused.

    Returns:
      A 4-tuple ``(params, sample_flow, get_flow_arrays, get_density)``:
      ``params`` are the flow parameters (pretrained when ``pretrain`` is
      true), and the three callables are documented on their definitions.
    """

    def get_CIF_MAF():
        # One continuously-indexed MAF block with fixed hidden sizes.
        return ContinuouslyIndexed(MAF(hidden_layer_sizes=[32, 32]))

    def create_flow():
        # Stack `num_layers` copies of [CIF-MAF, Reverse, BatchNorm].
        layers = []
        for _ in range(num_layers):
            layers += [get_CIF_MAF(), Reverse(), BatchNorm()]
        return nux.sequential(*layers)

    # Perform data-dependent initialization
    train_inputs = {
        "x": rnd.normal(rand_key, shape=(batch_size, input_shape)) * init_std
    }
    flow = nux.Flow(create_flow, rand_key, train_inputs, batch_axes=(0,))

    # These are what the function hands back; pretraining (if any) overwrites
    # them so that callers actually receive the trained values.
    flow_params, flow_state = flow.params, flow.state

    def sample_flow(
        params: hk.Params, state: hk.State, rng_key, n: int
    ) -> Tuple[jnp.ndarray, jnp.ndarray, hk.State]:
        """Draw ``n`` standard-normal points and push them through the flow.

        Returns the flow output ``"x"``, the clipped log-probabilities
        (base log-density minus the flow's ``"log_det"``), and the new state.
        """
        # Sample from base distribution, i.e. normal
        # We get Z -> L
        samples = rnd.normal(rng_key, shape=(n, input_shape))
        logprob = jnp.sum(norm.logpdf(samples), axis=-1)
        out, state = flow.stateful_apply(rng_key, {"x": samples}, params, state)
        # Need to add a minus to the log det term since we want p(L),
        # not P(Z) which is what log det would normally be
        # NOTE(review): original author marked this sign as "I *think* this is right".
        sample_log_probs = jnp.clip(-out["log_det"] + logprob, threshold)
        return (
            out["x"],  # type: ignore
            sample_log_probs,
            state,
        )  # type: ignore

    if pretrain:
        opt = optax.chain(
            nux.util.scale_by_belief(eps=1e-8),
            optax.scale(3e-3),
            optax.clip(15.0),
        )
        opt_state = opt.init(flow_params)
        n_steps = 200
        key = jax.random.split(rand_key, 2)[0]
        print("Pretraining Flow")

        def loss_fn(p, state, apply_key, batch):
            # Mean log_px of `batch` under parameters `p`; the state is
            # returned as auxiliary output.
            outputs, state = flow.stateful_apply(apply_key, batch, p, state)
            return jnp.mean(outputs["log_px"]), state

        @jit
        def step(p, state, opt_state, key):
            # Fresh randomness every step: one key for the data, one for the
            # flow itself, one carried forward to the next step.
            key, data_key, apply_key = rnd.split(key, 3)
            noise = rnd.normal(data_key, shape=(batch_size, input_shape)) * init_std
            (_, state), flow_grad = value_and_grad(loss_fn, has_aux=True)(
                p, state, apply_key, {"x": noise}
            )
            # The optimiser chain uses a positive scale, so applying the
            # updates moves along the gradient of mean log_px.
            p_updates, opt_state = opt.update(flow_grad, opt_state, p)
            p = optax.apply_updates(p, p_updates)
            return p, state, opt_state, key

        for _ in range(n_steps):
            flow_params, flow_state, opt_state, key = step(
                flow_params, flow_state, opt_state, key
            )
        print("Finished pretraining flow")

    def get_flow_arrays():
        """Return ``(params, state)`` of the (possibly pretrained) flow."""
        # We need this so that we can pmap the function to get the params
        return flow_params, flow_state

    def get_density(params, state, samples, rng_key):
        """Return clipped log-probabilities for ``samples``.

        Computed as the standard-normal log-density of ``samples`` minus the
        ``"log_det"`` the flow reports for them, bounded below by ``threshold``.
        """
        logprob = jnp.sum(norm.logpdf(samples), axis=-1)
        output = flow.stateful_apply(rng_key, {"x": samples}, params, state)
        # NOTE(review): original author marked this sign as "I *think* this is right".
        logprobs = -output[0]["log_det"] + logprob
        # Deal with occasional issues with point somehow outside the support
        return jnp.clip(logprobs, threshold)

    return flow_params, sample_flow, get_flow_arrays, get_density