Fix #666, broadcast no longer matches dims (#864)

Fix #666.

New behavior version 11 (#508).
Zettelkasten authored Dec 17, 2021
Commit af7a588 (parent 74cfd06)

Showing 6 changed files with 50 additions and 10 deletions.
docs/configuration_reference/behavior_version.rst (+9, -0)

@@ -22,6 +22,15 @@ and not listing legacy/deprecated parameters.
 Version History
 ---------------
 
+Behavior version 11 (2021-12-16)
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Broadcasting dims no longer match in :class:`CombineLayer` and others.
+Matching them was never needed: RETURNN broadcasts automatically to non-existing dims.
+To fix affected configs, do not add any broadcasting dims.
+
+See issue `#666 <https://github.com/rwth-i6/returnn/issues/666>`__.
+
 Behavior version 10 (2021-12-07)
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
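In practice this means a config should drop size-1 dims that exist only for broadcasting. A hedged before/after sketch, mirroring the updated test_CombineLayer_broadcast further below (layer names come from that test; the surrounding config is illustrative):

# Before (behavior version <= 10): the size-1 feature dim of "lin2" silently
# matched the size-5 feature dim of "lin1".
net_dict_old = {
  "lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"},
  "lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"},
  "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2"]},
}

# After (behavior version 11): squeeze the size-1 dim away; RETURNN then
# broadcasts "lin2_squeeze" over the missing feature dim automatically.
net_dict_new = {
  "lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"},
  "lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"},
  "lin2_squeeze": {"class": "squeeze", "from": "lin2", "axis": "f"},
  "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2_squeeze"]},
}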
returnn/tf/layers/basic.py (+12, -3)

@@ -1282,7 +1282,10 @@ def _get_common_input_position_axes(cls, input_data, position_data, old_gather_axis):
     :return: (common_axes_input, common_axes_position, specific_input_axes, specific_position_axes), all counted with
       batch dim.
     """
-    is_equal_opts = dict(allow_same_spatial_dim=True, broadcast_matches=True)
+    from returnn.util import BehaviorVersion
+    is_equal_opts = dict(allow_same_spatial_dim=True)
+    if BehaviorVersion.get() < 11:
+      is_equal_opts["broadcast_matches"] = True
     all_dim_tags, tags_dict = Dim.get_all_dimension_tags([input_data, position_data], is_equal_opts=is_equal_opts)
     input_tags, pos_tags = tags_dict[input_data], tags_dict[position_data]
     specific_input_axes = [i for i, tag in enumerate(input_tags) if tag not in pos_tags and i != old_gather_axis]
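The practical effect in :class:`GatherLayer`: a static size-1 dim shared by source and position stays a distinct common dim instead of being unified with an unrelated axis. A hedged sketch of the dim setup (dims taken from the new test_GatherLayer_broadcast_dim added below; the comments are illustrative):

from returnn.tf.util.data import batch_dim, SpatialDim

head_dim = SpatialDim("head", 1)   # size 1; no longer matches arbitrary dims
round_dim = SpatialDim("round", 2)
chunk_dim = SpatialDim("chunk")
time_dim = SpatialDim("time")
# source:   [batch, head, time], position: [batch, head, round, chunk]
# Common axes are now exactly {batch, head}; gathering over time_dim yields
# [batch, head, round, chunk] instead of failing on an ambiguous match.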
@@ -3771,7 +3774,10 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
     data = data.copy_as_batch_major()
     axis = cls._get_axis(data=data, axis=axis)
 
-    new_dim = SpatialDim("%s_expand_dims" % name, dim)
+    new_dim = Dim(
+      kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial,
+      description="%s_expand_dims" % name,
+      dimension=dim)
     data = data.copy_template(name="%s_output" % name)
     data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
     if isinstance(init_axis, str):
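For :class:`ExpandDimsLayer` this means a newly expanded feature axis now carries a Feature dim tag rather than a generic Spatial one. A minimal, hedged usage sketch (layer and source names are hypothetical):

# Net dict fragment (illustrative): expand a new static feature axis of size 1.
# Under this commit the created dim tag has kind Feature, matching the axis.
net_dict = {
  "expand": {"class": "expand_dims", "from": "some_layer", "axis": "f", "dim": 1},
}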
@@ -6420,9 +6426,12 @@ def _auto_var_axes(source1, source2, red1, red2):
     :return: var1 tags, var2 tags
     :rtype: (list[Dim], list[Dim])
     """
+    from returnn.util import BehaviorVersion
     is_equal_opts = dict(
       treat_feature_as_spatial=True, allow_same_spatial_dim=True,
-      broadcast_matches=True, undefined_matches=True, derived_matches=True)
+      undefined_matches=True, derived_matches=True)
+    if BehaviorVersion.get() < 11:
+      is_equal_opts["broadcast_matches"] = True
     all_dim_tags, tags_dict = Dim.get_all_dimension_tags([source1, source2], is_equal_opts=is_equal_opts)
     tags1, tags2 = tags_dict[source1], tags_dict[source2]
     var1 = [tag for i, tag in enumerate(tags1) if tag not in tags2 and i not in red1]
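All three touched call sites funnel into `Dim.get_all_dimension_tags`, which unifies dim tags across sources according to `is_equal_opts`. A hedged sketch of the changed matching (shapes are hypothetical; the API usage follows the diff above):

from returnn.tf.util.data import Data, Dim

a = Data(name="a", shape=(5, 3), batch_dim_axis=None)
b = Data(name="b", shape=(5, 1), batch_dim_axis=None)
# Legacy (<11): with broadcast_matches=True, b's size-1 dim was unified with
# a's size-3 dim, so both sources mapped onto the same two dim tags.
# New (>=11): the size-1 dim keeps its own tag, so it shows up separately in
# all_dim_tags and the sources no longer appear to share that axis.
all_dim_tags, tags_dict = Dim.get_all_dimension_tags(
  [a, b], is_equal_opts=dict(allow_same_spatial_dim=True))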
returnn/tf/util/data.py (+4, -1)

@@ -5049,6 +5049,7 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_
     This is always a template, and a new copy.
     :rtype: Data|None
     """
+    from returnn.util import BehaviorVersion
     if not sources:
       return None
     assert sources

@@ -5067,8 +5068,10 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_
     common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources])
     is_equal_opts = dict(
       ignore_feature_dim=ignore_feature_dim, treat_feature_as_spatial=True,
-      allow_same_spatial_dim=True, broadcast_matches=True,
+      allow_same_spatial_dim=True,
       undefined_matches=True, derived_matches=True)
+    if BehaviorVersion.get() < 11:
+      is_equal_opts["broadcast_matches"] = True
     all_dim_tags, tags_dict = Dim.get_all_dimension_tags(sources, is_equal_opts=is_equal_opts)
     # Check for potential undefined tags, and replace those with defined tags if possible.
     for axis, dim_tag in enumerate(common.dim_tags):
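The user-visible consequence for `Data.get_common_data` matches the updated test_Data_get_common_data_broadcast_multiple below: a broadcasting source should simply omit the dims it broadcasts over. A hedged sketch (shapes taken from that test):

from returnn.tf.util.data import Data

d1 = Data(name="d_orig", shape=(5, 5, 3), dtype="float32", batch_dim_axis=None)
d2 = Data(name="d_bc", shape=(5,), dtype="float32", batch_dim_axis=None)  # not (5, 1, 1) anymore
common = Data.get_common_data([d1, d2])
assert common.shape == d1.shape  # d2 broadcasts over the two absent dims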
returnn/util/basic.py (+1, -1)

@@ -209,7 +209,7 @@ class BehaviorVersion:
   The version will be set after the config is defined at __main__.init_config() or Engine.__init__()
   """
 
-  _latest_behavior_version = 10
+  _latest_behavior_version = 11
   _behavior_version = None  # type: typing.Optional[int]
 
   @classmethod
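For completeness, the behavior version is selected in the user config (see docs/configuration_reference/behavior_version.rst above); a minimal, hedged config sketch with all other settings omitted:

# RETURNN config fragment (illustrative): opt in to the stricter dim matching.
behavior_version = 11
# Configs still on behavior_version <= 10 keep the legacy broadcast matching
# until they migrate.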
tests/test_TFNetworkLayer.py (+23, -4)

@@ -920,7 +920,8 @@ def test_CombineLayer_broadcast():
   net_dict = {
     "lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"},
     "lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"},
-    "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2"]},
+    "lin2_squeeze": {"class": "squeeze", "from": "lin2", "axis": "f"},
+    "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2_squeeze"]},
     "output": {"class": "softmax", "loss": "ce", "from": "combine"}
   }
   config = Config({"debug_print_layer_output_template": True})
@@ -939,7 +940,7 @@ def test_CombineLayer_broadcast_multiple():
   with make_scope() as session:
     net_dict = {
       "p1": {"class": "variable", "shape": (5, 5, 3), "add_batch_axis": False},
-      "p2": {"class": "variable", "shape": (5, 1, 1), "add_batch_axis": False},
+      "p2": {"class": "variable", "shape": (5,), "add_batch_axis": False},
       "combine": {"class": "combine", "kind": "add", "from": ["p1", "p2"]},
       "output": {"class": "softmax", "loss": "ce", "from": "combine"}
     }
@@ -1275,7 +1276,7 @@ def test_CombineLayer_time_broadcast():
   config = Config({
     "debug_print_layer_output_template": True,
     "extern_data": {
-      "in1": {"shape": (n_features, 1), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
+      "in1": {"shape": (n_features,), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
       "in2": {"shape": (n_features, None), "batch_dim_axis": 0, "time_dim_axis": 2}
     }
   })
@@ -1299,7 +1300,7 @@ def test_CombineLayer_time_broadcast_swapped():
     "debug_print_layer_output_template": True,
     "extern_data": {
       "in1": {"shape": (n_features, None), "batch_dim_axis": 0, "time_dim_axis": 2},
-      "in2": {"shape": (n_features, 1), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
+      "in2": {"shape": (n_features,), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0},
     }
   })
   network = TFNetwork(config=config, train_flag=True)
@@ -3399,6 +3400,24 @@ def test_GatherLayer_search_beam():
       "initial_output": 0}}}})
 
 
+def test_GatherLayer_broadcast_dim():
+  from returnn.tf.util.data import batch_dim
+  head_dim = SpatialDim("head", 1)  # previously, this dim would match all others and therefore fail
+  round_dim = SpatialDim("round", 2)
+  chunk_dim = SpatialDim("chunk")
+  time_dim = SpatialDim("time")
+  config = Config({"extern_data": {
+    "source": {"dim_tags": [batch_dim, head_dim, time_dim]},
+    "position": {"dim_tags": [batch_dim, head_dim, round_dim, chunk_dim], "dtype": "int32"}},
+    "debug_print_layer_output_template": True})
+  net = TFNetwork(config=config)
+  net.construct_from_dict({
+    "output": {
+      'class': 'gather', 'from': 'data:source', 'position': 'data:position', 'axis': time_dim,
+      'out_shape': {batch_dim, head_dim, round_dim, chunk_dim}}
+  })
+
+
 def test_SliceNdLayer():
   n_batch = 5
   n_time = 7
tests/test_TFUtil.py (+1, -1)

@@ -608,7 +608,7 @@ def test_Data_get_common_data_extra_static_spatial():
 
 def test_Data_get_common_data_broadcast_multiple():
   d1 = Data(name='d_orig', shape=(5, 5, 3), dtype='float32', batch_dim_axis=None)
-  d2 = Data(name='d_bc', shape=(5, 1, 1), dtype='float32', batch_dim_axis=None)
+  d2 = Data(name='d_bc', shape=(5,), dtype='float32', batch_dim_axis=None)
   common = Data.get_common_data([d1, d2])
   assert d1.shape == common.shape
