diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 317b20153669f..f939b9cbcc4ec 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -286,6 +286,11 @@ def upsample_bicubic2d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
     grad_input_mask = [1, 1, 1, 1] if ATEN_SDPA_FALLBACK.upper() == "MASKED" else [1, 1, 1, 0]
+    grad_output = (
+        ["GI(0)", "GI(1)", "GI(2)", "GI(3)"]
+        if ATEN_SDPA_FALLBACK.upper() == "MASKED"
+        else ["GI(0)", "GI(1)", "GI(2)", ""]
+    )
     return [
         (
             "Constant",
@@ -310,7 +315,7 @@ def scaled_dot_product_attention_gradient():
                 "I(6)",
                 "I(7)",
             ],
-            ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
+            grad_output,
             {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
         ),
     ]