From afbec0f59772c0df688ec85a6744c2caf7a76a54 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Wed, 3 Jul 2024 19:29:58 -0700 Subject: [PATCH 1/3] added ut --- .../python/tools/symbolic_shape_infer.py | 20 +++++++ ...untime_test_python_symbolic_shape_infer.py | 53 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9bc2328cc71b6..213b3ab4299aa 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -225,6 +225,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "SimplifiedLayerNormalization": self._infer_LayerNormalization, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, + "MatMulNBits": self._infer_MatMulNBits, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather, @@ -1256,6 +1257,25 @@ def _infer_MatMul(self, node): # noqa: N802 def _infer_MatMulInteger(self, node): # noqa: N802 self._compute_matmul_shape(node, onnx.TensorProto.INT32) + def _infer_MatMulNBits(self, node): # noqa: N802 + lhs_shape = self._get_shape(node, 0) + rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")] + lhs_rank = len(lhs_shape) + assert lhs_rank > 0 + if lhs_rank == 1: + new_shape = rhs_shape[1:] + else: + new_shape = lhs_shape[:-1] + rhs_shape[1:] + # merge reduce dim + self._check_merged_dims( + [lhs_shape[-1], rhs_shape[0]], + allow_broadcast=False, + ) + # infer output_dtype from input type when not specified + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) + def _infer_NonMaxSuppression(self, node): # noqa: N802 selected = str(self._new_symbolic_dim_from_output(node)) vi = self.known_vi_[node.output[0]] diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index eca1430448e8e..dd919a87feb95 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -594,6 +594,59 @@ def test_dequantize_linear_ms_domain(self): ] self._check_shapes(graph, inferred.graph, expected_shapes) + def test_matmulnbits(self): + """ + Test ORT MatMulNBits op. + Check that the output shape is propagated from the inputs and that the output data + type comes from the first input. + """ + b_np = numpy.random.randint(0, 255, (4, 1, 8), numpy.uint8) + b = numpy_helper.from_array(b_np, name="b") + scale_np = numpy.random.rand(4).astype(numpy.float32) + scale = numpy_helper.from_array(scale_np, name="scale") + zero_point_np = numpy.random.randint(0, 255, (4), numpy.uint8) + zero_point = numpy_helper.from_array(zero_point_np, name="zero_point") + + initializers = [b, scale, zero_point] + + kwargs = { + "K": 10, + "N": 4, + "block_size": 16 + } + + nodes = [ + helper.make_node( + "MatMulNBits", + inputs=[ + "input_f32", + "b", + "scale", + "zero_point", + ], + outputs=["output_f32"], + **kwargs + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["x", 2, 3, 10]), + ] + + outputs = [ + helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "MatMulNBits_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["x", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): From f54a984170c83e792595476e17e7ad700e89a979 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 5 Jul 2024 10:27:50 -0700 Subject: [PATCH 2/3] fix linting --- .../onnxruntime_test_python_symbolic_shape_infer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index dd919a87feb95..29680c98fb4de 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -609,11 +609,7 @@ def test_matmulnbits(self): initializers = [b, scale, zero_point] - kwargs = { - "K": 10, - "N": 4, - "block_size": 16 - } + kwargs = {"K": 10, "N": 4, "block_size": 16} nodes = [ helper.make_node( @@ -625,7 +621,7 @@ def test_matmulnbits(self): "zero_point", ], outputs=["output_f32"], - **kwargs + **kwargs, ), ] From 917787666a752ca5d88e8a1bb18ed2fcd5696e93 Mon Sep 17 00:00:00 2001 From: Jing Fang Date: Fri, 5 Jul 2024 11:02:51 -0700 Subject: [PATCH 3/3] fixed order --- onnxruntime/python/tools/symbolic_shape_infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 213b3ab4299aa..ac959d5c061f7 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -206,10 +206,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, "GroupQueryAttention": self._infer_GroupQueryAttention, - "SparseAttention": self._infer_SparseAttention, - "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, + "MatMulNBits": self._infer_MatMulNBits, "MultiHeadAttention": self._infer_MultiHeadAttention, "NhwcConv": self._infer_NhwcConv, "PackedAttention": self._infer_PackedAttention, @@ -223,9 +222,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "RestorePadding": self._infer_RestorePadding, "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, + "SkipGroupNorm": self._infer_SkipGroupNorm, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, - "MatMulNBits": self._infer_MatMulNBits, + "SparseAttention": self._infer_SparseAttention, } self.aten_op_dispatcher_ = { "embedding": self._infer_Gather,