add transform part of the dq matmul tool chain
fajin-corp committed Jul 17, 2024
1 parent 6c7562b commit e8ce6b9
Showing 16 changed files with 826 additions and 37 deletions.
7 changes: 5 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -10,6 +10,7 @@
#include "core/common/inlined_containers.h"
#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)
#include "core/optimizer/rule_based_graph_transformer.h"
@@ -49,7 +50,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD)

@@ -78,7 +80,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr);

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)

include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -270,3 +270,8 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";

// When converting DQ + MatMul to MatMulNBits, this option controls the
// accuracy level of the resulting MatMulNBits node.
// Refer to the MatMulNBits op schema for more details.
// If not provided, the default is 4.
static const char* const kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel = "session.qdq_matmulnbits_accuracy_level";
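
For reference, here is a minimal sketch of setting the new option from application code through the public C++ API; the model path and the chosen level are illustrative, and Ort::SessionOptions::AddConfigEntry is assumed as the entry point for string-keyed session configuration (on Windows the session constructor takes a wide-character path):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // Request accuracy level 1 for the fused MatMulNBits nodes; valid values
  // are 0-4, as described in the MatMulNBits op schema.
  session_options.AddConfigEntry("session.qdq_matmulnbits_accuracy_level", "1");
  Ort::Session session(env, "model.onnx", session_options);  // hypothetical model path
  return 0;
}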
26 changes: 21 additions & 5 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -13,6 +13,7 @@
#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h"
#include "core/optimizer/selectors_actions/selector_action_transformer_apply_contexts.h"
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)

@@ -187,7 +188,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
concurrency::ThreadPool* intra_op_thread_pool) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool disable_quant_qdq =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1";
@@ -287,6 +289,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
onnxruntime::kJsExecutionProvider};
const InlinedHashSet<std::string_view> cpu_dml_eps = {onnxruntime::kCpuExecutionProvider,
onnxruntime::kDmlExecutionProvider};
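// The accuracy level arrives as a string-valued session config entry;
// ParseStringWithClassicLocale converts it to int64_t using the classic "C"
// locale, and "4" is substituted when the key is unset.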
const int64_t qdq_matmulnbits_accuracy_level =
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
"4"));
#ifdef MLAS_TARGET_AMD64_IX86
const bool avx2_precision_mode =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow();
@@ -300,7 +306,10 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
if (!qdq_is_int8_allowed) {
transformers.emplace_back(std::make_unique<QDQS8ToU8Transformer>(avx2_precision_mode, cpu_ep));
}
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed));
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
SatApplyContextVariant{},
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
}

transformers.emplace_back(std::make_unique<GemmActivationFusion>(cpu_ep));
@@ -409,7 +418,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable) {
const InlinedHashSet<std::string>& rules_and_transformers_to_disable,
concurrency::ThreadPool* intra_op_thread_pool) {
InlinedVector<std::unique_ptr<GraphTransformer>> transformers;
const bool saving = std::holds_alternative<SatRuntimeOptimizationSaveContext>(apply_context);

@@ -423,12 +433,18 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const bool qdq_is_int8_allowed =
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQIsInt8Allowed,
QDQIsInt8Allowed() ? "1" : "0") == "1";

const int64_t qdq_matmulnbits_accuracy_level =
ParseStringWithClassicLocale<int64_t>(
session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel,
"4"));
// runtime optimizations only support the CPU EP for now
const InlinedHashSet<std::string_view> cpu_ep = {onnxruntime::kCpuExecutionProvider};

if (!disable_quant_qdq) {
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed, apply_context));
transformers.emplace_back(std::make_unique<QDQSelectorActionTransformer>(qdq_is_int8_allowed,
apply_context,
qdq_matmulnbits_accuracy_level,
intra_op_thread_pool));
}

transformers.emplace_back(std::make_unique<ConvActivationFusion>(cpu_ep, apply_context));
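
Both generator overloads keep the new arguments optional, so existing call sites compile unchanged. A hypothetical caller that wants the DQ + MatMul fusion to repack weights in parallel would pass the session's intra-op pool explicitly; a minimal sketch, with illustrative identifiers:

InlinedVector<std::unique_ptr<GraphTransformer>> transformers =
    GenerateTransformers(TransformerLevel::Level2,
                         session_options,
                         cpu_execution_provider,
                         /*rules_and_transformers_to_disable=*/{},
                         /*intra_op_thread_pool=*/intra_op_pool);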
onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.cc
@@ -2,10 +2,12 @@
// Licensed under the MIT License.

#include "core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h"

#include "core/optimizer/qdq_transformer/qdq_util.h"
#include "core/optimizer/initializer.h"
#include "core/graph/node_attr_utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/mlas/inc/mlas_q4.h"

namespace onnxruntime {
namespace QDQ {

@@ -273,6 +275,176 @@ Status MatMulReplaceWithQLinear::Run(Graph& graph, const NodesToOptimize& select
}
}

DQMatMulReplaceWithMatMulNBits::DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool)
: accuracy_level_{accuracy_level},
domain_{kMSDomain},
op_type_{"MatMulNBits"},
value_moves_{[]() {
NTO::NodeLocation target{NTO::NodeType::kTarget, 0};

return std::vector<NodeAndMoveInfo>{
MoveAndAppend(target, ArgType::kInput, 0, ArgType::kInput),
MoveAll(target, ArgType::kOutput)};
}()},
intra_op_thread_pool_{intra_op_thread_pool} {
ORT_ENFORCE(accuracy_level_ >= 0 && accuracy_level_ <= 4, "MatMulNBits accuracy level must be between 0 and 4");
}

NodeAttributes
DQMatMulReplaceWithMatMulNBits::ExtraAttributes(const RuntimeState& runtime_state) const {
NodeAttributes extra_attributes;

const auto* dq_node = runtime_state.selected_nodes.Input(0);
auto& attrs = dq_node->GetAttributes();
const auto* weight_shape = dq_node->InputDefs()[0]->Shape();

utils::SetNodeAttribute(utils::MakeAttribute("K", weight_shape->dim(0).dim_value()), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("N", weight_shape->dim(1).dim_value()), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("accuracy_level", accuracy_level_), extra_attributes);
// Currently only 4 bits is supported. In the future, derive bits from the DQ weight type.
utils::SetNodeAttribute(utils::MakeAttribute("bits", static_cast<int64_t>(4)), extra_attributes);
utils::SetNodeAttribute(utils::MakeAttribute("block_size", attrs.at("block_size").i()), extra_attributes);

return extra_attributes;
}
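
// Worked example (illustrative shapes): for a DQ weight of shape
// [K, N] = [4096, 11008] with block_size = 32 and the session option left at
// its default, the replacement node is created with K = 4096, N = 11008,
// bits = 4, block_size = 32, accuracy_level = 4.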

Status DQMatMulReplaceWithMatMulNBits::ProcessNewNode(Graph& graph,
const NodesToOptimize& selected_nodes,
Node& replacement_node) const {
ORT_RETURN_IF_NOT(intra_op_thread_pool_, "Passed in thread pool should not be null");
const auto* dq_node = selected_nodes.Input(0);
const auto* weight_arg = dq_node->InputDefs()[0];
const auto* scale_arg = dq_node->InputDefs()[1];
const auto* zp_arg = dq_node->InputDefs().size() > 2 ? dq_node->InputDefs()[2] : nullptr;
const auto& attrs = dq_node->GetAttributes();

const ONNX_NAMESPACE::TensorProto* weight_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* scale_tensor_proto = nullptr;
const ONNX_NAMESPACE::TensorProto* zp_tensor_proto = nullptr;
graph.GetInitializedTensor(weight_arg->Name(), weight_tensor_proto);
graph.GetInitializedTensor(scale_arg->Name(), scale_tensor_proto);
if (zp_arg) {
graph.GetInitializedTensor(zp_arg->Name(), zp_tensor_proto);
}

auto K = weight_arg->Shape()->dim(0).dim_value();
auto N = weight_arg->Shape()->dim(1).dim_value();
auto block_size = attrs.at("block_size").i();
auto quant_num = (K + block_size - 1) / block_size;
auto blob_bytes = (block_size + 1) / 2;
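// Worked example (same illustrative shapes): with K = 4096 and block_size = 32,
// quant_num = ceil(4096 / 32) = 128 quantization blocks per column and
// blob_bytes = (32 + 1) / 2 = 16 packed bytes per 4-bit block, so weight_dst
// below gets shape [N, 128, 16].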

// Unfortunately, iterating over the source data is complicated: the data may
// be in an external file, a raw buffer, or a repeated field depending on the
// data type. UnpackTensor() already contains some of this logic and is the
// closest to what we need, but it does not handle external data.
Initializer weight_src(*weight_tensor_proto, graph.ModelPath());
Initializer scale_src(*scale_tensor_proto, graph.ModelPath());
std::optional<std::unique_ptr<Initializer>> zp_src_ptr;
Initializer weight_dst(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(weight_arg->Name() + "_T"),
std::vector<int64_t>{N, quant_num, blob_bytes});
Initializer scale_dst(static_cast<ONNX_NAMESPACE::TensorProto_DataType>(scale_src.data_type()),
graph.GenerateNodeArgName(scale_arg->Name() + "_T"),
std::vector<int64_t>{N * quant_num});
std::optional<std::unique_ptr<Initializer>> zp_dst_ptr;

if (zp_tensor_proto) {
zp_src_ptr.emplace(std::make_unique<Initializer>(*zp_tensor_proto, graph.ModelPath()));
zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName(zp_arg->Name() + "_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
} else if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT4) {
zp_dst_ptr.emplace(std::make_unique<Initializer>(ONNX_NAMESPACE::TensorProto_DataType_UINT8,
graph.GenerateNodeArgName("fused_DQ_MatMul_zero_point_T"),
std::vector<int64_t>{N * ((quant_num + 1) / 2)}));
}
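
// Note: a destination zero point is materialized for UINT4 weights even when
// the source DQ carries none; DequantizeLinear's implicit zero point is 0,
// which differs from the default MatMulNBits assumes for unsigned 4-bit data,
// so it has to be written out explicitly (rationale inferred from the branch
// above, not stated in the original change).
//
// The four MlasQDQTransposeBlockwiseQuantized branches below differ only in
// the scale element type (float vs. MLFloat16) and in the signedness template
// flag (true for INT4, false for UINT4 weights); each call repacks weights,
// scales, and optional zero points from the row-major QDQ layout into the
// column-blocked layout MatMulNBits consumes, parallelizing over
// intra_op_thread_pool_.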

if (scale_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
MlasQDQTransposeBlockwiseQuantized<float, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
} else {
MlasQDQTransposeBlockwiseQuantized<float, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<float>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<float>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
}
} else {
if (weight_src.data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT4) {
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, true>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);

} else {
MlasQDQTransposeBlockwiseQuantized<MLFloat16, 4, false>(
weight_src.DataAsByteSpan().data(),
scale_src.data<MLFloat16>(),
zp_src_ptr ? zp_src_ptr.value()->DataAsByteSpan().data() : nullptr,
weight_dst.data<uint8_t>(),
scale_dst.data<MLFloat16>(),
zp_dst_ptr ? zp_dst_ptr.value()->data<uint8_t>() : nullptr,
true,
static_cast<int>(K),
static_cast<int>(N),
static_cast<int>(block_size),
intra_op_thread_pool_);
}
}

ONNX_NAMESPACE::TensorProto weight_T_tp;
ONNX_NAMESPACE::TensorProto scale_T_tp;
std::optional<std::unique_ptr<ONNX_NAMESPACE::TensorProto>> zp_T_tp_ptr;

// TODO(fajin): external_data to memory location to avoid arena allocation
// https://github.com/microsoft/onnxruntime/pull/12465
weight_dst.ToProto(weight_T_tp);
scale_dst.ToProto(scale_T_tp);
if (zp_dst_ptr) {
zp_T_tp_ptr = std::make_unique<ONNX_NAMESPACE::TensorProto>();
zp_dst_ptr.value()->ToProto(*zp_T_tp_ptr.value());
}
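
// The repacked tensors become new graph initializers and are appended as
// inputs 1-3 of the MatMulNBits node below; input 0 (the activation) was
// already moved over from the original MatMul by value_moves_.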

auto& input_defs = replacement_node.MutableInputDefs();
input_defs.push_back(&graph_utils::AddInitializer(graph, weight_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);
input_defs.push_back(&graph_utils::AddInitializer(graph, scale_T_tp));
replacement_node.MutableInputArgsCount().push_back(1);

if (zp_T_tp_ptr) {
input_defs.push_back(&graph_utils::AddInitializer(graph, *zp_T_tp_ptr.value()));
replacement_node.MutableInputArgsCount().push_back(1);
}

return Status::OK();
}

static std::vector<NodeAndMoveInfo> GetGemmMoveInfo(bool does_q_node_exist) {
NTO::NodeLocation dq_A{NTO::NodeType::kInput, 0};
NTO::NodeLocation dq_B{NTO::NodeType::kInput, 1};
onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_actions.h
@@ -3,7 +3,12 @@

#pragma once

#include <memory>
#include <string>
#include <vector>

#include "core/optimizer/selectors_actions/actions.h"
#include "core/platform/threadpool.h"

namespace onnxruntime {

@@ -76,6 +81,30 @@ struct MatMulReplaceWithQLinear : public Action {
BinaryReplaceWithQLinear qlinear_matmul_replacer_;
};

// Used together with DQMatMulNodeGroupSelector, which performs the sanity checks.
struct DQMatMulReplaceWithMatMulNBits : public ReplaceWithNew {
DQMatMulReplaceWithMatMulNBits(int64_t accuracy_level,
concurrency::ThreadPool* intra_op_thread_pool);

private:
std::string OpType(const RuntimeState&) const override { return op_type_; }

std::string Domain(const RuntimeState&) const override { return domain_; }

NodeAttributes ExtraAttributes(const RuntimeState&) const override;

std::vector<NodeAndMoveInfo> ValueMoves(const RuntimeState&) const override { return value_moves_; }

// Transposes the initializers and adds them to the MatMulNBits inputs.
Status ProcessNewNode(Graph&, const NodesToOptimize&, Node&) const override;

const int64_t accuracy_level_;
const std::string domain_;
const std::string op_type_;
const std::vector<NodeAndMoveInfo> value_moves_;
concurrency::ThreadPool* intra_op_thread_pool_;
};

struct GemmReplaceWithQuant : public Action {
GemmReplaceWithQuant();

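
For orientation, constructing the new action standalone looks like the sketch below; in the real pipeline the QDQSelectorActionTransformer creates it internally, paired with DQMatMulNodeGroupSelector, from the accuracy level and thread pool threaded through graph_transformer_utils.cc (the pool variable is illustrative):

auto action = std::make_unique<QDQ::DQMatMulReplaceWithMatMulNBits>(
    /*accuracy_level=*/4,
    /*intra_op_thread_pool=*/intra_op_pool);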