Skip to content

Commit

Permalink
Dim broadcasting only allowed for auto-generated dims (#1388)
Browse files Browse the repository at this point in the history
Auto-generated dims are those created via the legacy n_out,
as well as other internally created dims.
Any dim tags created explicitly by the user
should never be treated as broadcast dims.

See also #666.
  • Loading branch information
albertz authored Sep 5, 2023
1 parent a9eebce commit 0f25d29
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 3 deletions.
8 changes: 7 additions & 1 deletion returnn/tensor/_dim_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,7 +1234,13 @@ def is_equal(
if other_kind == DimTypes.Feature or not other_kind:
other_kind = DimTypes.Spatial
if self.dimension != other.dimension:
if broadcast_matches and (self.dimension == 1 or other.dimension == 1):
if broadcast_matches and (
# Only auto-generated dim tags are allowed to be treated as broadcastable.
# This was another suggestion from here: https://github.com/rwth-i6/returnn/issues/666
# It was not implemented like this because the auto_generated flag was only introduced later.
(self.dimension == 1 and self.auto_generated)
or (other.dimension == 1 and other.auto_generated)
):
pass # pass on
else:
return False
Expand Down
8 changes: 7 additions & 1 deletion returnn/tensor/_tensor_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,13 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None) -> _t.Tenso
if dim_tag and dim_tag.dimension == 1 and dim_tag.batch == batch_info:
pass # keep it
else:
dim_tag = Dim(kind=Dim.Types.Batch, description="batch-broadcast", dimension=1, batch=batch_info)
dim_tag = Dim(
kind=Dim.Types.Batch,
description="batch-broadcast",
dimension=1,
batch=batch_info,
auto_generated=True,
)
return self.copy_add_batch_dim(batch_dim_axis=axis, batch=batch_info, dim_tag=dim_tag)

data_opts = self.get_kwargs()
Expand Down
31 changes: 30 additions & 1 deletion tests/test_TFUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
import tensorflow as tf
from returnn.tf.util.basic import *
from returnn.tf.util.data import SpatialDim, FeatureDim
from returnn.tf.network import ExternData
import returnn.tf.compat as tf_compat
from nose.tools import assert_equal, assert_not_equal, assert_is_instance, assert_is, assert_in, assert_true
from numpy.testing.utils import assert_almost_equal, assert_allclose
from pprint import pprint
import contextlib
import unittest
import numpy.testing
from returnn.util import better_exchook
Expand Down Expand Up @@ -762,6 +762,35 @@ def test_Data_get_common_data_broadcast_multiple_dim_tags():
assert out.dim_tags == (batch_dim, time_dim, input_dim, feat_dim)


@contextlib.contextmanager
def set_behavior_version(version: int):
    """
    Context manager which temporarily overrides the global RETURNN behavior version,
    restoring the previous value on exit (also when an exception propagates).

    :param version: behavior version to be active inside the ``with`` block
    """
    from returnn.util.basic import BehaviorVersion

    # We poke the private attribute directly because there is no public setter
    # for lowering the behavior version in tests.
    # noinspection PyProtectedMember
    prev_version = BehaviorVersion._behavior_version
    try:
        BehaviorVersion._behavior_version = version
        yield
    finally:
        BehaviorVersion._behavior_version = prev_version


def test_Data_get_common_data_no_broadcast_for_explicit():
    """
    A dim tag of size 1 which was created explicitly by the user (not auto-generated)
    must not be treated as a broadcast dim by :func:`Data.get_common_data`;
    instead it is kept as its own separate axis in the common data.
    See https://github.com/rwth-i6/returnn/issues/666.
    """
    from returnn.tf.util.data import batch_dim

    time_tag = SpatialDim("time")
    in_tag = FeatureDim("input", 3)
    feat_tag = FeatureDim("feat", 1)  # explicit dim of size 1, not auto-generated
    data_a = Data("a", dim_tags=[batch_dim, time_tag, in_tag])
    data_b = Data("b", dim_tags=[feat_tag])
    with set_behavior_version(0):
        common = Data.get_common_data([data_a, data_b], allow_broadcast_all_sources=True)
    assert common.dim_tags == (batch_dim, time_tag, in_tag, feat_tag)


def test_Data_get_common_data_extra2_static_spatial():
d1 = Data(name="t", shape=(None, 32, 32, 128), dtype="float32", auto_create_placeholders=True)
d2 = Data(name="r", shape=(None, 32, 32, 128), dtype="float32", auto_create_placeholders=True)
Expand Down

0 comments on commit 0f25d29

Please sign in to comment.