From e8f9102505c0b06f37bbef6580e7de1f3a051be8 Mon Sep 17 00:00:00 2001 From: Frithjof Petrick Date: Thu, 16 Dec 2021 18:48:09 +0100 Subject: [PATCH] ExpandDimsLayer: Add FeatureDim tag if axis == "f" --- returnn/tf/layers/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index a659f4b79a..0dad2bf29a 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3774,7 +3774,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs): data = data.copy_as_batch_major() axis = cls._get_axis(data=data, axis=axis) - new_dim = SpatialDim("%s_expand_dims" % name, dim) + new_dim = (FeatureDim if init_axis.lower() == "f" else SpatialDim)("%s_expand_dims" % name, dim) data = data.copy_template(name="%s_output" % name) data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis) if isinstance(init_axis, str):