Skip to content

Commit

Permalink
Adds ATen fallback for scaled_dot_product_attention (microsoft#21107)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

Introduces an ATen fallback for
`torch.nn.functional.scaled_dot_product_attention`. This operator was
introduced in torch 2.0 and, since then, has had many updates including
the implementation of memory efficient attention for V100 machines. The
current torchscript exporter exports a subgraph for attention which does
not provide the same memory savings that PyTorch's memory efficient
attention kernel provides. Allowing fallback to PyTorch ATen op for
attention helps mitigate memory spike issues for models leveraging
memory efficient attention.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Memory issues arose when integrating ONNX Runtime Training with AML
Stable Diffusion.

---------

Co-authored-by: root <[email protected]>
  • Loading branch information
prathikr and root authored Jul 22, 2024
1 parent 5b9369e commit 11ad299
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 1 deletion.
10 changes: 10 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```
#### ORTMODULE_ATEN_SDPA_FALLBACK
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, this is disabled. This env var can be used for enabling pre-export attention fall back to PyTorch's [_scaled_dot_product_efficient_attention](https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778) ATen kernel for execution when calling torch.nn.functional.scaled_dot_product_attention. NOTE: only use this feature if user model leverages memory efficient attention WITHOUT masking (ie. attn_mask=None). Utilize GPU profiling looks like NVIDIA Nsight Systems to identify if user model leverages memory efficient attention.

```bash
export ORTMODULE_ATEN_SDPA_FALLBACK=1 # ENABLE
unset ORTMODULE_ATEN_SDPA_FALLBACK # DISABLE
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
Expand Down
15 changes: 14 additions & 1 deletion orttraining/orttraining/core/graph/gradient_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,20 @@ IMPLEMENT_GRADIENT_BUILDER(GetExternalGradient) {
}

std::vector<ArgDef> output_args;
for (const auto& output : node_def.outputs) {
for (size_t output_index = 0; output_index < node_def.outputs.size(); ++output_index) {
// If the input is not used in the forward computation, we don't need it for gradient computation
// Required for ORTMODULE_ATEN_SDPA_FALLBACK
if (static_cast<int>(output_index) >= GetSrcNodeInputSize()) {
continue;
}

if (!IsGradientRequiredForSrcNodeInput(static_cast<int>(output_index))) {
output_args.emplace_back(ArgDef());
continue;
}

const auto& output = node_def.outputs[output_index];

if (output.find("GI(") == 0) {
size_t index = static_cast<size_t>(std::stoi(output.substr(3, output.length() - 4)));
output_args.emplace_back(GI(index));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# 'is_tensor' is optional, if not present, the default is False.

import json
import os

from onnxruntime.capi import _pybind_state as C

Expand Down Expand Up @@ -276,3 +277,39 @@ def upsample_nearest3d_gradient():
@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec")
def upsample_bicubic2d_gradient():
return _upsample_gradient("upsample_bicubic2d_backward", 2)


ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
if ATEN_SDPA_FALLBACK:
# based on the following internal PyTorch kernel for efficient attention:
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14784
@register_gradient("org.pytorch.aten", "ATen", "_scaled_dot_product_efficient_attention", "")
def scaled_dot_product_attention_gradient():
return [
(
"Constant",
[],
["grad_input_mask"],
{"value": {"value": [1, 1, 1, 0], "dtype": "int", "is_tensor": True}},
),
(
("ATen", "org.pytorch.aten"),
[
"GO(0)",
"I(0)",
"I(1)",
"I(2)",
"I(3)",
"O(0)",
"O(1)",
"O(2)",
"O(3)",
"I(5)",
"grad_input_mask",
"I(6)",
"I(7)",
],
["GI(0)", "GI(1)", "GI(2)", ""],
{"operator": {"value": "_scaled_dot_product_efficient_attention_backward", "dtype": "string"}},
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os
from typing import Callable

import torch
Expand Down Expand Up @@ -969,3 +970,26 @@ def softmax(g, input, dim, dtype=None):
softmax = g.op("Softmax", casted_input, axis_i=dim)

return softmax


ATEN_SDPA_FALLBACK = os.getenv("ORTMODULE_ATEN_SDPA_FALLBACK", None)
if ATEN_SDPA_FALLBACK:
# based on the following internal PyTorch kernel for efficient attention:
# https://github.com/pytorch/pytorch/blob/c12a4f2e65ad41b739aab1a261e2336b4a79fcfb/aten/src/ATen/native/native_functions.yaml#L14778
@register_symbolic("scaled_dot_product_attention")
def scaled_dot_product_attention(g, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
dropout_p_f = g.op("Cast", dropout_p, to_i=torch.onnx.TensorProtoDataType.FLOAT)
compute_logsumexp = g.op("Constant", value_t=torch.tensor([1], dtype=torch.bool))
return g.op(
"org.pytorch.aten::ATen",
query,
key,
value,
attn_mask,
compute_logsumexp,
dropout_p_f,
is_causal,
scale,
operator_s="_scaled_dot_product_efficient_attention",
outputs=4,
)[0]
Original file line number Diff line number Diff line change
Expand Up @@ -6953,3 +6953,74 @@ def generate_inputs(batch_size, max_seq_length, vocab_size):
else:
if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ:
del os.environ["ORTMODULE_MEMORY_OPT_LEVEL"]


@pytest.mark.skipif(
Version(torch.__version__) < Version("2.3.0"),
reason="torch.nn.attention module was introduced in PyTorch 2.3.0",
)
def test_aten_attention():
from torch.nn.attention import SDPBackend, sdpa_kernel

class _NeuralNetAttention(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, q, k, v, attn_mask=None):
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask)

def gen_inputs(device, dtype):
return [
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
torch.randn(32, 8, 128, 64, dtype=dtype, device=device, requires_grad=True),
]

def run_step(model, inputs, attn_mask=None):
prediction = model(*inputs, attn_mask)
prediction.sum().backward()
return prediction

device = "cuda"

os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"] = "1" # TESTING WITHOUT ATTN_MASK

pt_model = _NeuralNetAttention().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="mem_eff_attn"))

# reset manual seed to reset the generator
torch.manual_seed(2333)
pt_input = gen_inputs(device=device, dtype=torch.float32)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)

_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input[0].grad, pt_input[0].grad)
_test_helpers.assert_values_are_close(ort_input[1].grad, pt_input[1].grad)
_test_helpers.assert_values_are_close(ort_input[2].grad, pt_input[2].grad)

execution_mgr = ort_model._torch_module._execution_manager._training_manager
from onnxruntime.training.ortmodule._onnx_models import _get_onnx_file_name

path = os.path.join(
execution_mgr._debug_options.save_onnx_models.path,
_get_onnx_file_name(
execution_mgr._debug_options.save_onnx_models.name_prefix, "execution_model", execution_mgr._export_mode
),
)

onnx_model = onnx.load(path)
onnx_nodes = onnx_model.graph.node

mem_eff_attn_nodes = 0
for node in onnx_nodes:
if "ATen" in node.name:
for attr in node.attribute:
if b"_scaled_dot_product_efficient_attention" in attr.s:
mem_eff_attn_nodes += 1

assert mem_eff_attn_nodes > 0, "No mem_eff_attn nodes are found"

del os.environ["ORTMODULE_ATEN_SDPA_FALLBACK"]

0 comments on commit 11ad299

Please sign in to comment.