Skip to content

Commit

Permalink
Run lint
Browse files Browse the repository at this point in the history
  • Loading branch information
gramalingam committed Dec 26, 2024
1 parent 82f1919 commit fa3b94d
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 28 deletions.
4 changes: 2 additions & 2 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers

__all__ = [
"fuse_rms_normalization",
Expand Down
56 changes: 37 additions & 19 deletions onnxscript/rewriter/onnxruntime/xformers/_smollm_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
"""

import numpy
from onnxscript import script

import onnxscript.ir as ir
from onnxscript.onnx_types import FLOAT, INT64
from onnxscript import script
from onnxscript.onnx_opset import opset18
from onnxscript.onnx_types import FLOAT, INT64


def make_model(
Expand All @@ -28,11 +29,16 @@ def make_model(
model_rotary_emb_inv_freq,
):
@script()
def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_values_0_0: FLOAT[1,32,16,64], past_key_values_0_1: FLOAT[1,32,16,64]) -> (FLOAT[1,30,49152], FLOAT[1,32,46,64], FLOAT[1,32,46,64]):
def main_graph(
input_ids: INT64[1, 30],

Check failure

Code scanning / lintrunner

MYPY/type-arg Error

"INT64" expects no type arguments, but 2 given To disable, use # type: ignore[type-arg]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[1] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[30] instead? To disable, use # type: ignore[valid-type]
position_ids: INT64[1, 30],

Check failure

Code scanning / lintrunner

MYPY/type-arg Error

"INT64" expects no type arguments, but 2 given To disable, use # type: ignore[type-arg]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[1] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[30] instead? To disable, use # type: ignore[valid-type]
past_key_values_0_0: FLOAT[1, 32, 16, 64],

Check failure

Code scanning / lintrunner

MYPY/type-arg Error

"FLOAT" expects no type arguments, but 4 given To disable, use # type: ignore[type-arg]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[1] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[32] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[16] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[64] instead? To disable, use # type: ignore[valid-type]
past_key_values_0_1: FLOAT[1, 32, 16, 64],

Check failure

Code scanning / lintrunner

MYPY/type-arg Error

"FLOAT" expects no type arguments, but 4 given To disable, use # type: ignore[type-arg]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[1] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[32] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[16] instead? To disable, use # type: ignore[valid-type]

Check failure

Code scanning / lintrunner

MYPY/valid-type Error

Invalid type: try using Literal[64] instead? To disable, use # type: ignore[valid-type]
) -> (FLOAT[1, 30, 49152], FLOAT[1, 32, 46, 64], FLOAT[1, 32, 46, 64]):

Check failure

Code scanning / lintrunner

MYPY/syntax Error

Syntax error in type annotation To disable, use # type: ignore[syntax]
embedding = opset18.Gather(lm_head_weight, input_ids, axis=0)
val_2 = opset18.CastLike(1.0, 46)

Check failure

Code scanning / lintrunner

MYPY/type-var Error

Value of type variable "T1_CastLike" of "CastLike" of "Opset15" cannot be "float" To disable, use # type: ignore[type-var]

Check failure

Code scanning / lintrunner

MYPY/type-var Error

Value of type variable "T2_CastLike" of "CastLike" of "Opset15" cannot be "int" To disable, use # type: ignore[type-var]
arange = opset18.Range(16, 46, val_2)

Check failure

Code scanning / lintrunner

MYPY/type-var Error

Value of type variable "T_Range" of "Range" of "Opset11" cannot be "int" To disable, use # type: ignore[type-var]
val_5 = opset18.Cast(-3.4028235e+38, to=1)
val_5 = opset18.Cast(-3.4028235e38, to=1)

Check failure

Code scanning / lintrunner

MYPY/type-var Error

Value of type variable "T1_Cast" of "Cast" of "Opset13" cannot be "float" To disable, use # type: ignore[type-var]
val_7 = opset18.Cast([30, 47], to=7)

Check failure

Code scanning / lintrunner

MYPY/type-var Error

Value of type variable "T1_Cast" of "Cast" of "Opset13" cannot be "list[int]" To disable, use # type: ignore[type-var]
full = opset18.Expand(val_5, val_7)

Check failure

Code scanning / lintrunner

MYPY/type-var Error

Value of type variable "T_Expand" of "Expand" of "Opset13" cannot be "BFLOAT16 | BOOL | DOUBLE | FLOAT | FLOAT16 | INT16 | INT32 | INT64 | INT8 | STRING | UINT16 | UINT32 | UINT64 | UINT8" To disable, use # type: ignore[type-var]

Check failure

Code scanning / lintrunner

MYPY/arg-type Error

Argument 2 to "Expand" of "Opset13" has incompatible type "BFLOAT16 | BOOL | DOUBLE | FLOAT | FLOAT16 | INT16 | INT32 | INT64 | INT8 | STRING | UINT16 | UINT32 | UINT64 | UINT8"; expected "INT64" To disable, use # type: ignore[arg-type]
diagonal__1 = opset18.Constant(value_int=1)
Expand Down Expand Up @@ -117,7 +123,7 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value
_to_copy_5 = opset18.Cast(mul_2, to=1)
_to_copy_6 = opset18.Cast(embedding, to=1)
scalar_tensor_default = opset18.Cast(2, to=1)
pow_1 = _to_copy_6 ** scalar_tensor_default
pow_1 = _to_copy_6**scalar_tensor_default
val_55 = opset18.Constant(value_ints=[-1])
val_57 = opset18.Reshape([-1], val_55, allowzero=0)
mean = opset18.ReduceMean(pow_1, val_57, keepdims=1, noop_with_empty_axes=0)
Expand Down Expand Up @@ -344,7 +350,7 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value
add_3 = embedding + view_15
_to_copy_8 = opset18.Cast(add_3, to=1)
scalar_tensor_default_1 = opset18.Cast(2, to=1)
pow_2 = _to_copy_8 ** scalar_tensor_default_1
pow_2 = _to_copy_8**scalar_tensor_default_1
val_224 = opset18.Constant(value_ints=[-1])
val_225 = opset18.Reshape([-1], val_224, allowzero=0)
mean_1 = opset18.ReduceMean(pow_2, val_225, keepdims=1, noop_with_empty_axes=0)
Expand Down Expand Up @@ -378,7 +384,7 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value
add_5 = add_3 + view_21
_to_copy_10 = opset18.Cast(add_5, to=1)
scalar_tensor_default_2 = opset18.Cast(2, to=1)
pow_3 = _to_copy_10 ** scalar_tensor_default_2
pow_3 = _to_copy_10**scalar_tensor_default_2
val_236 = opset18.Constant(value_ints=[-1])
val_237 = opset18.Reshape([-1], val_236, allowzero=0)
mean_2 = opset18.ReduceMean(pow_3, val_237, keepdims=1, noop_with_empty_axes=0)
Expand All @@ -400,18 +406,29 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value
model = main_graph.to_model_proto()
return model


def make_model_with_random_weights():
model_layers_0_input_layernorm_weight = numpy.random.rand(2048).astype(numpy.float32)
model_layers_0_post_attention_layernorm_weight = numpy.random.rand(2048).astype(numpy.float32)
model_layers_0_post_attention_layernorm_weight = numpy.random.rand(2048).astype(
numpy.float32
)
model_norm_weight = numpy.random.rand(2048).astype(numpy.float32)
lm_head_weight = numpy.random.rand(49152,2048).astype(numpy.float32)
model_layers_0_self_attn_q_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32)
model_layers_0_self_attn_k_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32)
model_layers_0_self_attn_v_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32)
model_layers_0_self_attn_o_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32)
model_layers_0_mlp_gate_proj_weight = numpy.random.rand(8192,2048).astype(numpy.float32)
model_layers_0_mlp_up_proj_weight = numpy.random.rand(8192,2048).astype(numpy.float32)
model_layers_0_mlp_down_proj_weight = numpy.random.rand(2048,8192).astype(numpy.float32)
lm_head_weight = numpy.random.rand(49152, 2048).astype(numpy.float32)
model_layers_0_self_attn_q_proj_weight = numpy.random.rand(2048, 2048).astype(
numpy.float32
)
model_layers_0_self_attn_k_proj_weight = numpy.random.rand(2048, 2048).astype(
numpy.float32
)
model_layers_0_self_attn_v_proj_weight = numpy.random.rand(2048, 2048).astype(
numpy.float32
)
model_layers_0_self_attn_o_proj_weight = numpy.random.rand(2048, 2048).astype(
numpy.float32
)
model_layers_0_mlp_gate_proj_weight = numpy.random.rand(8192, 2048).astype(numpy.float32)
model_layers_0_mlp_up_proj_weight = numpy.random.rand(8192, 2048).astype(numpy.float32)
model_layers_0_mlp_down_proj_weight = numpy.random.rand(2048, 8192).astype(numpy.float32)
model_rotary_emb_inv_freq = numpy.random.rand(32).astype(numpy.float32)
model = make_model(
model_layers_0_input_layernorm_weight,
Expand All @@ -427,7 +444,8 @@ def make_model_with_random_weights():
model_layers_0_mlp_down_proj_weight,
model_rotary_emb_inv_freq,
)
return model
return model


class TestData:
def get_onnx_model(self):
Expand All @@ -442,8 +460,8 @@ def get_ort_inputs(self):
inputs = {
"input_ids": numpy.random.randint(0, 49152, (1, 30)).astype(numpy.int64),
"position_ids": numpy.ones((1, 30), dtype=numpy.int64),
"past_key_values_0_0": numpy.random.rand(1,32,16,64).astype(numpy.float32),
"past_key_values_0_1": numpy.random.rand(1,32,16,64).astype(numpy.float32),
"past_key_values_0_0": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
"past_key_values_0_1": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32),
}
self._ort_inputs = inputs
return self._ort_inputs
5 changes: 3 additions & 2 deletions onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha


def fuse_xformers(model):
fuse_rms_normalization(model)
fuse_normalization(model)
fuse_rotary_embedding(model)
fuse_cos_sin_cache(model)
fuse_sdpa(model)
fuse_mha(model)
fuse_mha(model)
4 changes: 1 addition & 3 deletions onnxscript/rewriter/onnxruntime/xformers/mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def _multi_head_attention(


_rule1 = pattern.RewriteRule(
_multi_head_attention_pattern,
_multi_head_attention,
_mha_validation
_multi_head_attention_pattern, _multi_head_attention, _mha_validation
)


Expand Down
2 changes: 1 addition & 1 deletion onnxscript/rewriter/onnxruntime/xformers/mha_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import unittest

import onnxscript.optimizer
import onnxscript.rewriter.onnxruntime.xformers as xformers
from onnxscript.rewriter.onnxruntime.xformers._smollm_2 import TestData
from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run
import onnxscript.rewriter.onnxruntime.xformers as xformers


class TestMultiHeadAttention(unittest.TestCase):
Expand Down
1 change: 0 additions & 1 deletion onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,6 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult:
if not self.domain.matches(node.domain):
return match.fail(f"Domain mismatch: expected {self.domain}, got {node.domain}.")


for name, attr_pattern in self.attributes.items():
attr_value = node.attributes.get(name)
if attr_value is None:
Expand Down

0 comments on commit fa3b94d

Please sign in to comment.