diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index bd193206cab3b..126f84f4d65cc 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -277,6 +277,7 @@ def upsample_nearest3d_gradient():
 def upsample_bicubic2d_gradient():
     return _upsample_gradient("upsample_bicubic2d_backward", 2)
 
+
 @register_gradient("org.pytorch.aten", "ATen", "_efficient_attention_forward", "")
 def scaled_dot_product_attention_gradient():
     return [
@@ -286,4 +287,4 @@ def scaled_dot_product_attention_gradient():
             ["GI(0)", "GI(1)", "GI(2)"],
             {"operator": {"value": "_efficient_attention_backward", "dtype": "string"}},
         ),
-    ]
\ No newline at end of file
+    ]
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 32f9a76f6b7c9..fd52862af2873 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -970,11 +970,12 @@ def softmax(g, input, dim, dtype=None):
 
     return softmax
 
+
 @register_symbolic("scaled_dot_product_attention")
 def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale):
     dropout_p_casted = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
     return g.op(
-        "org.pytorch.aten::ATen", 
+        "org.pytorch.aten::ATen",
         query,
         key,
         value,
@@ -982,5 +983,5 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_
         dropout_p_casted,
         is_causal,
         scale,
-        operator_s="_efficient_attention_forward"
-    )
\ No newline at end of file
+        operator_s="_efficient_attention_forward",
+    )
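
Not part of the patch above: a minimal usage sketch of how the registered symbolic and gradient might be exercised through ORTModule. The module, tensor shapes, and device below are illustrative assumptions; enabling the ATen efficient-attention export path may require additional ORTModule configuration not shown here.

# Illustrative sketch only; names and shapes are assumptions.
import torch
from onnxruntime.training.ortmodule import ORTModule


class AttentionBlock(torch.nn.Module):
    def forward(self, query, key, value):
        # Expected to export through the "scaled_dot_product_attention" symbolic
        # as an org.pytorch.aten::ATen node with
        # operator_s="_efficient_attention_forward".
        return torch.nn.functional.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )


model = ORTModule(AttentionBlock().to("cuda"))
query = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
key = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
value = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)

out = model(query, key, value)
# Backward relies on the gradient registered for "_efficient_attention_forward",
# which emits "_efficient_attention_backward" to produce GI(0), GI(1), GI(2).
out.sum().backward()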