Skip to content

Commit

Permalink
RF relative_positional_encoding, fix internal indices spatial dim
Browse files Browse the repository at this point in the history
Specifically for cross attention, it could happen that
max(q_seq_len+k_seq_len-1) != shape.
  • Loading branch information
albertz committed Dec 13, 2024
1 parent 4310803 commit 26a136f
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion returnn/frontend/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,13 +862,23 @@ def _make_indices(
query_spatial_dim_m1 = query_spatial_dim - 1
q_pos_vec = rf.range_over_dim(query_spatial_dim_m1) # [q_len-1]

# The masking in the output is quite custom (left+right masking), so our seq lens don't make sense,
# and might even cause to fail some tests (that e.g. max(q_seq_len+k_seq_len-1) == shape).
out_spatial_dim = Dim(
query_spatial_dim_m1.get_dim_value_tensor() + key_value_spatial_dim.get_dim_value_tensor(),
name=f"2*{query_spatial_dim.description}-1"
if (query_spatial_dim == key_value_spatial_dim)
else f"{query_spatial_dim.description}+{key_value_spatial_dim.description}-1",
)

# We want to have all distances as in rf.combine_bc(kv_pos_vec, "-", q_pos_vec) with shape [q_len,kv_len].
# We want to store only non-duplicates.
# The min value is with kv_pos=0, q_pos=q_len-1: -(q_len-1)
# The max value is with kv_pos=kv_len-1, q_pos=0: k_len-1
indices, out_spatial_dim = rf.concat(
indices, _ = rf.concat(
(q_pos_vec - query_spatial_dim_m1.get_dim_value_tensor(), query_spatial_dim_m1),
(kv_pos_vec, key_value_spatial_dim),
out_dim=out_spatial_dim,
handle_dynamic_dims=False,
)
if query_offset is not None:
Expand Down

0 comments on commit 26a136f

Please sign in to comment.