-
Notifications
You must be signed in to change notification settings - Fork 305
/
combiner.py
244 lines (217 loc) · 8.83 KB
/
combiner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
# Copyright 2022 Behavox LLC. (author: Daniil Kulko)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
class RandomCombine(torch.nn.Module):
"""
This module combines a list of Tensors, all with the same shape, to
produce a single output of that same shape which, in training time,
is a random combination of all the inputs; but which in test time
will be just the last input.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(
self,
num_inputs: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0,
) -> None:
"""
Args:
num_inputs:
The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
final_weight:
The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
pure_prob:
The probability, on each frame, with which we choose
only a single layer to output (rather than an interpolation)
stddev:
A standard deviation that we add to log-probs for computing
randomized weights.
The method of choosing which layers, or combinations of layers, to use,
is conceptually as follows::
With probability `pure_prob`::
With probability `final_weight`: choose final layer,
Else: choose random non-final layer.
Else::
Choose initial log-weights that correspond to assigning
weight `final_weight` to the final layer and equal
weights to other layers; then add Gaussian noise
with variance `stddev` to these log-weights, and normalize
to weights (note: the average weight assigned to the
final layer here will not be `final_weight` if stddev>0).
"""
super().__init__()
assert 0 <= pure_prob <= 1, pure_prob
assert 0 < final_weight < 1, final_weight
assert num_inputs >= 1, num_inputs
self.num_inputs = num_inputs
self.final_weight = final_weight
self.pure_prob = pure_prob
self.stddev = stddev
self.final_log_weight = (
torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1))
.log()
.item()
)
def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""Forward function.
Args:
inputs:
A list of Tensor, e.g. from various layers of a transformer.
All must be the same shape, of (*, num_channels)
Returns:
A Tensor of shape (*, num_channels). In test mode
this is just the final input.
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}"
if not self.training or torch.jit.is_scripting() or len(inputs) == 1:
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
(num_frames, num_channels, num_inputs)
)
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(
inputs[0].dtype, inputs[0].device, num_frames
)
weights = weights.reshape(num_frames, num_inputs, 1)
# ans: (num_frames, num_channels, 1)
ans = torch.matmul(stacked_inputs, weights)
# ans: (*, num_channels)
ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
return ans
def _get_random_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired
Returns:
A tensor of shape (num_frames, self.num_inputs), such that
`ans.sum(dim=1)` is all ones.
"""
pure_prob = self.pure_prob
if pure_prob == 0.0:
return self._get_random_mixed_weights(dtype, device, num_frames)
elif pure_prob == 1.0:
return self._get_random_pure_weights(dtype, device, num_frames)
else:
p = self._get_random_pure_weights(dtype, device, num_frames)
m = self._get_random_mixed_weights(dtype, device, num_frames)
return torch.where(
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
)
def _get_random_pure_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
exactly one weight equal to 1.0 on each frame.
"""
final_prob = self.final_weight
# final contains self.num_inputs - 1 in all elements
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
indexes = torch.where(
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
)
ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(
dtype=dtype
)
return ans
def _get_random_mixed_weights(
self, dtype: torch.dtype, device: torch.device, num_frames: int
) -> torch.Tensor:
"""Return a tensor of random one-hot weights, of shape
`(num_frames, self.num_inputs)`,
Args:
dtype:
The data-type desired for the answer, e.g. float, double.
device:
The device needed for the answer.
num_frames:
The number of sets of weights desired.
Returns:
A tensor of shape (num_frames, self.num_inputs), which elements
in [0..1] that sum to one over the second axis, i.e.
`ans.sum(dim=1)` is all ones.
"""
logprobs = (
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
* self.stddev
)
logprobs[:, -1] += self.final_log_weight
return logprobs.softmax(dim=1)
def _test_random_combine(
final_weight: float,
pure_prob: float,
stddev: float,
) -> None:
print(
f"_test_random_combine: final_weight={final_weight}, "
f"pure_prob={pure_prob}, stddev={stddev}"
)
num_inputs = 3
num_channels = 50
m = RandomCombine(
num_inputs=num_inputs,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev,
)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.
def _test_random_combine_main() -> None:
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.0)
_test_random_combine(0.999, 0, 0.0)
_test_random_combine(0.5, 0, 0.3)
_test_random_combine(0.5, 1, 0.3)
_test_random_combine(0.5, 0.5, 0.3)
if __name__ == "__main__":
_test_random_combine_main()