grad output adjustment
prathikr committed Jul 16, 2024
1 parent 8d0e879 commit b72a042
Showing 1 changed file with 6 additions and 1 deletion.
@@ -286,6 +286,11 @@ def upsample_bicubic2d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
     grad_input_mask = [1, 1, 1, 1] if ATEN_SDPA_FALLBACK.upper() == "MASKED" else [1, 1, 1, 0]
+    grad_output = (
+        ["GI(0)", "GI(1)", "GI(2)", "GI(3)"]
+        if ATEN_SDPA_FALLBACK.upper() == "MASKED"
+        else ["GI(0)", "GI(1)", "GI(2)", ""]
+    )
     return [
         (
             "Constant",
@@ -310,7 +315,7 @@ def scaled_dot_product_attention_gradient():
                 "I(6)",
                 "I(7)",
             ],
-            ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
+            grad_output,
             {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
         ),
     ]
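For readers skimming the diff, here is a minimal standalone sketch of the behavior being introduced: the backward node's output list now depends on the fallback mode, requesting the attention-bias gradient GI(3) only when ATEN_SDPA_FALLBACK is set to MASKED, matching the existing grad_input_mask selection on the line above it. The environment-variable name and helper function below are illustrative assumptions, not part of this commit.

import os

# Sketch only. Assumptions: ATEN_SDPA_FALLBACK is configured via an
# environment variable (the variable name here is hypothetical), and
# "GI(i)" denotes the gradient of the i-th input of
# _scaled_dot_product_efficient_attention (query, key, value, attn_bias).
# An empty string marks an output the backward op should not produce.
ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", "")

def select_sdpa_grad_outputs(fallback_mode: str) -> list[str]:
    # In MASKED mode the attention-bias gradient GI(3) is requested as
    # well, mirroring grad_input_mask = [1, 1, 1, 1]; otherwise only the
    # query/key/value gradients are produced ([1, 1, 1, 0]).
    if fallback_mode.upper() == "MASKED":
        return ["GI(0)", "GI(1)", "GI(2)", "GI(3)"]
    return ["GI(0)", "GI(1)", "GI(2)", ""]

print(select_sdpa_grad_outputs("MASKED"))   # ['GI(0)', 'GI(1)', 'GI(2)', 'GI(3)']
print(select_sdpa_grad_outputs("ENABLED"))  # ['GI(0)', 'GI(1)', 'GI(2)', '']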
