Skip to content

Commit

Permalink
ExpandDimsLayer: Add FeatureDim tag if axis == "f"
Browse files Browse the repository at this point in the history
  • Loading branch information
Zettelkasten committed Dec 16, 2021
1 parent 4d5fc59 commit 7076bc7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3774,7 +3774,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
data = data.copy_as_batch_major()
axis = cls._get_axis(data=data, axis=axis)

new_dim = SpatialDim("%s_expand_dims" % name, dim)
new_dim = (FeatureDim if init_axis.lower() == "f" else SpatialDim)("%s_expand_dims" % name, dim)
data = data.copy_template(name="%s_output" % name)
data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
if isinstance(init_axis, str):
Expand Down

0 comments on commit 7076bc7

Please sign in to comment.