Fix Ops (aten::_scaled_dot_product_efficient_attention) | feat (torchlib) (#1833)

attn_bias should be used.
titaiwangms authored Aug 29, 2024
1 parent d90b102 commit e037aa0
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
@@ -1974,7 +1974,7 @@ def aten__scaled_dot_product_efficient_attention(
"""_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)"""

result = aten_scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale
query, key, value, attn_bias, dropout_p=dropout_p, is_causal=is_causal, scale=scale
)

# The followings are not comsumed by the graph.
Expand Down
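The change above routes attn_bias through to aten_scaled_dot_product_attention instead of silently dropping it. A minimal reference sketch of why this matters, in plain PyTorch (not the actual torchlib op; `sdpa_reference` is a hypothetical helper): the bias is added to the attention logits before the softmax, so omitting it changes the attention weights and therefore the output.

```python
# Minimal sketch of scaled dot-product attention with an additive bias,
# assuming plain PyTorch; `sdpa_reference` is a hypothetical helper, not
# the torchlib implementation.
import math

import torch


def sdpa_reference(query, key, value, attn_bias=None, scale=None):
    # query: (batch, heads, seq_q, dim); key/value: (batch, heads, seq_kv, dim)
    scale = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
    logits = query @ key.transpose(-2, -1) * scale
    if attn_bias is not None:
        logits = logits + attn_bias  # the term the pre-fix code dropped
    return torch.softmax(logits, dim=-1) @ value


q, k, v = torch.randn(2, 4, 3, 8), torch.randn(2, 4, 6, 8), torch.randn(2, 4, 6, 8)
bias = torch.randn(2, 4, 3, 6)
# With a nonzero bias the outputs differ, which is what the fix restores.
assert not torch.allclose(sdpa_reference(q, k, v, bias), sdpa_reference(q, k, v))
```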
6 changes: 4 additions & 2 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -1321,12 +1321,14 @@ def sample_inputs__scaled_dot_product_efficient_attention(
     make = opinfo_core.partial(
         opinfo_core.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
     )
-    batch, seq_q, seq_kv, num_heads, head_dim = 4, 3, 6, 4, 8
+    batch, seq_q, seq_kv, num_heads, head_dim = 2, 3, 6, 4, 8

     dim_4_q_shape = (batch, num_heads, seq_q, head_dim)
     dim_4_kv_shape = (batch, num_heads, seq_kv, head_dim)
+    shape_attn_bias = (batch, num_heads, seq_q, seq_kv)

     qkv_shapes = [(dim_4_q_shape, dim_4_kv_shape)]
+
     samples = []
     for qkv_shape, is_causal, dropout_p, compute_log_sumexp in opinfo_core.product(
         qkv_shapes, [True, False], [0.0], [True, False]
@@ -1337,7 +1339,7 @@ def sample_inputs__scaled_dot_product_efficient_attention(
                 make(shape_q),
                 make(shape_kv),
                 make(shape_kv),
-                attn_bias=None,
+                attn_bias=make(shape_attn_bias),
                 is_causal=is_causal,
                 dropout_p=dropout_p,
                 compute_log_sumexp=compute_log_sumexp,
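With these shapes, attn_bias matches the attention-score matrix (batch, num_heads, seq_q, seq_kv), so the updated samples actually exercise the bias path instead of always passing None. A sketch of the equivalent direct call, using the generic torch.nn.functional.scaled_dot_product_attention as a stand-in (the aten::_scaled_dot_product_efficient_attention kernel itself is backend-specific):

```python
# Sketch of the shapes the updated OpInfo samples construct; uses the generic
# F.scaled_dot_product_attention for illustration rather than the
# backend-specific aten::_scaled_dot_product_efficient_attention kernel.
import torch
import torch.nn.functional as F

batch, seq_q, seq_kv, num_heads, head_dim = 2, 3, 6, 4, 8
q = torch.randn(batch, num_heads, seq_q, head_dim)
k = torch.randn(batch, num_heads, seq_kv, head_dim)
v = torch.randn(batch, num_heads, seq_kv, head_dim)
attn_bias = torch.randn(batch, num_heads, seq_q, seq_kv)  # added to q @ k^T scores

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 4, 3, 8])
```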
