From 9d4057a7b2b925b7842417e7bc41ea8888abf3fa Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Wed, 7 Dec 2022 11:07:01 +0100 Subject: [PATCH 1/4] add vgg test --- tests/test_layers.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/test_layers.py b/tests/test_layers.py index a77cdfe..1d742c1 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -1779,6 +1779,46 @@ def create_returnn_layer_dict(self, input): inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2}) +def test_vgg(): + n_batch, n_time, n_feat = 3, 20, 5 # B, T, F + + def model_func(wrapped_import, inputs: torch.Tensor): + if typing.TYPE_CHECKING or not wrapped_import: + import torch + else: + torch = wrapped_import("torch") + + class VggBlock(torch.nn.Module): + def __init__(self, n_in, n_out, kernel_size, pool_size=None, stride: typing.Union[int, typing.Tuple] = 1): + super().__init__() + self.conv = torch.nn.Conv2d(n_in, n_out, kernel_size, stride=stride) + self.activation = torch.nn.SiLU() + self.pooling = torch.nn.MaxPool2d(pool_size) if pool_size is not None else None + + def forward(self, x): + # ignore padding here + x = self.conv(x) + x = self.activation(x) + if self.pooling is not None: + x = self.pooling(x) + return x + + x = inputs.unsqueeze(1) # (B, 1, T, F) + # VGG block: + vgg_blocks = torch.nn.Sequential( + VggBlock(1, 32, (3, 3)), + ) + x = vgg_blocks(x) + x = x.transpose(2, 3).flatten(1, 2) # (B, F, T) + return x + + rnd = numpy.random.RandomState(42) + x = rnd.normal(0., 1., (n_batch, n_time, n_feat)).astype("float32") + converter = verify_torch_and_convert_to_returnn( + model_func, inputs=x, returnn_dummy_input_shape=x.shape, validate_allclose_kwargs=dict(rtol=0, atol=5e-3), + inputs_data_kwargs={"shape": (None, n_feat), "batch_dim_axis": 0, "time_dim_axis": 1, "feature_dim_axis": 2}) + + def test_multiple_outputs(): n_batch, n_time = 3, 7 n_in, n_out = 11, 13 From 8311196040e56e54fd2671318f60c825e33a9015 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Wed, 7 Dec 2022 13:05:18 +0100 Subject: [PATCH 2/4] ConvNd: set in_dim and in_spatial_dims dim tags --- pytorch_to_returnn/torch/nn/modules/conv.py | 55 +++++++++++++++++++-- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/pytorch_to_returnn/torch/nn/modules/conv.py b/pytorch_to_returnn/torch/nn/modules/conv.py index 7ba5308..03b795b 100644 --- a/pytorch_to_returnn/torch/nn/modules/conv.py +++ b/pytorch_to_returnn/torch/nn/modules/conv.py @@ -78,6 +78,15 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]: assert all(p == 0 for p in self.padding) # not implemented otherwise assert all(p == 0 for p in self.output_padding) # not implemented otherwise assert self.padding_mode == "zeros" # not implemented otherwise + + from pytorch_to_returnn.naming import Naming + naming = Naming.get_instance() + input_tensor = naming.tensors[input] + in_dim = input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[1]] + in_spatial_dims = [ + input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[dim + len(input.shape)]] + for dim in range(-self.nd, 0)] + d = { "class": "conv", "from": self._get_input_layer_name(input), "activation": None, @@ -85,7 +94,8 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]: "n_out": self.out_channels, "filter_size": self.kernel_size, "padding": "valid", - "in_spatial_dims": [self._get_input_axis_to_returnn(input, dim) for dim in 
range(-self.nd, 0)], + "in_spatial_dims": in_spatial_dims, + "in_dim": in_dim, } if any(s != 1 for s in self.stride): d["strides"] = self.stride @@ -121,11 +131,10 @@ def import_params_torch_to_returnn(self, *, layer: LayerBase, torch_module: _Con def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: LayerBase ) -> Tuple[Tuple[int, ...], Dict[int, int]]: """ - The basic returnn_axis_from_torch_axis should be correct, however, if the size of a dynamic axis changes (e.g. due - to strides and/or padding), this is not covered in the base method and we fix it here. + If the size of a dynamic axis changes (e.g. due to strides and/or padding), this is not covered in the base method + and we fix it here. Also, the basic returnn_axis_from_torch_axis fails if the RETURNN input feature dim is used as a + spatial dim for convolution. We try to cover this here and use the basic implementation as a fallback. """ - torch_shape, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn( - inputs_flat=inputs_flat, layer=layer) assert len(inputs_flat) == 1 torch_shape = list(inputs_flat[0].shape) torch_shape[1] = self.out_channels @@ -133,6 +142,42 @@ def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: Layer torch_ax = idx + 2 torch_shape[torch_ax] = (torch_shape[torch_ax] + 2 * self.padding[idx] - self.dilation[idx] * ( self.kernel_size[idx] - 1) - 1) // self.stride[idx] + 1 + + from pytorch_to_returnn.naming import Naming + naming = Naming.get_instance() + input_tensor = naming.tensors[inputs_flat[0]] + in_data = input_tensor.returnn_data + out_data = layer.output + assert in_data.batch_ndim == out_data.batch_ndim + + mapping_out_to_in = {} + if in_data.batch_dim_axis is not None and out_data.batch_dim_axis is not None: + mapping_out_to_in[out_data.batch_dim_axis] = in_data.batch_dim_axis + if in_data.time_dim_axis and out_data.time_dim_axis: + mapping_out_to_in[out_data.time_dim_axis] = in_data.time_dim_axis + in_channel = input_tensor.returnn_axis_from_torch_axis[1] + out_channel = [ + dim for dim in layer.output.get_static_axes() if layer.output.dim_tags[dim].dimension == self.out_channels] + if len(out_channel) == 1: + mapping_out_to_in[out_channel[0]] = in_channel + + if len(mapping_out_to_in) == in_data.batch_ndim - 1: + # only one axis is missing, just take remaining axis + remaining_in = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.keys())) + remaining_out = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.values())) + assert len(remaining_in) == 1 and len(remaining_out) == 1 + mapping_out_to_in[remaining_out.pop()] = remaining_in.pop() + + if len(mapping_out_to_in) == in_data.batch_ndim: + # found all axes, so we can proceed + returnn_axis_from_torch_axis = {} + for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items(): + torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis] # torch does not change order for conv + returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis + else: + # did not find all axes, so fall back to (possibly faulty) default mapping + _, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn( + inputs_flat=inputs_flat, layer=layer) return tuple(torch_shape), returnn_axis_from_torch_axis From 758862dd51bc367c9b72b2acf0607ced288c7054 Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 9 Dec 2022 08:28:09 +0100 Subject: [PATCH 3/4] small fix --- pytorch_to_returnn/torch/nn/modules/conv.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/pytorch_to_returnn/torch/nn/modules/conv.py b/pytorch_to_returnn/torch/nn/modules/conv.py index 03b795b..37eb3bc 100644 --- a/pytorch_to_returnn/torch/nn/modules/conv.py +++ b/pytorch_to_returnn/torch/nn/modules/conv.py @@ -163,8 +163,8 @@ def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: Layer if len(mapping_out_to_in) == in_data.batch_ndim - 1: # only one axis is missing, just take remaining axis - remaining_in = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.keys())) - remaining_out = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.values())) + remaining_in = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.values())) + remaining_out = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.keys())) assert len(remaining_in) == 1 and len(remaining_out) == 1 mapping_out_to_in[remaining_out.pop()] = remaining_in.pop() From 44d915299a73c37a468132cd60d17d0309965e2f Mon Sep 17 00:00:00 2001 From: Peter Vieting Date: Fri, 9 Dec 2022 11:31:09 +0100 Subject: [PATCH 4/4] remove heuristics --- pytorch_to_returnn/torch/nn/modules/conv.py | 51 ++++++++++----------- 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/pytorch_to_returnn/torch/nn/modules/conv.py b/pytorch_to_returnn/torch/nn/modules/conv.py index 37eb3bc..573252f 100644 --- a/pytorch_to_returnn/torch/nn/modules/conv.py +++ b/pytorch_to_returnn/torch/nn/modules/conv.py @@ -132,8 +132,8 @@ def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: Layer ) -> Tuple[Tuple[int, ...], Dict[int, int]]: """ If the size of a dynamic axis changes (e.g. due to strides and/or padding), this is not covered in the base method - and we fix it here. Also, the basic returnn_axis_from_torch_axis fails if the RETURNN input feature dim is used as a - spatial dim for convolution. We try to cover this here and use the basic implementation as a fallback. + and we fix it here. Also, returnn_axis_from_torch_axis needs to be modified for cases where the RETURNN input + feature dim is used as a spatial dim for convolution. 
""" assert len(inputs_flat) == 1 torch_shape = list(inputs_flat[0].shape) @@ -151,33 +151,32 @@ def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: Layer assert in_data.batch_ndim == out_data.batch_ndim mapping_out_to_in = {} - if in_data.batch_dim_axis is not None and out_data.batch_dim_axis is not None: - mapping_out_to_in[out_data.batch_dim_axis] = in_data.batch_dim_axis - if in_data.time_dim_axis and out_data.time_dim_axis: - mapping_out_to_in[out_data.time_dim_axis] = in_data.time_dim_axis + # map unchanged dims + for in_dim, in_dim_tag in enumerate(in_data.dim_tags): + if in_dim_tag in out_data.dim_tags: + mapping_out_to_in[out_data.dim_tags.index(in_dim_tag)] = in_dim + + # map channel dim in_channel = input_tensor.returnn_axis_from_torch_axis[1] out_channel = [ dim for dim in layer.output.get_static_axes() if layer.output.dim_tags[dim].dimension == self.out_channels] - if len(out_channel) == 1: - mapping_out_to_in[out_channel[0]] = in_channel - - if len(mapping_out_to_in) == in_data.batch_ndim - 1: - # only one axis is missing, just take remaining axis - remaining_in = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.values())) - remaining_out = set(range(in_data.batch_ndim)).difference(set(mapping_out_to_in.keys())) - assert len(remaining_in) == 1 and len(remaining_out) == 1 - mapping_out_to_in[remaining_out.pop()] = remaining_in.pop() - - if len(mapping_out_to_in) == in_data.batch_ndim: - # found all axes, so we can proceed - returnn_axis_from_torch_axis = {} - for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items(): - torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis] # torch does not change order for conv - returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis - else: - # did not find all axes, so fall back to (possibly faulty) default mapping - _, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn( - inputs_flat=inputs_flat, layer=layer) + assert len(out_channel) == 1 + mapping_out_to_in[out_channel[0]] = in_channel + + # map spatial axes, the order is the same as in in_spatial_dims + in_spatial_dims = [ + input_tensor.returnn_axis_from_torch_axis[dim + in_data.batch_ndim] + for dim in range(-self.nd, 0)] + for in_dim, out_dim in zip(in_spatial_dims, out_data.get_spatial_batch_axes()): + mapping_out_to_in[out_dim] = in_dim + + assert len(mapping_out_to_in) == in_data.batch_ndim, ( + f"Not all axes were mapped successfully. In: {in_data}, out: {out_data}") + returnn_axis_from_torch_axis = {} + for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items(): + torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis] # torch does not change order for conv + returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis + return tuple(torch_shape), returnn_axis_from_torch_axis