diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e609745b5e03f..0bb5c7432f0a7 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -10,6 +10,7 @@ #include "core/common/inlined_containers.h" #include "core/framework/session_options.h" #include "core/optimizer/graph_transformer.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/optimizer/rule_based_graph_transformer.h" @@ -49,7 +50,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) @@ -78,7 +80,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable = {}); + const InlinedHashSet& rules_and_transformers_to_disable = {}, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index c32e2a77e8453..17ae649e6f174 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // - "0": Gemm FastMath mode is not enabled. [DEFAULT] // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; + +// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. +// Refer to MatMulNBits op schema for more details. +// If not provided, default is 4. +static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level"; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4298551aec412..6e5be28f12745 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -13,6 +13,7 @@ #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/platform/threadpool.h" #if !defined(ORT_MINIMAL_BUILD) @@ -187,7 +188,8 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -287,6 +289,10 @@ InlinedVector> GenerateTransformers( onnxruntime::kJsExecutionProvider}; const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -300,7 +306,10 @@ InlinedVector> GenerateTransformers( if (!qdq_is_int8_allowed) { transformers.emplace_back(std::make_unique(avx2_precision_mode, cpu_ep)); } - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + SatApplyContextVariant{}, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep)); @@ -409,7 +418,8 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, - const InlinedHashSet& rules_and_transformers_to_disable) { + const InlinedHashSet& rules_and_transformers_to_disable, + concurrency::ThreadPool* intra_op_thread_pool) { InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); @@ -423,12 +433,18 @@ InlinedVector> GenerateTransformersForMinimalB const bool qdq_is_int8_allowed = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed, QDQIsInt8Allowed() ? "1" : "0") == "1"; - + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + apply_context, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 3d2a81ce7f8cd..3985827f235ea 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" - #include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" +#include "core/mlas/inc/mlas_q4.h" + namespace onnxruntime { namespace QDQ { @@ -273,6 +275,176 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } +DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) + : accuracy_level_{accuracy_level}, + domain_{kMSDomain}, + op_type_{"MatMulNBits"}, + value_moves_{[]() { + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; + return std::vector{ + MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), + MoveAll(target, ArgType::kOutput)}; + }()}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); +} + +NodeAttributes +DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_state) const { + NodeAttributes extra_attributes; + + const auto* dq_node = runtime_state.selected_nodes.Input(0); + auto& attrs = dq_node->GetAttributes(); + const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + + utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); + // currently only 4bits is supported. In the future, derive bits from DQ's weight type. + utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast(4)), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes); + + return extra_attributes; +} + +Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph, + const NodesToOptimize& selected_nodes, + Node& replacement_node) const { + ORT_RETURN_IF_NOT(intra_op_thread_pool_, "Passed in thread pool should not be null"); + const auto* dq_node = selected_nodes.Input(0); + const auto* weight_arg = dq_node->InputDefs()[0]; + const auto* scale_arg = dq_node->InputDefs()[1]; + const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr; + const auto& attrs = dq_node->GetAttributes(); + + const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr; + const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr; + const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr; + graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto); + graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto); + if (zp_arg) { + graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto); + } + + auto K = weight_arg->Shape()->dim(0).dim_value(); + auto N = weight_arg->Shape()->dim(1).dim_value(); + auto block_size = attrs.at("block_size").i(); + auto quant_num = (K + block_size - 1) / block_size; + auto blob_bytes = (block_size + 1) / 2; + + // Unfortunately iterating the source data is complicated, the data maybe in + // external file, a raw buffer, or a repeated field depending on the data + // type. UnpackTensor() already contains some of these logic and is closest + // to what we need. But it does not handle external data. + Initializer weight_src(*weight_tensor_proto, graph.ModelPath()); + Initializer scale_src(*scale_tensor_proto, graph.ModelPath()); + std::optional> zp_src_ptr; + Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(weight_arg->Name() + "_T"), + std::vector{N, quant_num, blob_bytes}); + Initializer scale_dst(static_cast(scale_src.data_type()), + graph.GenerateNodeArgName(scale_arg->Name() + "_T"), + std::vector{N * quant_num}); + std::optional> zp_dst_ptr; + + if (zp_tensor_proto) { + zp_src_ptr.emplace(std::make_unique(*zp_tensor_proto, graph.ModelPath())); + zp_dst_ptr.emplace(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName(zp_arg->Name() + "_T"), + std::vector{N * ((quant_num + 1) / 2)})); + } else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + zp_dst_ptr.emplace(std::make_unique(ONNX_NAMESPACE::TensorProto_DataType_UINT8, + graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"), + std::vector{N * ((quant_num + 1) / 2)})); + } + + if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } else { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } + } else { + if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + + } else { + MlasQDQTransposeBlockwiseQuantized( + weight_src.DataAsByteSpan().data(), + scale_src.data(), + zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr, + weight_dst.data(), + scale_dst.data(), + zp_dst_ptr ? zp_dst_ptr.value()->data() : nullptr, + true, + static_cast(K), + static_cast(N), + static_cast(block_size), + intra_op_thread_pool_); + } + } + + ONNX_NAMESPACE::TensorProto weight_T_tp; + ONNX_NAMESPACE::TensorProto scale_T_tp; + std::optional> zp_T_tp_ptr; + + // TODO(fajin): external_data to memory location to avoid arena allocation + // https://github.com/microsoft/onnxruntime/pull/12465 + weight_dst.ToProto(weight_T_tp); + scale_dst.ToProto(scale_T_tp); + if (zp_dst_ptr) { + zp_T_tp_ptr = std::make_unique(); + zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value()); + } + + auto& input_defs = replacement_node.MutableInputDefs(); + input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp)); + replacement_node.MutableInputArgsCount().push_back(1); + input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp)); + replacement_node.MutableInputArgsCount().push_back(1); + + if (zp_T_tp_ptr) { + input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value())); + replacement_node.MutableInputArgsCount().push_back(1); + } + + return Status::OK(); +} + static std::vector GetGemmMoveInfo(bool does_q_node_exist) { NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0}; NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1}; diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h index 8179a030508a5..d80c3f9d183bf 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h @@ -3,7 +3,12 @@ #pragma once +#include +#include +#include + #include "core/optimizer/selectors_actions/actions.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -76,6 +81,30 @@ struct MatMulReplaceWithQLinear : public Action { BinaryReplaceWithQLinear qlinear_matmul_replacer_; }; +// used together with DQMatMulNodeGroupSelector, which does the sanity check +struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew { + DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool); + + private: + std::string OpType(const RuntimeState&) const override { return op_type_; } + + std::string Domain(const RuntimeState&) const override { return domain_; } + + NodeAttributes ExtraAttributes(const RuntimeState&) const override; + + std::vector ValueMoves(const RuntimeState&) const override { return value_moves_; } + + // transpose initializers, and add to the MatMulNBits inputs + Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override; + + const int64_t accuracy_level_; + const std::string domain_; + const std::string op_type_; + const std::vector value_moves_; + concurrency::ThreadPool* intra_op_thread_pool_; +}; + struct GemmReplaceWithQuant : public Action { GemmReplaceWithQuant(); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index 80ead8f8c68d6..0b10f092cc565 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -228,6 +228,30 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool i #endif } +void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { + // 2 nodes. DQ -> MatMul. DQ is the second input to MatMul. + // DQ's weight is int4/uint4. DQ's scale is float/float16. + // DQ is block-quantized along axis 0, with block_size >= 16 and as 2's power. + const std::string action_name{"DQMatMul"}; + + std::unique_ptr action = + std::make_unique(qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); + +#if !defined(ORT_MINIMAL_BUILD) + std::unique_ptr selector = std::make_unique(); + qdq_selector_action_registry.RegisterSelectorAndAction(action_name, + {{"MatMul", {}}}, + std::move(selector), + std::move(action)); + +#else + qdq_selector_action_registry.RegisterAction(action_name, std::move(action)); +#endif +} + void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C(optional), 3=Gemm, 4=Q Y(optional) // Replace with QGemm @@ -271,7 +295,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) { #endif } -SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { +SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) { SelectorActionRegistry qdq_selector_action_registry; SplitQDQRules(qdq_selector_action_registry); DropQDQNodesRules(qdq_selector_action_registry); @@ -283,17 +309,22 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) { MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed); GemmQDQRules(qdq_selector_action_registry); WhereQDQRules(qdq_selector_action_registry); + DQMatMulToMatMulNBitsRules(qdq_selector_action_registry, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool); return qdq_selector_action_registry; } } // namespace -QDQSelectorActionTransformer::QDQSelectorActionTransformer( - bool is_int8_allowed, const SatApplyContextVariant& apply_context) +QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context, + int64_t qdq_matmulnbits_accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) : SelectorActionTransformer{ "QDQSelectorActionTransformer", - CreateSelectorActionRegistry(is_int8_allowed), + CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool), apply_context, // this transformer is only compatible with the CPU and DML EP {kCpuExecutionProvider, kDmlExecutionProvider}} { diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h index 1780923f3f273..ba636f76d1900 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h @@ -5,6 +5,7 @@ #include "core/optimizer/selectors_actions/selector_action_transformer.h" #include "core/mlas/inc/mlas.h" +#include "core/platform/threadpool.h" namespace onnxruntime { @@ -21,7 +22,10 @@ Transformer that fuses QDQ and fp32 ops into quantized ops. */ class QDQSelectorActionTransformer : public SelectorActionTransformer { public: - QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {}); + QDQSelectorActionTransformer(bool is_int8_allowed, + const SatApplyContextVariant& apply_context = {}, + int64_t qdq_matmulnbits_accuracy_level = 4, + concurrency::ThreadPool* intra_op_thread_pool = nullptr); }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 09705f61c82ce..11185491da9b5 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -414,6 +414,87 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, } } +bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + ORT_UNUSED_PARAMETER(q_nodes); + const auto& graph = graph_viewer.GetGraph(); + + // MatMul has only 1 DQ input and the DQ must have 1 output edge and not be a graph output + if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) { + return false; + } + + // DQ must be MatMul's the second input + if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) { + return false; + } + + // DQ weight/zero points types are int4/uint4, scales/output types are float or float16 + const auto* weight_arg = dq_nodes[0]->InputDefs()[0]; + const auto* scale_arg = dq_nodes[0]->InputDefs()[1]; + const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr; + int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type(); + if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && + dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) { + return false; + } + + if (!Is4BitIntType(dt_weight)) { + return false; + } + + // DQ is blockwise quantized along axis 0, and block_size must be 2's power and >= 16 + const auto& dq_attrs = dq_nodes[0]->GetAttributes(); + if (const auto a_iter = dq_attrs.find("axis"); + a_iter == dq_attrs.end() || a_iter->second.i() != 0) { + return false; + } + + const auto a_iter = dq_attrs.find("block_size"); + if (a_iter == dq_attrs.end()) { + return false; + } + + auto block_size = a_iter->second.i(); + if (block_size < 16 || ((block_size - 1) & block_size)) { + return false; + } + + // weight, scale and zero points (if exists) must be constants + const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true); + const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true); + const auto* zp_tensor_proto = zero_point_arg ? graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr; + + if (!weight_tensor_proto || !scale_tensor_proto) { + return false; + } + + if (zero_point_arg && !zp_tensor_proto) { + return false; + } + + // weight, scale and zero points (if exists) must have the rank 2 + if (weight_tensor_proto->dims_size() != 2 || + scale_tensor_proto->dims_size() != 2 || + (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) { + return false; + } + + // check weight, scale and zero points (if exists) shapes + if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] || + weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] || + (zp_tensor_proto && + (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] || + zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) { + return false; + } + + return true; +} + bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& node, const std::vector& dq_nodes, diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index 1a2a620acb480..491a15b62cb03 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -204,6 +204,14 @@ class MatMulNodeGroupSelector : public NodeGroupSelector { bool allow_4bit_; }; +// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +class DQMatMulNodeGroupSelector : public NodeGroupSelector { + private: + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmNodeGroupSelector : public NodeGroupSelector { @@ -358,6 +366,13 @@ class MatMulSelector : public BaseSelector { allow_16bit, allow_4bit)) {} }; +// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits" +class DQMatMulToMatMulNBitsSelector : public BaseSelector { + public: + explicit DQMatMulToMatMulNBitsSelector(gsl::span compatible_providers = {}) + : BaseSelector(std::make_unique(), compatible_providers) {} +}; + // Input: DQ nodes for A, B and optional C // Output: optional Q node for Y class GemmSelector : public BaseSelector { diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.cc b/onnxruntime/core/optimizer/selectors_actions/actions.cc index c8d5acbf66b78..bb4033afedc49 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.cc +++ b/onnxruntime/core/optimizer/selectors_actions/actions.cc @@ -102,12 +102,14 @@ static Status CreateReplacementNode(Graph& graph, Status ReplaceWithNew::Run(Graph& graph, const NodesToOptimize& selected_nodes) const { const RuntimeState runtime_state{graph, selected_nodes}; + Node* replacement{}; ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes, OpType(runtime_state), Domain(runtime_state), ExtraAttributes(runtime_state), ValueMoves(runtime_state), - /* only_update_dest_definitions */ false, nullptr)); + /* only_update_dest_definitions */ false, &replacement)); + ORT_RETURN_IF_ERROR(ProcessNewNode(graph, selected_nodes, *replacement)); return node_remover_.Run(graph, selected_nodes); } diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h index 9384bfa7027cd..4d5b520cc47cb 100644 --- a/onnxruntime/core/optimizer/selectors_actions/actions.h +++ b/onnxruntime/core/optimizer/selectors_actions/actions.h @@ -158,6 +158,8 @@ struct ReplaceWithNew : public Action { // specifies how the inputs and outputs for the replaced nodes are moved to the new node virtual std::vector ValueMoves(const RuntimeState&) const = 0; + virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); } + RemoveNodes node_remover_; }; @@ -187,5 +189,4 @@ struct ReplaceWithNewFixed : public ReplaceWithNew { const NodeAttributes extra_attrs_; const std::vector value_moves_; }; - } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 3ef6490a56ded..72cafa1034a4c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1613,7 +1613,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, - const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { + const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep, + concurrency::ThreadPool* intra_op_thread_pool) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); @@ -1621,7 +1622,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable); + optimizers_to_disable, intra_op_thread_pool); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -2009,7 +2010,8 @@ common::Status InferenceSession::Initialize() { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep)); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, + cpu_ep, GetIntraOpThreadPoolToUse())); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3171,7 +3173,8 @@ common::Status InferenceSession::AddPredefinedTransformers( if (use_full_build_optimizations) { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } else { const auto sat_context = minimal_build_optimization_handling == @@ -3180,7 +3183,8 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - optimizers_to_disable_); + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); } }(); diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 9ab4a82463d51..9bc50ce88ef16 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" +#include "core/framework/int4.h" #include "test/util/include/test_random_seed.h" namespace onnxruntime { @@ -108,6 +109,22 @@ class RandomValueGenerator { return val; } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + std::vector>::type + Uniform(gsl::span dims, TInt4 min, TInt4 max) { + using UnpackedType = typename TInt4::UnpackedType; + std::vector data_int8 = Uniform(dims, min.GetElem(0), max.GetElem(0)); + std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 1; + size_t c = i & 0x1; + data[r].SetElem(c, data_int8[i]); + } + return data; + } + // Gaussian distribution for float template typename std::enable_if< diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index 0282d09f340b2..33c56e6583e6b 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -116,22 +116,6 @@ class ModelTestBuilder { return MakeInput(shape, data); } - template - typename std::enable_if< - std::is_same_v || std::is_same_v, - NodeArg*>::type - MakeInputInt4(const std::vector& shape, typename TInt4::UnpackedType min, typename TInt4::UnpackedType max) { - using UnpackedType = typename TInt4::UnpackedType; - std::vector data_int8 = rand_gen_.Uniform(shape, min, max); - std::vector data(TInt4::CalcNumInt4Pairs(data_int8.size())); - for (size_t i = 0; i < data_int8.size(); i++) { - size_t r = i >> 1; - size_t c = i & 0x1; - data[r].SetElem(c, data_int8[i]); - } - return MakeInput(shape, data); - } - template NodeArg* MakeInput(const std::optional>& shape, std::optional input_name = std::nullopt) { diff --git a/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc new file mode 100644 index 0000000000000..3d117794104fa --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_matmulnbits_transformer_test.cc @@ -0,0 +1,425 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/common/span_utils.h" +#include "core/framework/int4.h" +#include "core/graph/node_attr_utils.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" + +#if defined(_MSC_VER) +#pragma warning(disable : 4127) +#endif // #if defined(_MSC_VER) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +// Input1 Input2 +// | | +// \ DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_NonConstDQ(const std::vector& input1_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* input2_arg = builder.MakeInput(input2_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{input2_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {input2_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input1_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_NonConstDQ) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_NonConstDQ({12, 37}, {37, 12}, 0, 16, -1); +} + +// Input2 +// | +// DQ / +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulNotConverted_FirstDQInput(const std::vector& weight_shape, + const std::vector& input2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* input2_arg = builder.MakeInput(input2_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {dq_output, input2_arg}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_FirstDQInput) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, 4); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); + RunDQMatMulNotConverted_FirstDQInput({12, 37}, {37, 12}, 0, 16, -1); +} + +// Input1 +// | +// \ DQ +// \ / +// MatMul +// | +// output +template +void RunDQMatMulNotConverted_TypeShapeMismatch(const std::vector& input1_shape, + const std::vector& weight_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + NodeArg* weight_arg = nullptr; + + // add DQ + if constexpr (std::is_same_v || std::is_same_v) { + weight_arg = builder.MakeInitializer(weight_shape, T(T::min_val, 0), T(T::max_val, 0)); + } else { + weight_arg = builder.MakeInitializer(weight_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + } + + auto* dq_output = builder.MakeIntermediate(); + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + + auto scale_shape = std::vector{weight_shape}; + scale_shape[axis] = (scale_shape[axis] + block_size - 1) / block_size; + auto* scale_arg = builder.MakeInitializer(scale_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + NodeArg* zp_arg; + if constexpr (std::is_same_v || std::is_same_v) { + zp_arg = builder.MakeInitializer(scale_shape, T(0, 0), T(2, 0)); + } else { + zp_arg = builder.MakeInitializer(scale_shape, 0, 2); + } + + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg, zp_arg}, {dq_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight_arg, scale_arg}, {dq_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_TypeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 16, 0); +} + +TEST(QDQTransformerTests, DQMatMulNotConvertedToMatMulNBits_ShapeMismatch) { + // DQ contrib op schema is not updated to support blocked quantization + // block size too small + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 8, 0); + // block size not 2's power + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 12}, 0, 17, 0); + // not axis 0 + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({12, 37}, {37, 37}, 1, 16, 0); + // not rank 2 + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); + RunDQMatMulNotConverted_TypeShapeMismatch({2, 12, 37}, {2, 37, 12}, 0, 16, 0); +} + +// Input1 +// | DQ +// \ / +// MatMul +// | DQ +// \ / +// MatMul +// | +// output +template +typename std::enable_if || std::is_same_v, void>::type +RunDQMatMulConverted(const std::vector& input1_shape, + const std::vector& weight1_shape, + const std::vector& weight2_shape, + const int64_t axis, + const int64_t block_size, + int64_t accuracy_level) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput(input1_shape, -100.0f, 100.0f); + auto* output_arg = builder.MakeOutput(); + + // add DQ + NodeAttributes attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", axis), attrs); + utils::SetNodeAttribute(utils::MakeAttribute("block_size", block_size), attrs); + auto scale1_shape = std::vector{weight1_shape}; + auto scale2_shape = std::vector{weight2_shape}; + scale1_shape[axis] = (scale1_shape[axis] + block_size - 1) / block_size; + scale2_shape[axis] = (scale2_shape[axis] + block_size - 1) / block_size; + + auto* weight1_arg = builder.MakeInitializer(weight1_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* weight2_arg = builder.MakeInitializer(weight2_shape, T(T::min_val, 0), T(T::max_val, 0)); + auto* dq1_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + auto* matmul1_output = builder.MakeIntermediate(); + + auto* scales1_arg = builder.MakeInitializer(scale1_shape, 8.0f, 12.0f); + auto* scales2_arg = builder.MakeInitializer(scale2_shape, 8.0f, 12.0f); + if constexpr (use_zp) { + auto* zp1_arg = builder.MakeInitializer(scale1_shape, T(0, 0), T(2, 0)); + auto* zp2_arg = builder.MakeInitializer(scale2_shape, T(0, 0), T(2, 0)); + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg, zp1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg, zp2_arg}, {dq2_output}, "", &attrs); + } else { + builder.AddNode("DequantizeLinear", {weight1_arg, scales1_arg}, {dq1_output}, "", &attrs); + builder.AddNode("DequantizeLinear", {weight2_arg, scales2_arg}, {dq2_output}, "", &attrs); + } + + builder.AddNode("MatMul", {input_arg, dq1_output}, {matmul1_output}); + builder.AddNode("MatMul", {matmul1_output, dq2_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(false); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count["com.microsoft.MatMulNBits"], 2); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + std::function add_session_options_fn{}; + if (accuracy_level >= 0) { + add_session_options_fn = [accuracy_level](SessionOptions& sess_opts) { + std::ignore = sess_opts.config_options.AddConfigEntry(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + std::to_string(accuracy_level).c_str()); + }; + } + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 21 /*opset_version*/, + 1e-5 /*per_sample_tolerance*/, + 1e-5 /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_fn); +} + +TEST(QDQTransformerTests, DQMatMulConvertedToMatMulNBits) { + // DQ contrib op schema is not updated to support blocked quantization + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 0); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); + RunDQMatMulConverted({12, 12}, {12, 37}, {37, 12}, 0, 16, 1); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/qdq_test_utils.h b/onnxruntime/test/optimizer/qdq_test_utils.h index 862408f31f004..52ac2a2541a79 100644 --- a/onnxruntime/test/optimizer/qdq_test_utils.h +++ b/onnxruntime/test/optimizer/qdq_test_utils.h @@ -517,7 +517,7 @@ GetQDQTestCaseFn BuildQDQSplitTestCase(const std::vector& input_shape, NodeArg* input_arg = nullptr; if constexpr (std::is_same_v || std::is_same_v) { - input_arg = builder.MakeInputInt4(input_shape, InputType::min_val, InputType::max_val); + input_arg = builder.MakeInput(input_shape, InputType(InputType::min_val, 0), InputType(InputType::max_val, 0)); dq_zp = InputType(static_cast(InputType::max_val / 2)); q_zp = OutputType(static_cast(OutputType::max_val / 2)); } else {