Skip to content

Commit

Permalink
Update example (#1521)
Browse files Browse the repository at this point in the history
The example in `examples\pattern_rewriting.py` was not updated when the
pattern-matchers API were unified. Fix this.

Question: do we want this standalone example? I think the documentation
folder examples are a better place to add anything we want, and this can
be removed. (But leaving it here for now.)

TODO: Improve the logging info produced in verbose mode, especially for
the pattern graph.

---------

Co-authored-by: Justin Chu <[email protected]>
  • Loading branch information
gramalingam and justinchuby authored May 10, 2024
1 parent 66d34e4 commit 9153dda
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 35 deletions.
2 changes: 0 additions & 2 deletions docs/intermediate_representation/tensors.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ In the following scenario, we show how to go from a `TensorProto` to an `ir.Tens
print("tensor_mean.size:", tensor_mean.size)
print("tensor_mean.nbytes:", tensor_mean.nbytes)
print("tensor_mean.raw:", tensor_mean.raw)
print("\nUse the display() method to view the tensor")
tensor_mean.display()
```

## Working with non-native NumPy dtypes: bfloat16, float8, int4
Expand Down
43 changes: 10 additions & 33 deletions examples/pattern_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import onnx.helper as oh
import onnx.numpy_helper as onh

import onnxscript
from onnxscript import ir
from onnxscript.rewriter import generic_pattern

Expand Down Expand Up @@ -67,18 +66,15 @@ def get_rotary_model(bad_model=False):
# The rewriting pattern
# =====================

op = onnxscript.opset18
msft_op = onnxscript.values.Opset("com.microsoft", 1)


def rotary_match_pattern(x, pos_ids, axis):
def rotary_match_pattern(op, x, pos_ids, axis):
"""The pattern to match."""
unsqueeze = op.Unsqueeze(x, axis)
cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT)

matmul = op.MatMul(pos_ids, cast)
transpose = op.Transpose(matmul)
output, length = msft_op.ConcatTraining(transpose, transpose)
output, length = op.ConcatTraining(transpose, transpose, domain="com.microsoft", outputs=2)

sin = op.Sin(output)
cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT)
Expand All @@ -87,25 +83,13 @@ def rotary_match_pattern(x, pos_ids, axis):
return cast1, cast2


def validate_rotary_mapping(g, match_result) -> bool:
"""The validation post matching.
Returns True to validate the replacement,
False not to apply it.
:param g: model
:param match_result: matched nodes
"""
del g
del match_result
return True


def rotary_apply_pattern(x, pos_ids, axis):
def rotary_apply_pattern(op, x, pos_ids, axis):
"""The replacement pattern."""
cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16)))
part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache)
part1, part2 = op.RotaryEmbedding(
x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2
)
return part1, part2


Expand All @@ -115,19 +99,10 @@ def rotary_apply_pattern(x, pos_ids, axis):
#
# The rule is easy to create.


rule_with_validation_function = generic_pattern.make_pattern_rule(
rotary_match_pattern,
rotary_apply_pattern,
validate_rotary_mapping,
rule = generic_pattern.make_pattern_rule(
rotary_match_pattern, rotary_apply_pattern, verbose=10
)

################################
# ``validate_rotary_mapping`` always return True.
# This argument can be ignored in that case.

rule = generic_pattern.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern)

##########################
# Let's apply it.
rule.apply_to_model(ir_model)
Expand Down Expand Up @@ -167,6 +142,8 @@ def rotary_apply_pattern(x, pos_ids, axis):

rule.apply_to_model(ir_model)

# TODO(rama): Update the following, the trace-printed looks different now.

######################################
# The logs shows every time the algorithm rejected a pattern.
# We can see the following:
Expand Down

0 comments on commit 9153dda

Please sign in to comment.