diff --git a/tests/test_rf_attention.py b/tests/test_rf_attention.py
index 94f61d193..ecdf4af04 100644
--- a/tests/test_rf_attention.py
+++ b/tests/test_rf_attention.py
@@ -542,6 +542,35 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
     run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
 
 
+def test_relative_positional_encoding_cross():
+    """
+    Run ``rf.relative_positional_encoding`` with distinct query and key/value spatial dims.
+
+    The decoder dim serves as ``query_spatial_dim`` and the encoder dim as ``key_value_spatial_dim``.
+    The first returned value is marked as the default output with shape ``(rel_dim, feat)``,
+    where ``rel_dim`` is the dim returned alongside it.
+    """
+    enc_time = Dim(Tensor("enc_spatial", [batch_dim], dtype="int32"))
+    dec_time = Dim(Tensor("dec_spatial", [batch_dim], dtype="int32"))
+    feat = Dim(8, name="in")
+    inputs = {
+        "enc": Tensor("enc", [batch_dim, enc_time, feat], dtype="float32"),
+        "dec": Tensor("dec", [batch_dim, dec_time, feat], dtype="float32"),
+    }
+    extern_data = TensorDict(inputs)
+
+    def _get_model(**_kwargs):
+        return rf.Module()
+
+    def _forward_step(**_kwargs):
+        pos_enc, rel_dim = rf.relative_positional_encoding(
+            query_spatial_dim=dec_time, key_value_spatial_dim=enc_time, feat_dim=feat
+        )
+        pos_enc.mark_as_default_output(shape=(rel_dim, feat))
+
+    run_model(extern_data, _get_model, _forward_step)
+
+
 def test_rel_pos_self_attention():
     time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
     in_dim = Dim(8, name="in")