diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 75512cb8e8c88..0a1d1e12b929d 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -276,3 +276,14 @@ def upsample_nearest3d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
 def upsample_bicubic2d_gradient():
     return _upsample_gradient("upsample_bicubic2d_backward", 2)
+
+@register_gradient("org.pytorch.aten", "ATen", "scaled_dot_product_attention", "")
+def scaled_dot_product_attention_gradient():
+    return [
+        (
+            ("ATen", "org.pytorch.aten"),
+            ["GO(0)", "I(0)", "I(1)", "I(2)"],
+            ["GI(0)", "GI(1)", "GI(2)"],
+            {"operator": {"value": "scaled_dot_product_attention", "dtype": "string"}},
+        ),
+    ]
\ No newline at end of file
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 0bd29b8d155c4..957b51f1f842e 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -969,3 +969,17 @@ def softmax(g, input, dim, dtype=None):
         softmax = g.op("Softmax", casted_input, axis_i=dim)
 
     return softmax
+
+@register_symbolic("scaled_dot_product_attention")
+def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale):
+    return g.op(
+        "org.pytorch.aten::ATen",
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        operator_s="scaled_dot_product_attention"
+    )
\ No newline at end of file
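
For reviewers: a minimal, hypothetical usage sketch (not part of the patch) showing how the two registrations would be exercised end to end. It assumes PyTorch >= 2.0 and an onnxruntime-training build where this ATen fallback path is active; the model, tensor shapes, and the exact enablement mechanism are illustrative assumptions, not taken from this diff.

```python
# Sketch only: exercises the new symbolic and gradient registrations by
# running a model that calls torch.nn.functional.scaled_dot_product_attention
# through ORTModule. Whether SDPA actually takes this ATen fallback path may
# depend on your build/configuration.
import torch
import torch.nn.functional as F
from onnxruntime.training.ortmodule import ORTModule


class SDPABlock(torch.nn.Module):
    def forward(self, q, k, v):
        # With this patch, the exporter emits org.pytorch.aten::ATen with
        # operator_s="scaled_dot_product_attention" for this call.
        return F.scaled_dot_product_attention(q, k, v)


model = ORTModule(SDPABlock())
# (batch, heads, seq_len, head_dim) -- arbitrary example shapes.
q = torch.randn(2, 4, 8, 16, requires_grad=True)
k = torch.randn(2, 4, 8, 16, requires_grad=True)
v = torch.randn(2, 4, 8, 16, requires_grad=True)
out = model(q, k, v)
# The registered gradient maps GO(0) plus inputs I(0..2) (query, key, value)
# to GI(0..2), so the backward pass is covered.
out.sum().backward()
```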