Skip to content

Commit

Permalink
attn aten fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jun 19, 2024
1 parent 3ae5df1 commit f22b8dc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,14 @@ def upsample_nearest3d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
return _upsample_gradient("upsample_bicubic2d_backward", 2)

@register_gradient("org.pytorch.aten", "ATen", "scaled_dot_product_attention", "")
def scaled_dot_product_attention_gradient():
return [
(
("ATen", "org.pytorch.aten"),
["GO(0)", "I(0)", "I(1)", "I(2)"],
["GI(0)", "GI(1)", "GI(2)"],
{"operator": {"value": "scaled_dot_product_attention", "dtype": "string"}},
),
]

Check warning

Code scanning / lintrunner

RUFF/W292 Warning

Original file line number Diff line number Diff line change
Expand Up @@ -969,3 +969,17 @@ def softmax(g, input, dim, dtype=None):
softmax = g.op("Softmax", casted_input, axis_i=dim)

return softmax

@register_symbolic("scaled_dot_product_attention")
def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale):
return g.op(
"org.pytorch.aten::ATen",

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

query,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

key,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

value,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

attn_mask,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

dropout_p,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

is_causal,

Check warning

Code scanning / lintrunner

RUFF/W291 Warning

scale,
operator_s="scaled_dot_product_attention"
)

Check warning

Code scanning / lintrunner

RUFF/W292 Warning

0 comments on commit f22b8dc

Please sign in to comment.