include Peng's & Vincent's edits
root committed Jul 2, 2024
1 parent 35bd07a commit dd1849a
Showing 3 changed files with 13 additions and 4 deletions.
7 changes: 6 additions & 1 deletion orttraining/orttraining/core/graph/gradient_builder.cc
@@ -1794,7 +1794,12 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
   }
 
   std::vector<ArgDef> output_args;
-  for (const auto& output : node_def.outputs) {
+  for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
+    const auto& output = node_def.outputs[output_index];
+    if (!IsGradientRequiredForSrcNodeInput(output_index)) {
+      output_args.emplace_back(ArgDef());
+      continue;
+    }
     if (output.find("GI(") == 0) {
       size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
       output_args.emplace_back(GI(index));
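The hunk above keeps the gradient output list of GetExternalGradient positionally aligned with the forward node's inputs: slots whose input does not require a gradient now receive an empty ArgDef instead of being skipped. A rough Python sketch of that control flow (build_output_args and needs_grad are illustrative names, not ORT APIs):

# Hypothetical Python sketch of the control flow added to GetExternalGradient above.
# Slots whose source input does not need a gradient get a placeholder (None here,
# standing in for an empty ArgDef) so that later outputs keep their positions.
def build_output_args(outputs, needs_grad):
    output_args = []
    for output_index, output in enumerate(outputs):
        if not needs_grad[output_index]:
            output_args.append(None)  # placeholder, analogous to ArgDef()
            continue
        output_args.append(output)  # analogous to resolving "GI(i)" to a real ArgDef
    return output_args

# Only inputs 0 and 2 require gradients:
print(build_output_args(["GI(0)", "GI(1)", "GI(2)"], [True, False, True]))
# -> ['GI(0)', None, 'GI(2)']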
(second changed file: Python gradient registry; file header not captured in this view)
@@ -280,7 +280,9 @@ def upsample_bicubic2d_gradient():
 
 # based on the following kernel implementation from PyTorch:
 # https://github.com/pytorch/pytorch/blob/52341c28e817ee6bc36b529823f8248ba395d5bb/aten/src/ATen/native/transformers/cuda/attention_backward.cu#L748
-@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention_cuda", "")
+# dispatch logic:
+# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
+@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
 def scaled_dot_product_attention_gradient():
     return [
         ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}),
@@ -302,6 +304,6 @@ def scaled_dot_product_attention_gradient():
                 "I(7)",
             ],
             ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
-            {"operator": {"value": "_scaled_dot_product_efficient_attention_backward_cuda", "dtype": "string"}},
+            {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
         ),
     ]
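For readers unfamiliar with this registry, the entry shape can be read off the diff above: the decorator arguments follow the pattern (domain, op type, ATen operator name, overload), and each tuple in the returned list is (op, inputs, outputs, attributes), where "I(i)" names the i-th forward input, "GO(i)" the gradient of the i-th forward output, and "GI(i)" the gradient produced for the i-th forward input. A hedged sketch with a made-up op (my_relu and my_relu_backward are hypothetical, and register_gradient is assumed to be in scope as in the registry module above):

# Hedged illustration only: "my_relu" and "my_relu_backward" are made-up names, and
# register_gradient is assumed to be in scope as in the registry module above.
# Each tuple is (op, inputs, outputs, attributes); "I(i)" is the i-th forward input,
# "GO(i)" the gradient of the i-th forward output, "GI(i)" the gradient for input i.
@register_gradient("org.pytorch.aten", "ATen", "my_relu", "")
def my_relu_gradient():
    return [
        (
            ("ATen", "org.pytorch.aten"),
            ["GO(0)", "I(0)"],
            ["GI(0)"],
            {"operator": {"value": "my_relu_backward", "dtype": "string"}},
        ),
    ]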
(third changed file: Python symbolic registry; file header not captured in this view)
@@ -973,6 +973,8 @@ def softmax(g, input, dim, dtype=None):
 
 # based on the following kernel implementation from PyTorch:
 # https://github.com/pytorch/pytorch/blob/00f675bb4c2ec02bb5ffecfc75571026e220701c/aten/src/ATen/native/transformers/cuda/attention.cu#L788
+# dispatch logic:
+# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
 @register_symbolic("scaled_dot_product_attention")
 def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
     dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
@@ -987,6 +989,6 @@ def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p
         dropout_p_f,
         is_causal,
         scale,
-        operator_s="_scaled_dot_product_efficient_attention_cuda",
+        operator_s="_scaled_dot_product_efficient_attention",
         outputs=4,
     )[0]
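End to end, a module that calls torch.nn.functional.scaled_dot_product_attention and takes the memory-efficient attention path should now export through the ATen op "_scaled_dot_product_efficient_attention" and differentiate via "_scaled_dot_product_efficient_attention_backward". A minimal sketch, assuming onnxruntime-training (ORTModule) and a CUDA-enabled PyTorch are installed:

# Minimal end-to-end sketch; assumes onnxruntime-training (ORTModule) and a
# CUDA-enabled PyTorch. The memory-efficient SDPA backend is the one that routes
# to _scaled_dot_product_efficient_attention and its backward.
import torch
from torch import nn
from onnxruntime.training.ortmodule import ORTModule

class Attn(nn.Module):
    def forward(self, q, k, v):
        return nn.functional.scaled_dot_product_attention(q, k, v)

model = ORTModule(Attn().to("cuda"))
q = torch.randn(2, 8, 128, 64, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
out = model(q, k, v)
out.sum().backward()  # flows through the registered ATen gradient above
print(q.grad.shape)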
