Skip to content

Commit

Permalink
Broadcasting dims no longer match
Browse files Browse the repository at this point in the history
Fix #666
  • Loading branch information
Zettelkasten committed Dec 16, 2021
1 parent 43c0c00 commit 0b38e4c
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
9 changes: 9 additions & 0 deletions docs/configuration_reference/behavior_version.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ and not listing legacy/deprecated parameters.
Version History
---------------

Behavior version 11 (2021-12-16)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Broadcasting dims no longer match in :class:`CombineLayer` and others.
This was never needed, instead broadcasting happens in RETURNN automatically to non-existing dims.
To fix this, do not add any broadcasting dims.

See issue `#666 <https://github.com/rwth-i6/returnn/issues/666>`__.

Behavior version 10 (2021-12-07)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
10 changes: 8 additions & 2 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,7 +1282,10 @@ def _get_common_input_position_axes(cls, input_data, position_data, old_gather_a
:return: (common_axes_input, common_axes_position, specific_input_axes, specific_position_axes), all counted with
batch dim.
"""
is_equal_opts = dict(allow_same_spatial_dim=True, broadcast_matches=True)
from returnn.util import BehaviorVersion
is_equal_opts = dict(allow_same_spatial_dim=True)
if BehaviorVersion.get() < 11:
is_equal_opts["broadcast_matches"] = True
all_dim_tags, tags_dict = Dim.get_all_dimension_tags([input_data, position_data], is_equal_opts=is_equal_opts)
input_tags, pos_tags = tags_dict[input_data], tags_dict[position_data]
specific_input_axes = [i for i, tag in enumerate(input_tags) if tag not in pos_tags and i != old_gather_axis]
Expand Down Expand Up @@ -6420,9 +6423,12 @@ def _auto_var_axes(source1, source2, red1, red2):
:return: var1 tags, var2 tags
:rtype: (list[Dim], list[Dim])
"""
from returnn.util import BehaviorVersion
is_equal_opts = dict(
treat_feature_as_spatial=True, allow_same_spatial_dim=True,
broadcast_matches=True, undefined_matches=True, derived_matches=True)
undefined_matches=True, derived_matches=True)
if BehaviorVersion.get() < 11:
is_equal_opts["broadcast_matches"] = True
all_dim_tags, tags_dict = Dim.get_all_dimension_tags([source1, source2], is_equal_opts=is_equal_opts)
tags1, tags2 = tags_dict[source1], tags_dict[source2]
var1 = [tag for i, tag in enumerate(tags1) if tag not in tags2 and i not in red1]
Expand Down
5 changes: 4 additions & 1 deletion returnn/tf/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5049,6 +5049,7 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_
This is always a template, and a new copy.
:rtype: Data|None
"""
from returnn.util import BehaviorVersion
if not sources:
return None
assert sources
Expand All @@ -5067,8 +5068,10 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_
common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources])
is_equal_opts = dict(
ignore_feature_dim=ignore_feature_dim, treat_feature_as_spatial=True,
allow_same_spatial_dim=True, broadcast_matches=True,
allow_same_spatial_dim=True,
undefined_matches=True, derived_matches=True)
if BehaviorVersion.get() < 11:
is_equal_opts["broadcast_matches"] = True
all_dim_tags, tags_dict = Dim.get_all_dimension_tags(sources, is_equal_opts=is_equal_opts)
# Check for potential undefined tags, and replace those with defined tags if possible.
for axis, dim_tag in enumerate(common.dim_tags):
Expand Down
2 changes: 1 addition & 1 deletion returnn/util/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ class BehaviorVersion:
The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
"""

_latest_behavior_version = 10
_latest_behavior_version = 11
_behavior_version = None # type: typing.Optional[int]

@classmethod
Expand Down

0 comments on commit 0b38e4c

Please sign in to comment.