From 00511d93713086bfbbeb04679a8f9a09a5f7b78c Mon Sep 17 00:00:00 2001 From: Frithjof Petrick Date: Thu, 16 Dec 2021 18:01:50 +0100 Subject: [PATCH 1/4] Broadcasting dims no longer match Fix #666 --- docs/configuration_reference/behavior_version.rst | 9 +++++++++ returnn/tf/layers/basic.py | 10 ++++++++-- returnn/tf/util/data.py | 5 ++++- returnn/util/basic.py | 2 +- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/configuration_reference/behavior_version.rst b/docs/configuration_reference/behavior_version.rst index 60c30d016b..af8a270ac0 100644 --- a/docs/configuration_reference/behavior_version.rst +++ b/docs/configuration_reference/behavior_version.rst @@ -22,6 +22,15 @@ and not listing legacy/deprecated parameters. Version History --------------- +Behavior version 11 (2021-12-16) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Broadcasting dims no longer match in :class:`CombineLayer` and others. +This was never needed, instead broadcasting happens in RETURNN automatically to non-existing dims. +To fix this, do not add any broadcasting dims. + +See issue `#666 `__. + Behavior version 10 (2021-12-07) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 7839a4730f..f83ce8e032 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1282,7 +1282,10 @@ def _get_common_input_position_axes(cls, input_data, position_data, old_gather_a :return: (common_axes_input, common_axes_position, specific_input_axes, specific_position_axes), all counted with batch dim. """ - is_equal_opts = dict(allow_same_spatial_dim=True, broadcast_matches=True) + from returnn.util import BehaviorVersion + is_equal_opts = dict(allow_same_spatial_dim=True) + if BehaviorVersion.get() < 11: + is_equal_opts["broadcast_matches"] = True all_dim_tags, tags_dict = Dim.get_all_dimension_tags([input_data, position_data], is_equal_opts=is_equal_opts) input_tags, pos_tags = tags_dict[input_data], tags_dict[position_data] specific_input_axes = [i for i, tag in enumerate(input_tags) if tag not in pos_tags and i != old_gather_axis] @@ -6420,9 +6423,12 @@ def _auto_var_axes(source1, source2, red1, red2): :return: var1 tags, var2 tags :rtype: (list[Dim], list[Dim]) """ + from returnn.util import BehaviorVersion is_equal_opts = dict( treat_feature_as_spatial=True, allow_same_spatial_dim=True, - broadcast_matches=True, undefined_matches=True, derived_matches=True) + undefined_matches=True, derived_matches=True) + if BehaviorVersion.get() < 11: + is_equal_opts["broadcast_matches"] = True all_dim_tags, tags_dict = Dim.get_all_dimension_tags([source1, source2], is_equal_opts=is_equal_opts) tags1, tags2 = tags_dict[source1], tags_dict[source2] var1 = [tag for i, tag in enumerate(tags1) if tag not in tags2 and i not in red1] diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 869093ba9b..74a1c08a42 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -5049,6 +5049,7 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_ This is always a template, and a new copy. :rtype: Data|None """ + from returnn.util import BehaviorVersion if not sources: return None assert sources @@ -5067,8 +5068,10 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_ common.beam = SearchBeam.get_combined_beam(*[s.beam for s in sources]) is_equal_opts = dict( ignore_feature_dim=ignore_feature_dim, treat_feature_as_spatial=True, - allow_same_spatial_dim=True, broadcast_matches=True, + allow_same_spatial_dim=True, undefined_matches=True, derived_matches=True) + if BehaviorVersion.get() < 11: + is_equal_opts["broadcast_matches"] = True all_dim_tags, tags_dict = Dim.get_all_dimension_tags(sources, is_equal_opts=is_equal_opts) # Check for potential undefined tags, and replace those with defined tags if possible. for axis, dim_tag in enumerate(common.dim_tags): diff --git a/returnn/util/basic.py b/returnn/util/basic.py index e7270d1fdf..c09e41f42b 100644 --- a/returnn/util/basic.py +++ b/returnn/util/basic.py @@ -209,7 +209,7 @@ class BehaviorVersion: The version will be set after the config is defined at __main__.init_config() or Engine.__init__() """ - _latest_behavior_version = 10 + _latest_behavior_version = 11 _behavior_version = None # type: typing.Optional[int] @classmethod From 3360e1b2291ff48421173aaad7d3c8738215a31b Mon Sep 17 00:00:00 2001 From: Frithjof Petrick Date: Thu, 16 Dec 2021 18:41:10 +0100 Subject: [PATCH 2/4] Fix tests for new behavior version 11 --- tests/test_TFNetworkLayer.py | 9 +++++---- tests/test_TFUtil.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 7c5e6709aa..75534fb976 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -920,7 +920,8 @@ def test_CombineLayer_broadcast(): net_dict = { "lin1": {"class": "linear", "activation": "sigmoid", "n_out": 5, "from": "data:data"}, "lin2": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "data:data"}, - "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2"]}, + "lin2_squeeze": {"class": "squeeze", "from": "lin2", "axis": "f"}, + "combine": {"class": "combine", "kind": "add", "from": ["lin1", "lin2_squeeze"]}, "output": {"class": "softmax", "loss": "ce", "from": "combine"} } config = Config({"debug_print_layer_output_template": True}) @@ -939,7 +940,7 @@ def test_CombineLayer_broadcast_multiple(): with make_scope() as session: net_dict = { "p1": {"class": "variable", "shape": (5, 5, 3), "add_batch_axis": False}, - "p2": {"class": "variable", "shape": (5, 1, 1), "add_batch_axis": False}, + "p2": {"class": "variable", "shape": (5,), "add_batch_axis": False}, "combine": {"class": "combine", "kind": "add", "from": ["p1", "p2"]}, "output": {"class": "softmax", "loss": "ce", "from": "combine"} } @@ -1275,7 +1276,7 @@ def test_CombineLayer_time_broadcast(): config = Config({ "debug_print_layer_output_template": True, "extern_data": { - "in1": {"shape": (n_features, 1), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0}, + "in1": {"shape": (n_features,), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0}, "in2": {"shape": (n_features, None), "batch_dim_axis": 0, "time_dim_axis": 2} } }) @@ -1299,7 +1300,7 @@ def test_CombineLayer_time_broadcast_swapped(): "debug_print_layer_output_template": True, "extern_data": { "in1": {"shape": (n_features, None), "batch_dim_axis": 0, "time_dim_axis": 2}, - "in2": {"shape": (n_features, 1), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0}, + "in2": {"shape": (n_features,), "batch_dim_axis": None, "time_dim_axis": None, "feature_dim_axis": 0}, } }) network = TFNetwork(config=config, train_flag=True) diff --git a/tests/test_TFUtil.py b/tests/test_TFUtil.py index 18e4408ab7..aa47543a47 100644 --- a/tests/test_TFUtil.py +++ b/tests/test_TFUtil.py @@ -608,7 +608,7 @@ def test_Data_get_common_data_extra_static_spatial(): def test_Data_get_common_data_broadcast_multiple(): d1 = Data(name='d_orig', shape=(5, 5, 3), dtype='float32', batch_dim_axis=None) - d2 = Data(name='d_bc', shape=(5, 1, 1), dtype='float32', batch_dim_axis=None) + d2 = Data(name='d_bc', shape=(5,), dtype='float32', batch_dim_axis=None) common = Data.get_common_data([d1, d2]) assert d1.shape == common.shape From a94072b24c72d92bf06be53e15d9d1b6870ff5ed Mon Sep 17 00:00:00 2001 From: Frithjof Petrick Date: Thu, 16 Dec 2021 18:48:09 +0100 Subject: [PATCH 3/4] ExpandDimsLayer: Add FeatureDim tag if axis == "f" --- returnn/tf/layers/basic.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index f83ce8e032..d15274410b 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -3774,7 +3774,10 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs): data = data.copy_as_batch_major() axis = cls._get_axis(data=data, axis=axis) - new_dim = SpatialDim("%s_expand_dims" % name, dim) + new_dim = Dim( + kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial, + description="%s_expand_dims" % name, + dimension=dim) data = data.copy_template(name="%s_output" % name) data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis) if isinstance(init_axis, str): From d620a08658977073b6ec2fd477678cdf1756e810 Mon Sep 17 00:00:00 2001 From: Frithjof Petrick Date: Fri, 17 Dec 2021 11:49:38 +0100 Subject: [PATCH 4/4] test_GatherLayer_broadcast_dim --- tests/test_TFNetworkLayer.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py index 75534fb976..a03c68b81f 100644 --- a/tests/test_TFNetworkLayer.py +++ b/tests/test_TFNetworkLayer.py @@ -3400,6 +3400,24 @@ def test_GatherLayer_search_beam(): "initial_output": 0}}}}) +def test_GatherLayer_broadcast_dim(): + from returnn.tf.util.data import batch_dim + head_dim = SpatialDim("head", 1) # previously, this dim would match all others and therefore fail. + round_dim = SpatialDim("round", 2) + chunk_dim = SpatialDim("chunk") + time_dim = SpatialDim("time") + config = Config({"extern_data": { + "source": {"dim_tags": [batch_dim, head_dim, time_dim]}, + "position": {"dim_tags": [batch_dim, head_dim, round_dim, chunk_dim], "dtype": "int32"}}, + "debug_print_layer_output_template": True}) + net = TFNetwork(config=config) + net.construct_from_dict({ + "output": { + 'class': 'gather', 'from': 'data:source', 'position': 'data:position', 'axis': time_dim, + 'out_shape': {batch_dim, head_dim, round_dim, chunk_dim}} + }) + + def test_SliceNdLayer(): n_batch = 5 n_time = 7