diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc
index aac803a59110a..9d22d2fa3ce2a 100755
--- a/orttraining/orttraining/core/graph/gradient_builder.cc
+++ b/orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1795,11 +1795,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
   std::vector<ArgDef> output_args;
   for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
-    const auto& output = node_def.outputs[output_index];
     if (!IsGradientRequiredForSrcNodeInput(output_index)) {
      output_args.emplace_back(ArgDef());
       continue;
     }
 
+    const auto& output = node_def.outputs[output_index];
     if (output.find("GI(") == 0) {
       size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
       output_args.emplace_back(GI(index));
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
index c9a8f819e8975..00c969cb40844 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py
@@ -286,7 +286,12 @@ def upsample_bicubic2d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
     return [
-        ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}),
+        (
+            "Constant",
+            [],
+            ["grad_input_mask"],
+            {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}},
+        ),
         (
             ("ATen", "org.pytorch.aten"),
             [
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index 28aca54023bfd..e21a93b4fdfee 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -3,8 +3,8 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 
-from typing import Callable
 import os
+from typing import Callable
 
 import torch
 import torch.onnx.symbolic_helper as sym_help
@@ -971,6 +971,7 @@ def softmax(g, input, dim, dtype=None):
 
     return softmax
 
+
 ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
 if ATEN_SDPA_FALLBACK:
     # based on the following internal PyTorch kernel for efficient attention:
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index ae5737e804fa3..75a62577de5da 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6981,7 +6981,7 @@ def run_step(model, inputs):
 
     mem_eff_attn_nodes = 0
     for node in onnx_nodes:
-        if ("ATen" in node.name) and ("scaled_dot_product_attention" in node.attributes.operator):
+        if "_scaled_dot_product_efficient_attention" in node.attributes.operator:
             mem_eff_attn_nodes += 1
 
     assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
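
Note (not part of the patch): the test change above stops matching on node names and instead looks for ATen nodes whose "operator" attribute is "_scaled_dot_product_efficient_attention". Below is a minimal, standalone Python sketch of an equivalent check written against the public onnx protobuf API. It assumes a training graph already exported by ORTModule with ORTMODULE_ATEN_SDPA_FALLBACK set before export has been saved to a file; the path is hypothetical, and the real test reads the graph through its own helpers (node.attributes.operator) rather than onnx.load.

# Standalone sketch; the model path is a placeholder for illustration only.
import onnx

model = onnx.load("ortmodule_exported_training_model.onnx")  # hypothetical path

mem_eff_attn_nodes = 0
for node in model.graph.node:
    # ATen fallback nodes are exported with op_type "ATen" in the
    # "org.pytorch.aten" domain; the target kernel name is carried in the
    # string attribute named "operator".
    if node.op_type == "ATen" and node.domain == "org.pytorch.aten":
        for attr in node.attribute:
            if attr.name == "operator" and attr.s == b"_scaled_dot_product_efficient_attention":
                mem_eff_attn_nodes += 1

assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"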