diff --git a/src/models/bestrq.py b/src/models/bestrq.py index 8413275..3d70600 100644 --- a/src/models/bestrq.py +++ b/src/models/bestrq.py @@ -119,7 +119,14 @@ def forward( if mask_time_indices is not None: mask_time_indices = mask_time_indices.to(torch.bool) - targets = self.rpq(input_values.view((*mask_time_indices.shape[:2], -1))) + is_correctly_padded = (input_values.shape[1] % mask_time_indices.shape[1]) == 0 + subsample_to = input_values.shape[1] // round(input_values.shape[1] / mask_time_indices.shape[1]) + + targets = self.rpq(input_values.view((mask_time_indices.shape[0], subsample_to, -1))) + + if not is_correctly_padded: + targets = targets[:, :, -mask_time_indices.shape[1] :] + targets = targets.masked_fill(~mask_time_indices[:, None, ...], -100) outputs = self.wav2vec2(