First version of attention fusion #1986

Open

gramalingam wants to merge 29 commits into main from rama/fuse-attn
Changes from 24 commits

Commits (29):
8de7231  First version (gramalingam, Dec 18, 2024)
a20b903  Add rotary embedding (gramalingam, Dec 18, 2024)
b8f7a08  Remove SDPA (gramalingam, Dec 18, 2024)
315c94e  Add comment (gramalingam, Dec 18, 2024)
2219fd3  Remove MHA (gramalingam, Dec 18, 2024)
f77f0e7  Merge branch 'main' into rama/fuse-attn (gramalingam, Dec 18, 2024)
5ec9d1e  Add rewrite for cos-sin computation (gramalingam, Dec 20, 2024)
90f0b7b  Merge branch 'rama/fuse-attn' of https://github.com/microsoft/onnx-sc… (gramalingam, Dec 20, 2024)
1fdc19b  Run lint (gramalingam, Dec 20, 2024)
eb916b8  Add cos sin test (gramalingam, Dec 20, 2024)
d874dbc  Extend rewriter to support node reuse (gramalingam, Dec 20, 2024)
a745039  Minor fixes (gramalingam, Dec 21, 2024)
17c06c3  Fix concat bug in rotary embedding (gramalingam, Dec 22, 2024)
c7c7c79  Minor cleanup (gramalingam, Dec 23, 2024)
834815b  Merge branch 'main' into rama/fuse-attn (gramalingam, Dec 23, 2024)
9a4a58e  Use callable to test callable (gramalingam, Dec 23, 2024)
766791d  Fix lint issues (gramalingam, Dec 23, 2024)
c7384af  Attention fusion (gramalingam, Dec 24, 2024)
d0254d1  Add support for cached state in rewrite (gramalingam, Dec 24, 2024)
b91166b  Cleanup MHA pattern (gramalingam, Dec 24, 2024)
205805c  Complete MHA pattern (gramalingam, Dec 26, 2024)
e907f3e  Add MHA fusion test (gramalingam, Dec 26, 2024)
82f1919  Add validation condition (gramalingam, Dec 26, 2024)
fa3b94d  Run lint (gramalingam, Dec 26, 2024)
9310b67  Merge with main (gramalingam, Jan 7, 2025)
e0f29e2  Fix merge conflict (gramalingam, Jan 7, 2025)
41aa177  Fix merge conflict (gramalingam, Jan 7, 2025)
2688d6e  Merge conflict fix (gramalingam, Jan 7, 2025)
c080f4a  Merge with main (gramalingam, Jan 9, 2025)
23 changes: 23 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
@@ -306,6 +306,29 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
    return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    """Replace a Reshape node by Identity when applicable."""
    input = _get_input(node, 0)
    shape = _get_input(node, 1)
    if input is None or shape is None:
        return None
    input_shape = input.shape
    if input_shape is None:
        return None
    input_shape_dims = list(input_shape.dims)
    if any(not isinstance(dim, int) for dim in input_shape_dims):
        return None
    shape_value = _get_numpy_value(shape)
    if shape_value is None:
        return None
    target_shape_dims = shape_value.tolist()
    if input_shape_dims == target_shape_dims:
        # No need to check for special values like -1, 0, etc. here
        return op.Identity(input)
    return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    input = _get_input(node, 0)
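For context, here is a minimal sketch (not part of this PR) of the kind of no-op Reshape the new constant-folding rule targets: the shape input is a constant that exactly matches the input's statically known shape, so the node can safely become an Identity. The onnxscript @script model and the optimizer call below are assumptions about how one might exercise the rule, not code from this change.

import onnxscript.optimizer
from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def redundant_reshape(x: FLOAT[2, 3]) -> FLOAT[2, 3]:
    # The target shape is a constant equal to the statically known input shape,
    # so this Reshape is a no-op and is eligible for folding into Identity.
    target_shape = op.Constant(value_ints=[2, 3])
    return op.Reshape(x, target_shape)


# Hypothetical usage: after optimization, the Reshape is expected to be replaced by Identity.
optimized = onnxscript.optimizer.optimize(redundant_reshape.to_model_proto())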
27 changes: 27 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,9 @@
# Licensed under the MIT License.
from __future__ import annotations

import math
from typing import Callable

import numpy as np

import onnxscript.ir as ir
@@ -77,3 +80,27 @@ def get_singleton_value(val: ir.Value | None):
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None


def is_singleton_value(
    val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
) -> bool:
    """Returns True if the value is a single-element tensor with the given value, and False otherwise."""
    scalar = get_singleton_value(val)
    if scalar is None:
        return False
    if callable(expected):
        return expected(scalar)
    if isinstance(expected, int):
        return expected == scalar
    # rtol must be specified for float comparison
    assert rtol is not None
    return math.isclose(scalar, expected, rel_tol=rtol)


def has_rank(value: ir.Value | None, rank: int) -> bool:
    """Returns True if the value is statically known to have the given rank, and False otherwise."""
    if value is None:
        return False
    shape = value.shape
    return (shape is not None) and (shape.rank() == rank)
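As a rough illustration (not part of this diff), the new helpers might be used as follows. Constructing an ir.Value directly and assigning its const_value is an assumption made for the sake of a self-contained snippet; in the rewriter these values come from pattern matching.

import math

import numpy as np

import onnxscript.ir as ir
from onnxscript.rewriter._ir_utils import has_rank, is_singleton_value

# A scalar constant, e.g. an attention scaling factor of 1/sqrt(head_size).
scale = ir.Value(name="scale", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([]))
scale.const_value = ir.tensor(np.array(0.125, dtype=np.float32))

# Float comparison requires an explicit rtol.
assert is_singleton_value(scale, 1.0 / math.sqrt(64), rtol=1e-3)
# A callable predicate is also accepted.
assert is_singleton_value(scale, lambda v: v > 0)

# has_rank succeeds only when the rank is statically known.
query = ir.Value(name="query", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(["B", "S", 64]))
assert has_rank(query, 3)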
18 changes: 18 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -1,3 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

__all__ = [
    "fuse_rms_normalization",
    "fuse_normalization",
    "fuse_rotary_embedding",
    "fuse_cos_sin_cache",
    "fuse_sdpa",
    "fuse_mha",
    "fuse_xformers",
]
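For orientation (not part of this diff), the sketch below shows how these entry points might be invoked. The fuse_* signatures are not visible in this diff, so treating fuse_xformers as an in-place transform on an ir.Model, and the file names, are assumptions.

import onnx

import onnxscript.ir as ir
from onnxscript.rewriter.onnxruntime.xformers import fuse_xformers

# Load an ONNX model into the onnxscript IR.
model = ir.serde.deserialize_model(onnx.load("llama_layer.onnx"))

# Assumed usage: apply the combined fusion pass (RMS normalization, rotary embedding,
# cos/sin cache, SDPA, MHA) to rewrite the graph.
fuse_xformers(model)

onnx.save(ir.serde.serialize_model(model), "llama_layer_fused.onnx")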
@@ -2,7 +2,7 @@
# Licensed under the MIT License.

"""
-A one-layer SmolLM model test case.
+A one-layer SmolLM model test case, with inputs: input_ids, attention_mask, and position_ids.
This is an onnxscript version of the model.
"""

@@ -234,7 +234,7 @@ def make_model_with_random_weights():
    return model


-class _SmollmTestData:
+class TestData:
    def get_onnx_model(self):
        if not hasattr(self, "_onnx_model"):
            model_proto = make_model_with_random_weights()