-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -276,3 +276,14 @@ def upsample_nearest3d_gradient(): | |
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec") | ||
def upsample_bicubic2d_gradient(): | ||
return _upsample_gradient("upsample_bicubic2d_backward", 2) | ||
|
||
@register_gradient("org.pytorch.aten", "ATen", "scaled_dot_product_attention", "") | ||
def scaled_dot_product_attention_gradient(): | ||
return [ | ||
( | ||
("ATen", "org.pytorch.aten"), | ||
["GO(0)", "I(0)", "I(1)", "I(2)"], | ||
["GI(0)", "GI(1)", "GI(2)"], | ||
{"operator": {"value": "scaled_dot_product_attention", "dtype": "string"}}, | ||
), | ||
] | ||
Check warning Code scanning / lintrunner RUFF/W292 Warning
No newline at end of file.
See https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -969,3 +969,17 @@ def softmax(g, input, dim, dtype=None): | |
softmax = g.op("Softmax", casted_input, axis_i=dim) | ||
|
||
return softmax | ||
|
||
@register_symbolic("scaled_dot_product_attention") | ||
def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale): | ||
return g.op( | ||
"org.pytorch.aten::ATen", | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
query, | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
key, | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
value, | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
attn_mask, | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
dropout_p, | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
is_causal, | ||
Check warning Code scanning / lintrunner RUFF/W291 Warning
Trailing whitespace.
See https://docs.astral.sh/ruff/rules/trailing-whitespace |
||
scale, | ||
operator_s="scaled_dot_product_attention" | ||
) | ||
Check warning Code scanning / lintrunner RUFF/W292 Warning
No newline at end of file.
See https://docs.astral.sh/ruff/rules/missing-newline-at-end-of-file |