diff --git a/pytorch_to_returnn/torch/nn/modules/conv.py b/pytorch_to_returnn/torch/nn/modules/conv.py
index 7ba5308..573252f 100644
--- a/pytorch_to_returnn/torch/nn/modules/conv.py
+++ b/pytorch_to_returnn/torch/nn/modules/conv.py
@@ -78,6 +78,15 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]:
     assert all(p == 0 for p in self.padding)  # not implemented otherwise
     assert all(p == 0 for p in self.output_padding)  # not implemented otherwise
     assert self.padding_mode == "zeros"  # not implemented otherwise
+
+    from pytorch_to_returnn.naming import Naming
+    naming = Naming.get_instance()
+    input_tensor = naming.tensors[input]
+    in_dim = input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[1]]
+    in_spatial_dims = [
+      input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[dim + len(input.shape)]]
+      for dim in range(-self.nd, 0)]
+
     d = {
       "class": "conv", "from": self._get_input_layer_name(input),
       "activation": None,
@@ -85,7 +94,8 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]:
       "n_out": self.out_channels,
       "filter_size": self.kernel_size,
       "padding": "valid",
-      "in_spatial_dims": [self._get_input_axis_to_returnn(input, dim) for dim in range(-self.nd, 0)],
+      "in_spatial_dims": in_spatial_dims,
+      "in_dim": in_dim,
     }
     if any(s != 1 for s in self.stride):
       d["strides"] = self.stride
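
Background for the hunks above (an illustration, not part of the patch): `returnn_axis_from_torch_axis` maps torch axis indices to RETURNN axis indices, so `[1]` picks the RETURNN dim tag of the torch channel axis, and the negative spatial axes are first shifted to positive indices via `dim + len(input.shape)`. A toy version of that index arithmetic, with a made-up layout:

# Toy illustration with hypothetical layouts: torch (B, C, T, F) vs. RETURNN (B, T, F, C).
returnn_axis_from_torch_axis = {0: 0, 1: 3, 2: 1, 3: 2}
ndim, nd = 4, 2  # tensor rank and number of spatial dims (Conv2d)
in_dim_axis = returnn_axis_from_torch_axis[1]  # RETURNN axis of the torch channel dim
in_spatial_axes = [returnn_axis_from_torch_axis[dim + ndim] for dim in range(-nd, 0)]
assert (in_dim_axis, in_spatial_axes) == (3, [1, 2])  # the patch then indexes dim_tags with these
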
""" - torch_shape, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn( - inputs_flat=inputs_flat, layer=layer) assert len(inputs_flat) == 1 torch_shape = list(inputs_flat[0].shape) torch_shape[1] = self.out_channels @@ -133,6 +142,41 @@ def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: Layer torch_ax = idx + 2 torch_shape[torch_ax] = (torch_shape[torch_ax] + 2 * self.padding[idx] - self.dilation[idx] * ( self.kernel_size[idx] - 1) - 1) // self.stride[idx] + 1 + + from pytorch_to_returnn.naming import Naming + naming = Naming.get_instance() + input_tensor = naming.tensors[inputs_flat[0]] + in_data = input_tensor.returnn_data + out_data = layer.output + assert in_data.batch_ndim == out_data.batch_ndim + + mapping_out_to_in = {} + # map unchanged dims + for in_dim, in_dim_tag in enumerate(in_data.dim_tags): + if in_dim_tag in out_data.dim_tags: + mapping_out_to_in[out_data.dim_tags.index(in_dim_tag)] = in_dim + + # map channel dim + in_channel = input_tensor.returnn_axis_from_torch_axis[1] + out_channel = [ + dim for dim in layer.output.get_static_axes() if layer.output.dim_tags[dim].dimension == self.out_channels] + assert len(out_channel) == 1 + mapping_out_to_in[out_channel[0]] = in_channel + + # map spatial axes, the order is the same as in in_spatial_dims + in_spatial_dims = [ + input_tensor.returnn_axis_from_torch_axis[dim + in_data.batch_ndim] + for dim in range(-self.nd, 0)] + for in_dim, out_dim in zip(in_spatial_dims, out_data.get_spatial_batch_axes()): + mapping_out_to_in[out_dim] = in_dim + + assert len(mapping_out_to_in) == in_data.batch_ndim, ( + f"Not all axes were mapped successfully. In: {in_data}, out: {out_data}") + returnn_axis_from_torch_axis = {} + for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items(): + torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis] # torch does not change order for conv + returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis + return tuple(torch_shape), returnn_axis_from_torch_axis diff --git a/tests/test_layers.py b/tests/test_layers.py index a77cdfe..1d742c1 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1779,6 +1779,46 @@ def create_returnn_layer_dict(self, input): inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2}) +def test_vgg(): + n_batch, n_time, n_feat = 3, 20, 5 # B, T, F + + def model_func(wrapped_import, inputs: torch.Tensor): + if typing.TYPE_CHECKING or not wrapped_import: + import torch + else: + torch = wrapped_import("torch") + + class VggBlock(torch.nn.Module): + def __init__(self, n_in, n_out, kernel_size, pool_size=None, stride: typing.Union[int, typing.Tuple] = 1): + super().__init__() + self.conv = torch.nn.Conv2d(n_in, n_out, kernel_size, stride=stride) + self.activation = torch.nn.SiLU() + self.pooling = torch.nn.MaxPool2d(pool_size) if pool_size is not None else None + + def forward(self, x): + # ignore padding here + x = self.conv(x) + x = self.activation(x) + if self.pooling is not None: + x = self.pooling(x) + return x + + x = inputs.unsqueeze(1) # (B, 1, T, F) + # VGG block: + vgg_blocks = torch.nn.Sequential( + VggBlock(1, 32, (3, 3)), + ) + x = vgg_blocks(x) + x = x.transpose(2, 3).flatten(1, 2) # (B, F, T) + return x + + rnd = numpy.random.RandomState(42) + x = rnd.normal(0., 1., (n_batch, n_time, n_feat)).astype("float32") + converter = verify_torch_and_convert_to_returnn( + model_func, inputs=x, returnn_dummy_input_shape=x.shape, 
diff --git a/tests/test_layers.py b/tests/test_layers.py
index a77cdfe..1d742c1 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -1779,6 +1779,46 @@ def create_returnn_layer_dict(self, input):
     inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2})
 
 
+def test_vgg():
+  n_batch, n_time, n_feat = 3, 20, 5  # B, T, F
+
+  def model_func(wrapped_import, inputs: torch.Tensor):
+    if typing.TYPE_CHECKING or not wrapped_import:
+      import torch
+    else:
+      torch = wrapped_import("torch")
+
+    class VggBlock(torch.nn.Module):
+      def __init__(self, n_in, n_out, kernel_size, pool_size=None, stride: typing.Union[int, typing.Tuple] = 1):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(n_in, n_out, kernel_size, stride=stride)
+        self.activation = torch.nn.SiLU()
+        self.pooling = torch.nn.MaxPool2d(pool_size) if pool_size is not None else None
+
+      def forward(self, x):
+        # ignore padding here
+        x = self.conv(x)
+        x = self.activation(x)
+        if self.pooling is not None:
+          x = self.pooling(x)
+        return x
+
+    x = inputs.unsqueeze(1)  # (B, 1, T, F)
+    # VGG block:
+    vgg_blocks = torch.nn.Sequential(
+      VggBlock(1, 32, (3, 3)),
+    )
+    x = vgg_blocks(x)
+    x = x.transpose(2, 3).flatten(1, 2)  # (B, F, T)
+    return x
+
+  rnd = numpy.random.RandomState(42)
+  x = rnd.normal(0., 1., (n_batch, n_time, n_feat)).astype("float32")
+  converter = verify_torch_and_convert_to_returnn(
+    model_func, inputs=x, returnn_dummy_input_shape=x.shape,
+    validate_allclose_kwargs=dict(rtol=0, atol=5e-3),
+    inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2})
+
+
 def test_multiple_outputs():
   n_batch, n_time = 3, 7
   n_in, n_out = 11, 13
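
The test covers exactly the case that motivates the returnn_axis_from_torch_axis fix: after unsqueeze(1), the RETURNN feature dim (F) becomes one of the conv spatial dims. The torch-side shape flow, as a plain sketch using the test's sizes (n_batch=3, n_time=20, n_feat=5):

import torch

x = torch.zeros(3, 20, 5)              # (B, T, F), as fed to the converter
x = x.unsqueeze(1)                     # (B, 1, T, F): F is now a spatial dim
x = torch.nn.Conv2d(1, 32, (3, 3))(x)  # valid padding: (B, 32, T - 2, F - 2)
assert tuple(x.shape) == (3, 32, 18, 3)
x = x.transpose(2, 3).flatten(1, 2)    # merge channels and F': (B, 32 * 3, T - 2)
assert tuple(x.shape) == (3, 96, 18)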