Commit 18648ad

feature flag

prathikr committed Jul 2, 2024
1 parent b5f5863 commit 18648ad

Showing 3 changed files with 56 additions and 47 deletions.
File 1 of 3:

@@ -25,6 +25,7 @@
# 'is_tensor' is optional, if not present, the default is False.

import json
+import os

from onnxruntime.capi import _pybind_state as C

@@ -278,30 +279,32 @@ def upsample_bicubic2d_gradient():
    return _upsample_gradient("upsample_bicubic2d_backward", 2)


-# based on the following internal PyTorch kernel for efficient attention:
-# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
-@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
-def scaled_dot_product_attention_gradient():
-    return [
-        ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}),
-        (
-            ("ATen", "org.pytorch.aten"),
-            [
-                "GO(0)",
-                "I(0)",
-                "I(1)",
-                "I(2)",
-                "I(3)",
-                "O(0)",
-                "O(1)",
-                "O(2)",
-                "O(3)",
-                "I(5)",
-                "grad_input_mask",
-                "I(6)",
-                "I(7)",
-            ],
-            ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
-            {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
-        ),
-    ]
+ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
+if ATEN_SDPA_FALLBACK:
+    # based on the following internal PyTorch kernel for efficient attention:
+    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
+    @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
+    def scaled_dot_product_attention_gradient():
+        return [
+            ("Constant", [], ["grad_input_mask"], {"value": {"value": [1, 1, 1, 1], "dtype": "int", "is_tensor": True}}),
+            (
+                ("ATen", "org.pytorch.aten"),
+                [
+                    "GO(0)",
+                    "I(0)",
+                    "I(1)",
+                    "I(2)",
+                    "I(3)",
+                    "O(0)",
+                    "O(1)",
+                    "O(2)",
+                    "O(3)",
+                    "I(5)",
+                    "grad_input_mask",
+                    "I(6)",
+                    "I(7)",
+                ],
+                ["GI(0)", "GI(1)", "GI(2)", "GI(3)"],
+                {"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
+            ),
+        ]
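In this registry's notation, I(i) and O(i) name the forward op's inputs and outputs while GO(i) and GI(i) name the gradients of its outputs and inputs, so the backward ATen node receives the incoming gradient, the forward inputs and saved outputs, and the grad_input_mask constant. Note that the flag is read once with os.getenv at module import, so it must already be in the environment when this registry module is first imported. A minimal usage sketch, not part of this commit (the ORTModule import path is the usual onnxruntime-training one and is assumed here):

import os

# Any non-empty value enables the ATen SDPA fallback; leaving the variable unset
# keeps the default export path for scaled_dot_product_attention.
os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"

import torch
from onnxruntime.training.ortmodule import ORTModule  # assumed import path

model = ORTModule(torch.nn.Linear(16, 16))  # the gated registration above is now in effect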
File 2 of 3:

@@ -4,6 +4,7 @@
# --------------------------------------------------------------------------

from typing import Callable
+import os

import torch
import torch.onnx.symbolic_helper as sym_help
@@ -970,23 +971,24 @@ def softmax(g, input, dim, dtype=None):

    return softmax

-
-# based on the following internal PyTorch kernel for efficient attention:
-# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
-@register_symbolic("scaled_dot_product_attention")
-def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
-    dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
-    compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
-    return g.op(
-        "org.pytorch.aten::ATen",
-        query,
-        key,
-        value,
-        attn_mask,
-        compute_logsumexp,
-        dropout_p_f,
-        is_causal,
-        scale,
-        operator_s="_scaled_dot_product_efficient_attention",
-        outputs=4,
-    )[0]
+ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
+if ATEN_SDPA_FALLBACK:
+    # based on the following internal PyTorch kernel for efficient attention:
+    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
+    @register_symbolic("scaled_dot_product_attention")
+    def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
+        dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+        compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
+        return g.op(
+            "org.pytorch.aten::ATen",
+            query,
+            key,
+            value,
+            attn_mask,
+            compute_logsumexp,
+            dropout_p_f,
+            is_causal,
+            scale,
+            operator_s="_scaled_dot_product_efficient_attention",
+            outputs=4,
+        )[0]
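The symbolic above maps torch's scaled_dot_product_attention onto a single org.pytorch.aten::ATen node that calls _scaled_dot_product_efficient_attention with four outputs, of which only the attention output (index 0) is returned to the ONNX graph. The test below pins PyTorch's EFFICIENT_ATTENTION backend via sdpa_kernel before calling SDPA; a small sketch of that kind of module, with illustrative names only:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel


class MiniSDPA(torch.nn.Module):  # hypothetical module, mirroring the test below
    def forward(self, q, k, v):
        # Pin the memory-efficient attention backend so runtime execution uses
        # the same kernel family the fallback export targets.
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)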
File 3 of 3:

@@ -6930,6 +6930,8 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
def test_aten_attention():
    from torch.nn.attention import SDPBackend, sdpa_kernel

+    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"
+
    class _NeuralNetAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()
@@ -6983,3 +6985,5 @@ def run_step(model, inputs):
            mem_eff_attn_nodes += 1

    assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
+
+    del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]
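The test enables the flag with a plain os.environ assignment and deletes it again at the end; if the assertion fails, the del never runs and the setting can leak into later tests. A hedged alternative, assuming the suite runs under pytest, is to let the monkeypatch fixture own the cleanup:

def test_aten_attention(monkeypatch):
    # monkeypatch.setenv restores the environment when the test finishes,
    # even if an assertion fails, so no explicit del is needed.
    monkeypatch.setenv("ORTMODULE_ATEN_SDPA_FALLBACK", "1")
    ...  # rest of the test unchanged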
