diff --git a/docs/intermediate_representation/tensors.md b/docs/intermediate_representation/tensors.md index 0c3e25abc..a372e5f0b 100644 --- a/docs/intermediate_representation/tensors.md +++ b/docs/intermediate_representation/tensors.md @@ -137,8 +137,6 @@ In the following scenario, we show how to go from a `TensorProto` to an `ir.Tens print("tensor_mean.size:", tensor_mean.size) print("tensor_mean.nbytes:", tensor_mean.nbytes) print("tensor_mean.raw:", tensor_mean.raw) - print("\nUse the display() method to view the tensor") - tensor_mean.display() ``` ## Working with non-native NumPy dtypes: bfloat16, float8, int4 diff --git a/examples/pattern_rewriting.py b/examples/pattern_rewriting.py index 737ce02e8..7ebe10157 100644 --- a/examples/pattern_rewriting.py +++ b/examples/pattern_rewriting.py @@ -13,7 +13,6 @@ import onnx.helper as oh import onnx.numpy_helper as onh -import onnxscript from onnxscript import ir from onnxscript.rewriter import generic_pattern @@ -67,18 +66,15 @@ def get_rotary_model(bad_model=False): # The rewriting pattern # ===================== -op = onnxscript.opset18 -msft_op = onnxscript.values.Opset("com.microsoft", 1) - -def rotary_match_pattern(x, pos_ids, axis): +def rotary_match_pattern(op, x, pos_ids, axis): """The pattern to match.""" unsqueeze = op.Unsqueeze(x, axis) cast = op.Cast(unsqueeze, to=onnx.TensorProto.FLOAT) matmul = op.MatMul(pos_ids, cast) transpose = op.Transpose(matmul) - output, length = msft_op.ConcatTraining(transpose, transpose) + output, length = op.ConcatTraining(transpose, transpose, domain="com.microsoft", outputs=2) sin = op.Sin(output) cast1 = op.Cast(sin, to=onnx.TensorProto.FLOAT) @@ -87,25 +83,13 @@ def rotary_match_pattern(x, pos_ids, axis): return cast1, cast2 -def validate_rotary_mapping(g, match_result) -> bool: - """The validation post matching. - - Returns True to validate the replacement, - False not to apply it. - - :param g: model - :param match_result: matched nodes - """ - del g - del match_result - return True - - -def rotary_apply_pattern(x, pos_ids, axis): +def rotary_apply_pattern(op, x, pos_ids, axis): """The replacement pattern.""" cos_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) sin_cache = op.Constant(value=onh.from_array(np.random.rand(256, 256).astype(np.float16))) - part1, part2 = msft_op.RotaryEmbedding(x, pos_ids, cos_cache, sin_cache) + part1, part2 = op.RotaryEmbedding( + x, pos_ids, cos_cache, sin_cache, domain="com.microsoft", outputs=2 + ) return part1, part2 @@ -115,19 +99,10 @@ def rotary_apply_pattern(x, pos_ids, axis): # # The rule is easy to create. - -rule_with_validation_function = generic_pattern.make_pattern_rule( - rotary_match_pattern, - rotary_apply_pattern, - validate_rotary_mapping, +rule = generic_pattern.make_pattern_rule( + rotary_match_pattern, rotary_apply_pattern, verbose=10 ) -################################ -# ``validate_rotary_mapping`` always return True. -# This argument can be ignored in that case. - -rule = generic_pattern.make_pattern_rule(rotary_match_pattern, rotary_apply_pattern) - ########################## # Let's apply it. rule.apply_to_model(ir_model) @@ -167,6 +142,8 @@ def rotary_apply_pattern(x, pos_ids, axis): rule.apply_to_model(ir_model) +# TODO(rama): Update the following, the trace-printed looks different now. + ###################################### # The logs shows every time the algorithm rejected a pattern. # We can see the following: