forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
position_encoding.py
222 lines (184 loc) · 7.92 KB
/
position_encoding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Position encodings and utilities."""
import abc
import functools
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
def generate_fourier_features(
    pos, num_bands, max_resolution=(224, 224),
    concat_pos=True, sine_only=False):
  """Builds linearly spaced Fourier position features for a set of points.

  For every spatial dimension, `num_bands` frequencies are spaced linearly
  between 1.0 and half of that dimension's entry in `max_resolution`
  (both endpoints included). Each coordinate is multiplied by each of its
  frequencies and passed through sin (and optionally cos) after scaling by pi.

  Args:
    pos: Positions of n points in d dimensional space, indexed as a 2-D
      array of shape [n, d].
    num_bands: The number of frequency bands (K) per dimension.
    max_resolution: One resolution value per dimension; the highest
      frequency used for a dimension is half of its resolution.
    concat_pos: If True, the raw positions are prepended to the features.
    sine_only: If True only sin features are produced, otherwise sin features
      followed by cos features.

  Returns:
    A 2-D array of shape [n, n_channels]. The Fourier part has d * K channels
    when sine_only is True and 2 * d * K otherwise; concat_pos adds d more
    channels at the front. With concat_pos=True and sine_only=False the
    channel order is:
      [dim_1, ..., dim_d,
       sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ...,
       sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d),
       cos(pi*f_1*dim_1), ..., cos(pi*f_K*dim_1), ...,
       cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)],
    where dim_i is pos[:, i] and f_k is the kth frequency of that dimension.
  """
  lowest_freq = 1.0
  # One row of K frequencies per dimension; the top frequency is res / 2.
  bands_per_dim = [
      jnp.linspace(lowest_freq, res / 2, num=num_bands, endpoint=True)
      for res in max_resolution]
  freqs = jnp.stack(bands_per_dim, axis=0)

  # Pair every coordinate with every frequency of its dimension: [n, d, K].
  scaled = pos[:, :, None] * freqs[None]
  # Flatten the (dimension, band) axes into one channel axis: [n, d * K].
  flat_width = np.prod(scaled.shape[1:])
  scaled = jnp.reshape(scaled, [-1, flat_width])

  if sine_only:
    # [n, d * K]
    features = jnp.sin(jnp.pi * scaled)
  else:
    # [n, 2 * d * K]: all sin channels first, then all cos channels.
    sin_part = jnp.sin(jnp.pi * scaled)
    cos_part = jnp.cos(jnp.pi * scaled)
    features = jnp.concatenate([sin_part, cos_part], axis=-1)

  if not concat_pos:
    return features
  # Prepend the d raw coordinates to the encoding.
  return jnp.concatenate([pos, features], axis=-1)
def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
  """Builds a dense grid of linearly spaced coordinates for an N-D index space.

  Args:
    index_dims: The size of each index dimension of the input array.
    output_range: A (min, max) pair; the coordinates along every dimension run
      linearly from min to max, endpoints included.

  Returns:
    A float32 jnp array of shape
    [index_dims[0], index_dims[1], .., index_dims[-1], N], where the last axis
    holds the N coordinates of each grid cell.
  """
  # One 1-D coordinate axis per index dimension.
  axes = [
      jnp.linspace(
          output_range[0], output_range[1],
          num=size, endpoint=True, dtype=jnp.float32)
      for size in index_dims]
  # 'ij' indexing keeps the grid axes in the same order as index_dims.
  grids = jnp.meshgrid(*axes, indexing='ij')
  return jnp.stack(grids, axis=-1)
class AbstractPositionEncoding(hk.Module, metaclass=abc.ABCMeta):
  """Abstract base class for the position encodings in this module."""

  @abc.abstractmethod
  def __call__(self, batch_size, pos):
    """Returns position encodings; must be implemented by subclasses.

    Args:
      batch_size: The batch size of the data to be encoded.
      pos: Position features, interpreted by the concrete subclass.
    """
    raise NotImplementedError()
class TrainablePositionEncoding(AbstractPositionEncoding):
  """Position encoding stored as a learned table of embeddings."""

  def __init__(self, index_dim, num_channels=128, init_scale=0.02, name=None):
    """Stores the table configuration.

    Args:
      index_dim: Number of positions (rows of the embedding table).
      num_channels: Number of channels per position (columns of the table).
      init_scale: Standard deviation of the truncated-normal initializer.
      name: Optional module name.
    """
    super().__init__(name=name)
    self._index_dim = index_dim
    self._num_channels = num_channels
    self._init_scale = init_scale

  def __call__(self, batch_size, pos=None):
    """Returns the embedding table, optionally broadcast over a batch.

    Args:
      batch_size: If not None, the table is broadcast to a leading batch axis
        of this size.
      pos: Ignored.

    Returns:
      The 'pos_embs' parameter of shape [index_dim, num_channels], or
      [batch_size, index_dim, num_channels] when batch_size is given.
    """
    del pos  # Unused.
    initializer = hk.initializers.TruncatedNormal(stddev=self._init_scale)
    table = hk.get_parameter(
        'pos_embs', [self._index_dim, self._num_channels], init=initializer)
    if batch_size is None:
      return table
    return jnp.broadcast_to(table[None, :, :], (batch_size,) + table.shape)
def _check_or_build_spatial_positions(pos, index_dims, batch_size):
  """Validates given position features or builds default ones (x, y, ...).

  Args:
    pos: None, or an array of position features. If None, a linear position
      grid is built from index_dims. Otherwise only the size of its last axis
      is checked.
    index_dims: An iterable giving the spatial/index size of the data to be
      featurized.
    batch_size: The batch size of the data to be featurized.

  Returns:
    The given `pos` unchanged, or, when `pos` is None, a newly built array of
    shape [batch_size, prod(index_dims), len(index_dims)].
  """
  if pos is not None:
    # Just a warning label: you probably don't want your spatial features to
    # have a different spatial layout than your pos coordinate system.
    # But feel free to override if you think it'll work!
    assert pos.shape[-1] == len(index_dims)
    return pos

  grid = build_linear_positions(index_dims)
  # Replicate the grid for every batch element, then flatten the index axes.
  batched = jnp.broadcast_to(grid[None], (batch_size,) + grid.shape)
  return jnp.reshape(batched, [batch_size, np.prod(index_dims), -1])
class FourierPositionEncoding(AbstractPositionEncoding):
  """Fourier (sinusoidal) position encoding."""

  def __init__(self, index_dims, num_bands, concat_pos=True,
               max_resolution=None, sine_only=False, name=None):
    """Stores the encoding configuration.

    Args:
      index_dims: The spatial/index size of the data to be featurized.
      num_bands: Number of frequency bands per dimension.
      concat_pos: Whether the raw positions are concatenated to the features.
      max_resolution: Per-dimension resolution used to set the top frequency.
        Falls back to index_dims when not provided (or otherwise falsy).
      sine_only: Whether only sin features (no cos) are produced.
      name: Optional module name.
    """
    super().__init__(name=name)
    self._index_dims = index_dims
    self._num_bands = num_bands
    self._concat_pos = concat_pos
    self._sine_only = sine_only
    # Use the index dims as the maximum resolution if it's not provided.
    self._max_resolution = max_resolution or index_dims

  def __call__(self, batch_size, pos=None):
    """Returns Fourier features for each batch element's positions.

    Args:
      batch_size: The batch size of the data to be featurized.
      pos: None to use a default linear grid over index_dims, or an array of
        position features whose last axis has len(index_dims) entries.

    Returns:
      The result of generate_fourier_features mapped over the leading (batch)
      axis of the positions.
    """
    positions = _check_or_build_spatial_positions(
        pos, self._index_dims, batch_size)

    def _encode_one(single_pos):
      # Encodes the positions of a single batch element.
      return generate_fourier_features(
          single_pos,
          num_bands=self._num_bands,
          max_resolution=self._max_resolution,
          concat_pos=self._concat_pos,
          sine_only=self._sine_only)

    return jax.vmap(_encode_one, in_axes=0, out_axes=0)(positions)
class PositionEncodingProjector(AbstractPositionEncoding):
  """Applies a linear projection on top of another position encoding."""

  def __init__(self, output_size, base_position_encoding, name=None):
    """Stores the projection configuration.

    Args:
      output_size: Output size of the linear projection.
      base_position_encoding: Callable taking (batch_size, pos) and returning
        the encoding to be projected.
      name: Optional module name.
    """
    super().__init__(name=name)
    self._base_position_encoding = base_position_encoding
    self._output_size = output_size

  def __call__(self, batch_size, pos=None):
    """Computes the base encoding and projects it with a linear layer.

    Args:
      batch_size: Forwarded to the base position encoding.
      pos: Forwarded to the base position encoding.

    Returns:
      The base encoding passed through hk.Linear(output_size).
    """
    encoded = self._base_position_encoding(batch_size, pos)
    projection = hk.Linear(output_size=self._output_size)
    return projection(encoded)
def build_position_encoding(
    position_encoding_type,
    index_dims,
    project_pos_dim=-1,
    trainable_position_encoding_kwargs=None,
    fourier_position_encoding_kwargs=None,
    name=None):
  """Builds the position encoding module selected by name.

  Args:
    position_encoding_type: Either 'trainable' or 'fourier'.
    index_dims: The spatial/index size of the data to be encoded.
    project_pos_dim: If greater than 0, the encoding is wrapped in a
      PositionEncodingProjector with this output size.
    trainable_position_encoding_kwargs: Extra keyword arguments for
      TrainablePositionEncoding; must not be None for the 'trainable' type.
    fourier_position_encoding_kwargs: Extra keyword arguments for
      FourierPositionEncoding; must not be None for the 'fourier' type.
    name: Optional name given to the base encoding module (the projector, if
      any, is created with its default name).

  Returns:
    The constructed position encoding module.

  Raises:
    ValueError: If position_encoding_type is not a known type.
  """
  if position_encoding_type == 'trainable':
    assert trainable_position_encoding_kwargs is not None
    # Construct 1D features: all index dimensions are flattened into one.
    flat_index_dim = np.prod(index_dims)
    encoder = TrainablePositionEncoding(
        index_dim=flat_index_dim,
        name=name,
        **trainable_position_encoding_kwargs)
  elif position_encoding_type == 'fourier':
    assert fourier_position_encoding_kwargs is not None
    encoder = FourierPositionEncoding(
        index_dims=index_dims,
        name=name,
        **fourier_position_encoding_kwargs)
  else:
    raise ValueError(f'Unknown position encoding: {position_encoding_type}.')

  if project_pos_dim > 0:
    # Project the position encoding to a target dimension:
    return PositionEncodingProjector(
        output_size=project_pos_dim,
        base_position_encoding=encoder)
  return encoder