Convolution over feature dim #135
base: main
@@ -78,14 +78,24 @@ def create_returnn_layer_dict(self, input: Tensor) -> Dict[str, Any]:
     assert all(p == 0 for p in self.padding)  # not implemented otherwise
     assert all(p == 0 for p in self.output_padding)  # not implemented otherwise
     assert self.padding_mode == "zeros"  # not implemented otherwise
+
+    from pytorch_to_returnn.naming import Naming
+    naming = Naming.get_instance()
+    input_tensor = naming.tensors[input]
+    in_dim = input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[1]]
+    in_spatial_dims = [
+      input_tensor.returnn_data.dim_tags[input_tensor.returnn_axis_from_torch_axis[dim + len(input.shape)]]
+      for dim in range(-self.nd, 0)]
+
     d = {
       "class": "conv", "from": self._get_input_layer_name(input),
       "activation": None,
       "with_bias": self.bias is not None,
       "n_out": self.out_channels,
       "filter_size": self.kernel_size,
       "padding": "valid",
-      "in_spatial_dims": [self._get_input_axis_to_returnn(input, dim) for dim in range(-self.nd, 0)],
+      "in_spatial_dims": in_spatial_dims,
+      "in_dim": in_dim,
     }
     if any(s != 1 for s in self.stride):
       d["strides"] = self.stride
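
To make the axis bookkeeping in `create_returnn_layer_dict` concrete, here is a small plain-Python sketch. It does not use the actual RETURNN `Data`/`Dim` classes; the axis layouts and the mapping dict are assumed purely for illustration of how `in_dim` and `in_spatial_dims` are picked out of `returnn_axis_from_torch_axis` for a Conv1d:

```python
# Toy sketch, plain Python only. Assumes a Conv1d (nd = 1) whose torch input has
# layout (B, C, T) while RETURNN stores the same data as (B, T, C); strings stand
# in for the real dim tags, and the dict stands in for returnn_axis_from_torch_axis.
nd = 1
torch_shape = ("B", "C", "T")                      # torch layout: batch, channels, time
returnn_dim_tags = ("B", "T", "C")                 # assumed RETURNN layout of the input
returnn_axis_from_torch_axis = {0: 0, 1: 2, 2: 1}  # torch axis -> RETURNN axis

# in_dim: the RETURNN dim tag sitting behind torch axis 1 (the channel axis)
in_dim = returnn_dim_tags[returnn_axis_from_torch_axis[1]]
assert in_dim == "C"

# in_spatial_dims: the RETURNN dim tags behind the last nd torch axes
in_spatial_dims = [
    returnn_dim_tags[returnn_axis_from_torch_axis[dim + len(torch_shape)]]
    for dim in range(-nd, 0)]
assert in_spatial_dims == ["T"]
```

Passing the dim tags themselves rather than axis descriptions like `"F"` is the point of this change; see the review thread at the end for why `_get_input_axis_to_returnn` is not used here.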
@@ -121,18 +131,52 @@ def import_params_torch_to_returnn(self, *, layer: LayerBase, torch_module: _Con
   def _get_output_shape_from_returnn(self, inputs_flat: List[Tensor], layer: LayerBase
                                      ) -> Tuple[Tuple[int, ...], Dict[int, int]]:
     """
-    The basic returnn_axis_from_torch_axis should be correct, however, if the size of a dynamic axis changes (e.g. due
-    to strides and/or padding), this is not covered in the base method and we fix it here.
+    If the size of a dynamic axis changes (e.g. due to strides and/or padding), this is not covered in the base method
+    and we fix it here. Also, returnn_axis_from_torch_axis needs to be modified for cases where the RETURNN input
+    feature dim is used as a spatial dim for convolution.
     """
-    torch_shape, returnn_axis_from_torch_axis = super(_ConvNd, self)._get_output_shape_from_returnn(
-      inputs_flat=inputs_flat, layer=layer)
     assert len(inputs_flat) == 1
     torch_shape = list(inputs_flat[0].shape)
     torch_shape[1] = self.out_channels
     for idx in range(self.nd):
       torch_ax = idx + 2
       torch_shape[torch_ax] = (torch_shape[torch_ax] + 2 * self.padding[idx] - self.dilation[idx] * (
         self.kernel_size[idx] - 1) - 1) // self.stride[idx] + 1
+
+    from pytorch_to_returnn.naming import Naming
+    naming = Naming.get_instance()
+    input_tensor = naming.tensors[inputs_flat[0]]
+    in_data = input_tensor.returnn_data
+    out_data = layer.output
+    assert in_data.batch_ndim == out_data.batch_ndim
+
+    mapping_out_to_in = {}
+    # map unchanged dims
+    for in_dim, in_dim_tag in enumerate(in_data.dim_tags):
+      if in_dim_tag in out_data.dim_tags:
+        mapping_out_to_in[out_data.dim_tags.index(in_dim_tag)] = in_dim
+
+    # map channel dim
+    in_channel = input_tensor.returnn_axis_from_torch_axis[1]
+    out_channel = [
+      dim for dim in layer.output.get_static_axes() if layer.output.dim_tags[dim].dimension == self.out_channels]
+    assert len(out_channel) == 1
+    mapping_out_to_in[out_channel[0]] = in_channel
+
+    # map spatial axes, the order is the same as in in_spatial_dims
+    in_spatial_dims = [
+      input_tensor.returnn_axis_from_torch_axis[dim + in_data.batch_ndim]
+      for dim in range(-self.nd, 0)]
+    for in_dim, out_dim in zip(in_spatial_dims, out_data.get_spatial_batch_axes()):
+      mapping_out_to_in[out_dim] = in_dim
+
+    assert len(mapping_out_to_in) == in_data.batch_ndim, (
+      f"Not all axes were mapped successfully. In: {in_data}, out: {out_data}")
+    returnn_axis_from_torch_axis = {}
+    for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items():
+      torch_axis = input_tensor.torch_axis_from_returnn_axis[returnn_in_axis]  # torch does not change order for conv
+      returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis
+
     return tuple(torch_shape), returnn_axis_from_torch_axis

Review thread on the axis-mapping code above:

Reviewer: I don't understand why you do it so complicated. This code here should be quite short, straightforward, and not using any heuristics. Your code here is full of heuristics, checking whether you can map all axes, etc. You don't need that. We know exactly how it must map.

Author: All dims that did not change are trivial, and the channel dim can be done as I do it here. How would you do it for the spatial dims? Just assume that the order of the spatial dims is the same as in `in_spatial_dims`?

Author: I added an update which does the mapping as I described above.

Reviewer: What order? Of the RETURNN output? We don't need to guess anything here. We know everything exactly.

Reviewer: Except BCHW vs BHWC, but you can just check where `out_dim` is.

Author: Ok, then that is exactly what I do now, right?
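
As a sanity check of the composition step at the end of `_get_output_shape_from_returnn` (RETURNN output axis → RETURNN input axis → torch axis), here is a minimal plain-Python sketch. The axis layouts are assumed only for illustration, and plain dicts stand in for the real mappings:

```python
# Assumed layouts for this example:
#   torch input/output: (B, C, T) / (B, C_out, T')   (torch keeps the axis order for conv)
#   RETURNN input:      (B, T, C)
#   RETURNN output:     (B, T', C_out)
torch_axis_from_returnn_in_axis = {0: 0, 1: 2, 2: 1}  # RETURNN in axis -> torch axis
mapping_out_to_in = {0: 0, 1: 1, 2: 2}                # RETURNN out axis -> RETURNN in axis
                                                      # (batch -> batch, T' -> T, C_out -> C)

returnn_axis_from_torch_axis = {}
for returnn_out_axis, returnn_in_axis in mapping_out_to_in.items():
    # torch does not change the axis order for conv, so the torch axis of the
    # matching input dim is also the torch axis of the corresponding output dim
    torch_axis = torch_axis_from_returnn_in_axis[returnn_in_axis]
    returnn_axis_from_torch_axis[torch_axis] = returnn_out_axis

# torch batch axis 0 -> RETURNN axis 0, torch channel axis 1 -> RETURNN axis 2,
# torch time axis 2 -> RETURNN axis 1
assert returnn_axis_from_torch_axis == {0: 0, 1: 2, 2: 1}
```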
A second review thread:

Reviewer: Why don't you use `_get_input_axis_to_returnn` for that?

Author: Because that would return an axis description like `"T"` or `"F"`. This would be mapped to a dim tag in the `ConvLayer` construction. However, in case we do convolution over the feature dim, `"F"` would be mapped to the `in_dim`, so the new feature dim and not the old one, which does not work.
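
The argument in that answer can be illustrated with a small plain-Python toy. This is not the RETURNN API; the `Dim` class, the names, and the `resolve` step are made up for illustration of the stated point: when the conv runs over the RETURNN feature dim, resolving the description `"F"` inside the layer yields the `in_dim` (the channel dim), not the old feature dim we actually want as a spatial dim, while passing the dim tag itself stays unambiguous.

```python
# Toy illustration only; Dim, resolve and the names are stand-ins, not the
# actual RETURNN classes or behavior.
class Dim:
    def __init__(self, name, size):
        self.name, self.size = name, size
    def __repr__(self):
        return f"Dim({self.name}, {self.size})"

old_feature = Dim("feature", 80)   # RETURNN feature dim of the input, used as conv spatial dim here
channels = Dim("channels", 3)      # torch channel axis, passed to the layer as in_dim

def resolve(axis_description, *, in_dim, spatial_dim):
    # late resolution of an axis description against the dims the conv layer knows about
    return {"F": in_dim, "T": spatial_dim}[axis_description]

# "F" in in_spatial_dims would resolve to in_dim, i.e. the channel dim ...
assert resolve("F", in_dim=channels, spatial_dim=old_feature) is channels
# ... but the spatial dim we want is the old feature dim, so pass the dim tag directly:
in_spatial_dims = [old_feature]
assert in_spatial_dims[0] is old_feature
```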