Skip to content

Commit

Permalink
RF test_relative_positional_encoding_cross
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Dec 13, 2024
1 parent f02b481 commit 77135fa
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/test_rf_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,27 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)


def test_relative_positional_encoding_cross():
    """Test rf.relative_positional_encoding with separate key/value and query spatial dims (cross-attention case)."""
    feature_dim = Dim(8, name="in")
    key_value_time = Dim(Tensor("enc_spatial", [batch_dim], dtype="int32"))
    query_time = Dim(Tensor("dec_spatial", [batch_dim], dtype="int32"))
    data_templates = {
        "enc": Tensor("enc", [batch_dim, key_value_time, feature_dim], dtype="float32"),
        "dec": Tensor("dec", [batch_dim, query_time, feature_dim], dtype="float32"),
    }
    extern_data = TensorDict(data_templates)

    # noinspection PyShadowingNames
    def _forward_step(**_kwargs):
        # Compute the relative positional encoding across the two distinct spatial dims
        # and register it (with its relative-position dim) as the default model output.
        encoding, rel_dim = rf.relative_positional_encoding(
            key_value_spatial_dim=key_value_time, query_spatial_dim=query_time, feat_dim=feature_dim
        )
        encoding.mark_as_default_output(shape=(rel_dim, feature_dim))

    run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step)


def test_rel_pos_self_attention():
time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
in_dim = Dim(8, name="in")
Expand Down

0 comments on commit 77135fa

Please sign in to comment.