Convolution over feature dim #135

Open · wants to merge 4 commits into base: main
54 changes: 49 additions & 5 deletions pytorch_to_returnn/torch/nn/modules/conv.py
@@ -78,14 +78,24 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]:
assert all(p == 0 for p in self.padding) # not implemented otherwise
assert all(p == 0 for p in self.output_padding) # not implemented otherwise
assert self.padding_mode == "zeros" # not implemented otherwise

from pytorch_to_returnn.naming import Naming
naming = Naming.get_instance()
input_tensor = naming.tensors[input]
in_dim = input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[1]]
in_spatial_dims = [
input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[dim + len(input.shape)]]
Member:
Why don't you use _get_input_axis_to_returnn for that?

Contributor Author:
Because that would return an axis description like "T" or "F", which would only be mapped to a dim tag during the ConvLayer construction. However, in case we do convolution over the feature dim, "F" would be mapped to the in_dim, i.e. the new feature dim and not the old one, which does not work.

for dim in range(-self.nd, 0)]

d = {
"class": "conv", "from": self._get_input_layer_name(input),
"activation": None,
"with_bias": self.bias is not None,
"n_out": self.out_channels,
"filter_size": self.kernel_size,
"padding": "valid",
"in_spatial_dims": [self._get_input_axis_to_returnn(input, dim) for dim in range(-self.nd, 0)],
"in_spatial_dims": in_spatial_dims,
"in_dim": in_dim,
}
if any(s != 1 for s in self.stride):
d["strides"] = self.stride
@@ -121,18 +131,52 @@ def import_params_torch_to_returnn(self, *, layer: LayerBase, torch_module: _Con
def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: LayerBase
) -> Tuple[Tuple[int, ...], Dict[int, int]]:
"""
The basic returnn_axis_from_torch_axis should be correct, however, if the size of a dynamic axis changes (e.g. due
to strides and/or padding), this is not covered in the base method and we fix it here.
If the size of a dynamic axis changes (e.g. due to strides and/or padding), this is not covered in the base method
and we fix it here. Also, returnn_axis_from_torch_axis needs to be modified for cases where the RETURNN input
feature dim is used as a spatial dim for convolution.
"""
torch_shape, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn(
inputs_flat=inputs_flat, layer=layer)
assert len(inputs_flat) == 1
torch_shape = list(inputs_flat[0].shape)
torch_shape[1] = self.out_channels
for idx in range(self.nd):
torch_ax = idx + 2
torch_shape[torch_ax] = (torch_shape[torch_ax] + 2 * self.padding[idx] - self.dilation[idx] * (
self.kernel_size[idx] - 1) - 1) // self.stride[idx] + 1
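# (The line above is the standard PyTorch conv output-size formula:
#  L_out = (L_in + 2*padding - dilation*(kernel_size - 1) - 1) // stride + 1,
#  e.g. L_in=20, kernel_size=3, stride=1, padding=0, dilation=1 -> L_out=18.)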

from pytorch_to_returnn.naming import Naming
Member:
I don't understand why you do it so complicated. This code here should be quite short, straightforward, and not using any heuristics. Your code here is full of heuristics, checking whether you can map all axes, etc. You don't need that. We know exactly how it must map.

Contributor Author:
> We know exactly how it must map.

All dims that did not change are trivial, and the channel dim can be handled as I do it here. How would you do it for the spatial dims? Just assume that the order of the spatial dims is the same as in in_spatial_dims?

Contributor Author:
I added an update which does the mapping as described above.

Member:
> Just assume that the order of spatial dims is the same as in in_spatial_dims?

What order? Of the RETURNN output? We don't need to guess anything here. We know everything exactly.

Member:
Except BCHW vs BHWC, but you can just check where out_dim is.

Contributor Author:
Ok, then that is exactly what I do now, right?

naming = Naming.get_instance()
input_tensor = naming.tensors[inputs_flat[0]]
in_data = input_tensor.returnn_data
out_data = layer.output
assert in_data.batch_ndim == out_data.batch_ndim

mapping_out_to_in = {}
# map unchanged dims
for in_dim, in_dim_tag in enumerate(in_data.dim_tags):
if in_dim_tag in out_data.dim_tags:
mapping_out_to_in[out_data.dim_tags.index(in_dim_tag)] = in_dim

# map channel dim
in_channel = input_tensor.returnn_axis_from_torch_axis[1]
out_channel = [
dim for dim in layer.output.get_static_axes() if layer.output.dim_tags[dim].dimension == self.out_channels]
assert len(out_channel) == 1
mapping_out_to_in[out_channel[0]] = in_channel

# map spatial axes, the order is the same as in in_spatial_dims
in_spatial_dims = [
input_tensor.returnn_axis_from_torch_axis[dim + in_data.batch_ndim]
for dim in range(-self.nd, 0)]
for in_dim, out_dim in zip(in_spatial_dims, out_data.get_spatial_batch_axes()):
mapping_out_to_in[out_dim] = in_dim

assert len(mapping_out_to_in) == in_data.batch_ndim, (
f"Not all axes were mapped successfully. In: {in_data}, out: {out_data}")
returnn_axis_from_torch_axis = {}
for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items():
torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis] # torch does not change order for conv
returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis

return tuple(torch_shape), returnn_axis_from_torch_axis
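To make the mapping steps concrete, here is a small self-contained toy (plain Python stand-ins only, no RETURNN or PyTorch objects; the axis layouts and names are invented for illustration and are not taken from the PR) that walks through the same logic for a hypothetical Conv2d whose RETURNN input is laid out channel-last and whose RETURNN output is laid out channel-first:

```python
# Strings stand in for RETURNN dim tags; "B" appears unchanged in input and output.
in_tags = ["B", "H", "W", "C_in"]      # hypothetical RETURNN input layout (channel-last)
out_tags = ["B", "C_out", "H'", "W'"]  # hypothetical RETURNN conv output layout (channel-first)

mapping_out_to_in = {}
# 1) unchanged dims: the same tag shows up in both input and output (here only "B")
for in_ax, tag in enumerate(in_tags):
    if tag in out_tags:
        mapping_out_to_in[out_tags.index(tag)] = in_ax
# 2) channel dim: the output axis whose static size equals out_channels
mapping_out_to_in[out_tags.index("C_out")] = in_tags.index("C_in")
# 3) spatial dims: zipped in the order given by in_spatial_dims
in_spatial = [in_tags.index("H"), in_tags.index("W")]
out_spatial = [out_tags.index("H'"), out_tags.index("W'")]
for in_ax, out_ax in zip(in_spatial, out_spatial):
    mapping_out_to_in[out_ax] = in_ax
assert mapping_out_to_in == {0: 0, 1: 3, 2: 1, 3: 2}

# 4) translate back to torch axes; torch keeps (B, C, H, W) for conv, so the
#    torch<->RETURNN axis map of the *input* is reused for the output.
torch_axis_from_returnn_in_axis = {0: 0, 1: 2, 2: 3, 3: 1}  # returnn (B, H, W, C) vs torch (B, C, H, W)
returnn_axis_from_torch_axis = {
    torch_axis_from_returnn_in_axis[in_ax]: out_ax
    for out_ax, in_ax in mapping_out_to_in.items()}
assert returnn_axis_from_torch_axis == {0: 0, 1: 1, 2: 2, 3: 3}  # output happens to match torch order here
```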


40 changes: 40 additions & 0 deletions tests/test_layers.py
@@ -1779,6 +1779,46 @@ def create_returnn_layer_dict(self, input):
inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2})


def test_vgg():
n_batch, n_time, n_feat = 3, 20, 5 # B, T, F

def model_func(wrapped_import, inputs: torch.Tensor):
if typing.TYPE_CHECKING or not wrapped_import:
import torch
else:
torch = wrapped_import("torch")

class VggBlock(torch.nn.Module):
def __init__(self, n_in, n_out, kernel_size, pool_size=None, stride: typing.Union[int, typing.Tuple] = 1):
super().__init__()
self.conv = torch.nn.Conv2d(n_in, n_out, kernel_size, stride=stride)
self.activation = torch.nn.SiLU()
self.pooling = torch.nn.MaxPool2d(pool_size) if pool_size is not None else None

def forward(self, x):
# ignore padding here
x = self.conv(x)
x = self.activation(x)
if self.pooling is not None:
x = self.pooling(x)
return x

x = inputs.unsqueeze(1) # (B, 1, T, F)
# VGG block:
vgg_blocks = torch.nn.Sequential(
VggBlock(1, 32, (3, 3)),
)
x = vgg_blocks(x)
x = x.transpose(2, 3).flatten(1, 2) # (B, F, T)
return x

rnd = numpy.random.RandomState(42)
x = rnd.normal(0., 1., (n_batch, n_time, n_feat)).astype("float32")
converter = verify_torch_and_convert_to_returnn(
model_func, inputs=x, returnn_dummy_input_shape=x.shape, validate_allclose_kwargs=dict(rtol=0, atol=5e-3),
inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2})
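As a quick sanity check of the shapes in this test (a sketch using plain torch only, independent of the converter and not part of the test itself):

```python
import torch
x = torch.randn(3, 20, 5)                            # (B, T, F)
y = torch.nn.Conv2d(1, 32, (3, 3))(x.unsqueeze(1))   # valid padding: T 20->18, F 5->3
assert y.shape == (3, 32, 18, 3)
z = y.transpose(2, 3).flatten(1, 2)                  # merge channels and reduced F: 32*3=96
assert z.shape == (3, 96, 18)                        # (B, F, T)
```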


def test_multiple_outputs():
n_batch, n_time = 3, 7
n_in, n_out = 11, 13