From 4af62918418a04bcd219471dfb8f8b27a1e8a8b5 Mon Sep 17 00:00:00 2001 From: duanshengliu <44742794+duanshengliu@users.noreply.github.com> Date: Sat, 24 Aug 2024 04:45:06 +0800 Subject: [PATCH] Refine `op_types_to_quantize` argument handling in matmul_4bits_quantizer.py (#21815) ### Description Refine `op_types_to_quantize` argument handling in matmul_4bits_quantizer.py ### Motivation and Context When `--op_types_to_quantize` is omitted, its string default `"MatMul"` is passed to `tuple(args.op_types_to_quantize)`, which iterates over the string's characters and yields `('M', 'a', 't', 'M', 'u', 'l')` instead of the expected `('MatMul',)`. --- .../python/tools/quantization/matmul_4bits_quantizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 975f82439c160..16ad36c48cc74 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -1062,7 +1062,6 @@ def parse_args(): ) parser.add_argument( "--op_types_to_quantize", - default="MatMul", type=str, nargs="+", choices=["MatMul", "Gather"], @@ -1089,7 +1088,7 @@ def parse_args(): input_model_path = args.input_model output_model_path = args.output_model quant_format = QuantFormat[args.quant_format] - op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else None + op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else ("MatMul",) quant_axes = tuple(args.quant_axes) if args.quant_axes else None if os.path.exists(output_model_path):