forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
274 lines (212 loc) · 8.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities functions for Jax."""
from collections import abc
import functools
from typing import Any, Callable, Dict, Mapping, Union
import distrax
import jax
from jax import core
from jax import lax
from jax import nn
import jax.numpy as jnp
from jax.tree_util import register_pytree_node
from jaxline import utils
import numpy as np
# Two-level, string-keyed mapping whose innermost values are arrays.
HaikuParams = Mapping[str, Mapping[str, jnp.ndarray]]
# Parameters may be a flat mapping of arrays, a `HaikuParams` mapping, or a
# single array.
Params = Union[Mapping[str, jnp.ndarray], HaikuParams, jnp.ndarray]
# Signature of an activation function: takes an array and returns an array.
_Activation = Callable[[jnp.ndarray], jnp.ndarray]
# Leaky ReLU with the negative slope fixed to 0.2 (presumably chosen to match
# TensorFlow's default slope -- TODO confirm). Being a module global, it can be
# looked up by name through `get_activation`.
tf_leaky_relu = functools.partial(nn.leaky_relu, negative_slope=0.2)
def filter_only_scalar_stats(stats):
  """Keeps only the entries of `stats` whose value holds exactly one element.

  Args:
    stats: A dict-like object (anything with `.items()`) whose values expose
      a `.size` attribute.

  Returns:
    A new dict with the entries of `stats` for which `value.size == 1`, in the
    original iteration order.
  """
  scalars = {}
  for name, value in stats.items():
    if value.size == 1:
      scalars[name] = value
  return scalars
def to_numpy(obj):
  """Converts every leaf of a pytree into a `numpy.ndarray`.

  Args:
    obj: Any pytree (nested dicts/lists/tuples/...) of array-like leaves.

  Returns:
    A pytree with the same structure as `obj` in which each leaf has been
    passed through `np.array`.
  """
  # The top-level `jax.tree_map` alias was deprecated and then removed from
  # JAX; `jax.tree_util.tree_map` is available in both old and new releases.
  return jax.tree_util.tree_map(np.array, obj)
@jax.custom_gradient
def geco_lagrange_product(lagrange_multiplier, constraint_ema, constraint_t):
  """Product of a Lagrange multiplier and a constraint with GECO gradients.

  The forward value is:
    lagrange_multiplier * constraint_ema

  The backward pass, for an incoming cotangent `g`, returns:
    w.r.t. `lagrange_multiplier`:  -g * constraint_t
    w.r.t. `constraint_ema`:        zeros
    w.r.t. `constraint_t`:          g * lagrange_multiplier

  If the same value is passed for both `constraint_ema` and `constraint_t`,
  the only visible effect is that the gradient of the Lagrange multiplier has
  its sign flipped.

  Args:
    lagrange_multiplier: The Lagrange multiplier.
    constraint_ema: The moving average of the constraint.
    constraint_t: The current value of the constraint.

  Returns:
    The product `lagrange_multiplier * constraint_ema`, with the custom
    gradients described above attached.
  """
  value = lagrange_multiplier * constraint_ema

  def backward(g):
    d_lagrange = -g * constraint_t
    d_constraint_ema = jnp.zeros_like(constraint_ema)
    d_constraint_t = g * lagrange_multiplier
    return d_lagrange, d_constraint_ema, d_constraint_t

  return value, backward
def bcast_if(x, t, n):
  """Repeats `x` into a list of length `n` when it is an instance of `t`.

  Args:
    x: The value to (possibly) broadcast.
    t: A type or tuple of types, as accepted by `isinstance`.
    n: How many times to repeat `x`.

  Returns:
    `[x] * n` if `isinstance(x, t)`, otherwise `x` unchanged.
  """
  if not isinstance(x, t):
    return x
  return [x] * n
def stack_time_into_channels(
    images: jnp.ndarray,
    data_format: str
) -> jnp.ndarray:
  """Folds axis 1 of `images` into the channel axis.

  `images` is cut into `images.shape[1]` slices along axis 1, each slice has
  that axis squeezed out, and the slices are concatenated along the channel
  axis.

  Args:
    images: Array whose axis 1 is the one to be folded away.
    data_format: Layout string of a single slice (i.e. with axis 1 removed);
      the position of the character "C" in it is the concatenation axis.

  Returns:
    An array with one dimension fewer than `images`.
  """
  channel_axis = data_format.index("C")
  num_steps = images.shape[1]
  frames = []
  for frame in jnp.split(images, num_steps, axis=1):
    frames.append(jnp.squeeze(frame, axis=1))
  return jnp.concatenate(frames, channel_axis)
def stack_device_dim_into_batch(obj):
  """Merges the two leading axes of every leaf of a pytree into one.

  Args:
    obj: A pytree whose leaves are arrays with at least two dimensions.

  Returns:
    A pytree with the same structure as `obj` where each leaf of shape
    `(a, b, ...)` has been reshaped to `(a * b, ...)`.
  """
  # The top-level `jax.tree_map` alias was deprecated and then removed from
  # JAX; `jax.tree_util.tree_map` is available in both old and new releases.
  return jax.tree_util.tree_map(
      lambda x: x.reshape((-1,) + x.shape[2:]), obj)
def nearest_neighbour_upsampling(x, scale, data_format="NHWC"):
  """Upsamples the two spatial axes of `x` by repeating every element.

  Args:
    x: A rank-4 array laid out according to `data_format`.
    scale: Integer factor by which both the height and the width grow.
    data_format: Either "NHWC" or "NCHW".

  Returns:
    An array of the same dtype and layout as `x` whose height and width are
    `scale` times larger.

  Raises:
    ValueError: If `data_format` is neither "NCHW" nor "NHWC".
  """
  if data_format not in ("NCHW", "NHWC"):
    raise ValueError(f"Unrecognized data_format={data_format}.")

  if data_format == "NCHW":
    batch, channels, height, width = x.shape
    expanded_shape = [batch, channels, height, 1, width, 1]
    tile_shape = [1, 1, 1, scale, 1, scale]
    output_shape = [batch, channels, scale * height, scale * width]
  else:
    batch, height, width, channels = x.shape
    expanded_shape = [batch, height, 1, width, 1, channels]
    tile_shape = [1, 1, scale, 1, scale, 1]
    output_shape = [batch, scale * height, scale * width, channels]

  # Insert a singleton axis after each spatial axis, broadcast it to `scale`
  # by multiplying with ones, then fold it back into the spatial axis.
  expanded = jnp.reshape(x, expanded_shape)
  tiles = jnp.ones(tile_shape, dtype=expanded.dtype)
  return jnp.reshape(expanded * tiles, output_shape)
def get_activation(arg: Union[_Activation, str]) -> _Activation:
  """Resolves an activation function from a name or passes a callable through.

  When `arg` is a string it is looked up, in order, in this module's globals,
  in `jax.nn` and in `jax.numpy`. The first match that is callable wins;
  matches that are not callable (for example an imported module that happens
  to share the name) are skipped instead of being returned.

  Args:
    arg: Either the name of an activation or the activation itself.

  Returns:
    The resolved callable.

  Raises:
    ValueError: If `arg` is a string that does not resolve to a callable, or
      if `arg` is neither a string nor a callable.
  """
  if isinstance(arg, str):
    # Try fetch in order - [this module, jax.nn, jax.numpy]
    found = globals().get(arg)
    if not callable(found):
      found = getattr(nn, arg, None)
    if not callable(found):
      found = getattr(jnp, arg, None)
    if not callable(found):
      raise ValueError(f"Unrecognized activation with name {arg}.")
    return found
  if not callable(arg):
    raise ValueError(f"Expected a callable, but got {type(arg)}")
  return arg
def merge_first_dims(x: jnp.ndarray, num_dims_to_merge: int = 2) -> jnp.ndarray:
  """Collapses the leading `num_dims_to_merge` axes of `x` into a single axis.

  Args:
    x: The array to reshape.
    num_dims_to_merge: Number of leading axes to merge together.

  Returns:
    `x` reshaped to `(-1,) + x.shape[num_dims_to_merge:]`.
  """
  trailing_shape = x.shape[num_dims_to_merge:]
  return x.reshape((-1, *trailing_shape))
def extract_image(
    inputs: Union[jnp.ndarray, Mapping[str, jnp.ndarray]]
) -> jnp.ndarray:
  """Extracts the image tensor from `inputs`.

  Args:
    inputs: Either an array, which is returned as is, or a mapping holding the
      image under the key "image" or, failing that, "x_image". Any
      `collections.abc.Mapping` is accepted, not only `dict`, in line with the
      type annotation.

  Returns:
    The image array.

  Raises:
    KeyError: If `inputs` is a mapping with neither "image" nor "x_image".
    NotImplementedError: If `inputs` is neither a mapping nor an array.
  """
  if isinstance(inputs, abc.Mapping):
    # "image" takes precedence over "x_image" when both are present.
    if "image" in inputs:
      return inputs["image"]
    return inputs["x_image"]
  if isinstance(inputs, jnp.ndarray):
    return inputs
  raise NotImplementedError(f"Not implemented of inputs of type"
                            f" {type(inputs)}.")
def extract_gt_state(inputs: Any) -> jnp.ndarray:
  """Extracts the tensor stored under the key "x" from `inputs`.

  Args:
    inputs: Either an array, which is returned as is, or a mapping holding the
      tensor under the key "x". Any `collections.abc.Mapping` is accepted, not
      only `dict`, consistently with `extract_image`.

  Returns:
    `inputs["x"]` for a mapping, otherwise `inputs` itself.

  Raises:
    KeyError: If `inputs` is a mapping without the key "x".
    NotImplementedError: If `inputs` is neither a mapping nor an array.
  """
  if isinstance(inputs, abc.Mapping):
    return inputs["x"]
  if isinstance(inputs, jnp.ndarray):
    return inputs
  raise NotImplementedError(f"Not implemented of inputs of type"
                            f" {type(inputs)}.")
def reshape_latents_conv_to_flat(conv_latents, axis_n_to_keep=1):
  """Flattens convolutional latents while keeping the two halves separated.

  `conv_latents` is split into two equal halves `q` and `p` along its last
  axis. Each half is flattened over every axis past the first
  `axis_n_to_keep`, and the two flattened halves are concatenated along the
  last axis, so that all entries of `q` precede all entries of `p`.

  Args:
    conv_latents: An array whose last axis has an even size.
    axis_n_to_keep: Number of leading axes that are left untouched.

  Returns:
    An array of shape `conv_latents.shape[:axis_n_to_keep] + (-1,)`.
  """
  q, p = jnp.split(conv_latents, 2, axis=-1)
  # `q` and `p` are single arrays, so they can be reshaped directly; the
  # original routed this through `jax.tree_map`, which has been removed from
  # the top-level `jax` namespace.
  q = q.reshape(q.shape[:axis_n_to_keep] + (-1,))
  p = p.reshape(p.shape[:axis_n_to_keep] + (-1,))
  return jnp.concatenate([q, p], axis=-1)
def triu_matrix_from_v(x, ndim):
  """Builds upper-triangular matrices from their flattened non-zero entries.

  Args:
    x: Array whose last axis has size `ndim * (ndim + 1) // 2` and holds the
      upper-triangular entries in the order given by `jnp.triu_indices(ndim)`.
      All leading axes are treated as batch axes.
    ndim: The number of rows (and columns) of each output matrix.

  Returns:
    An array of shape `x.shape[:-1] + (ndim, ndim)`, created with
    `jnp.zeros`' default dtype, whose upper triangle is filled from `x` and
    whose remaining entries are zero.
  """
  assert x.shape[-1] == (ndim * (ndim + 1)) // 2
  rows, cols = jnp.triu_indices(ndim)
  zeros = jnp.zeros(x.shape[:-1] + (ndim, ndim))
  # The ellipsis covers all batch axes, so a single scatter writes the
  # upper-triangular entries of every matrix in the batch.
  return zeros.at[..., rows, cols].set(x)
def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, Any]:
  """Flattens a nested mapping into a single-level dict with joined keys.

  Args:
    d: The (possibly nested) mapping to flatten.
    parent_key: Prefix prepended, followed by `sep`, to every key of `d`. An
      empty prefix leaves the keys unchanged.
    sep: Separator placed between the keys of consecutive nesting levels.

  Returns:
    A dict from joined keys to the non-mapping values of `d`. Only values that
    are `MutableMapping` instances are recursed into.
  """
  flat = {}
  for key, value in d.items():
    full_key = parent_key + sep + key if parent_key else key
    if not isinstance(value, abc.MutableMapping):
      flat[full_key] = value
      continue
    flat.update(flatten_dict(value, full_key, sep=sep))
  return flat
def convert_to_pytype(target, reference):
  """Rebuilds the leaves of `target` using the tree structure of `reference`.

  Args:
    target: The pytree whose leaves are to be reused.
    reference: The pytree whose structure the result should have. It must have
      the same number of leaves as `target`.

  Returns:
    A pytree with the structure of `reference` and the leaves of `target`,
    matched in flattening order.
  """
  # The top-level `jax.tree_flatten` / `jax.tree_unflatten` aliases were
  # deprecated and then removed from JAX; the `jax.tree_util` functions are
  # available in both old and new releases.
  treedef = jax.tree_util.tree_structure(reference)
  leaves = jax.tree_util.tree_leaves(target)
  return jax.tree_util.tree_unflatten(treedef, leaves)
def func_if_not_scalar(func):
  """Wraps `func` so that zero-dimensional inputs are returned untouched.

  Args:
    func: A reduction-like callable invoked as `func(array, axis=axis)`.

  Returns:
    A function `wrapped(array, axis=0)` that returns `array` itself when
    `array.ndim == 0` and `func(array, axis=axis)` otherwise. The metadata of
    `func` is copied onto it via `functools.wraps`.
  """

  def wrapped(array, axis=0):
    return array if array.ndim == 0 else func(array, axis=axis)

  return functools.wraps(func)(wrapped)


# `jnp.mean` over `axis` (0 by default) that leaves scalars as they are.
mean_if_not_scalar = func_if_not_scalar(jnp.mean)
class MultiBatchAccumulator(object):
  """Accumulates per-batch averaged statistics over multiple batches.

  Every call to `add` supplies a pytree of values already averaged over one
  batch together with the number of samples in that batch. The accumulator
  tracks the sample-weighted sum, the element-wise maximum and the
  element-wise minimum of those pytrees.
  """

  def __init__(self):
    # Sample-weighted sum of everything added so far; None until first `add`.
    self._obj = None
    # Element-wise maximum / minimum of the per-batch averaged values.
    self._obj_max = None
    self._obj_min = None
    # Total number of samples added so far.
    self._num_samples = None

  def add(self, averaged_values, num_samples):
    """Adds the statistics of one batch.

    Args:
      averaged_values: A pytree of values averaged over the batch. Every call
        must use the same tree structure.
      num_samples: The number of samples that `averaged_values` was averaged
        over; used as the weight of this batch.
    """
    # `jax.tree_map` was removed from the top-level `jax` namespace, hence the
    # `jax.tree_util` spelling.
    tree_map = jax.tree_util.tree_map
    if self._obj is None:
      # Multiplying by 1.0 yields fresh floating-point values instead of
      # keeping references to the caller's leaves.
      self._obj_max = tree_map(lambda y: y * 1.0, averaged_values)
      self._obj_min = tree_map(lambda y: y * 1.0, averaged_values)
      self._obj = tree_map(lambda y: y * num_samples, averaged_values)
      self._num_samples = num_samples
    else:
      self._obj_max = tree_map(jnp.maximum, self._obj_max, averaged_values)
      self._obj_min = tree_map(jnp.minimum, self._obj_min, averaged_values)
      self._obj = tree_map(lambda x, y: x + y * num_samples, self._obj,
                           averaged_values)
      self._num_samples += num_samples

  def value(self):
    """Returns the sample-weighted average of all added values."""
    return jax.tree_util.tree_map(lambda x: x / self._num_samples, self._obj)

  def max(self):
    """Returns the element-wise maximum, with every leaf cast to `float`."""
    return jax.tree_util.tree_map(float, self._obj_max)

  def min(self):
    """Returns the element-wise minimum, with every leaf cast to `float`."""
    return jax.tree_util.tree_map(float, self._obj_min)

  def sum(self):
    """Returns the sample-weighted sum (None if nothing was added yet)."""
    return self._obj
# Registers `distrax.Normal` as a pytree node (a module-level side effect that
# runs on import). Flattening yields `loc` and `scale` as the children and no
# auxiliary data; unflattening rebuilds the distribution as
# `distrax.Normal(loc, scale)`.
register_pytree_node(
distrax.Normal,
lambda instance: ([instance.loc, instance.scale], None),
lambda _, args: distrax.Normal(*args)
)
def inner_product(x: Any, y: Any) -> jnp.ndarray:
  """Computes the inner product of two pytrees with matching structure.

  Args:
    x: A pytree of arrays.
    y: A pytree with the same structure as `x` whose leaves can be multiplied
      element-wise with the corresponding leaves of `x`.

  Returns:
    The sum over all leaves of `jnp.sum(x_leaf * y_leaf)`. For pytrees without
    leaves this is the Python integer 0.
  """
  # The top-level `jax.tree_map` / `jax.tree_leaves` aliases were deprecated
  # and then removed from JAX; the `jax.tree_util` functions are available in
  # both old and new releases.
  products = jax.tree_util.tree_map(lambda x_, y_: jnp.sum(x_ * y_), x, y)
  return sum(jax.tree_util.tree_leaves(products))
# Re-exports of `jaxline.utils` helpers, so that they can be used through this
# module.
get_first = utils.get_first
bcast_local_devices = utils.bcast_local_devices
py_prefetch = utils.py_prefetch
# pmap-ed version of `jax.random.split`: for every device, splits the key `x`
# into `num` keys and returns them as a list. `num` (argument index 1) is a
# static argument broadcast to all devices.
p_split = jax.pmap(lambda x, num: list(jax.random.split(x, num)),
static_broadcasted_argnums=1)
def wrap_if_pmap(p_func):
  """Makes a collective a no-op when its axis name is not bound.

  Args:
    p_func: A collective called as `p_func(obj, axis_name)`, such as
      `lax.pmean` or `lax.psum`.

  Returns:
    A function `p_func_if_pmap(obj, axis_name)` that returns
    `p_func(obj, axis_name)` when `axis_name` is bound (e.g. inside a
    `jax.pmap` or `jax.vmap` with that axis name) and `obj` unchanged
    otherwise.
  """

  def p_func_if_pmap(obj, axis_name):
    try:
      # JAX collectives raise NameError themselves for an unbound axis name,
      # so no separate probe is needed. The original probed with the internal
      # `jax.core.axis_frame`, which is no longer available in recent JAX
      # releases and made this wrapper fail with AttributeError there.
      return p_func(obj, axis_name)
    except NameError:
      return obj

  return p_func_if_pmap


# `lax.pmean` / `lax.psum` that fall back to the identity outside of a mapped
# context.
pmean_if_pmap = wrap_if_pmap(lax.pmean)
psum_if_pmap = wrap_if_pmap(lax.psum)