From 9d8c18cf2a85783af6a87c6657bf4ce63606f383 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 11 Dec 2024 15:12:30 +0100 Subject: [PATCH 01/13] RF set_sparse_dim --- returnn/frontend/_backend.py | 9 +++++++++ returnn/frontend/dims.py | 11 +++++++++++ returnn/tf/frontend_layers/_backend.py | 8 ++++++++ 3 files changed, 28 insertions(+) diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py index 41f14f59f..a837f0d6a 100644 --- a/returnn/frontend/_backend.py +++ b/returnn/frontend/_backend.py @@ -1095,6 +1095,15 @@ def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Dim) -> Tensor: out.raw_tensor = source.raw_tensor return out + @staticmethod + def set_sparse_dim(source: Tensor, sparse_dim: Dim) -> Tensor: + """set sparse dim""" + # This default implementation works fine as long as the backend + # does not have special treatments of Tensor and dim tags itself (like TF net dict backend). + out = source.copy() + out.sparse_dim = sparse_dim + return out + _AllowedReduceModes = {"sum", "max", "min", "mean", "logsumexp", "any", "all", "argmin", "argmax"} @staticmethod diff --git a/returnn/frontend/dims.py b/returnn/frontend/dims.py index 3d58dc93a..940441e8a 100644 --- a/returnn/frontend/dims.py +++ b/returnn/frontend/dims.py @@ -15,6 +15,7 @@ "range_over_dim_strided", "range_over_merged_dims", "replace_dim", + "set_sparse_dim", "dim_match_priority_when_needed", "num_elements_of_shape", "masked_fraction_of_shape", @@ -94,6 +95,16 @@ def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Optional[Dim] = None) - return source._raw_backend.replace_dim(source, in_dim=in_dim, out_dim=out_dim), out_dim +def set_sparse_dim(source: Tensor, sparse_dim: Dim) -> Tensor: + """ + :param source: + :param sparse_dim: + :return: source with sparse_dim set + """ + # noinspection PyProtectedMember + return source._raw_backend.set_sparse_dim(source, sparse_dim) + + def dim_match_priority_when_needed(dim: Dim, *other_dims: Dim) -> Dim: """ :return: maybe copy 
of dim with higher match_priority if needed to distinguish from other_dims diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py index bdcc55230..1faf1ad31 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -774,6 +774,14 @@ def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Dim) -> Tensor: {"class": "reinterpret_data", "set_dim_tags": {in_dim: out_dim}, "from": source}, name="new_dim" ) + @staticmethod + def set_sparse_dim(source: Tensor, sparse_dim: Dim) -> Tensor: + """set sparse dim""" + return rfl.make_layer( + {"class": "reinterpret_data", "set_sparse": True, "set_sparse_dim": sparse_dim, "from": source}, + name="set_sparse_dim", + ) + @staticmethod def reduce(source: Tensor, *, mode: str, axis: Union[Dim, Sequence[Dim]], use_mask: bool = True) -> Tensor: """Reduce""" From ad1f02cba4ffa76cfc307731b2591badd4b60725 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 11 Dec 2024 18:25:42 +0100 Subject: [PATCH 02/13] RF concat, check that dims are static --- returnn/frontend/array_.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py index 3668ebcab..eda20ce58 100644 --- a/returnn/frontend/array_.py +++ b/returnn/frontend/array_.py @@ -385,6 +385,8 @@ def concat( assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True" if not out_dim: out_dim = sum(d for _, d in sources) + for src, dim in sources[:-1]: + assert dim.is_static(), f"concat {sources}, dim {dim} is not static" # noinspection PyProtectedMember return sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim), out_dim From 54579b88d5289e435f1189238702081a4796626c Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 11 Dec 2024 19:58:42 +0100 Subject: [PATCH 03/13] RF cum_concat_step simplify, pure RF implementation --- returnn/frontend/_backend.py | 15 --------------- 
returnn/frontend/array_.py | 13 +++++++++---- returnn/tf/frontend_layers/_backend.py | 14 -------------- returnn/torch/frontend/_backend.py | 12 ------------ 4 files changed, 9 insertions(+), 45 deletions(-) diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py index a837f0d6a..267751701 100644 --- a/returnn/frontend/_backend.py +++ b/returnn/frontend/_backend.py @@ -496,21 +496,6 @@ def pad( """ raise NotImplementedError - @staticmethod - def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor: - """ - Concatenates all previous frames over a time-axis. - See RETURNN :class:`CumConcatLayer` for details. - - :param source: same dims as prev_accum except for the accum axis - :param prev_accum: previous accumulated tensor, shape {..., axis} - :param axis: the axis to accumulate over - :param out_spatial_dim: the spatial dim of the output will be this dim. like axis+1. - :return: accumulated. accumulated shape {..., out_spatial_dim}, - same shape as prev_accum with axis replaced by out_spatial_dim. - """ - raise NotImplementedError - @staticmethod def stack(sources: Sequence[Tensor], *, out_dim: Dim) -> Tensor: """ diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py index eda20ce58..91b873ebf 100644 --- a/returnn/frontend/array_.py +++ b/returnn/frontend/array_.py @@ -509,13 +509,18 @@ def cum_concat_step( :return: (accumulated, out_spatial_dim). accumulated shape {..., out_spatial_dim}, same shape as prev_accum with axis replaced by out_spatial_dim. """ + # Note: Before, we had a backend function just for this. + # In case of TF-layers, this was using CumConcatLayer. + # This would allow for automatic optimization when inside a RecLayer. + # However, we don't really need this for eager frameworks, + # and we want to simplify this for now, + # using pure RF code. 
if not out_spatial_dim: out_spatial_dim = axis + 1 - # noinspection PyProtectedMember - return ( - source._raw_backend.cum_concat_step(source, prev_accum=prev_accum, axis=axis, out_spatial_dim=out_spatial_dim), - out_spatial_dim, + out, (out_spatial_dim,) = rf.pad( + prev_accum, axes=[axis], padding=[(0, 1)], out_dims=[out_spatial_dim], value=source, handle_dynamic_dims=True ) + return out, out_spatial_dim def stack(sources: Sequence[Tensor], *, out_dim: Optional[Dim] = None) -> Tuple[Tensor, Dim]: diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py index 1faf1ad31..4c21b2d04 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -375,20 +375,6 @@ def pad( name="pad", ) - @staticmethod - def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor: - """cum_concat_step""" - return rfl.make_layer( - { - "class": "cum_concat", - "from": source, - "state": {"state": prev_accum}, - "out_spatial_dim": out_spatial_dim, - "axis": axis, - }, - name="cum_concat", - ) - @staticmethod def activation(tensor: Tensor, func: str) -> Tensor: """activation""" diff --git a/returnn/torch/frontend/_backend.py b/returnn/torch/frontend/_backend.py index 57d01309e..bff749d7b 100644 --- a/returnn/torch/frontend/_backend.py +++ b/returnn/torch/frontend/_backend.py @@ -496,18 +496,6 @@ def pad( ) return out - @staticmethod - def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor: - """cum concat step""" - out = prev_accum.copy_template_replace_dim_tag( - axis=prev_accum.get_axis_from_description(axis), - new_dim_tag=out_spatial_dim, - name=f"{source.name}/cum_concat_step", - ) - source_raw = source.copy_compatible_to_dims_raw(prev_accum.dims) - out.raw_tensor = torch.cat((prev_accum.raw_tensor, source_raw), dim=prev_accum.get_axis_from_description(axis)) - return out - @staticmethod def stack(sources: Sequence[Tensor], 
*, out_dim: Dim) -> Tensor: """stack""" From 55f6bbce895b527eecb76fbe93c71ca625ecdf23 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 11 Dec 2024 20:26:02 +0100 Subject: [PATCH 04/13] RF concat, handle_dynamic_dims --- returnn/frontend/array_.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py index 91b873ebf..6839f7982 100644 --- a/returnn/frontend/array_.py +++ b/returnn/frontend/array_.py @@ -367,6 +367,7 @@ def concat( *sources: Tuple[Tensor, Dim], allow_broadcast: bool = False, out_dim: Optional[Dim] = None, + handle_dynamic_dims: Optional[bool] = None, ) -> Tuple[Tensor, Dim]: """ Concatenates multiple sources in the specified dimension. @@ -376,6 +377,7 @@ def concat( :param sources: list of (tensor, dim) pairs. dim is the axis to concatenate on. :param allow_broadcast: if True, the sources can have different dims, and the result will be broadcasted. :param out_dim: reuse existing dim for the resulting concatenated dim, if given + :param handle_dynamic_dims: if False, skip the check that all concat dims except the last are static :return: concatenated tensor, out_dim """ assert sources @@ -385,8 +387,9 @@ def concat( assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True" if not out_dim: out_dim = sum(d for _, d in sources) - for src, dim in sources[:-1]: - assert dim.is_static(), f"concat {sources}, dim {dim} is not static" + if handle_dynamic_dims is None or handle_dynamic_dims: + for src, dim in sources[:-1]: + assert dim.is_static(), f"concat {sources}, dim {dim} is not static, not yet implemented..." 
# noinspection PyProtectedMember return sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim), out_dim From 76b08ff72003475517760ade721ea2af33a556d4 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 11 Dec 2024 20:26:25 +0100 Subject: [PATCH 05/13] RF relative_positional_encoding, ignore dyn dims in concat #1666 --- returnn/frontend/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/returnn/frontend/attention.py b/returnn/frontend/attention.py index e3462b076..b9af5b095 100644 --- a/returnn/frontend/attention.py +++ b/returnn/frontend/attention.py @@ -869,6 +869,7 @@ def _make_indices( indices, out_spatial_dim = rf.concat( (q_pos_vec - query_spatial_dim_m1.get_dim_value_tensor(), query_spatial_dim_m1), (kv_pos_vec, key_value_spatial_dim), + handle_dynamic_dims=False, ) if query_offset is not None: indices = indices - query_offset From ceb99e61db189f436470fabd48aaf74ad52d8e32 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 10:53:16 +0100 Subject: [PATCH 06/13] RF pad, support non-scalar value --- returnn/torch/frontend/_backend.py | 29 +++++++++++++++------- tests/test_rf_array.py | 40 ++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/returnn/torch/frontend/_backend.py b/returnn/torch/frontend/_backend.py index bff749d7b..fa56ef25c 100644 --- a/returnn/torch/frontend/_backend.py +++ b/returnn/torch/frontend/_backend.py @@ -452,9 +452,6 @@ def pad( ) -> Tensor: """pad""" assert len(out_dims) == len(axes) == len(padding) - out = source.copy_template_new_dim_tags( - [out_dims[axes.index(dim)] if dim in axes else dim for dim in source.dim_tags], keep_special_axes=True - ) remaining_dims = set(axes) raw_pad = [] for dim in reversed(source.dims): @@ -469,10 +466,24 @@ def pad( ] if not remaining_dims: break - if isinstance(value, Tensor): - assert value.dims == (), f"value {value} must be a scalar" - value = value.raw_tensor - out.raw_tensor = 
torch.nn.functional.pad(source.raw_tensor, pad=raw_pad, mode=mode, value=value) + # Use torch.nn.functional.pad if possible. + if (isinstance(value, Tensor) and value.dims == ()) or (not isinstance(value, Tensor)): + if isinstance(value, Tensor): + assert value.dims == () + value = value.raw_tensor + out = source.copy_template_new_dim_tags( + [out_dims[axes.index(dim)] if dim in axes else dim for dim in source.dim_tags], keep_special_axes=True + ) + out.raw_tensor = torch.nn.functional.pad(source.raw_tensor, pad=raw_pad, mode=mode, value=value) + else: # Fallback to concat. + assert isinstance(value, Tensor) + assert all(dim in source.dims and dim not in axes for dim in value.dims) + assert len(axes) == 1 # not implemented otherwise currently... + ext_dim = Dim(1, name="ext") + value_ext = rf.expand_dim(value, ext_dim) + out = TorchBackend.concat( + (source, axes[0]), (value_ext, ext_dim), allow_broadcast=True, out_dim=out_dims[0] + ) if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims: if all(right == 0 for right in raw_pad[1::2]) and mode != "circular": pass # no masking needed @@ -490,9 +501,9 @@ def pad( rf.copy_to_device((left + middle).dyn_size_ext, out.device), ) out.raw_tensor = torch.where( - mask.copy_compatible_to(out, check_dtype=False, check_sparse=False).raw_tensor, + mask.copy_compatible_to_dims_raw(out.dims), out.raw_tensor, - value, + value.copy_compatible_to_dims_raw(out.dims) if isinstance(value, Tensor) else value, ) return out diff --git a/tests/test_rf_array.py b/tests/test_rf_array.py index 55d90312d..bc5c6f293 100644 --- a/tests/test_rf_array.py +++ b/tests/test_rf_array.py @@ -275,6 +275,46 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): assert all(out_.raw_tensor[b, seq_len] == 1.0) +def test_pad_time_right_non_scalar(): + time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) + in_dim = Dim(7, name="in") + extern_data = TensorDict( + { + "data": Tensor("data", [batch_dim, time_dim, in_dim], 
dtype="float32"), + "value": Tensor("value", [batch_dim], dtype="float32"), + } + ) + + # noinspection PyShadowingNames + def _forward_step(*, extern_data: TensorDict, **_kwargs): + data, value = extern_data["data"], extern_data["value"] + data.mark_as_output("data", shape=(batch_dim, time_dim, in_dim)) + value.mark_as_output("value", shape=(batch_dim,)) + out, (new_time,) = rf.pad(data, axes=[time_dim], padding=[(0, 1)], value=value) + out.mark_as_default_output(shape=(batch_dim, new_time, in_dim)) + + # TF-layers currently does not support this. + res = run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step, test_tensorflow=False) + data_: Tensor = res["data"] + value_: Tensor = res["value"] + out_: Tensor = res["output"] + assert data_.dims == (batch_dim, time_dim, in_dim) + new_time_dim = out_.dims[1] + assert out_.dims == (batch_dim, new_time_dim, in_dim) and new_time_dim != time_dim + assert time_dim.dyn_size_ext.dims == new_time_dim.dyn_size_ext.dims == (batch_dim,) + batch_size = batch_dim.get_dim_value() + assert batch_size > 1 + assert len(set(time_dim.dyn_size_ext.raw_tensor)) > 1 # not all the same + for b in range(batch_size): + seq_len = time_dim.dyn_size_ext.raw_tensor[b] + new_seq_len = new_time_dim.dyn_size_ext.raw_tensor[b] + print(f"batch {b}, seq_len {seq_len}, new_seq_len {new_seq_len}") + assert new_seq_len == seq_len + 1 + np.testing.assert_allclose(data_.raw_tensor[b, :seq_len], out_.raw_tensor[b, :seq_len]) + print(out_.raw_tensor[b]) + assert all(out_.raw_tensor[b, seq_len] == value_.raw_tensor[b]) + + def test_stack(): batch_dim_ = Dim(3, name="batch") time_dim = Dim(5, name="time") From 32fcbbbfa4277507917cdf70d1f0e85e8ad4156d Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 14:28:28 +0100 Subject: [PATCH 07/13] RF TF-layers concat fix out_dim --- returnn/tf/frontend_layers/_backend.py | 1 - 1 file changed, 1 deletion(-) diff --git a/returnn/tf/frontend_layers/_backend.py 
b/returnn/tf/frontend_layers/_backend.py index 4c21b2d04..a75089606 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -342,7 +342,6 @@ def concat( opts = {} if allow_broadcast: opts["allow_broadcast"] = True - out_dim = sum(d for _, d in sources) return rfl.make_layer( {"class": "concat", "from": sources, "out_dim": out_dim, **opts}, name="concat", From 0ae04a81b0c99150b2a9f38662415d05bf820726 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 14:34:30 +0100 Subject: [PATCH 08/13] TF ConcatLayer, fix explicit custom out_dim --- returnn/tf/layers/basic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index c18f4c159..16f340a03 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -517,11 +517,12 @@ def get_out_data_from_opts(cls, name, sources, out_dim=None, **kwargs): dimension = 0 for tag in concat_dim_tags: dimension += tag.dimension + sum_concat_dim_tags: Dim = sum(concat_dim_tags) if not out_dim: - out_dim = sum(concat_dim_tags) + out_dim = sum_concat_dim_tags assert isinstance(out_dim, Dim) - else: - sum(concat_dim_tags).declare_same_as(out_dim) + elif not out_dim.is_dim_known(): + sum_concat_dim_tags.declare_same_as(out_dim) assert out_dim.dimension == dimension def _as_common(x, axis): From 431080374302c77a309a46ff6c6513ef2f2a3a77 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 14:34:54 +0100 Subject: [PATCH 09/13] RF TF-layers, fix concat fix explicit out_dim --- returnn/tf/frontend_layers/_backend.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py index a75089606..dd6e109dd 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -342,6 +342,20 @@ def concat( opts = {} if allow_broadcast: opts["allow_broadcast"] = True + dim_deps = 
rfl.get_dim_deps(out_dim) + sources_dims = set() + for source, _ in sources: + sources_dims.update(source.dims) + need_explicit_dim_deps = False + for dim in dim_deps: + if dim not in sources_dims: + need_explicit_dim_deps = True + break + if need_explicit_dim_deps: + source0 = rfl.make_layer( + {"class": "copy", "from": sources[0][0], "extra_deps": dim_deps}, name="concat_extra_dim_deps" + ) + sources = ((source0, sources[0][1]),) + sources[1:] return rfl.make_layer( {"class": "concat", "from": sources, "out_dim": out_dim, **opts}, name="concat", From 26a136f756a18ee2fa2e36be8e9f744e01a74873 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 14:36:17 +0100 Subject: [PATCH 10/13] RF relative_positional_encoding, fix internal indices spatial dim Specifically for cross attention, it could happen that max(q_seq_len+k_seq_len-1) != shape. --- returnn/frontend/attention.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/returnn/frontend/attention.py b/returnn/frontend/attention.py index b9af5b095..482ef39f0 100644 --- a/returnn/frontend/attention.py +++ b/returnn/frontend/attention.py @@ -862,13 +862,23 @@ def _make_indices( query_spatial_dim_m1 = query_spatial_dim - 1 q_pos_vec = rf.range_over_dim(query_spatial_dim_m1) # [q_len-1] + # The masking in the output is quite custom (left+right masking), so our seq lens don't make sense, + # and might even cause some tests to fail (e.g. checks that max(q_seq_len+k_seq_len-1) == shape). + out_spatial_dim = Dim( + query_spatial_dim_m1.get_dim_value_tensor() + key_value_spatial_dim.get_dim_value_tensor(), + name=f"2*{query_spatial_dim.description}-1" + if (query_spatial_dim == key_value_spatial_dim) + else f"{query_spatial_dim.description}+{key_value_spatial_dim.description}-1", + ) + # We want to have all distances as in rf.combine_bc(kv_pos_vec, "-", q_pos_vec) with shape [q_len,kv_len]. # We want to store only non-duplicates. 
# The min value is with kv_pos=0, q_pos=q_len-1: -(q_len-1) # The max value is with kv_pos=kv_len-1, q_pos=0: k_len-1 - indices, out_spatial_dim = rf.concat( + indices, _ = rf.concat( (q_pos_vec - query_spatial_dim_m1.get_dim_value_tensor(), query_spatial_dim_m1), (kv_pos_vec, key_value_spatial_dim), + out_dim=out_spatial_dim, handle_dynamic_dims=False, ) if query_offset is not None: From 7b2882a7e631933c5d9f353dbb204f9550c57be7 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 14:43:19 +0100 Subject: [PATCH 11/13] RF test_relative_positional_encoding_cross --- tests/test_rf_attention.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/test_rf_attention.py b/tests/test_rf_attention.py index 94f61d193..ecdf4af04 100644 --- a/tests/test_rf_attention.py +++ b/tests/test_rf_attention.py @@ -542,6 +542,27 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) +def test_relative_positional_encoding_cross(): + enc_spatial_dim = Dim(Tensor("enc_spatial", [batch_dim], dtype="int32")) + dec_spatial_dim = Dim(Tensor("dec_spatial", [batch_dim], dtype="int32")) + in_dim = Dim(8, name="in") + extern_data = TensorDict( + { + "enc": Tensor("enc", [batch_dim, enc_spatial_dim, in_dim], dtype="float32"), + "dec": Tensor("dec", [batch_dim, dec_spatial_dim, in_dim], dtype="float32"), + } + ) + + # noinspection PyShadowingNames + def _forward_step(**_kwargs): + out, dim = rf.relative_positional_encoding( + key_value_spatial_dim=enc_spatial_dim, query_spatial_dim=dec_spatial_dim, feat_dim=in_dim + ) + out.mark_as_default_output(shape=(dim, in_dim)) + + run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step) + + def test_rel_pos_self_attention(): time_dim = Dim(Tensor("time", [batch_dim], dtype="int32")) in_dim = Dim(8, name="in") From d5074609f04d761a94e59fdf8565f2dce0d5e1ba Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 15:23:53 
+0100 Subject: [PATCH 12/13] RF test_rel_pos_self_attention, extend by batch test Fix #1666 --- tests/test_rf_attention.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_rf_attention.py b/tests/test_rf_attention.py index ecdf4af04..763b41dbc 100644 --- a/tests/test_rf_attention.py +++ b/tests/test_rf_attention.py @@ -571,6 +571,7 @@ def test_rel_pos_self_attention(): "data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"), } ) + check_batching = False class _Net(rf.Module): def __init__(self): @@ -586,6 +587,32 @@ def __init__(self): def __call__(self, x: Tensor, *, axis: Dim) -> Tensor: """forward""" + nonlocal check_batching + if check_batching: + assert rf.is_executing_eagerly() + assert batch_dim in x.dims and axis != batch_dim + y = self.self_att(x, axis=axis) + for b in range(batch_dim.get_dim_value()): + x_b = rf.gather(x, axis=batch_dim, indices=b) + assert batch_dim in axis.dyn_size_ext.dims # current assumption... + seq_len = rf.gather(axis.dyn_size_ext, axis=batch_dim, indices=b) + axis_b = Dim(seq_len) + # Note: The current order (replace_dim and then slice) is somewhat dependent + # on the current internal behavior of gather and replace_dim, + # which might change at some point... + x_b, _ = rf.replace_dim(x_b, in_dim=axis, out_dim=axis_b) + x_b, _ = rf.slice(x_b, axis=axis_b, start=0, end=seq_len, out_dim=axis_b) + y_b = self.self_att(x_b, axis=axis_b) + y_b_ = rf.gather(y, axis=batch_dim, indices=b) + y_b_, _ = rf.replace_dim(y_b_, in_dim=axis, out_dim=axis_b) + y_b_, _ = rf.slice(y_b_, axis=axis_b, start=0, end=seq_len, out_dim=axis_b) + y_b_ = y_b_.copy_transpose(y_b.dims) + # Assuming PyTorch... 
+ np.testing.assert_almost_equal( + y_b.raw_tensor.cpu().detach().numpy(), y_b_.raw_tensor.cpu().detach().numpy(), decimal=5 + ) + return y + return self.self_att(x, axis=axis) # noinspection PyShadowingNames @@ -594,6 +621,8 @@ def _forward_step(*, model: _Net, extern_data: TensorDict): out.mark_as_default_output(shape=(batch_dim, time_dim, model.out_dim)) run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step) + check_batching = True + run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step, test_tensorflow=False) def test_sinusoidal_positional_encoding(): From a112ff4328c91cdc8257331ff24701fe5994d12f Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Dec 2024 15:36:54 +0100 Subject: [PATCH 13/13] RF test_e_branchformer, disable small subcheck for now --- tests/test_rf_encoder_conformer.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/test_rf_encoder_conformer.py b/tests/test_rf_encoder_conformer.py index 261f6f25e..a086f24db 100644 --- a/tests/test_rf_encoder_conformer.py +++ b/tests/test_rf_encoder_conformer.py @@ -359,15 +359,16 @@ def _tensor(x: torch.Tensor, name: str, dims: Sequence[Dim]) -> Tensor: (batch_dim, num_heads_dim, enc_spatial_dim, key_dim_per_head), ), # Check RelPositionalEncoding vs our relative_positional_encoding - ( - (rf.RelPosSelfAttention.__call__, 0, "pos_emb", 0), - (RelPositionMultiHeadedAttention.forward, 0, "pos_emb", 0), - lambda x, **_: _tensor( - _reorder_rel_pos_emb_espnet_to_rf(x.squeeze(dim=0)), - "pos_emb", - [enc_spatial_dim - 1 + enc_spatial_dim, model_dim], - ), - ), + # Currently disabled this check, as the dim tags are different now... 
+ # ( + # (rf.RelPosSelfAttention.__call__, 0, "pos_emb", 0), + # (RelPositionMultiHeadedAttention.forward, 0, "pos_emb", 0), + # lambda x, **_: _tensor( + # _reorder_rel_pos_emb_espnet_to_rf(x.squeeze(dim=0)), + # "pos_emb", + # [enc_spatial_dim - 1 + enc_spatial_dim, model_dim], + # ), + # ), ( (EBranchformerLayer.__call__, 0, "x_mhsa", 0), (EBranchformerEncoderLayer.forward, 0, "x_att", 0),