From 20be09637372b903685dd090d3c5756964c7128c Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 12 Aug 2024 12:35:02 -0700 Subject: [PATCH] Support QDQ GatherElements in quantization tool --- onnxruntime/python/tools/quantization/operators/gather.py | 2 +- onnxruntime/python/tools/quantization/registry.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/operators/gather.py b/onnxruntime/python/tools/quantization/operators/gather.py index e390e874a2662..e6314407cdc9a 100644 --- a/onnxruntime/python/tools/quantization/operators/gather.py +++ b/onnxruntime/python/tools/quantization/operators/gather.py @@ -55,7 +55,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert node.op_type == "Gather" + assert node.op_type == "Gather" or node.op_type == "GatherElements" if self.quantizer.is_valid_quantize_weight(node.input[0]) or self.quantizer.force_quantize_no_input_check: self.quantizer.quantize_activation_tensor(node.input[0]) diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index caac829126e38..160b056e1de17 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -79,6 +79,7 @@ "MatMul": QDQMatMul, "Split": QDQSplit, "Gather": QDQGather, + "GatherElements": QDQGather, "Where": QDQWhere, "InstanceNormalization": QDQNormalization, "LayerNormalization": QDQNormalization,