
Commit

formatting
root committed Jun 27, 2024
1 parent 6bf3018 commit 35bd07a
Showing 2 changed files with 19 additions and 3 deletions.
@@ -278,13 +278,29 @@ def upsample_bicubic2d_gradient():
     return _upsample_gradient("upsample_bicubic2d_backward", 2)


 # based on the following kernel implementation from PyTorch:
 # https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention_cuda", "")
 def scaled_dot_product_attention_gradient():
     return [
         ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}),
         (
             ("ATen", "org.pytorch.aten"),
-            ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "O(0)", "O(1)", "O(2)", "O(3)", "I(5)", "grad_input_mask", "I(6)", "I(7)"],
+            [
+                "GO(0)",
+                "I(0)",
+                "I(1)",
+                "I(2)",
+                "I(3)",
+                "O(0)",
+                "O(1)",
+                "O(2)",
+                "O(3)",
+                "I(5)",
+                "grad_input_mask",
+                "I(6)",
+                "I(7)",
+            ],
             ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
             {"operator": {"value": "_scaled_dot_product_efficient_attention_backward_cuda", "dtype": "string"}},
         ),
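Each tuple returned by this function describes one node of the backward graph in ORTModule's custom gradient registry format, roughly (op type and domain, inputs, outputs, attributes): "I(i)" and "O(i)" name the forward node's i-th input and output, "GO(i)" the incoming gradient of output i, and "GI(i)" the gradient produced for input i, so the ATen node above forwards the saved tensors plus the grad_input_mask constant to _scaled_dot_product_efficient_attention_backward_cuda. As a rough sketch of the same format for a hypothetical unary op (the op names and import path below are illustrative assumptions, not part of this commit):

# Sketch only: the import path and op names are assumptions for illustration.
from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient


@register_gradient("org.pytorch.aten", "ATen", "my_unary_op", "")
def my_unary_op_gradient():
    # grad of input 0 = my_unary_op_backward(grad of output 0, input 0)
    return [
        (
            ("ATen", "org.pytorch.aten"),
            ["GO(0)", "I(0)"],
            ["GI(0)"],
            {"operator": {"value": "my_unary_op_backward", "dtype": "string"}},
        ),
    ]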
@@ -978,7 +978,7 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p
     dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
     compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
     return g.op(
-        "org.pytorch.aten::ATen",
+        "org.pytorch.aten::ATen",
         query,
         key,
         value,
@@ -988,5 +988,5 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p
         is_causal,
         scale,
         operator_s="_scaled_dot_product_efficient_attention_cuda",
-        outputs=4
+        outputs=4,
     )[0]
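This symbolic exports torch.nn.functional.scaled_dot_product_attention as an org.pytorch.aten::ATen contrib node invoking _scaled_dot_product_efficient_attention_cuda and returns only the first of its four outputs (the attention result) to the rest of the graph; the remaining outputs (such as the logsumexp needed for the backward pass) stay addressable as O(1) through O(3) by the gradient registered in the first file. ORTModule wires such symbolics into its own export path; purely as a usage sketch, a symbolic of this shape could also be attached to a plain torch.onnx export via the standard registration API (the opset number here is an arbitrary illustrative choice, and the function is assumed to be importable into the current scope):

import torch

# Assumes `scaled_dot_product_attention` is the symbolic function shown above,
# available in the current scope.
torch.onnx.register_custom_op_symbolic(
    "aten::scaled_dot_product_attention",  # ATen call intercepted during export
    scaled_dot_product_attention,          # emits the org.pytorch.aten::ATen node
    opset_version=14,                      # illustrative opset choice
)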
