forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
io_processors.py
830 lines (685 loc) · 28.7 KB
/
io_processors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""IO pre- and post-processors for Perceiver."""
import functools
import math
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple
import einops
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
from perceiver import position_encoding
# Maps each modality name to the number of index positions (tokens) that the
# modality occupies along axis 1 of the concatenated multimodal array.
ModalitySizeT = Mapping[str, int]
# Return type of every preprocessor:
#   (network inputs, per-modality sizes or None, inputs without position
#   encodings).
# The second element is a modality-size mapping (None for single-modality
# preprocessors), not an array.
# NOTE(review): MultimodalPreprocessor returns a dict of arrays as the third
# element; the `jnp.ndarray` annotation there is kept for compatibility.
PreprocessorOutputT = Tuple[jnp.ndarray, Optional[ModalitySizeT], jnp.ndarray]
# A preprocessor is any callable producing a PreprocessorOutputT.
PreprocessorT = Callable[..., PreprocessorOutputT]
# Postprocessors may return arrays or (for multimodal) dicts of arrays.
PostprocessorT = Callable[..., Any]
def reverse_space_to_depth(
    frames: jnp.ndarray,
    temporal_block_size: int = 1,
    spatial_block_size: int = 1) -> jnp.ndarray:
  """Reverse space to depth transform.

  Unfolds channel-stacked blocks back into the spatial (and, for video,
  temporal) axes; this is the inverse of `space_to_depth`.

  Args:
    frames: rank-4 array (batch, height, width, channels) or rank-5 array
      (batch, time, height, width, channels).
    temporal_block_size: block size along time; only used for rank-5 input.
    spatial_block_size: block size along height and width.

  Returns:
    Rank 4: [b, h * dh, w * dw, c]. Rank 5: [b, t * dt, h * dh, w * dw, c].

  Raises:
    ValueError: if `frames` is neither rank 4 nor rank 5.
  """
  rank = len(frames.shape)
  if rank == 4:
    pattern = 'b h w (dh dw c) -> b (h dh) (w dw) c'
    block_sizes = {'dh': spatial_block_size, 'dw': spatial_block_size}
  elif rank == 5:
    pattern = 'b t h w (dt dh dw c) -> b (t dt) (h dh) (w dw) c'
    block_sizes = {
        'dt': temporal_block_size,
        'dh': spatial_block_size,
        'dw': spatial_block_size,
    }
  else:
    raise ValueError(
        'Frames should be of rank 4 (batch, height, width, channels)'
        ' or rank 5 (batch, time, height, width, channels)')
  return einops.rearrange(frames, pattern, **block_sizes)
def space_to_depth(
    frames: jnp.ndarray,
    temporal_block_size: int = 1,
    spatial_block_size: int = 1) -> jnp.ndarray:
  """Space to depth transform.

  Folds non-overlapping spatial (and, for video, temporal) blocks into the
  channel axis; this is the inverse of `reverse_space_to_depth`.

  Args:
    frames: rank-4 array (batch, height, width, channels) or rank-5 array
      (batch, time, height, width, channels).
    temporal_block_size: block size along time; only used for rank-5 input.
    spatial_block_size: block size along height and width.

  Returns:
    Rank 4: [b, h / dh, w / dw, dh * dw * c].
    Rank 5: [b, t / dt, h / dh, w / dw, dt * dh * dw * c].

  Raises:
    ValueError: if `frames` is neither rank 4 nor rank 5.
  """
  rank = len(frames.shape)
  if rank == 4:
    pattern = 'b (h dh) (w dw) c -> b h w (dh dw c)'
    block_sizes = {'dh': spatial_block_size, 'dw': spatial_block_size}
  elif rank == 5:
    pattern = 'b (t dt) (h dh) (w dw) c -> b t h w (dt dh dw c)'
    block_sizes = {
        'dt': temporal_block_size,
        'dh': spatial_block_size,
        'dw': spatial_block_size,
    }
  else:
    raise ValueError(
        'Frames should be of rank 4 (batch, height, width, channels)'
        ' or rank 5 (batch, time, height, width, channels)')
  return einops.rearrange(frames, pattern, **block_sizes)
def extract_patches(images: jnp.ndarray,
                    sizes: Sequence[int],
                    strides: Sequence[int],
                    rates: Sequence[int],
                    padding: str = 'VALID') -> jnp.ndarray:
  """Extract patches from images.

  This function is a wrapper for jax.lax.conv_general_dilated_patches
  that conforms to the same interface as tf.image.extract_patches.
  The function extracts patches of shape sizes from the input images in the
  same manner as a convolution with kernel of shape sizes, stride equal to
  strides, and the given padding scheme.
  The patches are stacked in the channel dimension.

  Args:
    images: input batch of images of shape [B, H, W, C].
    sizes: size of extracted patches. Must be [1, size_rows, size_cols, 1].
    strides: strides, must be [1, stride_rows, stride_cols, 1].
    rates: sampling rate (as in dilated convolutions),
      must be [1, rate_rows, rate_cols, 1].
    padding: padding algorithm to use.

  Returns:
    Tensor of shape [B, patch_rows, patch_cols, size_rows * size_cols * C]

  Raises:
    ValueError: if `sizes`, `strides` or `rates` is not of the form
      [1, rows, cols, 1], or if `images` is not rank 4.
  """
  if len(sizes) != 4 or sizes[0] != 1 or sizes[3] != 1:
    raise ValueError(
        f'Shape of sizes must be [1, size_rows, size_cols, 1], got {sizes}.')
  if len(strides) != 4 or strides[0] != 1 or strides[3] != 1:
    raise ValueError(
        f'Shape of strides must be [1, stride_rows, stride_cols, 1], '
        f'got {strides}.')
  if len(rates) != 4 or rates[0] != 1 or rates[3] != 1:
    raise ValueError(
        f'Shape of rates must be [1, rate_rows, rate_cols, 1], got {rates}.')
  if images.ndim != 4:
    raise ValueError(
        f'Rank of images must be 4 (got tensor of shape {jnp.shape(images)})')

  # Rearrange axes of images to NCHW for conv_general_dilated_patches.
  images = einops.rearrange(images, 'n h w c -> n c h w')
  channels = images.shape[1]
  # Only the spatial entries (index 1 and 2) of sizes/strides/rates are used.
  patches = jax.lax.conv_general_dilated_patches(
      images, sizes[1:-1], strides[1:-1], padding, rhs_dilation=rates[1:-1])
  # conv_general_dilated_patches returns patches in channel-major order.
  # Rearrange to match interface of tf.image.extract_patches.
  patches = einops.rearrange(patches, 'n (c ph pw) h w -> n h w (ph pw c)',
                             c=channels, ph=sizes[1], pw=sizes[2])
  return patches
def patches_for_flow(inputs: jnp.ndarray) -> jnp.ndarray:
  """Extract 3x3x2 image patches for flow inputs.

  Every slice along axis 1 is zero-padded by one pixel on each side of
  height and width, then split into 3x3 patches with stride 1, so the spatial
  size is preserved. The per-slice work is vectorized over axis 1 with
  `jax.vmap`.

  Args:
    inputs: rank-5 array; each slice along axis 1 must be a rank-4
      [B, H, W, C] batch as required by `extract_patches`. Presumably axis 1
      indexes the flow frames — TODO confirm against callers.

  Returns:
    Array of shape [B, T, H, W, 9 * C], where T is the size of axis 1.
  """
  def _pad_and_patch(frame_batch):
    # One pixel of zero padding keeps H and W unchanged under a 'VALID'
    # 3x3 / stride-1 patch extraction.
    zero_pad = [[0, 0], [1, 1], [1, 1], [0, 0]]
    padded = jnp.pad(frame_batch, zero_pad, mode='constant')
    return extract_patches(
        padded,
        sizes=[1, 3, 3, 1],
        strides=[1, 1, 1, 1],
        rates=[1, 1, 1, 1],
        padding='VALID')

  per_slice = jax.vmap(_pad_and_patch, in_axes=1, out_axes=1)
  return per_slice(inputs)
# ------------------------------------------------------------
# ------------------- Up/down-sampling ---------------------
# ------------------------------------------------------------
class Conv2DDownsample(hk.Module):
  """Stack of conv -> (batchnorm) -> relu -> max-pool layers.

  Each layer uses a stride-2 7x7 convolution followed by a stride-2 3x3 max
  pool, so every layer downsamples height and width by 4x (4**num_layers in
  total).
  """

  def __init__(
      self,
      num_layers: int = 1,
      num_channels: int = 64,
      use_batchnorm: bool = True,
      bn_config: Optional[Mapping[str, float]] = None,
      name: Optional[str] = None,
  ):
    """Constructs a Conv2DDownsample model.

    Args:
      num_layers: The number of conv->max_pool layers.
      num_channels: The number of conv output channels.
      use_batchnorm: Whether to apply an `hk.BatchNorm` after each conv.
      bn_config: Optional keyword overrides for the `hk.BatchNorm` layers.
        Keys not given default to ``decay_rate=0.9``, ``eps=1e-5``,
        ``create_scale=True`` and ``create_offset=True``.
      name: Name of the module.
    """
    super().__init__(name=name)
    self._num_layers = num_layers
    self._use_batchnorm = use_batchnorm

    # Defaults first; user-provided entries take precedence.
    bn_kwargs = {
        'decay_rate': 0.9,
        'eps': 1e-5,
        'create_scale': True,
        'create_offset': True,
    }
    bn_kwargs.update(bn_config or {})

    self.layers = []
    for _ in range(self._num_layers):
      conv = hk.Conv2D(
          output_channels=num_channels,
          kernel_shape=7,
          stride=2,
          with_bias=False,
          padding='SAME',
          name='conv')
      norm = (hk.BatchNorm(name='batchnorm', **bn_kwargs)
              if use_batchnorm else None)
      self.layers.append({'conv': conv, 'batchnorm': norm})

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               test_local_stats: bool = False) -> jnp.ndarray:
    """Applies every downsampling layer in order.

    Args:
      inputs: image batch fed to `hk.Conv2D` (channels-last by haiku default).
      is_training: forwarded to the batchnorm layers.
      test_local_stats: forwarded to the batchnorm layers.

    Returns:
      The downsampled feature map.
    """
    features = inputs
    for layer in self.layers:
      features = layer['conv'](features)
      norm = layer['batchnorm']
      if norm is not None:
        features = norm(features, is_training, test_local_stats)
      features = jax.nn.relu(features)
      features = hk.max_pool(
          features,
          window_shape=(1, 3, 3, 1),
          strides=(1, 2, 2, 1),
          padding='SAME')
    return features
class Conv2DUpsample(hk.Module):
  """Upsamples 4x using two stride-2 2D transposed convolutions."""

  def __init__(
      self,
      n_outputs: int,
      name: Optional[str] = None,
  ):
    """Constructs a Conv2DUpsample model.

    Args:
      n_outputs: The number of output channels of the module; the
        intermediate layer uses twice as many channels.
      name: Name of the module.
    """
    super().__init__(name=name)
    # Both layers share everything except channel count and name.
    shared = dict(kernel_shape=4, stride=2, with_bias=True, padding='SAME')
    self.transp_conv1 = hk.Conv2DTranspose(
        output_channels=n_outputs * 2, name='transp_conv_1', **shared)
    self.transp_conv2 = hk.Conv2DTranspose(
        output_channels=n_outputs, name='transp_conv_2', **shared)

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               test_local_stats: bool = False) -> jnp.ndarray:
    """Applies transp_conv1 -> relu -> transp_conv2.

    `is_training` and `test_local_stats` are accepted for interface
    compatibility with other (de)convnets and are not used.
    """
    hidden = jax.nn.relu(self.transp_conv1(inputs))
    return self.transp_conv2(hidden)
class Conv3DUpsample(hk.Module):
  """Upsamples video features with a stack of 3D transposed convolutions.

  The stack has max(n_time_upsamples, n_space_upsamples) layers. The first
  `n_time_upsamples` layers use stride 2 along time and the first
  `n_space_upsamples` layers use stride 2 along height/width (stride 1
  otherwise). The channel count halves at every layer and ends at
  `n_outputs`.
  """

  def __init__(self,
               n_outputs: int,
               n_time_upsamples: int = 2,
               n_space_upsamples: int = 4,
               name: Optional[str] = None):
    """Constructs the module.

    Args:
      n_outputs: channel count of the final layer.
      n_time_upsamples: number of layers with stride 2 along time.
      n_space_upsamples: number of layers with stride 2 along height/width.
      name: name of the module.
    """
    super().__init__(name=name)
    self._n_outputs = n_outputs
    self._n_time_upsamples = n_time_upsamples
    self._n_space_upsamples = n_space_upsamples

  def __call__(self, x: jnp.ndarray, *, is_training: bool) -> jnp.ndarray:
    """Runs the transposed-conv stack (relu between layers, not after the last).

    `is_training` is accepted for interface compatibility and is not used.
    """
    num_layers = max(self._n_time_upsamples, self._n_space_upsamples)
    last = num_layers - 1
    for layer_idx in range(num_layers):
      t_stride = 2 if layer_idx < self._n_time_upsamples else 1
      s_stride = 2 if layer_idx < self._n_space_upsamples else 1
      x = hk.Conv3DTranspose(
          output_channels=self._n_outputs * 2 ** (last - layer_idx),
          stride=[t_stride, s_stride, s_stride],
          kernel_shape=[4, 4, 4],
          name=f'conv3d_transpose_{layer_idx}')(x)
      if layer_idx < last:
        x = jax.nn.relu(x)
    return x
class ImagePreprocessor(hk.Module):
  """Image preprocessing for Perceiver Encoder.

  Featurizes image or video inputs according to `prep_type` (optionally
  downsampling them) and attaches a position encoding by concatenation or
  addition.
  """

  def __init__(
      self,
      prep_type: str = 'conv',
      spatial_downsample: int = 4,
      temporal_downsample: int = 1,
      position_encoding_type: str = 'fourier',
      n_extra_pos_mlp: int = 0,
      num_channels: int = 64,
      conv_after_patching: bool = False,
      conv2d_use_batchnorm: bool = True,
      concat_or_add_pos: str = 'concat',
      name: Optional[str] = None,
      **position_encoding_kwargs):
    """Constructs the preprocessor.

    Args:
      prep_type: one of 'conv', 'patches', 'pixels' or 'conv1x1'.
      spatial_downsample: spatial downsampling factor. For 'conv' it must be
        a power of 4 (one 4x Conv2DDownsample layer per factor of 4).
      temporal_downsample: temporal downsampling factor; must be 1 for
        'conv' and 'conv1x1'.
      position_encoding_type: forwarded to
        `position_encoding.build_position_encoding`.
      n_extra_pos_mlp: number of residual `hk.Linear` layers applied to the
        position encoding (relu between them, not after the last).
      num_channels: output channels for 'conv', 'conv1x1' and the optional
        linear layer after 'patches'.
      conv_after_patching: if True, apply an `hk.Linear` to the 'patches'
        features.
      conv2d_use_batchnorm: whether the 'conv' downsampler uses batchnorm.
      concat_or_add_pos: 'concat' or 'add'; how the position encoding is
        combined with the features.
      name: name of the module.
      **position_encoding_kwargs: forwarded to
        `position_encoding.build_position_encoding`.

    Raises:
      ValueError: on an unknown `prep_type` or `concat_or_add_pos`, or on
        unsupported downsampling factors for 'conv'.
    """
    super().__init__(name=name)

    if prep_type not in ('conv', 'patches', 'pixels', 'conv1x1'):
      raise ValueError('Invalid prep_type!')

    if concat_or_add_pos not in ['concat', 'add']:
      raise ValueError(
          f'Invalid value {concat_or_add_pos} for concat_or_add_pos.')

    self._prep_type = prep_type
    self._spatial_downsample = spatial_downsample
    self._temporal_downsample = temporal_downsample
    self._concat_or_add_pos = concat_or_add_pos
    self._conv_after_patching = conv_after_patching
    self._num_channels = num_channels

    if self._prep_type == 'conv':
      # Downsampling with conv is currently restricted to powers of 4
      # spatially and no temporal downsampling. Count the factors of 4 with
      # exact integer arithmetic: a float `math.log(x, 4)` test is sensitive
      # to rounding and fails with a math-domain error for x <= 0.
      convnet_num_layers = 0
      remaining = spatial_downsample
      while remaining > 1 and remaining % 4 == 0:
        remaining //= 4
        convnet_num_layers += 1
      if remaining != 1 or temporal_downsample != 1:
        raise ValueError('Only powers of 4 expected for spatial '
                         'and 1 expected for temporal '
                         'downsampling with conv.')

      self.convnet = Conv2DDownsample(
          num_layers=convnet_num_layers,
          num_channels=num_channels,
          use_batchnorm=conv2d_use_batchnorm)

    elif self._prep_type == 'conv1x1':
      assert temporal_downsample == 1, 'conv1x1 does not downsample in time.'
      self.convnet_1x1 = hk.Conv2D(
          num_channels, kernel_shape=[1, 1],
          # spatial_downsample is unconstrained for 1x1 convolutions.
          stride=[spatial_downsample, spatial_downsample])

    # Partially construct the positional encoding function.
    # We fully construct it when we know the input size.
    self._positional_encoding_ctor = functools.partial(
        position_encoding.build_position_encoding,
        position_encoding_type=position_encoding_type,
        **position_encoding_kwargs)

    # Stack MLPs to get a deeper positional embedding.
    self._n_extra_pos_mlp = n_extra_pos_mlp

  def _build_network_inputs(
      self, inputs: jnp.ndarray, pos: jnp.ndarray,
      network_input_is_1d: bool = True) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Construct the final input, including position encoding.

    Args:
      inputs: featurized array [B, *index_dims, C].
      pos: optional positions forwarded to the position encoding.
      network_input_is_1d: if True and `inputs` has more than one index
        dimension, flatten the index dimensions into one.

    Returns:
      (inputs_with_pos, inputs) where the second element is the (possibly
      flattened) features without position encoding.
    """
    batch_size = inputs.shape[0]
    index_dims = inputs.shape[1:-1]

    # Reshape input features to a 1D index dimension if necessary.
    if len(inputs.shape) > 3 and network_input_is_1d:
      inputs = jnp.reshape(
          inputs, [batch_size, np.prod(index_dims), -1])

    # Construct the position encoding.
    pos_enc = self._positional_encoding_ctor(
        index_dims=index_dims)(batch_size=batch_size, pos=pos)

    # Residual linear layers with relu between them (none after the last).
    for i in range(0, self._n_extra_pos_mlp):
      pos_enc += hk.Linear(pos_enc.shape[-1])(pos_enc)
      if i < (self._n_extra_pos_mlp-1):
        pos_enc = jax.nn.relu(pos_enc)

    if not network_input_is_1d:
      # Reshape pos to match the input feature shape
      # if the network takes non-1D inputs
      sh = inputs.shape
      pos_enc = jnp.reshape(pos_enc, list(sh)[:-1]+[-1])

    if self._concat_or_add_pos == 'concat':
      inputs_with_pos = jnp.concatenate([inputs, pos_enc], axis=-1)
    elif self._concat_or_add_pos == 'add':
      inputs_with_pos = inputs + pos_enc

    return inputs_with_pos, inputs

  def __call__(
      self, inputs: jnp.ndarray, *,
      is_training: bool,
      pos: Optional[jnp.ndarray] = None,
      network_input_is_1d: bool = True) -> PreprocessorOutputT:
    """Featurizes `inputs` and attaches position encodings.

    Args:
      inputs: rank-4 image batch or rank-5 video batch (rank 5 is handled by
        `hk.BatchApply` for the conv-based prep types).
      is_training: forwarded to the 'conv' downsampler.
      pos: optional positions forwarded to the position encoding.
      network_input_is_1d: whether to flatten index dimensions to 1D.

    Returns:
      (inputs_with_pos, None, inputs_without_pos).

    Raises:
      ValueError: for 'pixels' inputs that are neither rank 4 nor rank 5.
    """
    if self._prep_type == 'conv':
      # Convnet image featurization.
      # Downsamples spatially by a factor of 4 per Conv2DDownsample layer.
      conv = self.convnet
      if len(inputs.shape) == 5:
        conv = hk.BatchApply(conv)

      inputs = conv(inputs, is_training=is_training)
    elif self._prep_type == 'conv1x1':
      # maps inputs to num_channels (64 by default)
      conv = self.convnet_1x1
      if len(inputs.shape) == 5:
        conv = hk.BatchApply(conv)

      inputs = conv(inputs)
    elif self._prep_type == 'patches':
      # Space2depth featurization.
      # Video: B x T x H x W x C
      inputs = space_to_depth(
          inputs,
          temporal_block_size=self._temporal_downsample,
          spatial_block_size=self._spatial_downsample)

      if inputs.ndim == 5 and inputs.shape[1] == 1:
        # for flow
        inputs = jnp.squeeze(inputs, axis=1)

      if self._conv_after_patching:
        inputs = hk.Linear(self._num_channels, name='patches_linear')(inputs)
    elif self._prep_type == 'pixels':
      # if requested, downsamples in the crudest way
      if inputs.ndim == 4:
        inputs = inputs[:,
                        ::self._spatial_downsample, ::self._spatial_downsample]
      elif inputs.ndim == 5:
        inputs = inputs[:, ::self._temporal_downsample,
                        ::self._spatial_downsample, ::self._spatial_downsample]
      else:
        raise ValueError('Unsupported data format for pixels.')

    inputs, inputs_without_pos = self._build_network_inputs(
        inputs, pos, network_input_is_1d)
    modality_sizes = None  # Size for each modality, only needed for multimodal
    return inputs, modality_sizes, inputs_without_pos
class ImagePostprocessor(hk.Module):
  """Image postprocessing for Perceiver."""

  def __init__(
      self,
      postproc_type: str = 'pixels',
      spatial_upsample: int = 1,
      temporal_upsample: int = 1,
      n_outputs: int = -1,  # only relevant for 'conv1x1', 'conv', and 'raft'
      input_reshape_size: Optional[Sequence[int]] = None,
      name: Optional[str] = None):
    """Constructs the postprocessor.

    Args:
      postproc_type: one of 'conv', 'patches', 'pixels', 'raft', 'conv1x1'.
        No network is built for 'raft' here (see `__call__`).
      spatial_upsample: spatial factor (must be 1 for 'pixels').
      temporal_upsample: temporal factor (must be 1 for 'pixels' and
        'conv1x1'); for 'conv', a value other than 1 selects Conv3DUpsample.
      n_outputs: output channels; required for 'conv' and 'conv1x1'.
      input_reshape_size: if given, inputs are reshaped to
        [B, *input_reshape_size, C] before postprocessing.
      name: name of the module.

    Raises:
      ValueError: on an unknown `postproc_type`, on upsampling with
        'pixels', or when `n_outputs` is missing for 'conv'/'conv1x1'.
    """
    super().__init__(name=name)

    if postproc_type not in ('conv', 'patches', 'pixels', 'raft', 'conv1x1'):
      raise ValueError('Invalid postproc_type!')

    # Architecture parameters:
    self._postproc_type = postproc_type

    self._temporal_upsample = temporal_upsample
    self._spatial_upsample = spatial_upsample
    self._input_reshape_size = input_reshape_size

    if self._postproc_type == 'pixels':
      # No postprocessing.
      if self._temporal_upsample != 1 or self._spatial_upsample != 1:
        raise ValueError('Pixels postprocessing should not currently upsample.')
    elif self._postproc_type == 'conv1x1':
      assert self._temporal_upsample == 1, 'conv1x1 does not upsample in time.'
      if n_outputs == -1:
        raise ValueError('Expected value for n_outputs')
      self.conv1x1 = hk.Conv2D(
          n_outputs, kernel_shape=[1, 1],
          # NOTE(review): a strided Conv2D shrinks the spatial size, so a
          # `spatial_upsample` > 1 actually downsamples here; confirm intent.
          stride=[self._spatial_upsample, self._spatial_upsample])
    elif self._postproc_type == 'conv':
      if n_outputs == -1:
        raise ValueError('Expected value for n_outputs')
      if self._temporal_upsample != 1:
        def int_log2(x):
          return int(np.round(np.log(x) / np.log(2)))
        self.convnet = Conv3DUpsample(
            n_outputs, int_log2(temporal_upsample), int_log2(spatial_upsample))
      else:
        self.convnet = Conv2DUpsample(n_outputs)

  def __call__(
      self, inputs: jnp.ndarray, *,
      is_training: bool,
      pos: Optional[jnp.ndarray] = None,
      modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
    """Postprocesses decoder outputs into images/video.

    Args:
      inputs: decoder output; reshaped first if `input_reshape_size` is set.
      is_training: forwarded to the convnet for 'conv'/'raft'.
      pos: unused.
      modality_sizes: unused.

    Returns:
      The postprocessed array ('pixels' returns the (reshaped) inputs).

    Raises:
      NotImplementedError: for 'raft' when no convnet has been provided.
    """
    if self._input_reshape_size is not None:
      inputs = jnp.reshape(
          inputs,
          [inputs.shape[0]] + list(self._input_reshape_size)
          + [inputs.shape[-1]])

    if self._postproc_type == 'conv' or self._postproc_type == 'raft':
      # Convnet image featurization.
      # `self.convnet` is only built for 'conv'; fail clearly for 'raft'
      # unless a convnet was supplied (e.g. by a subclass).
      conv = getattr(self, 'convnet', None)
      if conv is None:
        raise NotImplementedError(
            f"No convnet is configured for postproc_type "
            f"'{self._postproc_type}'; ImagePostprocessor does not build one "
            "for 'raft'.")
      if len(inputs.shape) == 5 and self._temporal_upsample == 1:
        conv = hk.BatchApply(conv)
      inputs = conv(inputs, is_training=is_training)
    elif self._postproc_type == 'conv1x1':
      inputs = self.conv1x1(inputs)
    elif self._postproc_type == 'patches':
      inputs = reverse_space_to_depth(
          inputs, self._temporal_upsample, self._spatial_upsample)

    return inputs
class OneHotPreprocessor(hk.Module):
  """One-hot preprocessor for Perceiver Encoder.

  Adds a singleton index axis at position 1; no position encoding is used.
  """

  def __init__(self, name: Optional[str] = None):
    super().__init__(name=name)

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               network_input_is_1d: bool = True) -> PreprocessorOutputT:
    """Returns (expanded, None, expanded) with a dummy index axis inserted.

    `is_training`, `pos` and `network_input_is_1d` are not used.
    """
    expanded = inputs[:, None, :]
    # Without position encodings the network input and the
    # inputs-without-pos output are the same array.
    return expanded, None, expanded
class AudioPreprocessor(hk.Module):
  """Audio preprocessing for Perceiver Encoder.

  Reshapes each example into consecutive chunks of `samples_per_patch`
  values and attaches a position encoding to every chunk.
  """

  def __init__(
      self,
      prep_type: str = 'patches',
      samples_per_patch: int = 96,
      position_encoding_type: str = 'fourier',
      n_extra_pos_mlp: int = 0,
      concat_or_add_pos: str = 'concat',
      name: Optional[str] = None,
      **position_encoding_kwargs):
    """Constructs the preprocessor.

    Args:
      prep_type: must be 'patches'.
      samples_per_patch: number of values per patch (last axis after reshape).
      position_encoding_type: forwarded to
        `position_encoding.build_position_encoding`.
      n_extra_pos_mlp: number of residual `hk.Linear` layers applied to the
        position encoding.
      concat_or_add_pos: 'concat' or 'add'.
      name: name of the module.
      **position_encoding_kwargs: forwarded to the position encoding builder.

    Raises:
      ValueError: on an unknown `prep_type` or `concat_or_add_pos`.
    """
    super().__init__(name=name)

    if prep_type not in ('patches',):
      raise ValueError('Invalid prep_type!')
    if concat_or_add_pos not in ['concat', 'add']:
      raise ValueError(
          f'Invalid value {concat_or_add_pos} for concat_or_add_pos.')

    self._samples_per_patch = samples_per_patch
    self._concat_or_add_pos = concat_or_add_pos

    # The encoding needs the index dims, so it is only partially bound here.
    self._positional_encoding_ctor = functools.partial(
        position_encoding.build_position_encoding,
        position_encoding_type=position_encoding_type,
        **position_encoding_kwargs)
    self._n_extra_pos_mlp = n_extra_pos_mlp

  def _build_network_inputs(
      self, inputs: jnp.ndarray,
      pos: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Returns (inputs combined with position encoding, inputs)."""
    batch_size = inputs.shape[0]
    pos_enc = self._positional_encoding_ctor(
        index_dims=inputs.shape[1:-1])(batch_size=batch_size, pos=pos)

    n_mlp = self._n_extra_pos_mlp
    for layer_idx in range(n_mlp):
      # Residual update; relu between layers but not after the last one.
      pos_enc += hk.Linear(pos_enc.shape[-1])(pos_enc)
      if layer_idx < n_mlp - 1:
        pos_enc = jax.nn.relu(pos_enc)

    if self._concat_or_add_pos == 'concat':
      inputs_with_pos = jnp.concatenate([inputs, pos_enc], axis=-1)
    elif self._concat_or_add_pos == 'add':
      inputs_with_pos = inputs + pos_enc
    return inputs_with_pos, inputs

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               network_input_is_1d: bool = True) -> PreprocessorOutputT:
    """Patches the audio and returns (inputs_with_pos, None, inputs_without_pos).

    `is_training` and `network_input_is_1d` are not used.
    """
    num_examples = inputs.shape[0]
    patched = jnp.reshape(inputs, [num_examples, -1, self._samples_per_patch])
    with_pos, without_pos = self._build_network_inputs(patched, pos)
    # Per-modality sizes are only produced by the multimodal preprocessor.
    return with_pos, None, without_pos
class AudioPostprocessor(hk.Module):
  """Audio postprocessing for Perceiver.

  Projects every output index to `samples_per_patch` values with a linear
  layer and flattens the result to one sequence per example.
  """

  def __init__(
      self,
      postproc_type: str = 'patches',
      samples_per_patch: int = 96,
      name: Optional[str] = None):
    """Constructs the postprocessor.

    Args:
      postproc_type: postprocessing type; only 'patches' is supported.
      samples_per_patch: number of values produced per output index.
      name: name of the module.

    Raises:
      ValueError: if `postproc_type` is not 'patches'.
    """
    super().__init__(name=name)

    if postproc_type not in ('patches',):
      raise ValueError('Invalid postproc_type!')

    # Architecture parameters:
    self._postproc_type = postproc_type
    self._samples_per_patch = samples_per_patch

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
    """Returns a [B, -1] array; `is_training`, `pos`, `modality_sizes` unused."""
    projected = hk.Linear(self._samples_per_patch)(inputs)
    return jnp.reshape(projected, [inputs.shape[0], -1])
class IdentityPostprocessor(hk.Module):
  """Postprocessor that returns its inputs untouched."""

  def __init__(self, name: Optional[str] = None):
    super().__init__(name=name)

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
    """Returns `inputs` as-is; the other arguments are ignored."""
    del is_training, pos, modality_sizes  # Unused.
    return inputs
def restructure(modality_sizes: ModalitySizeT,
                inputs: jnp.ndarray) -> Mapping[str, jnp.ndarray]:
  """Partitions a [B, N, C] tensor into tensors for each modality.

  Modalities are laid out along axis 1 in sorted-name order, each taking
  `modality_sizes[name]` consecutive positions.

  Args:
    modality_sizes: dict specifying the size (along axis 1) of each modality.
    inputs: input tensor.

  Returns:
    dict mapping name of modality to its slice `inputs[:, start:stop]`.
  """
  result = {}
  start = 0
  # Sorting gives a predictable ordering that matches MultimodalPreprocessor.
  for name in sorted(modality_sizes):
    stop = start + modality_sizes[name]
    result[name] = inputs[:, start:stop]
    start = stop
  return result
class MultimodalPreprocessor(hk.Module):
  """Multimodal preprocessing for Perceiver Encoder.

  Inputs for each modality are preprocessed, then padded with trainable
  position embeddings to have the same number of channels. They are then
  optionally randomly masked (when `mask_probs` is given) and concatenated
  along axis 1 in sorted-modality-name order.
  """

  def __init__(
      self,
      modalities: Mapping[str, PreprocessorT],
      mask_probs: Optional[Mapping[str, float]] = None,
      min_padding_size: int = 2,
      name: Optional[str] = None):
    """Constructor.

    Args:
      modalities: dict mapping modality name to preprocessor
      mask_probs: dict mapping modality name to masking probability of that
        modality. When given, it is indexed with every modality name, so it
        must contain all keys of `modalities`.
      min_padding_size: the minimum padding size for all modalities.
        The final output will have num_channels equal to the maximum channels
        across all modalities plus min_padding_size.
      name: name of module
    """
    super().__init__(name=name)
    self._modalities = modalities
    self._min_padding_size = min_padding_size
    self._mask_probs = mask_probs

  def __call__(self, inputs: Mapping[str, jnp.ndarray], *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               network_input_is_1d: bool = True) -> PreprocessorOutputT:
    """Preprocesses, pads, optionally masks and concatenates all modalities.

    Args:
      inputs: dict mapping modality name to that modality's input; indexed
        with every key of the `modalities` given to the constructor.
      is_training: forwarded to each per-modality preprocessor. It does not
        gate masking: masking is applied whenever `mask_probs` was given.
      pos: forwarded to each per-modality preprocessor.
      network_input_is_1d: forwarded to each per-modality preprocessor.

    Returns:
      A tuple of:
        - the concatenated, padded array (modalities along axis 1);
        - dict mapping modality name to its size along axis 1;
        - dict mapping modality name to its inputs without position encoding.
    """
    outputs = {}
    inputs_without_pos = {}
    for modality, preprocessor in self._modalities.items():
      # Per-modality modality_sizes (second element) are discarded.
      outputs[modality], _, inputs_without_pos[modality] = preprocessor(
          inputs[modality], is_training=is_training, pos=pos,
          network_input_is_1d=network_input_is_1d)

    # Each preprocessed output is used as [B, N, C] below (channels on
    # axis 2); pad all of them to the widest channel count plus a margin.
    common_channel_size = (max(o.shape[2] for o in outputs.values())
                           + self._min_padding_size)

    padded = {}
    modality_sizes = {}
    for modality, output in outputs.items():
      # A single trainable embedding per modality fills the padding channels,
      # broadcast over batch and index positions.
      pos_enc = position_encoding.TrainablePositionEncoding(
          1, num_channels=common_channel_size-output.shape[2],
          init_scale=0.02, name=f'{modality}_padding')
      padding = jnp.broadcast_to(
          pos_enc(batch_size=output.shape[0]),
          [output.shape[0], output.shape[1],
           common_channel_size-output.shape[2]])
      output_padded = jnp.concatenate([output, padding], axis=2)

      if self._mask_probs is not None:
        # Randomly mask out each token corresponding to this modality:
        # positions where the Bernoulli draw is True are replaced by a
        # learned mask token (presumably shaped [B, 1, C] — TODO confirm
        # against TrainablePositionEncoding).
        mask_token = position_encoding.TrainablePositionEncoding(
            1, num_channels=output_padded.shape[2],
            init_scale=0.02, name=f'{modality}_mask_token')(output.shape[0])
        mask_prob = self._mask_probs[modality]
        rng = hk.next_rng_key()
        mask = jax.random.bernoulli(rng, mask_prob,
                                    shape=[output.shape[0], output.shape[1]])
        mask = jnp.expand_dims(mask, axis=2)
        output_padded = (1 - mask) * output_padded + mask * mask_token

      padded[modality] = output_padded
      modality_sizes[modality] = output_padded.shape[1]

    # Apply a predictable ordering to the modalities
    padded_ls = [padded[k] for k in sorted(padded.keys())]
    return (jnp.concatenate(padded_ls, axis=1),  # pytype: disable=bad-return-type  # jax-ndarray
            modality_sizes,
            inputs_without_pos)
class MultimodalPostprocessor(hk.Module):
  """Multimodal postprocessing for Perceiver.

  Applies a separate postprocessor to each modality's slice of the decoder
  output.
  """

  def __init__(
      self,
      modalities: Mapping[str, PostprocessorT],
      input_is_dict: bool = False,
      name: Optional[str] = None):
    """Constructor.

    Args:
      modalities: dict mapping modality name to post processor for that
        modality
      input_is_dict: If True, input is assumed to be dictionary structured,
        and outputs keep the same dictionary shape. If False, input is a
        tensor which is sliced up during postprocessing by `modality_sizes`.
      name: name of the module
    """
    super().__init__(name=name)
    self._modalities = modalities
    self._input_is_dict = input_is_dict

  def __call__(
      self, inputs: jnp.ndarray, *,
      is_training: bool,
      pos: Optional[jnp.ndarray] = None,
      modality_sizes: Optional[ModalitySizeT] = None) -> Mapping[str,
                                                                 jnp.ndarray]:
    """Returns a dict mapping modality name to its postprocessed output.

    When the input is not a dict, `modality_sizes` is required to slice it.
    """
    if not self._input_is_dict:
      # Slice up modalities by their sizes.
      assert modality_sizes is not None
      inputs = restructure(modality_sizes=modality_sizes, inputs=inputs)

    outputs = {}
    for modality, postprocessor in self._modalities.items():
      outputs[modality] = postprocessor(
          inputs[modality], is_training=is_training, pos=pos,
          modality_sizes=None)
    return outputs
class ClassificationPostprocessor(hk.Module):
  """Classification postprocessing for Perceiver.

  Projects outputs to `num_classes` logits and keeps only index 0 along
  axis 1.
  """

  def __init__(
      self,
      num_classes: int,
      name: Optional[str] = None):
    super().__init__(name=name)
    self._num_classes = num_classes

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
    """Returns [B, num_classes] logits; the keyword arguments are unused."""
    classifier = hk.Linear(self._num_classes)
    return classifier(inputs)[:, 0, :]
class ProjectionPostprocessor(hk.Module):
  """Projection postprocessing for Perceiver.

  Applies a single linear layer with `num_outputs` output features.
  """

  def __init__(
      self,
      num_outputs: int,
      name: Optional[str] = None):
    super().__init__(name=name)
    self._num_outputs = num_outputs

  def __call__(self, inputs: jnp.ndarray, *,
               is_training: bool,
               pos: Optional[jnp.ndarray] = None,
               modality_sizes: Optional[ModalitySizeT] = None) -> jnp.ndarray:
    """Returns the linear projection of `inputs`; other arguments are unused."""
    return hk.Linear(self._num_outputs)(inputs)
class EmbeddingDecoder(hk.Module):
  """Haiku module to decode embeddings into vocabulary logits.

  Multiplies embeddings by the transposed embedding matrix and adds a
  learned per-token bias.
  """

  def __init__(self, embedding_matrix: jnp.ndarray, name='embedding_decoder'):
    """Constructs the module.

    Args:
      embedding_matrix: Array of shape [vocab_size, d_model].
      name: Name of the module.
    """
    super().__init__(name=name)
    self._embedding_matrix = embedding_matrix
    self._vocab_size, self._d_model = embedding_matrix.shape

  def __call__(self, embeddings: jnp.ndarray) -> jnp.ndarray:
    """Decodes embeddings into logits.

    Args:
      embeddings: array of shape [..., d_model], typically
        [batch, seq_len, d_model]; any number of leading dimensions is
        accepted.

    Returns:
      Logits of shape [..., vocab_size] with the same leading dimensions as
      `embeddings`.
    """
    leading_shape = tuple(embeddings.shape[:-1])
    # Flatten every leading dim so the projection is one 2D matmul.
    output = jnp.matmul(
        embeddings.reshape([-1, self._d_model]),
        jnp.transpose(self._embedding_matrix))
    bias = hk.get_parameter('bias', shape=[self._vocab_size], init=jnp.zeros)
    output = output + bias
    return output.reshape(leading_shape + (self._vocab_size,))