diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index f83ce8e032..72f7e553eb 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3774,7 +3774,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs): data = data.copy_as_batch_major() axis = cls._get_axis(data=data, axis=axis) - new_dim = SpatialDim("%s_expand_dims" % name, dim) + new_dim = (FeatureDim if init_axis.lower() == "f" else SpatialDim)("%s_expand_dims" % name, dim) data = data.copy_template(name="%s_output" % name) data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis) if isinstance(init_axis, str):