Add MatMulNBits shape infer to SymbolicShapeInference #21246

Merged (3 commits), Jul 5, 2024
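
For context, the rule being added: a MatMulNBits node multiplies an activation A of shape [..., K] by a block-quantized weight matrix whose logical K x N shape is carried only in the node's K and N attributes, so the inferred output shape is A's leading dims followed by N, and the output dtype follows A. A minimal standalone sketch of that rule (the helper name and the eager assertion are illustrative; the PR's implementation merges symbolic dims instead of asserting):

    def matmulnbits_output_shape(a_shape, k, n):
        # A: [..., K] times a logical K x N weight -> [..., N]
        assert len(a_shape) >= 1 and a_shape[-1] == k
        if len(a_shape) == 1:
            return [n]  # 1-D A acts as a row vector, so the result is [N]
        return list(a_shape[:-1]) + [n]

    # Mirrors the unit test below: ["x", 2, 3, 10] with K=10, N=4 -> ["x", 2, 3, 4]
    print(matmulnbits_output_shape(["x", 2, 3, 10], 10, 4))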
24 changes: 22 additions & 2 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -206,10 +206,9 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "GemmFloat8": self._infer_GemmFloat8,
             "GroupNorm": self._infer_GroupNorm,
             "GroupQueryAttention": self._infer_GroupQueryAttention,
-            "SparseAttention": self._infer_SparseAttention,
-            "SkipGroupNorm": self._infer_SkipGroupNorm,
             "LayerNormalization": self._infer_LayerNormalization,
             "LongformerAttention": self._infer_LongformerAttention,
+            "MatMulNBits": self._infer_MatMulNBits,
             "MultiHeadAttention": self._infer_MultiHeadAttention,
             "NhwcConv": self._infer_NhwcConv,
             "PackedAttention": self._infer_PackedAttention,
@@ -223,8 +222,10 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
             "RestorePadding": self._infer_RestorePadding,
             "RotaryEmbedding": self._infer_RotaryEmbedding,
             "SimplifiedLayerNormalization": self._infer_LayerNormalization,
+            "SkipGroupNorm": self._infer_SkipGroupNorm,
             "SkipLayerNormalization": self._infer_SkipLayerNormalization,
             "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
+            "SparseAttention": self._infer_SparseAttention,
         }
         self.aten_op_dispatcher_ = {
             "embedding": self._infer_Gather,
@@ -1256,6 +1257,25 @@ def _infer_MatMul(self, node):  # noqa: N802
     def _infer_MatMulInteger(self, node):  # noqa: N802
         self._compute_matmul_shape(node, onnx.TensorProto.INT32)
 
+    def _infer_MatMulNBits(self, node):  # noqa: N802
+        lhs_shape = self._get_shape(node, 0)
+        rhs_shape = [get_attribute(node, "K"), get_attribute(node, "N")]
+        lhs_rank = len(lhs_shape)
+        assert lhs_rank > 0
+        if lhs_rank == 1:
+            new_shape = rhs_shape[1:]
+        else:
+            new_shape = lhs_shape[:-1] + rhs_shape[1:]
+        # merge reduce dim
+        self._check_merged_dims(
+            [lhs_shape[-1], rhs_shape[0]],
+            allow_broadcast=False,
+        )
+        # infer output_dtype from input type when not specified
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
     def _infer_NonMaxSuppression(self, node):  # noqa: N802
         selected = str(self._new_symbolic_dim_from_output(node))
         vi = self.known_vi_[node.output[0]]
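
The second file in the diff adds a unit test for the new handler. One thing to note when reading it: the "b" initializer is the packed 4-bit weight, whose uint8 dims encode the block layout rather than the logical K x N shape, which is why the handler above must read K and N from attributes. A hedged sketch of that layout, assuming the standard MatMulNBits 4-bit packing (the helper is illustrative, not part of the PR):

    import math

    def packed_b_shape(k, n, block_size, bits=4):
        # Packed B is uint8 with shape (N, ceil(K / block_size), block_size * bits // 8)
        n_blocks_per_col = math.ceil(k / block_size)
        blob_size = block_size * bits // 8
        return (n, n_blocks_per_col, blob_size)

    print(packed_b_shape(10, 4, 16))  # (4, 1, 8) -- matches the test's b_np shape below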
@@ -594,6 +594,55 @@ def test_dequantize_linear_ms_domain(self):
         ]
         self._check_shapes(graph, inferred.graph, expected_shapes)
 
+    def test_matmulnbits(self):
+        """
+        Test ORT MatMulNBits op.
+        Check that the output shape is propagated from the inputs and that the output data
+        type comes from the first input.
+        """
+        b_np = numpy.random.randint(0, 255, (4, 1, 8), numpy.uint8)
+        b = numpy_helper.from_array(b_np, name="b")
+        scale_np = numpy.random.rand(4).astype(numpy.float32)
+        scale = numpy_helper.from_array(scale_np, name="scale")
+        zero_point_np = numpy.random.randint(0, 255, (4), numpy.uint8)
+        zero_point = numpy_helper.from_array(zero_point_np, name="zero_point")
+
+        initializers = [b, scale, zero_point]
+
+        kwargs = {"K": 10, "N": 4, "block_size": 16}
+
+        nodes = [
+            helper.make_node(
+                "MatMulNBits",
+                inputs=[
+                    "input_f32",
+                    "b",
+                    "scale",
+                    "zero_point",
+                ],
+                outputs=["output_f32"],
+                **kwargs,
+            ),
+        ]
+
+        inputs = [
+            helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["x", 2, 3, 10]),
+        ]
+
+        outputs = [
+            helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None),
+        ]
+
+        graph = helper.make_graph(nodes, "MatMulNBits_Test", inputs, outputs, initializers)
+        model = helper.make_model(graph)
+
+        inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+
+        expected_shapes = [
+            helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["x", 2, 3, 4]),
+        ]
+        self._check_shapes(graph, inferred.graph, expected_shapes)
+
 
 class TestSymbolicShapeInferenceForSlice(unittest.TestCase):
     def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim):
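
For reference, the new inference path can be driven the same way the test drives it; a usage sketch, assuming the packaged module path onnxruntime.tools.symbolic_shape_infer and a hypothetical model file containing MatMulNBits nodes:

    import onnx
    from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

    model = onnx.load("model_int4.onnx")  # hypothetical 4-bit quantized model
    inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
    onnx.save(inferred, "model_int4.inferred.onnx")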