Commit

minor fixes
prathikr committed Jul 2, 2024
1 parent be9ce0a commit e269e89
Showing 4 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1795,11 +1795,11 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {

   std::vector<ArgDef> output_args;
   for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
-    const auto& output = node_def.outputs[output_index];
     if (!IsGradientRequiredForSrcNodeInput(output_index)) {
       output_args.emplace_back(ArgDef());
       continue;
     }
+    const auto& output = node_def.outputs[output_index];
     if (output.find("GI(") == 0) {
       size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
       output_args.emplace_back(GI(index));

@@ -286,7 +286,12 @@ def upsample_bicubic2d_gradient():
 @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
     return [
-        ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}),
+        (
+            "Constant",
+            [],
+            ["grad_input_mask"],
+            {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}},
+        ),
         (
             ("ATen", "org.pytorch.aten"),
             [
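
The hunk above only re-wraps the first entry of the gradient definition across multiple lines; the tuple contents are unchanged. As a reading aid, here is a minimal, self-contained sketch of the (op, inputs, outputs, attributes) entry format this registry appears to use. The registry dict and decorator below are simplified stand-ins, not the actual ONNX Runtime implementation; only the decorator arguments and the "Constant" tuple itself are taken from the diff.

_GRADIENT_REGISTRY = {}  # hypothetical stand-in for the real registry table


def register_gradient(domain, op_type, name, overload):
    # Simplified decorator: store the gradient-definition function under its key.
    def wrapper(fn):
        _GRADIENT_REGISTRY[(domain, op_type, name, overload)] = fn
        return fn

    return wrapper


@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
def scaled_dot_product_attention_gradient():
    # Each entry reads as (op, inputs, outputs, attributes); the reformatting in
    # the diff only re-wraps this first tuple, it does not change its contents.
    return [
        (
            "Constant",
            [],
            ["grad_input_mask"],
            {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}},
        ),
    ]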

@@ -3,8 +3,8 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------

-from typing import Callable
 import os
+from typing import Callable

 import torch
 import torch.onnx.symbolic_helper as sym_help
@@ -971,6 +971,7 @@ def softmax(g, input, dim, dtype=None):

     return softmax

+
 ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
 if ATEN_SDPA_FALLBACK:
     # based on the following internal PyTorch kernel for efficient attention:
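
The second hunk above gates the SDPA-fallback registration behind the ORTMODULE_ATEN_SDPA_FALLBACK environment variable, evaluated at import time. A short usage sketch follows: set the flag before importing, since any non-empty value makes the `if ATEN_SDPA_FALLBACK:` branch run. The import path is an assumption; only the variable name comes from the diff.

import os

# The registration above happens at import time, so set the flag first.
# Any non-empty string is truthy here; "1" is used purely as an example.
os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"

from onnxruntime.training.ortmodule import ORTModule  # assumed import path  # noqa: E402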

@@ -6981,7 +6981,7 @@ def run_step(model, inputs):

     mem_eff_attn_nodes = 0
     for node in onnx_nodes:
-        if ("ATen" in node.name) and ("scaled_dot_product_attention" in node.attributes.operator):
+        if "_scaled_dot_product_efficient_attention" in node.attributes.operator:
             mem_eff_attn_nodes += 1

     assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
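
The test change above switches from matching the node name to matching the operator attribute, which pins the check to the specific ATen kernel. Below is a small equivalent helper, sketched under the assumption that `onnx_nodes` and `node.attributes.operator` behave as in the test; the helper name is illustrative.

def count_mem_eff_attn_nodes(onnx_nodes):
    # Count nodes whose ATen operator attribute names the memory-efficient
    # SDPA kernel, mirroring the updated condition in the test above.
    return sum(
        1
        for node in onnx_nodes
        if "_scaled_dot_product_efficient_attention" in node.attributes.operator
    )


# e.g. assert count_mem_eff_attn_nodes(onnx_nodes) > 0, "No mem_eff_attn nodes are found"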
