From f22b8dc80230eec86541ab0fba5c67ae804b9f4e Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 19 Jun 2024 21:09:35 +0000
Subject: [PATCH 01/32] attn aten fallback

---
 .../ortmodule/_custom_gradient_registry.py    | 11 +++++++++++
 .../ortmodule/_custom_op_symbolic_registry.py | 14 ++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 75512cb8e8c88..0a1d1e12b929d 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -276,3 +276,14 @@ def upsample_nearest3d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
    return _upsample_gradient("upsample_bicubic2d_backward", 2)
+
+@register_gradient("org.pytorch.aten", "ATen", "scaled_dot_product_attention", "")
+def scaled_dot_product_attention_gradient():
+    return [
+        (
+            ("ATen", "org.pytorch.aten"),
+            ["GO(0)", "I(0)", "I(1)", "I(2)"],
+            ["GI(0)", "GI(1)", "GI(2)"],
+            {"operator": {"value": "scaled_dot_product_attention", "dtype": "string"}},
+        ),
+    ]
\ No newline at end of file
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 0bd29b8d155c4..957b51f1f842e 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -969,3 +969,17 @@ def softmax(g, input, dim, dtype=None):
    softmax = g.op("Softmax", casted_input, axis_i=dim)
    return softmax
+
+@register_symbolic("scaled_dot_product_attention")
+def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale):
+    return g.op(
+        "org.pytorch.aten::ATen",
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+        scale,
+        operator_s="scaled_dot_product_attention"
+    )
\ No newline at end of file

From 612e425a8b26dab9545549f98447d521eb81d337 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 20 Jun 2024 22:53:45 +0000
Subject: [PATCH 02/32] use correct operator names

---
 .../ortmodule/_custom_gradient_registry.py    |  4 ++--
 .../ortmodule/_custom_op_symbolic_registry.py | 15 ++++++++-------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 0a1d1e12b929d..bd193206cab3b 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -277,13 +277,13 @@ def upsample_nearest3d_gradient():
def upsample_bicubic2d_gradient():
    return _upsample_gradient("upsample_bicubic2d_backward", 2)
-@register_gradient("org.pytorch.aten", "ATen", "scaled_dot_product_attention", "")
+@register_gradient("org.pytorch.aten", "ATen", "_efficient_attention_forward", "")
def scaled_dot_product_attention_gradient():
    return [
        (
            ("ATen", "org.pytorch.aten"),
            ["GO(0)", "I(0)", "I(1)", "I(2)"],
            ["GI(0)", "GI(1)", "GI(2)"],
-            {"operator": {"value": "scaled_dot_product_attention", "dtype": "string"}},
+            {"operator": {"value": "_efficient_attention_backward", "dtype": "string"}},
        ),
    ]
\ No newline at end of
file diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 957b51f1f842e..32f9a76f6b7c9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -972,14 +972,15 @@ def softmax(g, input, dim, dtype=None): @register_symbolic("scaled_dot_product_attention") def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale): + dropout_p_casted = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) return g.op( "org.pytorch.aten::ATen", - query, - key, - value, - attn_mask, - dropout_p, - is_causal, + query, + key, + value, + attn_mask, + dropout_p_casted, + is_causal, scale, - operator_s="scaled_dot_product_attention" + operator_s="_efficient_attention_forward" ) \ No newline at end of file From bdcfebbaa3f9fcc8410a3f820b9a4e9b8eb89dd9 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 20 Jun 2024 22:55:29 +0000 Subject: [PATCH 03/32] formatting --- .../python/training/ortmodule/_custom_gradient_registry.py | 3 ++- .../training/ortmodule/_custom_op_symbolic_registry.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index bd193206cab3b..126f84f4d65cc 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -277,6 +277,7 @@ def upsample_nearest3d_gradient(): def upsample_bicubic2d_gradient(): return _upsample_gradient("upsample_bicubic2d_backward", 2) + @register_gradient("org.pytorch.aten", "ATen", "_efficient_attention_forward", "") def scaled_dot_product_attention_gradient(): return [ @@ -286,4 +287,4 @@ def scaled_dot_product_attention_gradient(): ["GI(0)", "GI(1)", "GI(2)"], {"operator": {"value": "_efficient_attention_backward", "dtype": "string"}}, ), - ] \ No newline at end of file + ] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 32f9a76f6b7c9..fd52862af2873 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -970,11 +970,12 @@ def softmax(g, input, dim, dtype=None): return softmax + @register_symbolic("scaled_dot_product_attention") def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale): dropout_p_casted = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) return g.op( - "org.pytorch.aten::ATen", + "org.pytorch.aten::ATen", query, key, value, @@ -982,5 +983,5 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_ dropout_p_casted, is_causal, scale, - operator_s="_efficient_attention_forward" - ) \ No newline at end of file + operator_s="_efficient_attention_forward", + ) From 80c3107da16e9ff4c03f9b14faa028d6194c3e27 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 20 Jun 2024 23:02:54 +0000 Subject: [PATCH 04/32] add unit test --- .../python/orttraining_test_ortmodule_api.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff 
--git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index f35bb47f6b41d..ec512ad6b722d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -25,6 +25,7 @@ # Import autocasting libs from torch import nn from torch.cuda import amp +from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer from transformers.modeling_outputs import SequenceClassifierOutput @@ -6925,3 +6926,38 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): else: if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ: del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] + +def test_aten_upsample_bicubic(): + class _NeuralNetAttention(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, q, k, v): + with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + + def gen_inputs(device, dtype): + return [ + torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True), + torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True), + torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True), + ] + + device = "cuda" + pt_model = _NeuralNetAttention().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, inputs): + prediction = model(*inputs) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = gen_inputs(device=device, dtype=torch.float32) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) \ No newline at end of file From 2b29b4c56bf117f18e160328ebd0c5ef7d0a1fd0 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 20 Jun 2024 23:03:37 +0000 Subject: [PATCH 05/32] formatting --- .../test/python/orttraining_test_ortmodule_api.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index ec512ad6b722d..dfe6984c1c498 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6927,7 +6927,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ: del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] -def test_aten_upsample_bicubic(): + +def test_aten_attention(): class _NeuralNetAttention(torch.nn.Module): def __init__(self): super().__init__() @@ -6960,4 +6961,4 @@ def run_step(model, inputs): ort_prediction = run_step(ort_model, ort_input) _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) \ No newline at end of file + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) From d2b85663253195fb0799835662966edf0dd2ecce Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Jun 2024 00:29:37 +0000 Subject: [PATCH 06/32] use pytorch sdpa kernel --- .../ortmodule/_custom_op_symbolic_registry.py | 14 
++++++++------ .../test/python/orttraining_test_ortmodule_api.py | 1 + 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index fd52862af2873..fca0868bd0193 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -971,17 +971,19 @@ def softmax(g, input, dim, dtype=None): return softmax +# based on the following kernel implementation from PyTorch: +# https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/attention.cpp#L638 @register_symbolic("scaled_dot_product_attention") -def scaled_dot_product_attention(g, query, key, value, attn_mask, dropout_p, is_causal, scale): - dropout_p_casted = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) +def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): + dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) return g.op( - "org.pytorch.aten::ATen", + "org.pytorch.aten::ATen", query, key, value, attn_mask, - dropout_p_casted, + dropout_p_f, is_causal, scale, - operator_s="_efficient_attention_forward", - ) + operator_s="scaled_dot_product_attention" + ) \ No newline at end of file diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index dfe6984c1c498..1445720292429 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6962,3 +6962,4 @@ def run_step(model, inputs): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +test_aten_attention() \ No newline at end of file From 0ca8fa0019b1cd4b62e56fc545ecd073c1124ed5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Jun 2024 00:30:21 +0000 Subject: [PATCH 07/32] bug fix --- .../orttraining/test/python/orttraining_test_ortmodule_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 1445720292429..dfe6984c1c498 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6962,4 +6962,3 @@ def run_step(model, inputs): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -test_aten_attention() \ No newline at end of file From 8999ff2f67c22a3a43fc32e9005ba72885a1f0eb Mon Sep 17 00:00:00 2001 From: root Date: Wed, 26 Jun 2024 00:31:24 +0000 Subject: [PATCH 08/32] lint --- .../training/ortmodule/_custom_op_symbolic_registry.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index fca0868bd0193..f9b0749c91113 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ 
b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -977,7 +977,7 @@ def softmax(g, input, dim, dtype=None): def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) return g.op( - "org.pytorch.aten::ATen", + "org.pytorch.aten::ATen", query, key, value, @@ -985,5 +985,5 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p dropout_p_f, is_causal, scale, - operator_s="scaled_dot_product_attention" - ) \ No newline at end of file + operator_s="scaled_dot_product_attention", + ) From 6bf30188f9182eaf5be53862ef0cfc7637b203cd Mon Sep 17 00:00:00 2001 From: root Date: Thu, 27 Jun 2024 23:25:40 +0000 Subject: [PATCH 09/32] use different kernel --- .../training/ortmodule/_custom_gradient_registry.py | 9 +++++---- .../ortmodule/_custom_op_symbolic_registry.py | 11 +++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 126f84f4d65cc..ab31d498a71a8 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -278,13 +278,14 @@ def upsample_bicubic2d_gradient(): return _upsample_gradient("upsample_bicubic2d_backward", 2) -@register_gradient("org.pytorch.aten", "ATen", "_efficient_attention_forward", "") +@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention_cuda", "") def scaled_dot_product_attention_gradient(): return [ + ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}), ( ("ATen", "org.pytorch.aten"), - ["GO(0)", "I(0)", "I(1)", "I(2)"], - ["GI(0)", "GI(1)", "GI(2)"], - {"operator": {"value": "_efficient_attention_backward", "dtype": "string"}}, + ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "O(0)", "O(1)", "O(2)", "O(3)", "I(5)", "grad_input_mask", "I(6)", "I(7)"], + ["GI(0)", "GI(1)", "GI(2)", "GI(3)"], + {"operator": {"value": "_scaled_dot_product_efficient_attention_backward_cuda", "dtype": "string"}}, ), ] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index f9b0749c91113..4b86c996df412 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -972,18 +972,21 @@ def softmax(g, input, dim, dtype=None): # based on the following kernel implementation from PyTorch: -# https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/attention.cpp#L638 +# https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/cuda/attention.cu#L788 @register_symbolic("scaled_dot_product_attention") def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) + compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool)) return g.op( - "org.pytorch.aten::ATen", + "org.pytorch.aten::ATen", query, key, value, 
attn_mask, + compute_logsumexp, dropout_p_f, is_causal, scale, - operator_s="scaled_dot_product_attention", - ) + operator_s="_scaled_dot_product_efficient_attention_cuda", + outputs=4 + )[0] From 35bd07ae9921c3c9bff57999adafa5d1cc1be586 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 27 Jun 2024 23:26:54 +0000 Subject: [PATCH 10/32] formatting --- .../ortmodule/_custom_gradient_registry.py | 18 +++++++++++++++++- .../ortmodule/_custom_op_symbolic_registry.py | 4 ++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index ab31d498a71a8..2e1a370de844c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -278,13 +278,29 @@ def upsample_bicubic2d_gradient(): return _upsample_gradient("upsample_bicubic2d_backward", 2) +# based on the following kernel implementation from PyTorch: +# https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention_cuda", "") def scaled_dot_product_attention_gradient(): return [ ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}), ( ("ATen", "org.pytorch.aten"), - ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "O(0)", "O(1)", "O(2)", "O(3)", "I(5)", "grad_input_mask", "I(6)", "I(7)"], + [ + "GO(0)", + "I(0)", + "I(1)", + "I(2)", + "I(3)", + "O(0)", + "O(1)", + "O(2)", + "O(3)", + "I(5)", + "grad_input_mask", + "I(6)", + "I(7)", + ], ["GI(0)", "GI(1)", "GI(2)", "GI(3)"], {"operator": {"value": "_scaled_dot_product_efficient_attention_backward_cuda", "dtype": "string"}}, ), diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 4b86c996df412..3adae577e9196 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -978,7 +978,7 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool)) return g.op( - "org.pytorch.aten::ATen", + "org.pytorch.aten::ATen", query, key, value, @@ -988,5 +988,5 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p is_causal, scale, operator_s="_scaled_dot_product_efficient_attention_cuda", - outputs=4 + outputs=4, )[0] From dd1849a84ba9b29b32fa28712b88bec847ce1c8e Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 18:15:25 +0000 Subject: [PATCH 11/32] include Peng's & Vincent's editS --- orttraining/orttraining/core/graph/gradient_builder.cc | 7 ++++++- .../python/training/ortmodule/_custom_gradient_registry.py | 6 ++++-- .../training/ortmodule/_custom_op_symbolic_registry.py | 4 +++- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 22dcf4eb92411..aac803a59110a 100755 --- 
a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1794,7 +1794,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { } std::vector output_args; - for (const auto& output : node_def.outputs) { + for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { + const auto& output = node_def.outputs[output_index]; + if (!IsGradientRequiredForSrcNodeInput(output_index)) { + output_args.emplace_back(ArgDef()); + continue; + } if (output.find("GI(") == 0) { size_t index = static_cast(std::stoi(output.substr(3, output.length() - 4))); output_args.emplace_back(GI(index)); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 2e1a370de844c..9848b2518b5f5 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -280,7 +280,9 @@ def upsample_bicubic2d_gradient(): # based on the following kernel implementation from PyTorch: # https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748 -@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention_cuda", "") +# dispatch logic: +# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784 +@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "") def scaled_dot_product_attention_gradient(): return [ ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}), @@ -302,6 +304,6 @@ def scaled_dot_product_attention_gradient(): "I(7)", ], ["GI(0)", "GI(1)", "GI(2)", "GI(3)"], - {"operator": {"value": "_scaled_dot_product_efficient_attention_backward_cuda", "dtype": "string"}}, + {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}}, ), ] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 3adae577e9196..0e873338eb095 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -973,6 +973,8 @@ def softmax(g, input, dim, dtype=None): # based on the following kernel implementation from PyTorch: # https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/cuda/attention.cu#L788 +# dispatch logic: +# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778 @register_symbolic("scaled_dot_product_attention") def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) @@ -987,6 +989,6 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p dropout_p_f, is_causal, scale, - operator_s="_scaled_dot_product_efficient_attention_cuda", + operator_s="_scaled_dot_product_efficient_attention", outputs=4, )[0] From 8219ec9744003040c0fae1d4e97d7ecff7454b49 Mon Sep 17 00:00:00 2001 From: 
root Date: Tue, 2 Jul 2024 19:06:54 +0000 Subject: [PATCH 12/32] adjust test and comments --- .../ortmodule/_custom_gradient_registry.py | 4 +--- .../ortmodule/_custom_op_symbolic_registry.py | 4 +--- .../python/orttraining_test_ortmodule_api.py | 22 ++++++++++++++++++- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 9848b2518b5f5..2319481358f95 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -278,9 +278,7 @@ def upsample_bicubic2d_gradient(): return _upsample_gradient("upsample_bicubic2d_backward", 2) -# based on the following kernel implementation from PyTorch: -# https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748 -# dispatch logic: +# based on the following internal PyTorch kernel for efficient attention: # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "") def scaled_dot_product_attention_gradient(): diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 0e873338eb095..f979c94fc63b2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -971,9 +971,7 @@ def softmax(g, input, dim, dtype=None): return softmax -# based on the following kernel implementation from PyTorch: -# https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/cuda/attention.cu#L788 -# dispatch logic: +# based on the following internal PyTorch kernel for efficient attention: # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778 @register_symbolic("scaled_dot_product_attention") def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index dfe6984c1c498..7a9abb48860f2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6946,7 +6946,7 @@ def gen_inputs(device, dtype): device = "cuda" pt_model = _NeuralNetAttention().to(device) - ort_model = ORTModule(copy.deepcopy(pt_model)) + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn")) def run_step(model, inputs): prediction = model(*inputs) @@ -6962,3 +6962,23 @@ def run_step(model, inputs): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + + execution_mgr = ort_model._torch_module._execution_manager._training_manager + from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name + + path = os.path.join( + 
execution_mgr._debug_options.save_onnx_models.path, + _get_onnx_file_name( + execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode + ), + ) + + onnx_model = onnx.load(path) + onnx_nodes = onnx_model.graph.node + + mem_eff_attn_nodes = 0 + for node in onnx_nodes: + if ("ATen" in node.name) and ("scaled_dot_product_attention" in node.attributes.operator): + mem_eff_attn_nodes += 1 + + assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found" From 65c2cb7e192e87d15ffe00ddf09c9160d7a598f0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 19:08:06 +0000 Subject: [PATCH 13/32] move import inside test --- .../orttraining/test/python/orttraining_test_ortmodule_api.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 7a9abb48860f2..37bc6c066a1f9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -25,7 +25,6 @@ # Import autocasting libs from torch import nn from torch.cuda import amp -from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import AdamW, AutoConfig, BertForSequenceClassification, Trainer from transformers.modeling_outputs import SequenceClassifierOutput @@ -6929,6 +6928,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): def test_aten_attention(): + from torch.nn.attention import SDPBackend, sdpa_kernel + class _NeuralNetAttention(torch.nn.Module): def __init__(self): super().__init__() From 18648adeb3a1ebe5ac561b99704f886b7b2933ac Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 20:47:28 +0000 Subject: [PATCH 14/32] feature flag --- .../ortmodule/_custom_gradient_registry.py | 57 ++++++++++--------- .../ortmodule/_custom_op_symbolic_registry.py | 42 +++++++------- .../python/orttraining_test_ortmodule_api.py | 4 ++ 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 2319481358f95..c9a8f819e8975 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -25,6 +25,7 @@ # 'is_tensor' is optional, if not present, the default is False. 
import json +import os from onnxruntime.capi import _pybind_state as C @@ -278,30 +279,32 @@ def upsample_bicubic2d_gradient(): return _upsample_gradient("upsample_bicubic2d_backward", 2) -# based on the following internal PyTorch kernel for efficient attention: -# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784 -@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "") -def scaled_dot_product_attention_gradient(): - return [ - ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}), - ( - ("ATen", "org.pytorch.aten"), - [ - "GO(0)", - "I(0)", - "I(1)", - "I(2)", - "I(3)", - "O(0)", - "O(1)", - "O(2)", - "O(3)", - "I(5)", - "grad_input_mask", - "I(6)", - "I(7)", - ], - ["GI(0)", "GI(1)", "GI(2)", "GI(3)"], - {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}}, - ), - ] +ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None) +if ATEN_SDPA_FALLBACK: + # based on the following internal PyTorch kernel for efficient attention: + # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784 + @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "") + def scaled_dot_product_attention_gradient(): + return [ + ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}), + ( + ("ATen", "org.pytorch.aten"), + [ + "GO(0)", + "I(0)", + "I(1)", + "I(2)", + "I(3)", + "O(0)", + "O(1)", + "O(2)", + "O(3)", + "I(5)", + "grad_input_mask", + "I(6)", + "I(7)", + ], + ["GI(0)", "GI(1)", "GI(2)", "GI(3)"], + {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}}, + ), + ] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index f979c94fc63b2..28aca54023bfd 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- from typing import Callable +import os import torch import torch.onnx.symbolic_helper as sym_help @@ -970,23 +971,24 @@ def softmax(g, input, dim, dtype=None): return softmax - -# based on the following internal PyTorch kernel for efficient attention: -# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778 -@register_symbolic("scaled_dot_product_attention") -def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): - dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) - compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool)) - return g.op( - "org.pytorch.aten::ATen", - query, - key, - value, - attn_mask, - compute_logsumexp, - dropout_p_f, - is_causal, - scale, - operator_s="_scaled_dot_product_efficient_attention", - outputs=4, - )[0] +ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None) +if ATEN_SDPA_FALLBACK: + # based on the following internal PyTorch kernel for efficient attention: + # 
https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778 + @register_symbolic("scaled_dot_product_attention") + def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): + dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT) + compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool)) + return g.op( + "org.pytorch.aten::ATen", + query, + key, + value, + attn_mask, + compute_logsumexp, + dropout_p_f, + is_causal, + scale, + operator_s="_scaled_dot_product_efficient_attention", + outputs=4, + )[0] diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 37bc6c066a1f9..ae5737e804fa3 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6930,6 +6930,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size): def test_aten_attention(): from torch.nn.attention import SDPBackend, sdpa_kernel + os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" + class _NeuralNetAttention(torch.nn.Module): def __init__(self): super().__init__() @@ -6983,3 +6985,5 @@ def run_step(model, inputs): mem_eff_attn_nodes += 1 assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found" + + del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] From be9ce0a1ede845a1d7b973999e8d57c980d0e997 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 20:52:31 +0000 Subject: [PATCH 15/32] add documentation --- docs/ORTModule_Training_Guidelines.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 8d5472ba30601..bb1bf0aace73c 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable ``` +#### ORTMODULE_ATEN_SDPA_FALLBACK + +- **Feature Area**: *ORTMODULE/Optimizations* +- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution. 
+ + ```bash + export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE + export ORTMODULE_ATEN_SDPA_FALLBACK=0 # DISABLE + ``` + ### 2.2 Memory Optimization Q: *Want to run a bigger batch size?* From e269e893ddb613cb98e3fe17ba544bff254f2310 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 21:39:51 +0000 Subject: [PATCH 16/32] minor fixes --- orttraining/orttraining/core/graph/gradient_builder.cc | 2 +- .../python/training/ortmodule/_custom_gradient_registry.py | 7 ++++++- .../training/ortmodule/_custom_op_symbolic_registry.py | 3 ++- .../test/python/orttraining_test_ortmodule_api.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index aac803a59110a..9d22d2fa3ce2a 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1795,11 +1795,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { std::vector output_args; for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { - const auto& output = node_def.outputs[output_index]; if (!IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } + const auto& output = node_def.outputs[output_index]; if (output.find("GI(") == 0) { size_t index = static_cast(std::stoi(output.substr(3, output.length() - 4))); output_args.emplace_back(GI(index)); diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index c9a8f819e8975..00c969cb40844 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -286,7 +286,12 @@ def upsample_bicubic2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "") def scaled_dot_product_attention_gradient(): return [ - ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}), + ( + "Constant", + [], + ["grad_input_mask"], + {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}, + ), ( ("ATen", "org.pytorch.aten"), [ diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 28aca54023bfd..e21a93b4fdfee 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -3,8 +3,8 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- -from typing import Callable import os +from typing import Callable import torch import torch.onnx.symbolic_helper as sym_help @@ -971,6 +971,7 @@ def softmax(g, input, dim, dtype=None): return softmax + ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None) if ATEN_SDPA_FALLBACK: # based on the following internal PyTorch kernel for efficient attention: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index ae5737e804fa3..75a62577de5da 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6981,7 +6981,7 @@ def run_step(model, inputs): mem_eff_attn_nodes = 0 for node in onnx_nodes: - if ("ATen" in node.name) and ("scaled_dot_product_attention" in node.attributes.operator): + if "_scaled_dot_product_efficient_attention" in node.attributes.operator: mem_eff_attn_nodes += 1 assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found" From d3cc487a2e1a4d5cd2f06a688b65ee535f9e9b85 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 2 Jul 2024 21:43:06 +0000 Subject: [PATCH 17/32] doc update --- docs/ORTModule_Training_Guidelines.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index bb1bf0aace73c..f0336bfe2c24d 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -311,7 +311,6 @@ A classical usage of disabling the deep copy: when the deep copy before module e ```bash export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE - export ORTMODULE_ATEN_SDPA_FALLBACK=0 # DISABLE ``` ### 2.2 Memory Optimization From f5005289ba1253bfee01fdc99221bfe1efd7eb1c Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Jul 2024 19:15:45 +0000 Subject: [PATCH 18/32] peng fix, xavier suggestion --- orttraining/orttraining/core/graph/gradient_builder.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 9d22d2fa3ce2a..735f77a8ef90b 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1795,11 +1795,15 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { std::vector output_args; for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { - if (!IsGradientRequiredForSrcNodeInput(output_index)) { + // If the input is not used in the forward computation, we don't need it for gradient computation + // Required for ORTMODULE_ATEN_SDPA_FALLBACK + if (output_index >= GetSrcNodeInputSize() or !IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } + const auto& output = node_def.outputs[output_index]; + if (output.find("GI(") == 0) { size_t index = static_cast(std::stoi(output.substr(3, output.length() - 4))); output_args.emplace_back(GI(index)); From c4cdab615b6266ce90b59be5427144e34584837a Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Jul 2024 20:29:34 +0000 Subject: [PATCH 19/32] bug fix --- orttraining/orttraining/core/graph/gradient_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 
735f77a8ef90b..bebb9b395fbbd 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1797,7 +1797,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { // If the input is not used in the forward computation, we don't need it for gradient computation // Required for ORTMODULE_ATEN_SDPA_FALLBACK - if (output_index >= GetSrcNodeInputSize() or !IsGradientRequiredForSrcNodeInput(output_index)) { + if ((output_index >= GetSrcNodeInputSize()) || !IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } From c05a5ee66e59299a8785f5674ed482985584d585 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Jul 2024 20:36:47 +0000 Subject: [PATCH 20/32] bug fix --- orttraining/orttraining/core/graph/gradient_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index bebb9b395fbbd..fc24d214e20c5 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1797,7 +1797,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { // If the input is not used in the forward computation, we don't need it for gradient computation // Required for ORTMODULE_ATEN_SDPA_FALLBACK - if ((output_index >= GetSrcNodeInputSize()) || !IsGradientRequiredForSrcNodeInput(output_index)) { + if ((static_cast(output_index) >= GetSrcNodeInputSize()) || !IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } From f82bd48df0dd8a9076df3d4aa2e25efd80c53d1f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Jul 2024 20:49:51 +0000 Subject: [PATCH 21/32] bug fix --- orttraining/orttraining/core/graph/gradient_builder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index fc24d214e20c5..107213945e8e6 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1797,7 +1797,7 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { // If the input is not used in the forward computation, we don't need it for gradient computation // Required for ORTMODULE_ATEN_SDPA_FALLBACK - if ((static_cast(output_index) >= GetSrcNodeInputSize()) || !IsGradientRequiredForSrcNodeInput(output_index)) { + if ((static_cast(output_index) >= GetSrcNodeInputSize()) || !IsGradientRequiredForSrcNodeInput(output_index)) { output_args.emplace_back(ArgDef()); continue; } From 668409b96eb7383d6bac166ea499a1e30b379b41 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 8 Jul 2024 21:33:38 +0000 Subject: [PATCH 22/32] adjust unit test --- .../test/python/orttraining_test_ortmodule_api.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 75a62577de5da..a1ceb66cc0193 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ 
b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6964,7 +6964,9 @@ def run_step(model, inputs): ort_prediction = run_step(ort_model, ort_input) _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad) + _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad) + _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad) execution_mgr = ort_model._torch_module._execution_manager._training_manager from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name @@ -6981,8 +6983,10 @@ def run_step(model, inputs): mem_eff_attn_nodes = 0 for node in onnx_nodes: - if "_scaled_dot_product_efficient_attention" in node.attributes.operator: - mem_eff_attn_nodes += 1 + if "ATen" in node.name: + for attr in node.attribute: + if "_scaled_dot_product_efficient_attention" in attr.s: + mem_eff_attn_nodes += 1 assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found" From b5f116937fcb11ad068fa2e9745c49b626801dc8 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 9 Jul 2024 00:13:24 +0000 Subject: [PATCH 23/32] adjust checks --- orttraining/orttraining/core/graph/gradient_builder.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 107213945e8e6..76fe0ee91d4c6 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -1797,7 +1797,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) { for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) { // If the input is not used in the forward computation, we don't need it for gradient computation // Required for ORTMODULE_ATEN_SDPA_FALLBACK - if ((static_cast(output_index) >= GetSrcNodeInputSize()) || !IsGradientRequiredForSrcNodeInput(output_index)) { + if (static_cast(output_index) >= GetSrcNodeInputSize()) { + continue; + } + + if (!IsGradientRequiredForSrcNodeInput(static_cast(output_index))) { output_args.emplace_back(ArgDef()); continue; } From 31becabb5ac7c50f31ef0cd55106835f3f7d2caf Mon Sep 17 00:00:00 2001 From: root Date: Tue, 9 Jul 2024 00:52:39 +0000 Subject: [PATCH 24/32] grad input fix --- .../python/training/ortmodule/_custom_gradient_registry.py | 2 +- .../orttraining/test/python/orttraining_test_ortmodule_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 00c969cb40844..007794f4ad052 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -290,7 +290,7 @@ def scaled_dot_product_attention_gradient(): "Constant", [], ["grad_input_mask"], - {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}, + {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}}, ), ( ("ATen", "org.pytorch.aten"), diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index a1ceb66cc0193..b6e3fb35341a0 100644 --- 
a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6985,7 +6985,7 @@ def run_step(model, inputs):
     for node in onnx_nodes:
         if "ATen" in node.name:
             for attr in node.attribute:
-                if "_scaled_dot_product_efficient_attention" in attr.s:
+                if b"_scaled_dot_product_efficient_attention" in attr.s:
                     mem_eff_attn_nodes += 1
 
     assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"

From 5aa147d90c47025dc652a3266f34ed5f72854539 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 9 Jul 2024 19:01:30 +0000
Subject: [PATCH 25/32] handle both with and without bias

---
 docs/ORTModule_Training_Guidelines.md        |  4 +-
 .../ortmodule/_custom_gradient_registry.py   |  3 +-
 .../python/orttraining_test_ortmodule_api.py | 59 ++++++++++++++++---
 3 files changed, 55 insertions(+), 11 deletions(-)

diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index f0336bfe2c24d..6ac59a18edee0 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -310,7 +310,9 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 - **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution.
 
   ```bash
-  export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
+  export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE **WITHOUT** ATTN_MASK INPUT
+  export ORTMODULE_ATEN_SDPA_FALLBACK=MASKED # ENABLE **WITH** ATTN_MASK INPUT
+  unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
   ```
 
 ### 2.2 Memory Optimization
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 007794f4ad052..5c7747a4e9b25 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -285,12 +285,13 @@ def upsample_bicubic2d_gradient():
 # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
+    grad_input_mask = [1, 1, 1, 1] if ATEN_SDPA_FALLBACK.upper() == "MASKED" else [1, 1, 1, 0]
     return [
         (
             "Constant",
             [],
             ["grad_input_mask"],
-            {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
+            {"value": {"value": grad_input_mask, "dtype": "int", "is_tensor": True}},
         ),
         (
             ("ATen", "org.pytorch.aten"),
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index b6e3fb35341a0..5cd121932eabf 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6930,15 +6930,13 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
 def test_aten_attention():
     from torch.nn.attention import SDPBackend, sdpa_kernel
 
-    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"
-
     class _NeuralNetAttention(torch.nn.Module):
         def __init__(self):
             super().__init__()
 
-        def forward(self, q, k, v):
+        def forward(self, q, k, v, attn_mask=None):
             with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
-                return torch.nn.functional.scaled_dot_product_attention(q, k, v)
+                return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)
 
     def gen_inputs(device, dtype):
         return [
@@ -6947,15 +6945,18 @@ def gen_inputs(device, dtype):
             torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
         ]
 
+    def run_step(model, inputs, attn_mask=None):
+        prediction = model(*inputs, attn_mask)
+        prediction.sum().backward()
+        return prediction
+
     device = "cuda"
+
+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" # TESTING WITHOUT ATTN_MASK
+
     pt_model = _NeuralNetAttention().to(device)
     ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))
 
-    def run_step(model, inputs):
-        prediction = model(*inputs)
-        prediction.sum().backward()
-        return prediction
-
     # reset manual seed to reset the generator
     torch.manual_seed(2333)
     pt_input = gen_inputs(device=device, dtype=torch.float32)
@@ -6990,4 +6991,44 @@ def run_step(model, inputs):
 
     assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
 
+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "MASKED" # TESTING WITH ATTN_MASK
+
+    pt_model = _NeuralNetAttention().to(device)
+    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn_masked"))
+
+    # reset manual seed to reset the generator
+    torch.manual_seed(2333)
+    pt_input = gen_inputs(device=device, dtype=torch.float32)
+    attn_mask = torch.randint(2, (32, 8, 128, 128), dtype=torch.float32, device=device, requires_grad=True)
+    ort_input = copy.deepcopy(pt_input)
+    pt_prediction = run_step(pt_model, pt_input, attn_mask)
+    ort_prediction = run_step(ort_model, ort_input, attn_mask)
+
+    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
+    _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
+    _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
+    _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)
+
+    execution_mgr = ort_model._torch_module._execution_manager._training_manager
+    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
+
+    path = os.path.join(
+        execution_mgr._debug_options.save_onnx_models.path,
+        _get_onnx_file_name(
+            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
+        ),
+    )
+
+    onnx_model = onnx.load(path)
+    onnx_nodes = onnx_model.graph.node
+
+    mem_eff_attn_nodes = 0
+    for node in onnx_nodes:
+        if "ATen" in node.name:
+            for attr in node.attribute:
+                if b"_scaled_dot_product_efficient_attention" in attr.s:
+                    mem_eff_attn_nodes += 1
+
+    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
+
     del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]

From 37eb6bc2dda88c7b65ee67daa67327bd178a702e Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 9 Jul 2024 19:03:09 +0000
Subject: [PATCH 26/32] full mask

---
 .../orttraining/test/python/orttraining_test_ortmodule_api.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 5cd121932eabf..f2597789cf9d0 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6999,7 +6999,7 @@ def run_step(model, inputs, attn_mask=None):
     # reset manual seed to reset the generator
     torch.manual_seed(2333)
     pt_input = gen_inputs(device=device, dtype=torch.float32)
-    attn_mask = torch.randint(2, (32, 8, 128, 128), dtype=torch.float32, device=device, requires_grad=True)
+    attn_mask = torch.ones(32, 8, 128, 128, dtype=torch.float32, device=device, requires_grad=True)
     ort_input = copy.deepcopy(pt_input)
     pt_prediction = run_step(pt_model, pt_input, attn_mask)
     ort_prediction = run_step(ort_model, ort_input, attn_mask)

From 3484926fef587d6f22aeb119be3d43eabcf74659 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 12 Jul 2024 21:49:42 +0000
Subject: [PATCH 27/32] lint

---
 .../test/python/orttraining_test_ortmodule_api.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 47a6c046e5aa1..b89a810cd9b47 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6980,7 +6980,7 @@ def run_step(model, inputs, attn_mask=None):
 
     device = "cuda"
 
-    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" # TESTING WITHOUT ATTN_MASK
+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"  # TESTING WITHOUT ATTN_MASK
 
     pt_model = _NeuralNetAttention().to(device)
     ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))
@@ -7019,8 +7019,8 @@ def run_step(model, inputs, attn_mask=None):
 
     assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
 
-    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "MASKED" # TESTING WITH ATTN_MASK
-
+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "MASKED"  # TESTING WITH ATTN_MASK
+
     pt_model = _NeuralNetAttention().to(device)
     ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn_masked"))

From 8d0e87956fb8b91978dd459df6455c12a60ed9b1 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 15 Jul 2024 18:15:22 +0000
Subject: [PATCH 28/32] add version check for test

---
 .../orttraining/test/python/orttraining_test_ortmodule_api.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index b89a810cd9b47..1c4173000716d 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6955,6 +6955,10 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
     del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]
 
 
+@pytest.mark.skipif(
+    Version(torch.__version__) < Version("2.3.0"),
+    reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
+)
 def test_aten_attention():
     from torch.nn.attention import SDPBackend, sdpa_kernel

From b72a042ed36782a99173cc70eadd4429df3564dd Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 16 Jul 2024 18:50:22 +0000
Subject: [PATCH 29/32] grad output adjustment

---
 .../python/training/ortmodule/_custom_gradient_registry.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 317b20153669f..f939b9cbcc4ec 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -286,6 +286,11 @@ def upsample_bicubic2d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
     grad_input_mask = [1, 1, 1, 1] if ATEN_SDPA_FALLBACK.upper() == "MASKED" else [1, 1, 1, 0]
+    grad_output = (
+        ["GI(0)", "GI(1)", "GI(2)", "GI(3)"]
+        if ATEN_SDPA_FALLBACK.upper() == "MASKED"
+        else ["GI(0)", "GI(1)", "GI(2)", ""]
+    )
     return [
         (
             "Constant",
@@ -310,7 +315,7 @@ def scaled_dot_product_attention_gradient():
                 "I(6)",
                 "I(7)",
             ],
-            ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
+            grad_output,
             {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
         ),
     ]

From 6b4dd101337a021bf83127e418685d3df7663c3d Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 16 Jul 2024 20:08:17 +0000
Subject: [PATCH 30/32] add more docs

---
 docs/ORTModule_Training_Guidelines.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 6ac59a18edee0..6ba77ff8448bf 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -307,7 +307,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 #### ORTMODULE_ATEN_SDPA_FALLBACK
 
 - **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution.
+- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution. NOTE: will not work if model uses both masked and unmasked attention, can only be one.
 
   ```bash
   export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE **WITHOUT** ATTN_MASK INPUT

From 999b04bc2888673b6d5af13fa5a1aa06b22ed3db Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 17 Jul 2024 18:32:18 +0000
Subject: [PATCH 31/32] remove support for masked attention

---
 docs/ORTModule_Training_Guidelines.md        |  5 +--
 .../ortmodule/_custom_gradient_registry.py   | 10 +----
 .../python/orttraining_test_ortmodule_api.py | 40 -------------------
 3 files changed, 4 insertions(+), 51 deletions(-)

diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 6ba77ff8448bf..779dbe5c74f6e 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -307,11 +307,10 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 #### ORTMODULE_ATEN_SDPA_FALLBACK
 
 - **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution. NOTE: will not work if model uses both masked and unmasked attention, can only be one.
+- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution. NOTE: only works if attn_mask=None when torch.nn.functional.scaled_dot_product_attention is called.
   ```bash
-  export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE **WITHOUT** ATTN_MASK INPUT
-  export ORTMODULE_ATEN_SDPA_FALLBACK=MASKED # ENABLE **WITH** ATTN_MASK INPUT
+  export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
   unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
   ```
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index f939b9cbcc4ec..97650f509ac88 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -285,18 +285,12 @@ def upsample_bicubic2d_gradient():
 # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
-    grad_input_mask = [1, 1, 1, 1] if ATEN_SDPA_FALLBACK.upper() == "MASKED" else [1, 1, 1, 0]
-    grad_output = (
-        ["GI(0)", "GI(1)", "GI(2)", "GI(3)"]
-        if ATEN_SDPA_FALLBACK.upper() == "MASKED"
-        else ["GI(0)", "GI(1)", "GI(2)", ""]
-    )
     return [
         (
             "Constant",
             [],
             ["grad_input_mask"],
-            {"value": {"value": grad_input_mask, "dtype": "int", "is_tensor": True}},
+            {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
         ),
         (
             ("ATen", "org.pytorch.aten"),
@@ -315,7 +309,7 @@ def scaled_dot_product_attention_gradient():
                 "I(6)",
                 "I(7)",
             ],
-            grad_output,
+            ["GI(0)", "GI(1)", "GI(2)", ""],
             {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
         ),
     ]
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 1c4173000716d..21b0241959195 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -7023,44 +7023,4 @@ def run_step(model, inputs, attn_mask=None):
 
     assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
 
-    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "MASKED"  # TESTING WITH ATTN_MASK
-
-    pt_model = _NeuralNetAttention().to(device)
-    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn_masked"))
-
-    # reset manual seed to reset the generator
-    torch.manual_seed(2333)
-    pt_input = gen_inputs(device=device, dtype=torch.float32)
-    attn_mask = torch.ones(32, 8, 128, 128, dtype=torch.float32, device=device, requires_grad=True)
-    ort_input = copy.deepcopy(pt_input)
-    pt_prediction = run_step(pt_model, pt_input, attn_mask)
-    ort_prediction = run_step(ort_model, ort_input, attn_mask)
-
-    _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
-    _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
-    _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
-    _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)
-
-    execution_mgr = ort_model._torch_module._execution_manager._training_manager
-    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
-
-    path = os.path.join(
-        execution_mgr._debug_options.save_onnx_models.path,
-        _get_onnx_file_name(
-            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
-        ),
-    )
-
-    onnx_model = onnx.load(path)
-    onnx_nodes = onnx_model.graph.node
-
-    mem_eff_attn_nodes = 0
-    for node in onnx_nodes:
-        if "ATen" in node.name:
-            for attr in node.attribute:
-                if b"_scaled_dot_product_efficient_attention" in attr.s:
-                    mem_eff_attn_nodes += 1
-
-    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
-
     del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]

From b1fe48992797e8df07d6c7a38bd4c18306906960 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 18 Jul 2024 20:37:22 +0000
Subject: [PATCH 32/32] adjust docs

---
 docs/ORTModule_Training_Guidelines.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 779dbe5c74f6e..c79ba59a07ee9 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -307,7 +307,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e
 #### ORTMODULE_ATEN_SDPA_FALLBACK
 
 - **Feature Area**: *ORTMODULE/Optimizations*
-- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's efficient_attention ATen kernel for execution. NOTE: only works if attn_mask=None when torch.nn.functional.scaled_dot_product_attention is called.
+- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution when calling torch.nn.functional.scaled_dot_product_attention. NOTE: only use this feature if the user model leverages memory efficient attention WITHOUT masking (i.e., attn_mask=None). Utilize GPU profiling tools like NVIDIA Nsight Systems to identify if the user model leverages memory efficient attention.
 
   ```bash
   export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
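
For reference, a minimal usage sketch of the fallback this series lands. It is not taken from the patches themselves; it mirrors the `test_aten_attention` unit test above and assumes a CUDA build of onnxruntime-training plus PyTorch >= 2.3 (for `torch.nn.attention`). Module name and tensor shapes are illustrative only.

```python
# Illustrative sketch: exercise ORTMODULE_ATEN_SDPA_FALLBACK=1 from user code.
import os

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

from onnxruntime.training.ortmodule import ORTModule

# Must be set before the first forward pass, since the fallback is applied at export time.
os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"


class SDPA(torch.nn.Module):
    def forward(self, q, k, v):
        # attn_mask must stay None for the fallback to apply (see the final docs wording above).
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)


model = ORTModule(SDPA().to("cuda"))
q, k, v = (torch.randn(32, 8, 128, 64, device="cuda", requires_grad=True) for _ in range(3))
# The exported training graph should contain an ATen node whose operator attribute is
# "_scaled_dot_product_efficient_attention", as asserted in the unit test.
model(q, k, v).sum().backward()
```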