Skip to content

Commit

Permalink
Migrate onnxrewriter (#1346)
Browse files Browse the repository at this point in the history
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1334
* #1340
* __->__ #1346

Squashed of the following steps:
- #1328
- #1329
- #1330
- #1331
- #1332
- #1333
- #1343
- #1345

Co-authored-by: Shubham Bhokare
<[email protected]>
Co-authored-by: Justin Chu <[email protected]>
Co-authored-by: Xavier Dupré <[email protected]>
Co-authored-by: "G. Ramalingam" <[email protected]>
Co-authored-by: kunal-vaishnavi
<[email protected]>
Co-authored-by: Ti-Tai Wang <[email protected]>
  • Loading branch information
5 people authored Apr 5, 2024
1 parent e29b43a commit 2c74be7
Show file tree
Hide file tree
Showing 388 changed files with 12,542 additions and 94 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
**/*.pb filter=lfs diff=lfs merge=lfs -text
**/*.onnx filter=lfs diff=lfs merge=lfs -text
2 changes: 2 additions & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install nox
run: python -m pip install nox
- name: Pull Test Data
run: git lfs pull
- name: Run tests
run: nox -t ${{ matrix.nox-tag }} --forcecolor -- -v --cov=onnxscript --cov-report=xml --cov-append --cov-branch -n=auto --junit-xml pytest.xml
env:
Expand Down
33 changes: 27 additions & 6 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include_patterns = [
'**/*.pyi',
]
exclude_patterns = [
'onnxscript/tests/models/**',
'tests/models/**',
]
command = [
'python',
Expand Down Expand Up @@ -43,9 +43,26 @@ exclude_patterns = [
'onnxscript/evaluator_test.py',
'onnxscript/evaluator.py',
'onnxscript/onnx_types.py',
'onnxscript/tests/**', # Skip linting test files for speed
'tests/**', # Skip linting test files for speed
'onnxscript/**/*_test.py', # Skip linting test files for speed
'onnxscript/function_libs/torch_lib/ops/**', # Operators typing do not play well with mypy
'onnxscript/optimizer/evaluator.py', # FIXME
'onnxscript/optimizer/constant_folding.py', # FIXME
'onnxscript/_legacy_ir/__init__.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/fastgelu.py', # FIXME
'onnxscript/rewriter/onnxruntime/instance_to_group_normalization.py', # FIXME
'onnxscript/rewriter/function_rule.py', # FIXME
'onnxscript/_legacy_ir/irbuilder.py', # FIXME
'onnxscript/optimizer/fold_constants_v0.py', # FIXME
'onnxscript/rewriter/pattern.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/multihead_attention.py', # FIXME
'onnxscript/tools/function_unittest_producer.py', # FIXME
'onnxscript/_legacy_ir/visitor.py', # FIXME
'onnxscript/_legacy_ir/protobuilder.py', # FIXME
'onnxscript/rewriter/onnxruntime/transformers/layernorm.py', # FIXME
'onnxscript/ir/serde.py', # FIXME
'onnxscript/rewriter/generic_pattern_test.py', # FIXME
'onnxscript/rewriter/generic_pattern.py', # FIXME
]
command = [
'python',
Expand Down Expand Up @@ -74,7 +91,7 @@ include_patterns = [
'**/*.py',
]
exclude_patterns = [
'onnxscript/tests/onnx_backend_test_code/**',
'tests/onnx_backend_test_code/**',
]
command = [
'python',
Expand Down Expand Up @@ -102,12 +119,16 @@ include_patterns = [
'**/*.py',
]
exclude_patterns = [
'examples/**', # TODO: Merge with docs/examples
'docs/examples/**',
'docs/tutorial/examples/**',
'onnxscript/converter_test.py',
'onnxscript/tests/functions/**',
'onnxscript/tests/models/**',
'onnxscript/tests/onnx_backend_test_code/**',
'tests/functions/**',
'tests/models/**',
'tests/onnx_backend_test_code/**',
'onnxscript/optimizer/**', # FIXME
'onnxscript/rewriter/**', # FIXME
'onnxscript/_legacy_ir/**', # FIXME
]
command = [
'python',
Expand Down
191 changes: 191 additions & 0 deletions examples/pattern_rewriting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Onnx Pattern Rewriting.
This script shows how to define a rewriting rule based on patterns.
The objective is to replace some nodes in an ONNX model with another,
more efficient sequence of nodes.
First a dummy model
===================
"""

import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
import onnxscript._legacy_ir as oir
import onnxscript.rewriter.generic_pattern as org


def get_rotary_model(bad_model=False):
    """Build the small ONNX model used throughout this example.

    The graph chains Unsqueeze, Cast, MatMul and Transpose, feeds the result
    twice into a ``com.microsoft`` concat node, then applies Sin and Cos to
    the first concat output, each followed by a Cast.

    :param bad_model: if True, the concat node is created with op type
        ``ConcatTrainingBad`` instead of ``ConcatTraining``
    :return: the model created by ``onnx.helper.make_model``
    """
    concat_op_type = "ConcatTrainingBad" if bad_model else "ConcatTraining"

    # All graph inputs are declared with an empty shape.
    graph_inputs = [
        oh.make_tensor_value_info(name, elem_type, shape=[])
        for name, elem_type in (
            ("x", onnx.TensorProto.INT64),
            ("pos_ids", onnx.TensorProto.FLOAT),
            ("axis", onnx.TensorProto.INT64),
        )
    ]

    graph_nodes = [
        oh.make_node("Unsqueeze", ["x", "axis"], ["_onx_unsqueeze0"]),
        oh.make_node("Cast", ["_onx_unsqueeze0"], ["_onx_cast0"], to=1),
        oh.make_node("MatMul", ["pos_ids", "_onx_cast0"], ["_onx_matmul0"]),
        oh.make_node("Transpose", ["_onx_matmul0"], ["_onx_transpose0"]),
        oh.make_node(
            concat_op_type,
            ["_onx_transpose0", "_onx_transpose0"],
            ["_onx_concattraining0", "_onx_concattraining1"],
            domain="com.microsoft",
        ),
        oh.make_node("Sin", ["_onx_concattraining0"], ["_onx_sin0"]),
        oh.make_node("Cast", ["_onx_sin0"], ["_onx_cast02"], to=1),
        oh.make_node("Cos", ["_onx_concattraining0"], ["_onx_cos0"]),
        oh.make_node("Cast", ["_onx_cos0"], ["_onx_cast03"], to=1),
    ]

    # The two final Cast results are the graph outputs; their element type is
    # left undefined.
    graph_outputs = [
        oh.make_tensor_value_info(name, onnx.TensorProto.UNDEFINED, [])
        for name in ("_onx_cast02", "_onx_cast03")
    ]

    graph = oh.make_graph(graph_nodes, "experiment", graph_inputs, graph_outputs)
    opsets = [
        oh.make_opsetid("", 18),
        oh.make_opsetid("com.microsoft", 18),
    ]
    return oh.make_model(graph, opset_imports=opsets)


# Build the model with the regular ``ConcatTraining`` node, then convert it to
# the intermediate representation the rewriting rule is applied to below.
model = get_rotary_model()
ir_model = oir.irbuilder.build_ir(model)


####################################
# The rewriting pattern
# =====================

# Opset handles used to write the patterns: ``op`` for the default ONNX domain
# (opset 18) and ``msft_op`` for the ``com.microsoft`` domain.
op = onnxscript.opset18
# NOTE(review): get_rotary_model declares opset 18 for "com.microsoft" while
# this handle uses version 1 -- confirm the difference is intentional.
msft_op = onnxscript.values.Opset("com.microsoft", 1)


def rotary_match_pattern(x, pos_ids, axis):
    """The pattern to match.

    Mirrors the node sequence produced by ``get_rotary_model``:
    Unsqueeze -> Cast -> MatMul -> Transpose -> ConcatTraining, followed by
    Sin and Cos on the first concat output, each cast to float.

    :param x: first input of the Unsqueeze node
    :param pos_ids: first input of the MatMul node
    :param axis: second input of the Unsqueeze node
    :return: the two Cast results (sin branch, cos branch)
    """
    # NOTE(review): the log excerpt at the end of this example refers to
    # pattern values by these local names (transpose, output, length, cos),
    # so renaming or inlining them presumably changes the pattern -- confirm
    # before refactoring.
    unsqueeze = op.Unsqueeze(x, axis)
    cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)

    matmul = op.MatMul(pos_ids, cast)
    transpose = op.Transpose(matmul)
    # ConcatTraining has two outputs; only ``output`` is consumed below.
    output, length = msft_op.ConcatTraining(transpose, transpose)

    sin = op.Sin(output)
    cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
    cos = op.Cos(output)
    cast2 = op.Cast(cos, to=onnx.TensorProto.FLOAT)
    return cast1, cast2


def validate_rotary_mapping(g, matched_nodes, added_nodes) -> bool:
    """Validate a match before the replacement is applied.

    This example accepts every match, so all arguments are ignored.

    :param g: model
    :param matched_nodes: matched nodes
    :param added_nodes: nodes replacing the matched nodes
    :return: True to apply the replacement, False to skip it
        (always True here)
    """
    # Explicitly discard the unused arguments.
    del g, matched_nodes, added_nodes
    return True


def rotary_apply_pattern(x, pos_ids, axis):
    """The replacement pattern.

    Replaces the matched nodes with a single ``com.microsoft``
    ``RotaryEmbedding`` node fed by two constant caches.

    :param x: same input as in the match pattern
    :param pos_ids: same input as in the match pattern
    :param axis: same input as in the match pattern (not used by the
        replacement)
    :return: the two outputs of the RotaryEmbedding node
    """
    # The caches are random 256x256 float16 constants -- presumably
    # placeholder values for illustration only.
    cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
    part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
    return part1, part2


###########################
# The rule
# ========
#
# The rule is easy to create.


# Arguments, in order: the pattern to match, the replacement pattern, and the
# validation callback.
rule = org.make_pattern_rule(
    rotary_match_pattern,
    rotary_apply_pattern,
    validate_rotary_mapping,
)

################################
# ``validate_rotary_mapping`` always returns True.
# In that case, the argument can be omitted.

# Same rule as above, built without the validation callback.
rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)

##########################
# Let's apply it.
# The return value is not used: the rule presumably rewrites ``ir_model``
# in place.
rule.apply_to_model(ir_model)


########################
# And finally, we can generate the model.

opt_onx = oir.protobuilder.build_model_proto(ir_model)

########################
# Let's see what it looks like.
# Each node is printed as ``op_type(inputs) -> outputs``.

for node in opt_onx.graph.node:
    print(f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}")

#############################
# What if it fails?
# =================


# Rebuild the model with the ``ConcatTrainingBad`` node, which the pattern
# (written with ``ConcatTraining``) does not describe.
model = get_rotary_model(True)
ir_model = oir.irbuilder.build_ir(model)

rule.apply_to_model(ir_model)
opt_onx = oir.protobuilder.build_model_proto(ir_model)

# List the operator types present after applying the rule.
print([n.op_type for n in opt_onx.graph.node])

################################
# The match did not happen.
# Let's increase the verbosity.

rule = org.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern, verbose=10)

rule.apply_to_model(ir_model)

######################################
# The log shows every time the algorithm rejected a pattern.
# We can see the following:
#
# ::
#
# [OnnxGenericPattern.match] NONE - line: 673:onnxscript.rewriter.generic_pattern, op_type=Cast
# --hint--: BACKWARD: different node types
# --pattern
# ConcatTraining(transpose, transpose) -> (output, length)
# -- model
# ConcatTrainingBad(_onx_transpose0, _onx_transpose0) -> (_onx_concattraining0, _onx_concattraining1)
# iteration=1
# --marked-- #2
# Cast(_onx_cos0) ~ Cast(cos) [140186194226496-140186194222320]
# Cos(_onx_concattraining0) ~ Cos(output) [140186194230816-140186194223472]
# len(stacked)=0:[]
#
# The match was rejected at line 673 of `generic_pattern.py`.
# The hint says that, while comparing two nodes in the backward direction,
# the node types did not match.
# The ``--marked--`` section also shows that two nodes had already been matched.
Loading

0 comments on commit 2c74be7

Please sign in to comment.