Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jun 20, 2024
1 parent 612e425 commit bdcfebb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def upsample_nearest3d_gradient():
def upsample_bicubic2d_gradient():
    """Build the gradient definition for 2-D bicubic upsampling.

    Delegates to the shared ``_upsample_gradient`` helper, naming the
    backward ATen op and the number of spatial dimensions.
    """
    backward_op = "upsample_bicubic2d_backward"
    spatial_dims = 2
    return _upsample_gradient(backward_op, spatial_dims)


@register_gradient("org.pytorch.aten", "ATen", "_efficient_attention_forward", "")
def scaled_dot_product_attention_gradient():
return [
Expand All @@ -286,4 +287,4 @@ def scaled_dot_product_attention_gradient():
["GI(0)", "GI(1)", "GI(2)"],
{"operator": {"value": "_efficient_attention_backward", "dtype": "string"}},
),
]
]
Original file line number Diff line number Diff line change
Expand Up @@ -970,17 +970,18 @@ def softmax(g, input, dim, dtype=None):

return softmax


@register_symbolic("scaled_dot_product_attention")
def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale):
    """Export aten::scaled_dot_product_attention as the contrib ATen op.

    Emits an ``org.pytorch.aten::ATen`` node that invokes
    ``_efficient_attention_forward``, casting the dropout probability to
    float first (the forward op expects a float dropout input).
    """
    # Cast dropout_p to FLOAT so the downstream ATen op receives the dtype it expects.
    dropout_prob_f32 = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
    attention_inputs = (
        query,
        key,
        value,
        attn_mask,
        dropout_prob_f32,
        is_causal,
        scale,
    )
    return g.op(
        "org.pytorch.aten::ATen",
        *attention_inputs,
        operator_s="_efficient_attention_forward",
    )

0 comments on commit bdcfebb

Please sign in to comment.