Skip to content

Commit

Permalink
Add helper function applying correct padding for even kernel sizes (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanEnergetics authored May 21, 2024
1 parent 5e4d013 commit 1264482
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
24 changes: 24 additions & 0 deletions i6_models/parts/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,30 @@ def get_same_padding(input_size: Union[int, Tuple[int, ...]]) -> Union[int, Tupl
raise TypeError(f"unexpected size type {type(input_size)}")


def apply_same_padding(x: torch.Tensor, kernel_size: Union[int, Tuple[int, ...]], **kwargs) -> torch.Tensor:
"""
Pad tensor almost symmetrically in one or more dimensions in order to not reduce time dimension
when applying convolution with the given kernel. As opposed to the standard padding parameter
this also handles even kernel sizes.
:param x:
:param kernel_size: kernel size of the convolution for which the tensor is padded
:param kwargs: keyword args passed to functional.pad
:return: padded tensor
"""
if isinstance(kernel_size, int):
h = (kernel_size - 1) // 2
return functional.pad(x, (h, kernel_size - 1 - h), **kwargs)
elif isinstance(kernel_size, tuple):
paddings = ()
for k in reversed(kernel_size): # padding function starts with last dim
h = (k - 1) // 2
paddings += (h, k - 1 - h)
return functional.pad(x, paddings, **kwargs)
else:
raise TypeError(f"Unexpected size type {type(kernel_size)}")


def mask_pool(seq_mask: torch.Tensor, *, kernel_size: int, stride: int, padding: int) -> torch.Tensor:
"""
apply strides to the masking
Expand Down
42 changes: 42 additions & 0 deletions tests/test_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from itertools import product

import torch
from torch import nn

from i6_models.parts.frontend.common import apply_same_padding, get_same_padding


def test_output_shape():
# test for even and odd dim
last_dim = 101
pre_last_dim = 100

iff = lambda x, y: x and y or not x and not y # x <=> y
strided_dim = lambda d, s: (d - 1) // s + 1 # expected out dimension for strided conv

# `get_same_padding` seems to work for some stride > 1
for kernel in product(range(1, 21), repeat=2):
conv = nn.Conv2d(1, 1, kernel_size=kernel, stride=(1, 1), padding=get_same_padding(kernel))

x = torch.randn(1, 1, pre_last_dim, last_dim)

out = conv(x)

# we expect `get_same_padding` to only cover odd kernel sizes
assert all(
iff(out_dim == in_dim, k % 2 == 1) for in_dim, out_dim, k in zip(x.shape[2:], out.shape[2:], kernel)
), f"Failed for {x.shape=}, {out.shape=}, {kernel=} and stride=1"

for kernel, stride in product(product(range(1, 21), repeat=2), range(1, 7)):
conv = nn.Conv2d(1, 1, kernel_size=kernel, stride=(1, stride))

x = torch.randn(1, 1, pre_last_dim, last_dim)
x_padded = apply_same_padding(x, kernel)

out = conv(x_padded)

# correct out dimensions for all possible kernel sizes and strides
assert all(
out_dim == strided_dim(in_dim, s)
for in_dim, out_dim, k, s in zip(x.shape[2:], out.shape[2:], kernel, (1, stride))
), f"Failed for {x.shape=}, {out.shape=}, {kernel=} and {stride=}"

0 comments on commit 1264482

Please sign in to comment.