From 7df44225b22a31c3b28fbd5ee8166507fb5c9ab3 Mon Sep 17 00:00:00 2001
From: Jing Fang
Date: Tue, 16 Jul 2024 11:03:37 -0700
Subject: [PATCH] add transform part of the dq matmul tool chain

---
 .../core/optimizer/graph_transformer_utils.h       |   7 +-
 .../onnxruntime_session_options_config_keys.h      |   5 +
 .../core/optimizer/graph_transformer_utils.cc      |  26 ++-
 .../selectors_actions/qdq_actions.cc               | 179 +++++++++++++++++-
 .../selectors_actions/qdq_actions.h                |  30 +++
 .../qdq_selector_action_transformer.cc             |  39 +++-
 .../qdq_selector_action_transformer.h              |   6 +-
 .../selectors_actions/qdq_selectors.cc             |  81 ++++++++
 .../selectors_actions/qdq_selectors.h              |  15 ++
 .../optimizer/selectors_actions/actions.cc         |   4 +-
 .../optimizer/selectors_actions/actions.h          |   3 +-
 11 files changed, 380 insertions(+), 15 deletions(-)

diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
index e609745b5e03f..0bb5c7432f0a7 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -10,6 +10,7 @@
 #include "core/common/inlined_containers.h"
 #include "core/framework/session_options.h"
 #include "core/optimizer/graph_transformer.h"
+#include "core/platform/threadpool.h"
 
 #if !defined(ORT_MINIMAL_BUILD)
 #include "core/optimizer/rule_based_graph_transformer.h"
@@ -49,7 +50,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
     TransformerLevel level,
     const SessionOptions& session_options,
     const IExecutionProvider& execution_provider /*required by constant folding*/,
-    const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
+    const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
+    concurrency::ThreadPool* intra_op_thread_pool = nullptr);
 
 #endif  // !defined(ORT_MINIMAL_BUILD)
 
@@ -78,7 +80,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalBuild(
     const SessionOptions& session_options,
     const SatApplyContextVariant& apply_context,
     const IExecutionProvider& cpu_execution_provider,
-    const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
+    const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
+    concurrency::ThreadPool* intra_op_thread_pool = nullptr);
 
 #endif  // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index c32e2a77e8453..17ae649e6f174 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
 // - "0": Gemm FastMath mode is not enabled. [DEFAULT]
 // - "1": Gemm FastMath mode is enabled.
 static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
+
+// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
+// Refer to the MatMulNBits op schema for more details.
+// If not provided, the default is 4.
+static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
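The new session option can be set through the public configuration API before session creation. Below is a minimal sketch (not part of this patch) of how an application would lower the accuracy level to fp32 accumulation; Ort::SessionOptions::AddConfigEntry is the existing C++ wrapper over OrtApi::AddSessionConfigEntry, and the model path is a placeholder:

    #include <onnxruntime_cxx_api.h>

    int main() {
      Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "dq_matmul_demo"};
      Ort::SessionOptions session_options;
      // Same key string as kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel above.
      // "1" requests fp32 accumulation; this patch defaults to "4" (int8).
      session_options.AddConfigEntry("session.qdq_matmulnbits_accuracy_level", "1");
      Ort::Session session{env, "model_with_dq_matmul.onnx", session_options};
      return 0;
    }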
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 4298551aec412..6e5be28f12745 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -13,6 +13,7 @@
 #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
 #include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
 #include "core/session/onnxruntime_session_options_config_keys.h"
+#include "core/platform/threadpool.h"
 
 #if !defined(ORT_MINIMAL_BUILD)
 
@@ -187,7 +188,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
     TransformerLevel level,
     const SessionOptions& session_options,
     const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
-    const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
+    const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
+    concurrency::ThreadPool* intra_op_thread_pool) {
   InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
   const bool disable_quant_qdq =
       session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
@@ -287,6 +289,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
                                                              onnxruntime::kJsExecutionProvider};
       const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
                                                             onnxruntime::kDmlExecutionProvider};
+      const int64_t qdq_matmulnbits_accuracy_level =
+          ParseStringWithClassicLocale<int64_t>(
+              session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
+                                                                "4"));
 #ifdef MLAS_TARGET_AMD64_IX86
       const bool avx2_precision_mode =
           session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -300,7 +306,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       if (!qdq_is_int8_allowed) {
         transformers.emplace_back(std::make_unique<Avx2WeightS8ToU8Transformer>(avx2_precision_mode, cpu_ep));
       }
-      transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed));
+      transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
+                                                                               SatApplyContextVariant{},
+                                                                               qdq_matmulnbits_accuracy_level,
+                                                                               intra_op_thread_pool));
     }
 
     transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
@@ -409,7 +418,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalBuild(
     const SessionOptions& session_options,
     const SatApplyContextVariant& apply_context,
     const IExecutionProvider& cpu_execution_provider,
-    const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
+    const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
+    concurrency::ThreadPool* intra_op_thread_pool) {
   InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
   const bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context);
 
@@ -423,12 +433,18 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalBuild(
       const bool qdq_is_int8_allowed =
           session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed,
                                                             QDQIsInt8Allowed() ? "1" : "0") == "1";
-
+      const int64_t qdq_matmulnbits_accuracy_level =
+          ParseStringWithClassicLocale<int64_t>(
+              session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
+                                                                "4"));
       // runtime optimizations only support CPU EP now
       const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};
 
       if (!disable_quant_qdq) {
-        transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed, apply_context));
+        transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
+                                                                                 apply_context,
+                                                                                 qdq_matmulnbits_accuracy_level,
+                                                                                 intra_op_thread_pool));
       }
 
       transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));
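Both call sites parse the option with the same default of "4". For reference, the accuracy_level values the option feeds into MatMulNBits are paraphrased below from the contrib-op schema; the enum itself is illustrative only (the op takes a plain int64 attribute):

    // Illustrative only: MatMulNBits takes accuracy_level as a raw int64 attribute.
    // Meanings follow the MatMulNBits op schema: the minimum accuracy used for
    // the computation on input A.
    enum class MatMulNBitsAccuracyLevel : int64_t {
      kUnset = 0,  // implementation decides
      kFp32 = 1,
      kFp16 = 2,
      kBf16 = 3,
      kInt8 = 4,  // default used by this transform
    };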
"1" : "0") == "1"; - + const int64_t qdq_matmulnbits_accuracy_level = + ParseStringWithClassicLocale( + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, + "4")); // runtime optimizations only support CPU EP now const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; if (!disable_quant_qdq) { - transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, apply_context)); + transformers.emplace_back(std::make_unique(qdq_is_int8_allowed, + apply_context, + qdq_matmulnbits_accuracy_level, + intra_op_thread_pool)); } transformers.emplace_back(std::make_unique(cpu_ep, apply_context)); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc index 3d2a81ce7f8cd..8aff6155a11af 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h" - #include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/optimizer/initializer.h" #include "core/graph/node_attr_utils.h" +#include "core/mlas/inc/mlas_q4.h" + namespace onnxruntime { namespace QDQ { @@ -273,6 +275,181 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select } } +DQMatMulToMatMulNBitsAction::DQMatMulToMatMulNBitsAction(int64_t accuracy_level, + concurrency::ThreadPool* intra_op_thread_pool) + : accuracy_level_{accuracy_level}, + domain_{kMSDomain}, + op_type_{"MatMulNBits"}, + value_moves_{[]() { + NTO::NodeLocation target{NTO::NodeType::kTarget, 0}; + return std::vector{ + MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput), + MoveAll(target, ArgType::kOutput)}; + }()}, + intra_op_thread_pool_{intra_op_thread_pool} { + ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4"); + + if (!intra_op_thread_pool) { + OrtThreadPoolParams to; + intra_op_thread_pool_optional_ = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + } +} + +NodeAttributes +DQMatMulToMatMulNBitsAction::ExtraAttributes(const RuntimeState& runtime_state) const { + NodeAttributes extra_attributes; + + const auto* dq_node = runtime_state.selected_nodes.Input(0); + auto& attrs = dq_node->GetAttributes(); + const auto* weight_shape = dq_node->InputDefs()[0]->Shape(); + + utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes); + utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes); + // currently only 4bits is supported. In the future, derive bits from DQ's weight type. 
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
index 8179a030508a5..833a57485fe4c 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -3,7 +3,12 @@
 
 #pragma once
 
+#include <memory>
+#include <optional>
+#include <vector>
+
 #include "core/optimizer/selectors_actions/actions.h"
+#include "core/platform/threadpool.h"
 
 namespace onnxruntime {
 
@@ -76,6 +81,31 @@ struct MatMulReplaceWithQLinear : public Action {
   BinaryReplaceWithQLinear qlinear_matmul_replacer_;
 };
 
+// used together with DQMatMulNodeGroupSelector, which does the sanity check
+struct DQMatMulToMatMulNBitsAction : public ReplaceWithNew {
+  DQMatMulToMatMulNBitsAction(int64_t accuracy_level,
+                              concurrency::ThreadPool* intra_op_thread_pool);
+
+ private:
+  std::string OpType(const RuntimeState&) const override { return op_type_; }
+
+  std::string Domain(const RuntimeState&) const override { return domain_; }
+
+  NodeAttributes ExtraAttributes(const RuntimeState&) const override;
+
+  std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState&) const override { return value_moves_; }
+
+  // transpose initializers, and add to the MatMulNBits inputs
+  Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override;
+
+  const int64_t accuracy_level_;
+  const std::string domain_;
+  const std::string op_type_;
+  const std::vector<NodeAndMoveInfo> value_moves_;
+  concurrency::ThreadPool* intra_op_thread_pool_;
+  std::optional<std::unique_ptr<concurrency::ThreadPool>> intra_op_thread_pool_optional_;
+};
+
 struct GemmReplaceWithQuant : public Action {
   GemmReplaceWithQuant();
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
index 80ead8f8c68d6..17e66a3953b97 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc
@@ -228,6 +228,30 @@ void MatMulQDQRules(SelectorActionRegistry& qdq_selector_action_registry, bool is_int8_allowed) {
 #endif
 }
 
+void DQMatMulToMatMulNBitsRules(SelectorActionRegistry& qdq_selector_action_registry,
+                                int64_t qdq_matmulnbits_accuracy_level,
+                                concurrency::ThreadPool* intra_op_thread_pool) {
+  // 2 nodes: DQ -> MatMul. DQ is the second input to MatMul.
+  // DQ's weight is int4/uint4. DQ's scale is float/float16.
+  // DQ is block-quantized along axis 0, with block_size >= 16 and a power of 2.
+  const std::string action_name{"DQMatMulToMatMulNBits"};
+
+  std::unique_ptr<Action> action =
+      std::make_unique<DQMatMulToMatMulNBitsAction>(qdq_matmulnbits_accuracy_level,
+                                                    intra_op_thread_pool);
+
+#if !defined(ORT_MINIMAL_BUILD)
+  std::unique_ptr<NodeSelector> selector = std::make_unique<DQMatMulToMatMulNBitsSelector>();
+  qdq_selector_action_registry.RegisterSelectorAndAction(action_name,
+                                                         {{"MatMul", {}}},
+                                                         std::move(selector),
+                                                         std::move(action));
+
+#else
+  qdq_selector_action_registry.RegisterAction(action_name, std::move(action));
+#endif
+}
+
 void GemmQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
   // 3 to 5 nodes. 0=DQ A, 1=DQ B, 2=DQ C (optional), 3=Gemm, 4=Q Y (optional)
   // Replace with QGemm
@@ -271,7 +295,9 @@ void WhereQDQRules(SelectorActionRegistry& qdq_selector_action_registry) {
 #endif
 }
 
-SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed) {
+SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed,
+                                                    int64_t qdq_matmulnbits_accuracy_level,
+                                                    concurrency::ThreadPool* intra_op_thread_pool) {
   SelectorActionRegistry qdq_selector_action_registry;
   SplitQDQRules(qdq_selector_action_registry);
   DropQDQNodesRules(qdq_selector_action_registry);
@@ -283,17 +309,22 @@ SelectorActionRegistry CreateSelectorActionRegistry(bool is_int8_allowed,
   MatMulQDQRules(qdq_selector_action_registry, is_int8_allowed);
   GemmQDQRules(qdq_selector_action_registry);
   WhereQDQRules(qdq_selector_action_registry);
+  DQMatMulToMatMulNBitsRules(qdq_selector_action_registry,
+                             qdq_matmulnbits_accuracy_level,
+                             intra_op_thread_pool);
 
   return qdq_selector_action_registry;
 }
 
 }  // namespace
 
-QDQSelectorActionTransformer::QDQSelectorActionTransformer(
-    bool is_int8_allowed, const SatApplyContextVariant& apply_context)
+QDQSelectorActionTransformer::QDQSelectorActionTransformer(bool is_int8_allowed,
+                                                           const SatApplyContextVariant& apply_context,
+                                                           int64_t qdq_matmulnbits_accuracy_level,
+                                                           concurrency::ThreadPool* intra_op_thread_pool)
     : SelectorActionTransformer{
           "QDQSelectorActionTransformer",
-          CreateSelectorActionRegistry(is_int8_allowed),
+          CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool),
           apply_context,
           // this transformer is only compatible with the CPU and DML EP
           {kCpuExecutionProvider, kDmlExecutionProvider}} {
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h
index 1780923f3f273..ba636f76d1900 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h
@@ -5,6 +5,7 @@
 
 #include "core/optimizer/selectors_actions/selector_action_transformer.h"
 #include "core/mlas/inc/mlas.h"
+#include "core/platform/threadpool.h"
 
 namespace onnxruntime {
 
@@ -21,7 +22,10 @@ Transformer that fuses QDQ and fp32 ops into quantized ops.
 */
 class QDQSelectorActionTransformer : public SelectorActionTransformer {
  public:
-  QDQSelectorActionTransformer(bool is_int8_allowed, const SatApplyContextVariant& apply_context = {});
+  QDQSelectorActionTransformer(bool is_int8_allowed,
+                               const SatApplyContextVariant& apply_context = {},
+                               int64_t qdq_matmulnbits_accuracy_level = 4,
+                               concurrency::ThreadPool* intra_op_thread_pool = nullptr);
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
index 09705f61c82ce..692db4eb327b5 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc
@@ -414,6 +414,87 @@ bool MatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
   }
 }
 
+bool DQMatMulNodeGroupSelector::Check(const GraphViewer& graph_viewer,
+                                      const Node& node,
+                                      const std::vector<const Node*>& dq_nodes,
+                                      const std::vector<const Node*>& q_nodes) const {
+  ORT_UNUSED_PARAMETER(q_nodes);
+  const auto& graph = graph_viewer.GetGraph();
+
+  // MatMul has only 1 DQ input, and the DQ must have 1 output edge and must not be a graph output
+  if (dq_nodes.size() != 1 || !optimizer_utils::CheckOutputEdges(graph, *dq_nodes[0], 1)) {
+    return false;
+  }
+
+  // DQ must be MatMul's second input
+  if (node.InputDefs()[1] != dq_nodes[0]->OutputDefs()[0]) {
+    return false;
+  }
+
+  // DQ weight/zero-point types are int4/uint4; scale/output types are float or float16
+  const auto* weight_arg = dq_nodes[0]->InputDefs()[0];
+  const auto* scale_arg = dq_nodes[0]->InputDefs()[1];
+  const auto* zero_point_arg = dq_nodes[0]->InputDefs().size() == 3 ? dq_nodes[0]->InputDefs()[2] : nullptr;
+  int32_t dt_weight = weight_arg->TypeAsProto()->tensor_type().elem_type();
+  int32_t dt_scales = scale_arg->TypeAsProto()->tensor_type().elem_type();
+  if (dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT &&
+      dt_scales != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16) {
+    return false;
+  }
+
+  if (!Is4BitIntType(dt_weight)) {
+    return false;
+  }
+
+  // DQ is blockwise quantized along axis 0, and block_size must be a power of 2 and >= 16
+  const auto& dq_attrs = dq_nodes[0]->GetAttributes();
+  if (const auto a_iter = dq_attrs.find("axis");
+      a_iter == dq_attrs.end() || a_iter->second.i() != 0) {
+    return false;
+  }
+
+  const auto a_iter = dq_attrs.find("block_size");
+  if (a_iter == dq_attrs.end()) {
+    return false;
+  }
+
+  auto block_size = a_iter->second.i();
+  if (block_size < 16 || ((block_size - 1) & block_size)) {
+    return false;
+  }
+
+  // weight, scale and zero points (if present) must be constants
+  const auto* weight_tensor_proto = graph.GetConstantInitializer(weight_arg->Name(), true);
+  const auto* scale_tensor_proto = graph.GetConstantInitializer(scale_arg->Name(), true);
+  const auto* zp_tensor_proto = zero_point_arg ? graph.GetConstantInitializer(zero_point_arg->Name(), true) : nullptr;
+
+  if (!weight_tensor_proto || !scale_tensor_proto) {
+    return false;
+  }
+
+  if (zero_point_arg && !zp_tensor_proto) {
+    return false;
+  }
+
+  // weight, scale and zero points (if present) must have rank 2
+  if (weight_tensor_proto->dims_size() != 2 ||
+      scale_tensor_proto->dims_size() != 2 ||
+      (zp_tensor_proto && zp_tensor_proto->dims_size() != 2)) {
+    return false;
+  }
+
+  // check weight, scale and zero point (if present) shapes
+  if ((weight_tensor_proto->dims()[0] + block_size - 1) / block_size != scale_tensor_proto->dims()[0] ||
+      weight_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1] ||
+      (zp_tensor_proto &&
+       (zp_tensor_proto->dims()[0] != scale_tensor_proto->dims()[0] ||
+        zp_tensor_proto->dims()[1] != scale_tensor_proto->dims()[1]))) {
+    return false;
+  }
+
+  return true;
+}
+
 bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer,
                                   const Node& node,
                                   const std::vector<const Node*>& dq_nodes,
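The block_size gate in DQMatMulNodeGroupSelector::Check relies on the standard bit trick: a positive integer b is a power of two exactly when (b - 1) & b == 0. A small self-contained illustration (the helper name is ours, not part of the patch):

    #include <cassert>
    #include <cstdint>

    // Mirrors the selector's check: block_size must be >= 16 and a power of two.
    bool IsSupportedBlockSize(int64_t block_size) {
      return block_size >= 16 && ((block_size - 1) & block_size) == 0;
    }

    int main() {
      assert(IsSupportedBlockSize(16));    // 15 & 16 == 0
      assert(IsSupportedBlockSize(128));   // 127 & 128 == 0
      assert(!IsSupportedBlockSize(24));   // 23 & 24 == 16 -> not a power of two
      assert(!IsSupportedBlockSize(8));    // power of two, but below the minimum
      return 0;
    }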
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
index 1a2a620acb480..491a15b62cb03 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h
@@ -204,6 +204,14 @@ class MatMulNodeGroupSelector : public NodeGroupSelector {
   bool allow_4bit_;
 };
 
+// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits"
+class DQMatMulNodeGroupSelector : public NodeGroupSelector {
+ private:
+  bool Check(const GraphViewer& graph_viewer, const Node& node,
+             const std::vector<const Node*>& dq_nodes,
+             const std::vector<const Node*>& q_nodes) const override;
+};
+
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmNodeGroupSelector : public NodeGroupSelector {
@@ -358,6 +366,13 @@ class MatMulSelector : public BaseSelector {
                                                         allow_16bit, allow_4bit)) {}
 };
 
+// Convert "1 DQ node for input B -> MatMul" to "MatMulNBits"
+class DQMatMulToMatMulNBitsSelector : public BaseSelector {
+ public:
+  explicit DQMatMulToMatMulNBitsSelector(gsl::span<const char* const> compatible_providers = {})
+      : BaseSelector(std::make_unique<DQMatMulNodeGroupSelector>(), compatible_providers) {}
+};
+
 // Input: DQ nodes for A, B and optional C
 // Output: optional Q node for Y
 class GemmSelector : public BaseSelector {
diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.cc b/onnxruntime/core/optimizer/selectors_actions/actions.cc
index c8d5acbf66b78..bb4033afedc49 100644
--- a/onnxruntime/core/optimizer/selectors_actions/actions.cc
+++ b/onnxruntime/core/optimizer/selectors_actions/actions.cc
@@ -102,12 +102,14 @@ static Status CreateReplacementNode(Graph& graph,
 
 Status ReplaceWithNew::Run(Graph& graph, const NodesToOptimize& selected_nodes) const {
   const RuntimeState runtime_state{graph, selected_nodes};
+  Node* replacement{};
   ORT_RETURN_IF_ERROR(CreateReplacementNode(graph, selected_nodes,
                                             OpType(runtime_state),
                                             Domain(runtime_state),
                                             ExtraAttributes(runtime_state),
                                             ValueMoves(runtime_state),
-                                            /* only_update_dest_definitions */ false, nullptr));
+                                            /* only_update_dest_definitions */ false, &replacement));
 
+  ORT_RETURN_IF_ERROR(ProcessNewNode(graph, selected_nodes, *replacement));
   return node_remover_.Run(graph, selected_nodes);
 }
diff --git a/onnxruntime/core/optimizer/selectors_actions/actions.h b/onnxruntime/core/optimizer/selectors_actions/actions.h
index 9384bfa7027cd..4d5b520cc47cb 100644
--- a/onnxruntime/core/optimizer/selectors_actions/actions.h
+++ b/onnxruntime/core/optimizer/selectors_actions/actions.h
@@ -158,6 +158,8 @@ struct ReplaceWithNew : public Action {
   // specifies how the inputs and outputs for the replaced nodes are moved to the new node
   virtual std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState&) const = 0;
 
+  virtual Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const { return Status::OK(); }
+
   RemoveNodes node_remover_;
 };
 
@@ -187,5 +189,4 @@ struct ReplaceWithNewFixed : public ReplaceWithNew {
   const NodeAttributes extra_attrs_;
   const std::vector<NodeAndMoveInfo> value_moves_;
 };
-
 }  // namespace onnxruntime
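The ProcessNewNode hook added above runs after CreateReplacementNode and before the matched nodes are removed, which is what lets DQMatMulToMatMulNBitsAction append its repacked initializers to the new MatMulNBits node. A hypothetical minimal override (TagNewNodeAction and its attribute are ours, not part of this patch; the remaining pure-virtual overrides are elided):

    struct TagNewNodeAction : public ReplaceWithNew {
      // OpType / Domain / ValueMoves overrides elided for brevity.
     private:
      Status ProcessNewNode(Graph& /*graph*/, const NodesToOptimize& /*selected*/,
                            Node& replacement) const override {
        // The replacement node is fully created at this point, so extra inputs
        // or attributes can still be attached before the old nodes are removed.
        replacement.AddAttribute("tagged", static_cast<int64_t>(1));
        return Status::OK();
      }
    };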