From 1264482c04ed46eba02a21285379b735401adfd0 Mon Sep 17 00:00:00 2001
From: DanEnergetics
Date: Tue, 21 May 2024 12:17:05 +0200
Subject: [PATCH] Add helper function applying correct padding for even kernel
 sizes (#54)

---
 i6_models/parts/frontend/common.py | 24 +++++++++++++++++
 tests/test_padding.py              | 42 ++++++++++++++++++++++++++++++
 2 files changed, 66 insertions(+)
 create mode 100644 tests/test_padding.py

diff --git a/i6_models/parts/frontend/common.py b/i6_models/parts/frontend/common.py
index e5f0730e..bc1bfced 100644
--- a/i6_models/parts/frontend/common.py
+++ b/i6_models/parts/frontend/common.py
@@ -20,6 +20,30 @@ def get_same_padding(input_size: Union[int, Tuple[int, ...]]) -> Union[int, Tupl
         raise TypeError(f"unexpected size type {type(input_size)}")
 
 
+def apply_same_padding(x: torch.Tensor, kernel_size: Union[int, Tuple[int, ...]], **kwargs) -> torch.Tensor:
+    """
+    Pad the tensor almost symmetrically in one or more dimensions so that a convolution with
+    the given kernel does not reduce the time dimension. Unlike the standard padding parameter
+    of the convolution, this also handles even kernel sizes.
+
+    :param x: tensor to be padded
+    :param kernel_size: kernel size of the convolution for which the tensor is padded
+    :param kwargs: keyword args passed to functional.pad
+    :return: padded tensor
+    """
+    if isinstance(kernel_size, int):
+        h = (kernel_size - 1) // 2
+        return functional.pad(x, (h, kernel_size - 1 - h), **kwargs)
+    elif isinstance(kernel_size, tuple):
+        paddings = ()
+        for k in reversed(kernel_size):  # functional.pad expects pad values starting from the last dim
+            h = (k - 1) // 2
+            paddings += (h, k - 1 - h)
+        return functional.pad(x, paddings, **kwargs)
+    else:
+        raise TypeError(f"unexpected size type {type(kernel_size)}")
+
+
 def mask_pool(seq_mask: torch.Tensor, *, kernel_size: int, stride: int, padding: int) -> torch.Tensor:
     """
     apply strides to the masking
diff --git a/tests/test_padding.py b/tests/test_padding.py
new file mode 100644
index 00000000..c46b8d83
--- /dev/null
+++ b/tests/test_padding.py
@@ -0,0 +1,42 @@
+from itertools import product
+
+import torch
+from torch import nn
+
+from i6_models.parts.frontend.common import apply_same_padding, get_same_padding
+
+
+def test_output_shape():
+    # test with one even and one odd input dim
+    last_dim = 101
+    pre_last_dim = 100
+
+    iff = lambda x, y: bool(x) == bool(y)  # logical biconditional: x <=> y
+    strided_dim = lambda d, s: (d - 1) // s + 1  # expected output dim of a "same"-padded conv with stride s
+
+    # `get_same_padding` is only checked with stride 1; it keeps the dims for odd kernel sizes only
+    for kernel in product(range(1, 21), repeat=2):
+        conv = nn.Conv2d(1, 1, kernel_size=kernel, stride=(1, 1), padding=get_same_padding(kernel))
+
+        x = torch.randn(1, 1, pre_last_dim, last_dim)
+
+        out = conv(x)
+
+        # we expect `get_same_padding` to only cover odd kernel sizes
+        assert all(
+            iff(out_dim == in_dim, k % 2 == 1) for in_dim, out_dim, k in zip(x.shape[2:], out.shape[2:], kernel)
+        ), f"Failed for {x.shape=}, {out.shape=}, {kernel=} and stride=1"
+
+    for kernel, stride in product(product(range(1, 21), repeat=2), range(1, 7)):
+        conv = nn.Conv2d(1, 1, kernel_size=kernel, stride=(1, stride))
+
+        x = torch.randn(1, 1, pre_last_dim, last_dim)
+        x_padded = apply_same_padding(x, kernel)
+
+        out = conv(x_padded)
+
+        # output dims match the "same"-padded strided conv formula for all tested kernels and strides
+        assert all(
+            out_dim == strided_dim(in_dim, s)
+            for in_dim, out_dim, s in zip(x.shape[2:], out.shape[2:], (1, stride))
+        ), f"Failed for {x.shape=}, {out.shape=}, {kernel=} and {stride=}"
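
---

Usage note (editor addition, not part of the patch): a minimal sketch of the new
helper, assuming the patched i6_models is importable. It shows that
apply_same_padding keeps the spatial/time dims even for even kernel sizes, where
get_same_padding alone cannot produce "same" output; the tensor layout
(batch, channels, feature, time) is an assumption for illustration.

    import torch
    from torch import nn

    from i6_models.parts.frontend.common import apply_same_padding

    x = torch.randn(1, 1, 100, 101)  # (batch, channels, feature, time)
    kernel = (4, 6)  # even kernel sizes in both dims
    conv = nn.Conv2d(1, 1, kernel_size=kernel)  # no padding inside the conv itself

    # left pad is (k - 1) // 2, right pad is k - 1 - left, starting from the last dim:
    # kernel (4, 6) -> pad (2, 3) on the time dim, (1, 2) on the feature dim
    out = conv(apply_same_padding(x, kernel))
    assert out.shape[2:] == x.shape[2:]  # dims preserved despite even kernels

The split is intentionally asymmetric: for an even kernel the total padding k - 1
is odd, so the extra element goes to the right, which is why the docstring says
"almost symmetrically".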