From 8219ec9744003040c0fae1d4e97d7ecff7454b49 Mon Sep 17 00:00:00 2001
From: root
Date: Tue, 2 Jul 2024 19:06:54 +0000
Subject: [PATCH] adjust test and comments

---
 .../ortmodule/_custom_gradient_registry.py    |  4 +---
 .../ortmodule/_custom_op_symbolic_registry.py |  4 +---
 .../python/orttraining_test_ortmodule_api.py  | 22 ++++++++++++++++++-
 3 files changed, 23 insertions(+), 7 deletions(-)

diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index 9848b2518b5f5..2319481358f95 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -278,9 +278,7 @@ def upsample_bicubic2d_gradient():
     return _upsample_gradient("upsample_bicubic2d_backward", 2)
 
 
-# based on the following kernel implementation from PyTorch:
-# https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748
-# dispatch logic:
+# based on the following internal PyTorch kernel for efficient attention:
 # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 0e873338eb095..f979c94fc63b2 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -971,9 +971,7 @@ def softmax(g, input, dim, dtype=None):
     return softmax
 
 
-# based on the following kernel implementation from PyTorch:
-# https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/cuda/attention.cu#L788
-# dispatch logic:
+# based on the following internal PyTorch kernel for efficient attention:
 # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
 @register_symbolic("scaled_dot_product_attention")
 def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index dfe6984c1c498..7a9abb48860f2 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6946,7 +6946,7 @@ def gen_inputs(device, dtype):
     device = "cuda"
 
     pt_model = _NeuralNetAttention().to(device)
-    ort_model = ORTModule(copy.deepcopy(pt_model))
+    ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))
 
     def run_step(model, inputs):
         prediction = model(*inputs)
@@ -6962,3 +6962,23 @@ def run_step(model, inputs):
 
     _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
     _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
+
+    execution_mgr = ort_model._torch_module._execution_manager._training_manager
+    from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name
+
+    path = os.path.join(
+        execution_mgr._debug_options.save_onnx_models.path,
+        _get_onnx_file_name(
+            execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
+        ),
+    )
+
+    onnx_model = onnx.load(path)
+    onnx_nodes = onnx_model.graph.node
+
+    mem_eff_attn_nodes = 0
+    for node in onnx_nodes:
+        if "ATen" in node.name and any(b"_scaled_dot_product_efficient_attention" in attr.s for attr in node.attribute):
+            mem_eff_attn_nodes += 1
+
+    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"