From fa3b94d8931302893c6681c92a098f6165649bcb Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 26 Dec 2024 15:56:50 -0800 Subject: [PATCH] Run lint --- .../rewriter/onnxruntime/xformers/__init__.py | 4 +- .../onnxruntime/xformers/_smollm_2.py | 56 ++++++++++++------- .../onnxruntime/xformers/fuse_xformers.py | 5 +- .../rewriter/onnxruntime/xformers/mha.py | 4 +- .../rewriter/onnxruntime/xformers/mha_test.py | 2 +- onnxscript/rewriter/pattern.py | 1 - 6 files changed, 44 insertions(+), 28 deletions(-) diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py index 38ff2619a..c5b1803a1 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -3,12 +3,12 @@ from __future__ import annotations from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers +from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization -from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha -from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers __all__ = [ "fuse_rms_normalization", diff --git a/onnxscript/rewriter/onnxruntime/xformers/_smollm_2.py b/onnxscript/rewriter/onnxruntime/xformers/_smollm_2.py index 085d1fe55..48258760f 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_smollm_2.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_smollm_2.py @@ -7,10 +7,11 @@ """ import numpy -from onnxscript import script + import onnxscript.ir as ir -from onnxscript.onnx_types import FLOAT, INT64 +from onnxscript import script from onnxscript.onnx_opset import opset18 +from onnxscript.onnx_types import FLOAT, INT64 def make_model( @@ -28,11 +29,16 @@ def make_model( model_rotary_emb_inv_freq, ): @script() - def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_values_0_0: FLOAT[1,32,16,64], past_key_values_0_1: FLOAT[1,32,16,64]) -> (FLOAT[1,30,49152], FLOAT[1,32,46,64], FLOAT[1,32,46,64]): + def main_graph( + input_ids: INT64[1, 30], + position_ids: INT64[1, 30], + past_key_values_0_0: FLOAT[1, 32, 16, 64], + past_key_values_0_1: FLOAT[1, 32, 16, 64], + ) -> (FLOAT[1, 30, 49152], FLOAT[1, 32, 46, 64], FLOAT[1, 32, 46, 64]): embedding = opset18.Gather(lm_head_weight, input_ids, axis=0) val_2 = opset18.CastLike(1.0, 46) arange = opset18.Range(16, 46, val_2) - val_5 = opset18.Cast(-3.4028235e+38, to=1) + val_5 = opset18.Cast(-3.4028235e38, to=1) val_7 = opset18.Cast([30, 47], to=7) full = opset18.Expand(val_5, val_7) diagonal__1 = opset18.Constant(value_int=1) @@ -117,7 +123,7 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value _to_copy_5 = opset18.Cast(mul_2, to=1) _to_copy_6 = opset18.Cast(embedding, to=1) scalar_tensor_default = opset18.Cast(2, to=1) - pow_1 = _to_copy_6 ** scalar_tensor_default + pow_1 = _to_copy_6**scalar_tensor_default val_55 = opset18.Constant(value_ints=[-1]) val_57 = opset18.Reshape([-1], val_55, allowzero=0) mean = opset18.ReduceMean(pow_1, val_57, keepdims=1, noop_with_empty_axes=0) @@ -344,7 +350,7 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value add_3 = embedding + view_15 _to_copy_8 = opset18.Cast(add_3, to=1) scalar_tensor_default_1 = opset18.Cast(2, to=1) - pow_2 = _to_copy_8 ** scalar_tensor_default_1 + pow_2 = _to_copy_8**scalar_tensor_default_1 val_224 = opset18.Constant(value_ints=[-1]) val_225 = opset18.Reshape([-1], val_224, allowzero=0) mean_1 = opset18.ReduceMean(pow_2, val_225, keepdims=1, noop_with_empty_axes=0) @@ -378,7 +384,7 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value add_5 = add_3 + view_21 _to_copy_10 = opset18.Cast(add_5, to=1) scalar_tensor_default_2 = opset18.Cast(2, to=1) - pow_3 = _to_copy_10 ** scalar_tensor_default_2 + pow_3 = _to_copy_10**scalar_tensor_default_2 val_236 = opset18.Constant(value_ints=[-1]) val_237 = opset18.Reshape([-1], val_236, allowzero=0) mean_2 = opset18.ReduceMean(pow_3, val_237, keepdims=1, noop_with_empty_axes=0) @@ -400,18 +406,29 @@ def main_graph(input_ids: INT64[1,30], position_ids: INT64[1,30], past_key_value model = main_graph.to_model_proto() return model + def make_model_with_random_weights(): model_layers_0_input_layernorm_weight = numpy.random.rand(2048).astype(numpy.float32) - model_layers_0_post_attention_layernorm_weight = numpy.random.rand(2048).astype(numpy.float32) + model_layers_0_post_attention_layernorm_weight = numpy.random.rand(2048).astype( + numpy.float32 + ) model_norm_weight = numpy.random.rand(2048).astype(numpy.float32) - lm_head_weight = numpy.random.rand(49152,2048).astype(numpy.float32) - model_layers_0_self_attn_q_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32) - model_layers_0_self_attn_k_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32) - model_layers_0_self_attn_v_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32) - model_layers_0_self_attn_o_proj_weight = numpy.random.rand(2048,2048).astype(numpy.float32) - model_layers_0_mlp_gate_proj_weight = numpy.random.rand(8192,2048).astype(numpy.float32) - model_layers_0_mlp_up_proj_weight = numpy.random.rand(8192,2048).astype(numpy.float32) - model_layers_0_mlp_down_proj_weight = numpy.random.rand(2048,8192).astype(numpy.float32) + lm_head_weight = numpy.random.rand(49152, 2048).astype(numpy.float32) + model_layers_0_self_attn_q_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_self_attn_k_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_self_attn_v_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_self_attn_o_proj_weight = numpy.random.rand(2048, 2048).astype( + numpy.float32 + ) + model_layers_0_mlp_gate_proj_weight = numpy.random.rand(8192, 2048).astype(numpy.float32) + model_layers_0_mlp_up_proj_weight = numpy.random.rand(8192, 2048).astype(numpy.float32) + model_layers_0_mlp_down_proj_weight = numpy.random.rand(2048, 8192).astype(numpy.float32) model_rotary_emb_inv_freq = numpy.random.rand(32).astype(numpy.float32) model = make_model( model_layers_0_input_layernorm_weight, @@ -427,7 +444,8 @@ def make_model_with_random_weights(): model_layers_0_mlp_down_proj_weight, model_rotary_emb_inv_freq, ) - return model + return model + class TestData: def get_onnx_model(self): @@ -442,8 +460,8 @@ def get_ort_inputs(self): inputs = { "input_ids": numpy.random.randint(0, 49152, (1, 30)).astype(numpy.int64), "position_ids": numpy.ones((1, 30), dtype=numpy.int64), - "past_key_values_0_0": numpy.random.rand(1,32,16,64).astype(numpy.float32), - "past_key_values_0_1": numpy.random.rand(1,32,16,64).astype(numpy.float32), + "past_key_values_0_0": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32), + "past_key_values_0_1": numpy.random.rand(1, 32, 16, 64).astype(numpy.float32), } self._ort_inputs = inputs return self._ort_inputs diff --git a/onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py b/onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py index d837f0467..13161115b 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py +++ b/onnxscript/rewriter/onnxruntime/xformers/fuse_xformers.py @@ -3,11 +3,12 @@ from __future__ import annotations from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization -from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha + def fuse_xformers(model): fuse_rms_normalization(model) @@ -15,4 +16,4 @@ def fuse_xformers(model): fuse_rotary_embedding(model) fuse_cos_sin_cache(model) fuse_sdpa(model) - fuse_mha(model) \ No newline at end of file + fuse_mha(model) diff --git a/onnxscript/rewriter/onnxruntime/xformers/mha.py b/onnxscript/rewriter/onnxruntime/xformers/mha.py index d2964aa91..d1b56c059 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/mha.py +++ b/onnxscript/rewriter/onnxruntime/xformers/mha.py @@ -166,9 +166,7 @@ def _multi_head_attention( _rule1 = pattern.RewriteRule( - _multi_head_attention_pattern, - _multi_head_attention, - _mha_validation + _multi_head_attention_pattern, _multi_head_attention, _mha_validation ) diff --git a/onnxscript/rewriter/onnxruntime/xformers/mha_test.py b/onnxscript/rewriter/onnxruntime/xformers/mha_test.py index 24ecf952a..d9f5d240a 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/mha_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/mha_test.py @@ -5,9 +5,9 @@ import unittest import onnxscript.optimizer +import onnxscript.rewriter.onnxruntime.xformers as xformers from onnxscript.rewriter.onnxruntime.xformers._smollm_2 import TestData from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run -import onnxscript.rewriter.onnxruntime.xformers as xformers class TestMultiHeadAttention(unittest.TestCase): diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 14f9cddab..23fa02ddf 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -541,7 +541,6 @@ def matches(self, node: ir.Node, match: MatchResult) -> MatchResult: if not self.domain.matches(node.domain): return match.fail(f"Domain mismatch: expected {self.domain}, got {node.domain}.") - for name, attr_pattern in self.attributes.items(): attr_value = node.attributes.get(name) if attr_value is None: