Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ATen fallback for scaled_dot_product_attention #21107

Merged
merged 35 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
f22b8dc
attn aten fallback
Jun 19, 2024
612e425
use correct operator names
Jun 20, 2024
bdcfebb
formatting
Jun 20, 2024
80c3107
add unit test
Jun 20, 2024
2b29b4c
formatting
Jun 20, 2024
d2b8566
use pytorch sdpa kernel
Jun 26, 2024
0ca8fa0
bug fix
Jun 26, 2024
8999ff2
lint
Jun 26, 2024
6bf3018
use different kernel
Jun 27, 2024
35bd07a
formatting
Jun 27, 2024
dd1849a
include Peng's & Vincent's edits
Jul 2, 2024
8219ec9
adjust test and comments
prathikr Jul 2, 2024
65c2cb7
move import inside test
prathikr Jul 2, 2024
b5f5863
merge with master
prathikr Jul 2, 2024
18648ad
feature flag
prathikr Jul 2, 2024
be9ce0a
add documentation
prathikr Jul 2, 2024
e269e89
minor fixes
prathikr Jul 2, 2024
d3cc487
doc update
prathikr Jul 2, 2024
f500528
peng fix, xavier suggestion
prathikr Jul 8, 2024
c4cdab6
bug fix
prathikr Jul 8, 2024
c05a5ee
bug fix
prathikr Jul 8, 2024
f82bd48
bug fix
prathikr Jul 8, 2024
668409b
adjust unit test
prathikr Jul 8, 2024
b5f1169
adjust checks
prathikr Jul 9, 2024
31becab
grad input fix
prathikr Jul 9, 2024
5aa147d
handle both with and without bias
prathikr Jul 9, 2024
37eb6bc
full mask
prathikr Jul 9, 2024
ae3b5e7
merge with main
prathikr Jul 12, 2024
3484926
lint
prathikr Jul 12, 2024
8d0e879
add version check for test
prathikr Jul 15, 2024
b72a042
grad output adjustment
prathikr Jul 16, 2024
6b4dd10
add more docs
prathikr Jul 16, 2024
999b04b
remove support for masked attention
prathikr Jul 17, 2024
b1fe489
adjust docs
prathikr Jul 18, 2024
4ab54e6
Merge remote-tracking branch 'origin' into prathikrao/attn-aten-fallback
prathikr Jul 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```

#### ORTMODULE_ATEN_SDPA_FALLBACK

- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used to enable the pre-export attention fallback to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution when calling torch.nn.functional.scaled_dot_product_attention. NOTE: only use this feature if the user model leverages memory efficient attention WITHOUT masking (i.e. attn_mask=None). Use a GPU profiling tool such as NVIDIA Nsight Systems to identify whether the user model leverages memory efficient attention.

```bash
export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
Expand Down
15 changes: 14 additions & 1 deletion orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,20 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
}

std::vector<ArgDef> output_args;
for (const auto& output : node_def.outputs) {
for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
// If the input is not used in the forward computation, we don't need it for gradient computation
// Required for ORTMODULE_ATEN_SDPA_FALLBACK
if (static_cast<int>(output_index) >= GetSrcNodeInputSize()) {
continue;
}

if (!IsGradientRequiredForSrcNodeInput(static_cast<int>(output_index))) {
output_args.emplace_back(ArgDef());
continue;
}

const auto& output = node_def.outputs[output_index];

if (output.find("GI(") == 0) {
size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
output_args.emplace_back(GI(index));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# 'is_tensor' is optional, if not present, the default is False.

import json
import os

from onnxruntime.capi import _pybind_state as C

Expand Down Expand Up @@ -276,3 +277,39 @@ def upsample_nearest3d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
return _upsample_gradient("upsample_bicubic2d_backward", 2)


# Opt-in switch: the gradient below is only registered when the
# ORTMODULE_ATEN_SDPA_FALLBACK environment variable holds a non-empty value.
ATEN_SDPA_FALLBACK = os.environ.get("ORTMODULE_ATEN_SDPA_FALLBACK")
if ATEN_SDPA_FALLBACK:
    # based on the following internal PyTorch kernel for efficient attention:
    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
    @register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
    def scaled_dot_product_attention_gradient():
        """Describe the backward graph for the efficient-attention ATen op.

        Returns a two-node gradient definition: a Constant node producing
        ``grad_input_mask`` and an ATen node invoking
        ``_scaled_dot_product_efficient_attention_backward``.
        """
        # Mask constant fed to the backward kernel; the last entry is 0 and the
        # matching fourth gradient output below is left unnamed.
        grad_mask_node = (
            "Constant",
            [],
            ["grad_input_mask"],
            {"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
        )

        # Note that forward input I(4) is intentionally not forwarded here
        # (presumably the compute_logsumexp flag -- TODO confirm against the exporter).
        backward_inputs = [
            "GO(0)",
            "I(0)",
            "I(1)",
            "I(2)",
            "I(3)",
            "O(0)",
            "O(1)",
            "O(2)",
            "O(3)",
            "I(5)",
            "grad_input_mask",
            "I(6)",
            "I(7)",
        ]
        backward_outputs = ["GI(0)", "GI(1)", "GI(2)", ""]
        backward_attrs = {
            "operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"},
        }
        backward_node = (
            ("ATen", "org.pytorch.aten"),
            backward_inputs,
            backward_outputs,
            backward_attrs,
        )

        return [grad_mask_node, backward_node]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os
from typing import Callable

import torch
Expand Down Expand Up @@ -969,3 +970,26 @@ def softmax(g, input, dim, dtype=None):
softmax = g.op("Softmax", casted_input, axis_i=dim)

return softmax


# Opt-in switch: the symbolic below is only registered when the
# ORTMODULE_ATEN_SDPA_FALLBACK environment variable holds a non-empty value.
ATEN_SDPA_FALLBACK = os.environ.get("ORTMODULE_ATEN_SDPA_FALLBACK")
if ATEN_SDPA_FALLBACK:
    # based on the following internal PyTorch kernel for efficient attention:
    # https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
    @register_symbolic("scaled_dot_product_attention")
    def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
        """Export scaled_dot_product_attention as an ATen efficient-attention node.

        Emits an ``org.pytorch.aten::ATen`` op configured with
        ``operator="_scaled_dot_product_efficient_attention"`` and four
        outputs, and returns only the first of them.
        """
        # The dropout probability is cast to FLOAT before being handed to the ATen op.
        dropout_as_float = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
        # Boolean constant passed in the compute_logsumexp position.
        logsumexp_flag = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))

        aten_outputs = g.op(
            "org.pytorch.aten::ATen",
            query,
            key,
            value,
            attn_mask,
            logsumexp_flag,
            dropout_as_float,
            is_causal,
            scale,
            operator_s="_scaled_dot_product_efficient_attention",
            outputs=4,
        )
        # Only the first output is returned to the caller; the other three stay in the graph.
        return aten_outputs[0]
Original file line number Diff line number Diff line change
Expand Up @@ -6953,3 +6953,74 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
else:
if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]


@pytest.mark.skipif(
    Version(torch.__version__) < Version("2.3.0"),
    reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
)
def test_aten_attention():
    """Check the ATen SDPA fallback against eager PyTorch.

    With ORTMODULE_ATEN_SDPA_FALLBACK set, runs one forward/backward step
    through both a PyTorch module and its ORTModule copy (no attention mask),
    compares the prediction and the q/k/v gradients, and then verifies that the
    saved execution model contains at least one ATen node whose attribute
    references ``_scaled_dot_product_efficient_attention``.
    """
    from torch.nn.attention import SDPBackend, sdpa_kernel

    class _NeuralNetAttention(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, q, k, v, attn_mask=None):
            with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
                return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)

    def gen_inputs(device, dtype):
        return [
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
            torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
        ]

    def run_step(model, inputs, attn_mask=None):
        prediction = model(*inputs, attn_mask)
        prediction.sum().backward()
        return prediction

    device = "cuda"

    os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1"  # TESTING WITHOUT ATTN_MASK
    # Everything that runs with the flag set lives inside try/finally so that a
    # failing assertion cannot leak the flag into later tests in this process.
    try:
        pt_model = _NeuralNetAttention().to(device)
        ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))

        # reset manual seed to reset the generator
        torch.manual_seed(2333)
        pt_input = gen_inputs(device=device, dtype=torch.float32)
        ort_input = copy.deepcopy(pt_input)
        pt_prediction = run_step(pt_model, pt_input)
        ort_prediction = run_step(ort_model, ort_input)

        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
        _test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
        _test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
        _test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)

        execution_mgr = ort_model._torch_module._execution_manager._training_manager
        from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name

        path = os.path.join(
            execution_mgr._debug_options.save_onnx_models.path,
            _get_onnx_file_name(
                execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
            ),
        )

        onnx_model = onnx.load(path)
        onnx_nodes = onnx_model.graph.node

        # Count ATen nodes whose attributes mention the efficient-attention kernel
        # (the substring also matches the corresponding *_backward operator name).
        mem_eff_attn_nodes = 0
        for node in onnx_nodes:
            if "ATen" in node.name:
                for attr in node.attribute:
                    if b"_scaled_dot_product_efficient_attention" in attr.s:
                        mem_eff_attn_nodes += 1

        assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"
    finally:
        os.environ.pop("ORTMODULE_ATEN_SDPA_FALLBACK", None)
Loading