From 445b8ec3eb671cc4bf1c2ffd8e31b0f450988634 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 5 Dec 2023 17:31:26 +0000 Subject: [PATCH 1/5] add sm80 kernel to op try to use graph transformer as prepack --- cmake/onnxruntime_optimizer.cmake | 4 + cmake/onnxruntime_providers_cuda.cmake | 5 + .../core/optimizer/graph_transformer_utils.h | 5 +- .../cuda/quantization/matmul_nbits.cu | 182 ++++++++++ .../cuda/quantization/matmul_nbits.cuh | 10 + .../core/graph/contrib_ops/contrib_defs.cc | 11 + onnxruntime/core/graph/graph_utils.cc | 6 - onnxruntime/core/graph/graph_utils.h | 24 +- .../core/mickey/blk_q4/f16_prepack_sm80.h | 59 ++++ onnxruntime/core/optimizer/gpu_ops_prepack.cc | 317 +++++++++++++++++ onnxruntime/core/optimizer/gpu_ops_prepack.h | 24 ++ .../core/optimizer/graph_transformer_utils.cc | 18 +- .../providers/cuda/cuda_provider_factory.cc | 17 + .../providers/cuda/cuda_provider_factory.h | 3 +- onnxruntime/core/session/inference_session.cc | 16 +- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 149 ++++++++ .../test/optimizer/gpu_op_prepack_test.cc | 329 ++++++++++++++++++ .../test/optimizer/graph_transform_test.cc | 27 +- .../optimizer/graph_transform_test_builder.cc | 1 + .../optimizer/graph_transform_utils_test.cc | 16 +- .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 149 -------- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 2 +- .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 2 +- .../cuda/test_cases/cuda_test_provider.cc | 6 + 24 files changed, 1197 insertions(+), 185 deletions(-) create mode 100644 onnxruntime/core/optimizer/gpu_ops_prepack.cc create mode 100644 onnxruntime/core/optimizer/gpu_ops_prepack.h create mode 100644 onnxruntime/test/optimizer/gpu_op_prepack_test.cc diff --git a/cmake/onnxruntime_optimizer.cmake b/cmake/onnxruntime_optimizer.cmake index f15d5b8dd6f80..5a5702e9d7cd4 100644 --- a/cmake/onnxruntime_optimizer.cmake +++ b/cmake/onnxruntime_optimizer.cmake @@ -109,6 +109,10 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}) + +# using optimizer as cuda prepacking, so extra headers are needed +target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) + if (onnxruntime_ENABLE_TRAINING) target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT}) if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index aeeac10ead27d..2a73a14f1588d 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -213,6 +213,7 @@ include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) @@ -284,6 +285,10 @@ endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) + # TODO only needed in DEBUG builds, need cmake expert advice on how to do that + 
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ") + set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ") + install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index e609745b5e03f..56853e3229030 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -22,6 +22,7 @@ namespace onnxruntime { class IExecutionProvider; +class ExecutionProviders; namespace optimizer_utils { @@ -48,7 +49,7 @@ std::unique_ptr GenerateRuleBasedGraphTransformer( InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, - const IExecutionProvider& execution_provider /*required by constant folding*/, + const ExecutionProviders& execution_providers /* cpu ep required by constant folding*/, const InlinedHashSet& rules_and_transformers_to_disable = {}); #endif // !defined(ORT_MINIMAL_BUILD) @@ -77,7 +78,7 @@ InlinedVector> GenerateTransformersForMinimalB TransformerLevel level, const SessionOptions& session_options, const SatApplyContextVariant& apply_context, - const IExecutionProvider& cpu_execution_provider, + const ExecutionProviders& execution_providers, const InlinedHashSet& rules_and_transformers_to_disable = {}); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index d4d583906b7f4..618ba49aed8d7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -10,6 +10,8 @@ #include "core/providers/cuda/cuda_common.h" #include "matmul_nbits.cuh" +#include "blk_q4/f16_gemm_sm80.h" + using namespace onnxruntime::cuda; using namespace cub; @@ -348,6 +350,186 @@ template bool TryMatMul4Bits( int shared_mem_per_block, cudaStream_t stream); +/** + * @brief Helper function to run the GEMM kernel for 4bits quantized gemm on SM80. + * Only support fp16 for now. 
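+ *
+ * Template parameters (they mirror the dispatch in blkq4_fp16_gemm_sm80_dispatch below):
+ *   block_size            - quantization block size (16, 32 or 64)
+ *   column_wise_blocking  - true for column-wise quantization blocks, false for row-wise
+ *   small_m               - selects the kernel configuration used when m <= 16
+ *   has_offsets           - whether zero point offsets are provided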
+*/ +template< + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream, + gsl::span a, + gsl::span weights, + gsl::span scales, + gsl::span offsets, + gsl::span output) { + + using ElementDequant = cutlass::half_t; + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + + using GemmRunner = BlkQ4F16GemmImpl; + + using ElementAccumulator = typename GemmRunner::ElementAccumulator; + using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; + using ElementOutput = typename GemmRunner::ElementOutput; + using ElementW = typename GemmRunner::ElementW; + using ElementWPack = typename GemmRunner::ElementWPack; + using ElementQScale = typename GemmRunner::ElementQScale; + using ElementQOffset = typename GemmRunner::ElementQOffset; + + using LayoutInputA = typename GemmRunner::LayoutInputA; + using LayoutOutput = typename GemmRunner::LayoutOutput; + using LayoutInputWPack = typename GemmRunner::LayoutInputWPack; + using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + + ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct"); + cutlass::TensorRef ref_a( + reinterpret_cast(a.data()), + LayoutInputA::packed({m, k})); + + ORT_RETURN_IF_NOT(weights.size_bytes() == k/2 * n/2 * sizeof(ElementWPack), "weights size is not correct"); + cutlass::TensorRef ref_W( + reinterpret_cast(weights.data()), + LayoutInputWPack::packed({k/2, n/2})); + + ORT_RETURN_IF_NOT(scales.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQScale), + "scales size is not correct"); + cutlass::TensorRef ref_scales( + reinterpret_cast(scales.data()), + LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn})); + + ORT_RETURN_IF_NOT(output.size_bytes() == m * n * sizeof(ElementOutput), "output size is not correct"); + cutlass::TensorRef ref_output( + reinterpret_cast(output.data()), + LayoutOutput::packed({m, n})); + + // run GEMM + cutlass::Status status; + if constexpr (has_offsets) { + ORT_RETURN_IF_NOT(offsets.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQOffset), + "offsets size is not correct"); + cutlass::TensorRef ref_offsets( + reinterpret_cast(offsets.data()), + LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn})); + status = GemmRunner::run( + stream, problem_size, ref_a, ref_W, ref_scales, ref_offsets, + ref_output, ref_output); + } else { + status = GemmRunner::run( + stream, problem_size, ref_a, ref_W, ref_scales, + ref_output, ref_output); + } + ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + return Status::OK(); +} + +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + gsl::span a, + gsl::span weights, + gsl::span scales, + gsl::span offsets, + gsl::span output) { + + switch (block_size) + { + case 16: + if (column_wise_blocking) { + if (m > 16) { + if (offsets.empty()) + return blkq4_gemm_sm80<16, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<16, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80<16, true, true, false>(m, n, k, stream, a, 
weights, scales, offsets, output); + else + return blkq4_gemm_sm80<16, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + } + } else { + if (m > 16) { + if (offsets.empty()) + return blkq4_gemm_sm80<16, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<16, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80<16, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<16, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + } + } + break; + + case 32: + if (column_wise_blocking) { + if (m > 16) { + if (offsets.empty()) + return blkq4_gemm_sm80<32, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<32, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80<32, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<32, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + } + } else { + if (m > 16) { + if (offsets.empty()) + return blkq4_gemm_sm80<32, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<32, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80<32, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<32, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + } + } + break; + + case 64: + if (column_wise_blocking) { + if (m > 16) { + if (offsets.empty()) + return blkq4_gemm_sm80<64, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<64, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80<64, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<64, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + } + } else { + if (m > 16) { + if (offsets.empty()) + return blkq4_gemm_sm80<64, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<64, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + } else { + if (offsets.empty()) + return blkq4_gemm_sm80<64, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + else + return blkq4_gemm_sm80<64, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + } + } + break; + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported block size: ", block_size); +} + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh index 9ccbe4c4d97a8..62b05901c9533 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -22,6 +22,16 @@ bool TryMatMul4Bits( int shared_mem_per_block, cudaStream_t stream); +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + gsl::span a, + gsl::span 
weights, + gsl::span scales, + gsl::span offsets, + gsl::span output); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 6709398c788f0..96b583ec3ec3c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3362,6 +3362,15 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; + // This is really bad as we expose a property to the user that should never be set by user. + // We have to use this as we perform cuda prepacking during graph optimization phase and + // this is the only way to pass the information to the runtime. + static const char* PrepackProperty_doc = R"DOC( +Indicates whether the weight matrix is prepacked (value 1), or not (value 0, default). +This property should NEVER be set by user. It is set by ONNX Runtime internally during +model loading time. +)DOC"; + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) .SetDomain(kMSDomain) .SinceVersion(1) @@ -3377,6 +3386,8 @@ Input zero_points is stored as uint8_t or same as type(A). It has the same packi "computation. 4 means input A can be quantized with the same block_size to int8 internally from " "type T1.", AttributeProto::INT, static_cast(0)) + .Attr("column_wise_blocking", "whether to quantize weight columnwise (value 1, default), or rowwise (value 0)", AttributeProto::INT, static_cast(1)) + .Attr("prepacked", PrepackProperty_doc, AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") diff --git a/onnxruntime/core/graph/graph_utils.cc b/onnxruntime/core/graph/graph_utils.cc index 20e8161ee79fd..7e0924b3a8064 100644 --- a/onnxruntime/core/graph/graph_utils.cc +++ b/onnxruntime/core/graph/graph_utils.cc @@ -244,12 +244,6 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, std::string_view op_typ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) -const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { - const auto& attrs = node.GetAttributes(); - const auto iter = attrs.find(attr_name); - return iter == attrs.end() ? nullptr : &iter->second; -} - NodeArg& AddInitializer(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) { // sanity check as AddInitializedTensor silently ignores attempts to add a duplicate initializer const ONNX_NAMESPACE::TensorProto* existing = nullptr; diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index cf76ec785fda5..be1049d78fdf0 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -26,7 +26,29 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Returns the attribute of a Node with a given name. */ -const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name); +static inline +const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { + const auto& attrs = node.GetAttributes(); + const auto iter = attrs.find(attr_name); + return iter == attrs.end() ? 
nullptr : &iter->second; +} + +template +inline Status TryGetNodeAttribute(const Node& node, const std::string& attr_name, T& value); + +template <> +inline Status TryGetNodeAttribute(const Node& node, const std::string& attr_name, int64_t& value) { + const auto* attr = GetNodeAttribute(node, attr_name); + if (!attr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node attribute '", attr_name, "' is not set."); + } + if (!attr->has_i()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node attribute '", attr_name, "' is not an integer."); + } + value = attr->i(); + return Status::OK(); +} + /** Add a new initializer to 'graph'. Checks that new_initializer does not already exist in 'graph' before adding it. diff --git a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index a08cfb97eed4a..784ecb0fee8b5 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -19,6 +19,7 @@ #pragma once #include "core/common/common.h" +#include "core/framework/float16.h" #include "core/util/matrix_layout.h" namespace onnxruntime { @@ -321,5 +322,63 @@ struct BlockwiseQuantization { } }; +static inline bool IsSm80WithWholeBlocks( + int weight_rows, [[maybe_unused]] int weight_cols, + int major, [[maybe_unused]] int minor) { + if (major < 8) { + return false; + } + + // Kernel implementation detail: + // K must be aligned with thread block tile (64) due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. + // To fix this, we need to write our own predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return (weight_rows % 64 == 0); +} + +template +inline +bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { + using Base = BlockwiseQuantization; + if (weight_cols % Base::QuantBlocking::kColumn != 0) { + return false; + } + if (weight_rows % Base::QuantBlocking::kRow != 0) { + return false; + } + return IsSm80WithWholeBlocks(weight_rows, weight_cols, major, minor); +} + +static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) { + switch (block_size) + { + case 16: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 32: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 64: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + } + return false; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.cc b/onnxruntime/core/optimizer/gpu_ops_prepack.cc new file mode 100644 index 0000000000000..acfe660eff395 --- /dev/null +++ b/onnxruntime/core/optimizer/gpu_ops_prepack.cc @@ -0,0 +1,317 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved.
+// Licensed under the MIT License.
+//
+// Module Abstract:
+// This module defines the logic for prepacking weights
+// (aka Initializers in onnxruntime) of GPU operators.
+// Unlike CPU operators, overriding the PrePack() method
+// of class OpKernel results in GPU memory fragmentation.
+// So we try to rewrite the weight tensors during the graph
+// optimization phase to avoid this problem.
+//
+// Unfortunately, there are still some serious problems
+// with this approach:
+// 1. Rewriting of the initializer tensors is restricted
+//    by operator shape inferencing rules. For example,
+//    there are 3 initializers for MatMulNBits, and we
+//    can't combine them into a single initializer. We
+//    also have to make sure the operator's shape inference
+//    logic does NOT verify the initializer's shape.
+// 2. This rewriting logic is tightly coupled to each GPU
+//    operator. It really should be defined together with
+//    those operators, instead of in a completely
+//    different module.
+// 3. The prepacking logic depends on the underlying GPU
+//    hardware. Currently this part is hard-coded for SM80.
+
+
+#include "core/graph/graph_utils.h"
+#include "core/optimizer/initializer.h"
+#include "core/optimizer/gpu_ops_prepack.h"
+#include "core/optimizer/utils.h"
+#include "core/graph/graph_utils.h"
+
+#include "blk_q4/f16_prepack_sm80.h"
+
+#include "core/providers/cuda/cuda_provider_factory.h"
+#include "core/providers/cuda/cuda_execution_provider_info.h"
+
+namespace onnxruntime {
+
+extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA();
+
+/**
+ * @brief Read an initialized tensor from protobuf and store it in ort_value.
+ *        Keep in mind that ort_value owns the tensor memory after this call.
+*/
+inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) {
+  const ONNX_NAMESPACE::TensorProto* tensor_proto;
+  ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto),
+                    "Missing initializer for ", arg->Name());
+
+  const auto* path_c_str = graph.ModelPath().ToPathString().c_str();
+
+  return utils::TensorProtoToOrtValue(
+      Env::Default(), path_c_str, *tensor_proto,
+      std::make_shared(), ort_value);
+}
+
+template
+inline gsl::span make_span(std::string& str) {
+  return gsl::make_span(reinterpret_cast(str.data()), str.size() / sizeof(T));
+}
+
+//
+// Prepacking logic specific to MatMulNBits on sm80
+//
+
+static inline bool IsNodeMatMulNbitsFp16(const Node& node) {
+  if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain)) {
+    return false;
+  }
+  const auto* acts = node.InputDefs()[0];
+  if (acts == nullptr || acts->Type() == nullptr || acts->Type()->find("float16") == std::string::npos) {
+    return false;
+  }
+  return true;
+}
+
+template
+void Sm80BlkQ4PrepackT(
+    int rows, int columns,
+    gsl::span weights,
+    gsl::span scales,
+    gsl::span zp,
+    std::string& packed_w,
+    std::string& packed_scales,
+    std::string& packed_zp) {
+  using Base = onnxruntime::cuda::BlockwiseQuantization<
+      MLFloat16,
+      block_size,
+      4,
+      column_quant_blk>;
+  const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
+  const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
+
+  packed_w.resize(q_weight_shape.product() * sizeof(uint8_t));
+  Base::prepack_weights(
+      rows, columns, weights,
+      make_span(packed_w));
+
+  packed_scales.resize(meta_shape.product() * sizeof(MLFloat16));
+  Base::prepack_quant_scales(
+      rows, columns, scales,
+      make_span(packed_scales));
+
+  if (!zp.empty())
{ + packed_zp.resize(meta_shape.product() * sizeof(uint8_t)); + Base::prepack_quant_offsets( + rows, columns, zp, + make_span(packed_zp)); + } +} + +void Sm80BlkQ4Prepack( + int block_size, bool column_quant_blk, + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { + switch (block_size) { + case 16: + if (column_quant_blk) { + Sm80BlkQ4PrepackT<16, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } else { + Sm80BlkQ4PrepackT<16, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } + break; + case 32: + if (column_quant_blk) { + Sm80BlkQ4PrepackT<32, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } else { + Sm80BlkQ4PrepackT<32, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } + break; + case 64: + if (column_quant_blk) { + Sm80BlkQ4PrepackT<64, true>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } else { + Sm80BlkQ4PrepackT<64, false>(rows, columns, weights, scales, zp, packed_w, packed_scales, packed_zp); + } + break; + default: + ORT_THROW("Unsupported block size: ", block_size); + } +} + +/** + *@brief Prepack weights of the operator MatMulNBits. + * The caller should make sure the node is of type MatMulNBits. + */ +Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { + modified = false; + int64_t att_i; + + // + // Verify prepacking is needed and supported + // + Status status = graph_utils::TryGetNodeAttribute(node, "prepacked", att_i); + bool prepacked = status.IsOK() ? att_i != 0 : false; + if (prepacked) { + return Status::OK(); // already prepacked, nothing to do + } + + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "bits", att_i)); + int nbits = static_cast(att_i); + if (nbits != 4) { + return Status::OK(); // only support 4 bits for now + } + + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "K", att_i)); + int k = static_cast(att_i); + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "N", att_i)); + int n = static_cast(att_i); + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "block_size", att_i)); + int block_size = static_cast(att_i); + + status = graph_utils::TryGetNodeAttribute(node, "column_wise_blocking", att_i); + bool column_wise_quant_blk = status.IsOK() ? 
att_i != 0 : true; + + auto* provider_info = TryGetProviderInfo_CUDA(); + ORT_ENFORCE(provider_info != nullptr, "Failed to query CUDA provider info while prepacking cuda operators."); + int major, minor; + ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, + "Failed to query CUDA device version while prepacking cuda operators."); + + if (!onnxruntime::cuda::BlkQuantGemmSm80Supported(block_size, column_wise_quant_blk, k, n, major, minor)) { + return Status::OK(); // not supported + } + + // + // Verification passed, start prepacking + // + auto& node_name = node.Name(); + auto& mutable_input_defs = node.MutableInputDefs(); + + NodeArg* old_weights_arg = mutable_input_defs[1]; + NodeArg* old_scales_arg = mutable_input_defs[2]; + NodeArg* old_zp_arg = nullptr; + + // holders of the packed weight tensor memory + std::string packed_weights; + std::string packed_scales; + std::string packed_zp; + + { + // owners of the weight tensor memory, keep around until consumed by the prepacking function + OrtValue weights_val; + OrtValue scales_val; + OrtValue zp_val; + + ORT_RETURN_IF_ERROR(GetOrtValue(old_weights_arg, graph, weights_val)); + const gsl::span weights = weights_val.GetMutable()->DataAsSpan(); + + ORT_RETURN_IF_ERROR(GetOrtValue(old_scales_arg, graph, scales_val)); + const gsl::span scales = scales_val.GetMutable()->DataAsSpan(); + + gsl::span zp; + if (mutable_input_defs.size() > 3) { + old_zp_arg = mutable_input_defs[3]; + if (old_zp_arg != nullptr && old_zp_arg->Exists()) { + ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val)); + zp = zp_val.GetMutable()->DataAsSpan(); + } + } + + Sm80BlkQ4Prepack(block_size, column_wise_quant_blk, k, n, weights, scales, zp, packed_weights, packed_scales, packed_zp); + +#if 0 + // debug print if prepacked tests fail + std::cout << " ====== packed weight ====== " << std::endl << std::hex; + const gsl::span packed_weights_span(reinterpret_cast(packed_weights.data()), packed_weights.size()); + for (int r = 0; r < k; r++) { + for (int c = 0; c < n/2; c++) { + std::cout << std::setw(2) << std::setfill('0') << static_cast(packed_weights_span[c * k + r]) << " "; + } + std::cout << std::endl; + } + std::cout << std::dec; +#endif + } + + // + // write packed weight tensor to node parameters + // + ONNX_NAMESPACE::TensorProto packed_weights_proto; + packed_weights_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_weight")); + packed_weights_proto.add_dims(packed_weights.size() / sizeof(uint8_t)); + packed_weights_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType()); + packed_weights_proto.set_raw_data(std::move(packed_weights)); + NodeArg& packed_weights_arg = graph_utils::AddInitializer(graph, packed_weights_proto); + graph.RemoveConsumerNode(old_weights_arg->Name(), &node); + mutable_input_defs[1] = &packed_weights_arg; + graph.AddConsumerNode(packed_weights_arg.Name(), &node); + + ONNX_NAMESPACE::TensorProto packed_scales_proto; + packed_scales_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_scales")); + packed_scales_proto.add_dims(packed_scales.size() / sizeof(MLFloat16)); + packed_scales_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType()); + packed_scales_proto.set_raw_data(std::move(packed_scales)); + NodeArg& packed_scales_arg = graph_utils::AddInitializer(graph, packed_scales_proto); + graph.RemoveConsumerNode(old_scales_arg->Name(), &node); + mutable_input_defs[2] = &packed_scales_arg; + graph.AddConsumerNode(packed_scales_arg.Name(), &node); + + if 
(!packed_zp.empty()) {
+    ONNX_NAMESPACE::TensorProto packed_zp_proto;
+    packed_zp_proto.set_name(graph.GenerateNodeArgName(node_name + "_prepacked_zp"));
+    packed_zp_proto.add_dims(packed_zp.size() / sizeof(uint8_t));
+    packed_zp_proto.set_data_type(onnxruntime::utils::ToTensorProtoElementType());
+    packed_zp_proto.set_raw_data(std::move(packed_zp));
+    NodeArg& packed_zp_arg = graph_utils::AddInitializer(graph, packed_zp_proto);
+    graph.RemoveConsumerNode(old_zp_arg->Name(), &node);
+    mutable_input_defs[3] = &packed_zp_arg;
+    graph.AddConsumerNode(packed_zp_arg.Name(), &node);
+  }
+
+  node.AddAttribute("prepacked", static_cast(1));
+  modified = true;
+  return Status::OK();
+}
+
+
+Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
+  GraphViewer graph_viewer(graph);
+  const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
+
+  // int fused_count = 0;
+  for (auto node_index : node_topology_list) {
+    auto* p_node = graph.GetNode(node_index);
+    if (p_node == nullptr)
+      continue;  // node was removed as part of an earlier fusion
+
+    Node& node = *p_node;
+    ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
+
+    if (node.GetExecutionProviderType() != onnxruntime::kCudaExecutionProvider) {
+      continue;  // only interested in CUDA nodes
+    }
+
+    // Run prepack if the node is MatMulNBits.
+    // When we have more operators to support, we should use a map to dispatch the prepack
+    // functions instead of adding a whole bunch of if branches here.
+    if (IsNodeMatMulNbitsFp16(node)) {
+      bool packed = false;
+      ORT_RETURN_IF_ERROR(PackMatMulNBitsFp16(node, graph, packed));
+      modified |= packed;
+      continue;
+    }
+  }
+
+  return Status::OK();
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.h b/onnxruntime/core/optimizer/gpu_ops_prepack.h
new file mode 100644
index 0000000000000..d6770a2bfb1cb
--- /dev/null
+++ b/onnxruntime/core/optimizer/gpu_ops_prepack.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/optimizer/graph_transformer.h"
+
+namespace onnxruntime {
+
+/**
+@Class GpuOpsPrepack
+Prepack the weights (initializers) of supported GPU operators into the layouts
+expected by their CUDA kernels. Currently this covers fp16 MatMulNBits on SM80.
+*/ +class GpuOpsPrepack : public GraphTransformer { + public: + GpuOpsPrepack() noexcept + : GraphTransformer("GpuOpsPrepack", InlinedHashSet{onnxruntime::kCudaExecutionProvider}) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + private: +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index f319e7254568d..6304a56b4e756 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -6,6 +6,7 @@ #include #include +#include "core/framework/execution_providers.h" #include "core/optimizer/conv_activation_fusion.h" #include "core/optimizer/nhwc_transformer.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" @@ -42,6 +43,7 @@ #include "core/optimizer/gemm_activation_fusion.h" #include "core/optimizer/gemm_sum_fusion.h" #include "core/optimizer/gemm_transpose_fusion.h" +#include "core/optimizer/gpu_ops_prepack.h" #include "core/optimizer/identical_children_consolidation.h" #include "core/optimizer/identity_elimination.h" #include "core/optimizer/layer_norm_fusion.h" @@ -183,8 +185,11 @@ std::unique_ptr GenerateRuleBasedGraphTransformer( InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, - const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const ExecutionProviders& execution_providers, /*required by constant folding*/ const InlinedHashSet& rules_and_transformers_to_disable) { + // CPU EP required to run constant folding + const IExecutionProvider& cpu_execution_provider = *execution_providers.Get(onnxruntime::kCpuExecutionProvider); + InlinedVector> transformers; const bool disable_quant_qdq = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; @@ -376,6 +381,14 @@ InlinedVector> GenerateTransformers( // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); + +#ifdef USE_CUDA + // Cuda weight prepacking. 
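+  // The GpuOpsPrepack transformer rewrites initializers of supported CUDA operators
+  // (currently fp16 MatMulNBits on SM80) into the layout their kernels expect, so it
+  // is only registered when a CUDA execution provider is present.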
+ auto* cuda_ep = execution_providers.Get(onnxruntime::kCudaExecutionProvider); + if (cuda_ep != nullptr) { + transformers.emplace_back(std::make_unique()); + } +#endif #endif } break; @@ -397,8 +410,9 @@ InlinedVector> GenerateTransformersForMinimalB TransformerLevel level, const SessionOptions& session_options, const SatApplyContextVariant& apply_context, - const IExecutionProvider& cpu_execution_provider, + const ExecutionProviders& execution_providers, const InlinedHashSet& rules_and_transformers_to_disable) { + const auto& cpu_execution_provider = *execution_providers.Get(onnxruntime::kCpuExecutionProvider); InlinedVector> transformers; const bool saving = std::holds_alternative(apply_context); diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 103c79c93b2ca..b959c4943cefe 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -79,6 +79,23 @@ struct ProviderInfo_CUDA_Impl final : ProviderInfo_CUDA { return nullptr; } + OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) override { + int device_id; + auto cuda_err = cudaGetDevice(&device_id); + if (cuda_err != cudaSuccess) { + return CreateStatus(ORT_FAIL, "Failed to get device id."); + } + cudaDeviceProp prop; + cuda_err = cudaGetDeviceProperties(&prop, device_id); + if (cuda_err != cudaSuccess) { + return CreateStatus(ORT_FAIL, "Failed to get device properties."); + } + *major = prop.major; + *minor = prop.minor; + + return nullptr; + } + std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) override { return std::make_unique(device_id, name); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/core/providers/cuda/cuda_provider_factory.h index 4d5ef658f6be0..320eb4ed82cfc 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.h +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.h @@ -22,7 +22,8 @@ class NvtxRangeCreator; struct ProviderInfo_CUDA { virtual OrtStatus* SetCurrentGpuDeviceId(_In_ int device_id) = 0; - virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; + virtual OrtStatus* GetCurrentGpuDeviceId(_Out_ int* device_id) = 0; + virtual OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) = 0; virtual std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateCUDAPinnedAllocator(const char* name) = 0; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5fd66c459d382..4162ca5970320 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1540,15 +1540,15 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) Status ApplyOrtFormatModelRuntimeOptimizations( onnxruntime::Graph& graph, const logging::Logger& logger, const SessionOptions& session_options, - const InlinedHashSet& optimizers_to_disable, const IExecutionProvider& cpu_ep) { + const InlinedHashSet& optimizers_to_disable, const ExecutionProviders& execution_providers) { bool modified = false; for (int level = static_cast(TransformerLevel::Level2); level <= static_cast(session_options.graph_optimization_level); ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( - static_cast(level), session_options, 
SatRuntimeOptimizationLoadContext{}, cpu_ep, - optimizers_to_disable); + static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, + execution_providers, optimizers_to_disable); for (const auto& transformer : transformers) { ORT_RETURN_IF_ERROR(transformer->Apply(graph, modified, logger)); @@ -1919,9 +1919,8 @@ common::Status InferenceSession::Initialize() { *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); ORT_RETURN_IF_ERROR_SESSIONID_( - ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, cpu_ep)); + ApplyOrtFormatModelRuntimeOptimizations(graph, *session_logger_, session_options_, optimizers_to_disable_, execution_providers_)); #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } @@ -3059,7 +3058,6 @@ common::Status InferenceSession::AddPredefinedTransformers( TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { - const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); if (graph_optimization_level >= level) { @@ -3070,7 +3068,7 @@ common::Status InferenceSession::AddPredefinedTransformers( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; if (use_full_build_optimizations) { - return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, + return optimizer_utils::GenerateTransformers(level, session_options_, execution_providers_, optimizers_to_disable_); } else { const auto sat_context = @@ -3079,8 +3077,8 @@ common::Status InferenceSession::AddPredefinedTransformers( ? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{ record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; - return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - optimizers_to_disable_); + return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, + execution_providers_, optimizers_to_disable_); } }(); diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index 6ea8b55505214..ca5f80030dd6d 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -21,6 +21,155 @@ namespace onnxruntime { namespace test { +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block comes from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. 
+ * @param[out] dequants The dequantized weights, column major layout.
+ * @param[out] q_weights The quantized weights, column major layout.
+ * @param[out] q_scales The scales, column major layout.
+ * @param[out] q_zp The zero points, column major layout.
+ */
+template
+inline void blkq4_weights_gen(
+    int rows, int columns,
+    std::vector& dequants,
+    std::vector& q_weights,
+    std::vector& q_scales,
+    std::vector& q_zp) {
+  using Base = onnxruntime::cuda::BlockwiseQuantization<
+      ElementT,
+      block_size,
+      4,
+      col_blocking>;
+
+  using QuantBlocking = typename Base::QuantBlocking;
+  using ElementW = typename Base::ElementW;
+  using LayoutWPack = typename Base::LayoutWPack;
+  using ElementQOffset = typename Base::ElementQOffset;
+
+  static_assert(std::is_same::value);
+  static_assert(std::is_same::value);
+  static_assert(std::is_same::value);
+
+  unsigned int seed = 28571;  // Replace with desired seed value
+  std::seed_seq seq{seed};
+  std::mt19937 gen(seq);
+  std::uniform_int_distribution dis(0, 8192);
+
+  const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
+  const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
+  const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
+
+  //
+  // For testing quantization and dequantization, it is not
+  // straightforward to avoid flaky tests due to rounding errors.
+  // The way we try to achieve this is to:
+  //  1. Generate a set of quantized weights, scales and offsets
+  //  2. Dequantize the weights
+  //  3. Quantize the dequantized weights
+  //  4. Compare the dequantized-and-then-quantized weights with
+  //     the original quantized weights
+  //
+  // Random filling of the initial values is key to getting this right.
+  // For weights, we must ensure each block gets a full range of
+  // values, i.e. it must contain both 0 and 15. And the scales must
+  // all be positive.
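+  //
+  // For illustration only (not part of the generated data): with the
+  // dequantization formula used below, dequant = scale * (w - offset),
+  // a 4-bit weight w = 11 with scale 0.5 and offset 8 dequantizes to
+  // 0.5 * (11 - 8) = 1.5, and re-quantizing 1.5 with the same scale and
+  // offset recovers exactly 11, so step 4 compares equal.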
+ // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t v = dis(gen); + uint32_t m = (v % 63) + 1; + uint32_t e = (v >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr (has_offsets) { + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr (has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } +} + static inline void sm80_prepack_weights_ref( int rows, int columns, diff --git a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc new file mode 100644 index 0000000000000..b2caac6efd869 --- /dev/null +++ b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc @@ -0,0 +1,329 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" +#include "core/mlas/inc/mlas.h" +#include "core/graph/graph.h" +#include "core/optimizer/initializer.h" + +#include "core/mlas/inc/mlas_q4.h" +#include "core/providers/cuda/cuda_provider_factory_creator.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" + +#include "test/cuda_host/blkq4_fp16_quant_sm80.h" +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +namespace onnxruntime { +namespace test { +#ifndef DISABLE_CONTRIB_OPS + +std::shared_ptr LoadCudaEp() { + OrtCUDAProviderOptions cuda_options; + auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options); + if (!factory) { + return nullptr; + } + return std::move(factory->CreateProvider()); +} + +/** + * @brief Testing helper for GPU prepacking logic in the graph transformer. + * This is an modification of the TransformerTester function from + * onnxruntime/test/optimizer/graph_transform_test_builder.cc + * with: + * - the addition of cuda execution provider in the session. + * - a different location for the model checker, right after session initialization + * as the initializers will be deleted during session run. +*/ +void GpuPrepackTester( + const std::shared_ptr& cuda_ep, + const std::function& build_test_case, + const std::function& check_transformed_graph, + TransformerLevel baseline_level, + TransformerLevel target_level, + int opset_version = 12, + double per_sample_tolerance = 0.0, + double relative_per_sample_tolerance = 0.0, + const std::function& add_session_options = {}, + const InlinedHashSet& disabled_optimizers = {}) { + // Build the model for this test. + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = opset_version; + domain_to_version[kMSDomain] = 1; + Model model("TransformerTester", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + ASSERT_TRUE(build_test_case); + build_test_case(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + auto run_model = [&](TransformerLevel level, std::vector& fetches) { + SessionOptions session_options; + session_options.graph_optimization_level = level; + if (level == target_level) { + // we don't really need the model file, but it seems to be the only way to keep the + // transformed initializers so that they can be verified. 
+ session_options.optimized_model_filepath = + ToPathString("gpu_prepack_test_model_opt_level_" + std::to_string(static_cast(level)) + ".onnx"); + } + if (add_session_options) { + add_session_options(session_options); + } + InferenceSessionWrapper session{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(cuda_ep)); + + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); + if (!disabled_optimizers.empty()) { + ASSERT_STATUS_OK(session.FilterEnabledOptimizers(InlinedHashSet{disabled_optimizers})); + } + + ASSERT_STATUS_OK(session.Initialize()); + + RunOptions run_options; + ASSERT_STATUS_OK(session.Run(run_options, + helper.feeds_, + helper.output_names_, + &fetches)); + + if (level == target_level) { + if (check_transformed_graph) { + check_transformed_graph(session); + } + } + }; + + std::vector baseline_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(baseline_level, baseline_fetches)); + + std::vector target_fetches; + ASSERT_NO_FATAL_FAILURE(run_model(target_level, target_fetches)); + + size_t num_outputs = baseline_fetches.size(); + ASSERT_EQ(num_outputs, target_fetches.size()); + + // for (size_t i = 0; i < num_outputs; i++) { + // std::pair ret = + // CompareOrtValue(target_fetches[i], + // baseline_fetches[i], + // per_sample_tolerance, + // relative_per_sample_tolerance, + // false); + // EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + // } +} + +inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) { + const ONNX_NAMESPACE::TensorProto* tensor_proto; + ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), + "Missing initializer for ", arg->Name()); + + const auto* path_c_str = graph.ModelPath().ToPathString().c_str(); + + return utils::TensorProtoToOrtValue( + Env::Default(), path_c_str, *tensor_proto, + std::make_shared(), ort_value); +} + +template +void MatMulQ4Test(int M, int N, int K, const std::shared_ptr& cuda_ep){ + // + // Type definitions + // + using ElementT = MLFloat16; + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + columnwise_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + // + // Generate random inputs + // + const auto q_weight_shape = Base::get_quant_weights_shape(K, N); + const auto meta_shape = Base::get_quant_meta_shape(K, N); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + blkq4_weights_gen( + K, N, dequants, q_weights, q_scales, q_zp); + + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( + q_weights, q_weight_shape); + MatrixRef tensor_scale( + q_scales, meta_shape); + MatrixRef tensor_offset; + if constexpr (has_offsets) { + tensor_offset = MatrixRef(q_zp, zp_shape); + } + + // Compute prepacked weights + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(K, N / 2)); + onnxruntime::test::sm80_prepack_weights_ref(K, N, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr 
(Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_scales_ref( + K, N, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } + } + + std::vector packed_zp_ref; + if constexpr (has_offsets) { + packed_zp_ref.resize(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_offsets_ref( + K, N, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } + } + } + + auto build_test_case = [&](ModelTestBuilder& builder) { + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise_blocking, K, N, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + auto* input_arg = builder.MakeInput({M, K}, MLFloat16(0.0f), MLFloat16(31.0f)); + auto* output_arg = builder.MakeOutput(); + auto* weight_arg = builder.MakeInitializer({q_weight_shape[1], q_weight_shape[0]}, q_weights); + auto* scale_arg = builder.MakeInitializer({static_cast(q_scales.size())}, q_scales); + + std::vector input_args{input_arg, weight_arg, scale_arg}; + if constexpr (has_offsets) { + auto* zero_point_arg = builder.MakeInitializer({static_cast(q_zp.size())}, q_zp); + input_args.push_back(zero_point_arg); + } else { + ASSERT_TRUE(q_zp.empty()); + } + Node& node = builder.AddNode("MatMulNBits", input_args, {output_arg}, kMSDomain); + node.AddAttribute("K", static_cast(K)); + node.AddAttribute("N", static_cast(N)); + node.AddAttribute("block_size", static_cast(block_size)); + node.AddAttribute("bits", static_cast(4)); + node.AddAttribute("column_wise_blocking", static_cast(columnwise_blocking)); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + const auto& graph = session.GetGraph(); + for (const auto& node : graph.Nodes()) { + if (node.OpType() == "MatMulNBits") { + ASSERT_EQ(node.Domain(), kMSDomain); + ASSERT_EQ(node.GetAttributes().at("prepacked").i(), 1); + { + // Verify prepacked weights + OrtValue packed_w_val; + ASSERT_STATUS_OK(GetOrtValue(node.InputDefs()[1], graph, packed_w_val)); + const gsl::span weights_data = packed_w_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(weights_data.size(), packed_w_ref.size()); + for (size_t i = 0; i < packed_w_ref.size(); ++i) { + int expected = packed_w_ref[i]; + int found = weights_data[i]; + ASSERT_EQ(expected, found) << "prepacked weight mismatch index i = " << i << " shape[" << K << "," << N/2 << "]"; + } + } + { + // Verify prepacked scales + OrtValue packed_s_val; + ASSERT_STATUS_OK(GetOrtValue(node.InputDefs()[2], graph, packed_s_val)); + const gsl::span scales_data = packed_s_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(scales_data.size(), packed_scales_ref.size()); + for (size_t i = 0; i < packed_scales_ref.size(); ++i) { + float expected = packed_scales_ref[i]; + float found = scales_data[i]; + ASSERT_EQ(expected, found) << "prepacked scale mismatch index i = " << i << " shape[" << meta_shape[0] << "," << meta_shape[1] << "]"; + } + } + 
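+        // Zero points are optional: when present they are verified against the
+        // reference packing like the scales; otherwise the node must still have
+        // no more than the original three inputs.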
if constexpr (has_offsets) { + // Verify prepacked zero points + OrtValue packed_z_val; + ASSERT_STATUS_OK(GetOrtValue(node.InputDefs()[3], graph, packed_z_val)); + const gsl::span offsets_data = packed_z_val.GetMutable()->DataAsSpan(); + ASSERT_EQ(offsets_data.size(), packed_zp_ref.size()); + for (size_t i = 0; i < packed_zp_ref.size(); ++i) { + int expected = packed_zp_ref[i]; + int found = offsets_data[i]; + ASSERT_EQ(expected, found) << "prepacked zero-point mismatch index i = " << i << " shape[" << meta_shape[0] << "," << meta_shape[1] << "]"; + } + } else { + ASSERT_LE(node.InputDefs().size(), 3); + } + std::cout << "Prepacked weights verified." << std::endl; + } + } + }; + + GpuPrepackTester(cuda_ep, + build_test_case, + check_graph, + TransformerLevel::Level2, + TransformerLevel::Level3); + +} + +TEST(GpuOpPrepackTests, MatmulNBits) { + std::shared_ptr provider = LoadCudaEp(); + if (!provider) { + GTEST_SKIP() << "Skipping tests when CUDA EP is not available"; + } + + MatMulQ4Test<64, true, true>(1, 128, 64, provider); + MatMulQ4Test<64, false, true>(1, 128, 64, provider); + MatMulQ4Test<64, true, false>(1, 128, 64, provider); + MatMulQ4Test<64, false, false>(1, 128, 64, provider); + + // MatMulQ4Test<32, true, true>(1, 64, 64, provider); + // MatMulQ4Test<32, true, false>(1, 64, 64, provider); + // MatMulQ4Test<32, false, true>(1, 64, 64, provider); + // MatMulQ4Test<32, false, false>(1, 64, 64, provider); + // MatMulQ4Test<32, true, true>(1, 64, 128, provider); + // MatMulQ4Test<32, true, false>(1, 64, 128, provider); + // MatMulQ4Test<32, false, true>(1, 64, 128, provider); + // MatMulQ4Test<32, false, false>(1, 64, 128, provider); +} + +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 1535e2b60a3bd..039560bece0d4 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -16,6 +16,7 @@ #include "asserts.h" #include "core/common/span_utils.h" #include "core/framework/data_types.h" +#include "core/framework/execution_providers.h" #include "core/framework/ort_value.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" @@ -819,12 +820,15 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map e = - std::make_unique(CPUExecutionProviderInfo()); + std::shared_ptr e = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + execution_providers.Add(kCpuExecutionProvider, std::move(e)); bool has_constant_folding = false; onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options, + execution_providers, {}); for (auto& transformer : transformers) { if (transformer->Name() == "ConstantFolding") { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1)); @@ -4563,11 +4567,14 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { } static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_options) { - std::unique_ptr e = - std::make_unique(CPUExecutionProviderInfo()); + std::shared_ptr e = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + 
execution_providers.Add(kCpuExecutionProvider, std::move(e)); bool has_gelu_approximation = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, + execution_providers, {}); for (auto& transformer : transformers) { if (transformer->Name() == "GeluApproximation") { has_gelu_approximation = true; @@ -4580,9 +4587,13 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt // Test session option configuration for DoubleQDQPairsRemover TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { - std::unique_ptr cpu_ep = std::make_unique(CPUExecutionProviderInfo()); + std::shared_ptr cpu_ep = std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep)); + bool has_double_qdq_remover = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, + execution_providers, {}); for (auto& transformer : transformers) { if (transformer->Name() == "DoubleQDQPairsRemover") { has_double_qdq_remover = true; diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.cc b/onnxruntime/test/optimizer/graph_transform_test_builder.cc index a5024f510b3cd..8d40843da17a7 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.cc @@ -84,6 +84,7 @@ void TransformerTester(const std::function& buil add_session_options(session_options); } InferenceSessionWrapper session{session_options, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_data.data(), static_cast(model_data.size()))); if (transformer) { ASSERT_STATUS_OK(session.RegisterGraphTransformer(std::move(transformer), level)); diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 66b74641e41d3..919d08cd628ea 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -35,10 +35,15 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l1_transformer = "ConstantFolding"; std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + std::shared_ptr e = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + auto status = execution_providers.Add(kCpuExecutionProvider, std::move(e)); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, execution_providers); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, + execution_providers, disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() 
== all_transformers.size() - 1); @@ -61,8 +66,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, execution_providers); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, + execution_providers, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h index bbe370675fc48..50ade8da14f54 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -27,155 +27,6 @@ namespace test { Status sm80_supported(); -/** - * @brief Generate a set of quantized weights, scales and offsets - * and dequantized weights for testing quantization and - * dequantization. All outputs are column major layout. - * - * @tparam ElementT The type of the dequantized weights. - * @tparam block_size The block size of the quantization. - * @tparam col_blocking Whether to use column blocking (all elements of - * a block comes from a single column) or row blocking - * @tparam has_offsets Whether to generate offsets. - * - * @param[in] rows The number of rows of the weight matrix. - * @param[in] columns The number of columns of the weight matrix. - * @param[out] dequants The dequantized weights, column major layout. - * @param[out] q_weights The quantized weights, column major layout. - * @param[out] q_scales The scales, column major layout. - * @param[out] q_zp The zero points, column major layout. - */ -template -inline void blkq4_weights_gen( - int rows, int columns, - std::vector& dequants, - std::vector& q_weights, - std::vector& q_scales, - std::vector& q_zp) { - using Base = onnxruntime::cuda::BlockwiseQuantization< - ElementT, - block_size, - 4, - col_blocking>; - - using QuantBlocking = typename Base::QuantBlocking; - using ElementW = typename Base::ElementW; - using LayoutWPack = typename Base::LayoutWPack; - using ElementQOffset = typename Base::ElementQOffset; - - static_assert(std::is_same::value); - static_assert(std::is_same::value); - static_assert(std::is_same::value); - - unsigned int seed = 28571; // Replace with desired seed value - std::seed_seq seq{seed}; - std::mt19937 gen(seq); - std::uniform_int_distribution dis(0, 8192); - - const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); - const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); - - // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. 
- // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. - // - - q_weights.resize(q_weight_shape.product()); - MatrixRef tensor_q_weight( - q_weights, make_Position(rows / 2, columns)); - int v = 7; - for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { - for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); - } - } - - q_scales.resize(meta_shape.product()); - for (size_t i = 0; i < q_scales.size(); i++) { - uint32_t v = dis(gen); - uint32_t m = (v % 63) + 1; - uint32_t e = (v >> 6) % 4; - q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); - } - MatrixRef tensor_scale( - q_scales, meta_shape); - - MatrixRef tensor_offset; - if constexpr (has_offsets) { - q_zp.resize(zp_shape.product()); - tensor_offset = MatrixRef( - q_zp, zp_shape); - for (int c = 0; c < zp_shape[1]; c++) { - for (int r = 0; r < zp_shape[0]; ++r) { - uint8_t v0 = dis(gen) % 16; - uint8_t v1 = 8; - if (r * 2 + 1 < meta_shape[0]) { - v1 = dis(gen) % 16; - } - tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); - } - } - } - - dequants.resize(rows * columns); - MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); - - // Dequantize weights and save into matrix B - for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { - for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { - auto weight_cord = make_Position(row / 2, col); - auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - uint8_t offset = 8; - if constexpr (has_offsets) { - if (scale_cord[0] % 2 == 0) { - offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; - } else { - offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; - } - } - int w = 0; - if (row % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - float scale = float(tensor_scale.at(scale_cord)); - float dequant = scale * float(w - offset); - tensor_dequant.at(row, col) = ElementT(dequant); - // Prints for help debugging in case of test failure - // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); - } - } -} - template < int block_size, bool column_wise_blocking, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index e687ae73e66f2..b1d122c571341 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -47,7 +47,7 @@ void testPrepack(int rows, int columns) { std::vector q_scales; std::vector q_zp; std::vector dequants; - onnxruntime::cuda::test::blkq4_weights_gen( + blkq4_weights_gen( rows, columns, dequants, q_weights, q_scales, q_zp); // for quantization tool, the input is row major, all outputs are column major diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu 
b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index 69c929d446ce4..a7f9cf16de054 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -192,7 +192,7 @@ void run_blkq4_gemm(int m, int n, int k) { std::vector q_scales; std::vector q_zp; std::vector dequants; - onnxruntime::cuda::test::blkq4_weights_gen( + onnxruntime::test::blkq4_weights_gen( problem_size.k(), problem_size.n(), dequants, q_weights, q_scales, q_zp); using PrepackT = onnxruntime::cuda::BlockwiseQuantization< diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc index 96c1e173316de..bc0715e77c96f 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_test_provider.cc @@ -44,6 +44,12 @@ struct ProviderInfo_CUDA_TestImpl : ProviderInfo_CUDA { return nullptr; } + virtual OrtStatus* GetCurrentGpuDeviceVersion(_Out_ int* major, _Out_ int* minor) override { + *major = 0; + *minor = 0; + return nullptr; + } + std::unique_ptr CreateCUDAAllocator(int16_t, const char*) override { return nullptr; } From a88d6804a3ac997505cd282e92cef0969bdafcd5 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Wed, 14 Feb 2024 22:30:53 +0000 Subject: [PATCH 2/5] connect kernel to op --- .../cuda/quantization/matmul_nbits.cc | 21 ++-- .../cuda/quantization/matmul_nbits.cu | 119 ++++++++++++------ .../cuda/quantization/matmul_nbits.cuh | 11 +- .../cuda/quantization/matmul_nbits.h | 16 ++- .../core/mickey/blk_q4/f16_prepack_sm80.h | 41 ++++-- .../test/optimizer/gpu_op_prepack_test.cc | 81 ++++++++---- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 3 + .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 25 ++-- 8 files changed, 211 insertions(+), 106 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 1cec6f6a12f1c..c2bc76d7ee1aa 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -14,7 +14,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -using namespace onnxruntime::cuda; template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { @@ -24,14 +23,6 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(3); const Tensor* reorder_idx = ctx->Input(4); - const auto* a_data = a->Data(); - const uint8_t* blob_data = b->Data(); - const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); - - typedef typename ToCudaType::MappedType CudaT; - constexpr bool transa = false; constexpr bool transb = true; MatMulComputeHelper helper; @@ -43,6 +34,18 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); + if (prepack_ > 0){ + return PrepackedGemm( + static_cast(ctx->GetComputeStream()->GetHandle()), + a, b, scales, zero_points, Y); + } + + const auto* a_data = a->Data(); + const uint8_t* blob_data = b->Data(); + const auto* scales_data = scales->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); + bool is_4bit_done = (reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType()) && TryMatMul4Bits( diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 618ba49aed8d7..9e394cc82e355 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -12,9 +12,6 @@ #include "blk_q4/f16_gemm_sm80.h" -using namespace onnxruntime::cuda; -using namespace cub; - namespace onnxruntime { namespace contrib { namespace cuda { @@ -355,24 +352,28 @@ template bool TryMatMul4Bits( * Only support fp16 for now. */ template< + typename ElementT, int block_size, bool column_wise_blocking, bool small_m, bool has_offsets> Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream, - gsl::span a, + gsl::span a, gsl::span weights, - gsl::span scales, + gsl::span scales, gsl::span offsets, - gsl::span output) { - + gsl::span output) { + static_assert(std::is_same::value + || std::is_same::value + || std::is_same::value, + "Only support fp16 for now"); using ElementDequant = cutlass::half_t; using QuantBlocking = typename std::conditional, cutlass::MatrixShape<1, block_size>>::type; - using GemmRunner = BlkQ4F16GemmImpl; + using GemmRunner = onnxruntime::cuda::BlkQ4F16GemmImpl; using ElementAccumulator = typename GemmRunner::ElementAccumulator; using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; @@ -430,15 +431,20 @@ Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream, return Status::OK(); } -Status blkq4_fp16_gemm_sm80_dispatch( - int block_size, - bool column_wise_blocking, - int m, int n, int k, cudaStream_t stream, - gsl::span a, - gsl::span weights, - gsl::span scales, - gsl::span offsets, - gsl::span output) { +template +Status +blkq4_fp16_gemm_sm80_dispatch( + int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream, + ElementT const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + ElementT const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + ElementT* output_ptr, size_t output_size) { + auto a = gsl::make_span(a_ptr, a_size); + auto weights = gsl::make_span(weights_ptr, weights_size); + auto scales = gsl::make_span(scales_ptr, scales_size); + auto offsets = gsl::make_span(offsets_ptr, offsets_size); + auto output = gsl::make_span(output_ptr, output_size); switch (block_size) { @@ -446,26 +452,26 @@ Status blkq4_fp16_gemm_sm80_dispatch( if (column_wise_blocking) { if (m > 16) { if (offsets.empty()) - return blkq4_gemm_sm80<16, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + return 
blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<16, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } else { if (offsets.empty()) - return blkq4_gemm_sm80<16, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<16, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } } else { if (m > 16) { if (offsets.empty()) - return blkq4_gemm_sm80<16, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<16, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } else { if (offsets.empty()) - return blkq4_gemm_sm80<16, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<16, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } } break; @@ -474,26 +480,26 @@ Status blkq4_fp16_gemm_sm80_dispatch( if (column_wise_blocking) { if (m > 16) { if (offsets.empty()) - return blkq4_gemm_sm80<32, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<32, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } else { if (offsets.empty()) - return blkq4_gemm_sm80<32, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<32, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } } else { if (m > 16) { if (offsets.empty()) - return blkq4_gemm_sm80<32, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<32, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } else { if (offsets.empty()) - return blkq4_gemm_sm80<32, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<32, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } } break; @@ -502,26 +508,26 @@ Status blkq4_fp16_gemm_sm80_dispatch( if (column_wise_blocking) { if (m > 16) { if (offsets.empty()) - return blkq4_gemm_sm80<64, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return 
blkq4_gemm_sm80<64, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } else { if (offsets.empty()) - return blkq4_gemm_sm80<64, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<64, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } } else { if (m > 16) { if (offsets.empty()) - return blkq4_gemm_sm80<64, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<64, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } else { if (offsets.empty()) - return blkq4_gemm_sm80<64, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); else - return blkq4_gemm_sm80<64, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output); + return blkq4_gemm_sm80(m, n, k, stream, a, weights, scales, offsets, output); } } break; @@ -530,6 +536,37 @@ Status blkq4_fp16_gemm_sm80_dispatch( return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported block size: ", block_size); } +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + half const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + half const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + half* output_ptr, size_t output_size); + +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, + bool column_wise_blocking, + int m, int n, int k, cudaStream_t stream, + cutlass::half_t const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + cutlass::half_t const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + cutlass::half_t* output_ptr, size_t output_size); + +template +Status blkq4_fp16_gemm_sm80_dispatch( + int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream, + onnxruntime::MLFloat16 const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + onnxruntime::MLFloat16 const* scales_ptr, size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + onnxruntime::MLFloat16* output_ptr, size_t output_size); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh index 62b05901c9533..b3a325bf806b4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh @@ -22,15 +22,16 @@ bool TryMatMul4Bits( int shared_mem_per_block, cudaStream_t stream); +template Status blkq4_fp16_gemm_sm80_dispatch( int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream, - gsl::span a, - gsl::span weights, - gsl::span scales, - gsl::span offsets, - gsl::span output); + ElementT const* a_ptr, size_t a_size, + uint8_t const* weights_ptr, size_t weights_size, + ElementT const* scales_ptr, 
size_t scales_size, + uint8_t const* offsets_ptr, size_t offsets_size, + ElementT* output_ptr, size_t output_size); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index f5c2c6c4e4fdf..9e672d724c7a8 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -14,16 +14,29 @@ namespace onnxruntime { namespace contrib { namespace cuda { -using namespace onnxruntime::cuda; template class MatMulNBits final : public CudaKernel { public: + using CudaT = typename ToCudaType::MappedType; + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op," + " additional bits support is planned."); + int64_t column_wise_quant_blk = 1; + info.GetAttrOrDefault("column_wise_blocking", &column_wise_quant_blk, int64_t(1)); + column_wise_quant_blk_ = column_wise_quant_blk != 0; + info.GetAttrOrDefault("prepacked", &prepack_, int64_t(0)); + } + + Status PrepackedGemm(cudaStream_t stream, const Tensor* a, const Tensor* b, + const Tensor* scales, const Tensor* zero_points, Tensor* Y) const { + ORT_THROW("Prepacked gemm is not supported for MatMulNBits op."); } Status ComputeInternal(OpKernelContext* context) const override; @@ -34,6 +47,7 @@ class MatMulNBits final : public CudaKernel { int64_t block_size_; int64_t nbits_; bool column_wise_quant_blk_{true}; + int64_t prepack_{0}; }; } // namespace cuda diff --git a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index 784ecb0fee8b5..a1f9b80d7754b 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -80,6 +80,23 @@ struct BlockwiseQuantization { return make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); } + static inline bool weight_dimension_supported(int rows, int columns) { + // prepacking works on a 16x16 block, so the dimensions must be multiples of 16 + if (((rows % 16) != 0) || ((columns % 16) != 0)) { + return false; + } + + // verify the dimensions are multiples of the block size + if (((rows % QuantBlocking::kRow) != 0) || ((columns % QuantBlocking::kColumn) != 0)) { + return false; + } + + // All the above restrictions can be relaxed by adding more logic to the prepack + // and gemm implementation to support edge cases. But for now, these restrictions + // does not affect LLM weight dimensions, so we keep it simple. + return true; + } + /** * @brief Prepack weight matrix to facilitate matrix loading, depending on MMA * instruction layout. 
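
Editorial note, not part of the patch: the weight_dimension_supported() guard added above encodes a simple rule — both weight dimensions must be multiples of 16 (the prepack tile) and of the quantization block shape. The self-contained sketch below only restates that rule for clarity; sketch_weight_dims_supported is a hypothetical name, and blk_rows/blk_cols stand in for QuantBlocking::kRow/kColumn.

    #include <cassert>

    // Hypothetical stand-alone restatement of the dimension rule;
    // blk_rows/blk_cols play the role of QuantBlocking::kRow/kColumn.
    static bool sketch_weight_dims_supported(int rows, int columns, int blk_rows, int blk_cols) {
      // prepacking works on 16x16 tiles, and every quantization block must be whole
      return (rows % 16 == 0) && (columns % 16 == 0) &&
             (rows % blk_rows == 0) && (columns % blk_cols == 0);
    }

    int main() {
      assert(sketch_weight_dims_supported(128, 64, /*blk_rows=*/64, /*blk_cols=*/1));   // K=128, N=64, block_size=64, column-wise
      assert(!sketch_weight_dims_supported(100, 64, /*blk_rows=*/64, /*blk_cols=*/1));  // 100 is not a multiple of 16 or 64
      return 0;
    }
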
@@ -114,10 +131,10 @@ struct BlockwiseQuantization { const gsl::span& weights, // <- int4 weights, column major const gsl::span& weights_prepacked // <- int4 prepacked weights tensor, same size buffer ) { - ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 && - (rows % QuantBlocking::kRow) == 0 && - (columns % QuantBlocking::kColumn) == 0, - "Does not support odd number of rows or columns!"); +#ifndef NDEBUG + ORT_ENFORCE(weight_dimension_supported(rows, columns), + "This function must be guarded by weight_dimension_supported()!"); +#endif ORT_ENFORCE(weights.size() == size_t(rows * columns / 2), "Weight tensor shape mismatch!"); ORT_ENFORCE(weights_prepacked.size() == weights.size(), @@ -175,6 +192,10 @@ struct BlockwiseQuantization { const gsl::span& scales, // <- quant scales, column major layout const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer ) { +#ifndef NDEBUG + ORT_ENFORCE(weight_dimension_supported(rows, columns), + "This function must be guarded by weight_dimension_supported()!"); +#endif auto meta_shape = get_quant_meta_shape(rows, columns); ORT_ENFORCE(scales.size() == size_t(meta_shape.product()), "Quantization scale tensor shape mismatch!"); @@ -245,10 +266,11 @@ struct BlockwiseQuantization { const gsl::span& offsets, // <- quant offsets, int4, column major layout const gsl::span& offsets_prepacked // <- quant offsets prepacked, double size buffer ) { +#ifndef NDEBUG + ORT_ENFORCE(weight_dimension_supported(rows, columns), + "This function must be guarded by weight_dimension_supported()!"); +#endif auto meta_shape = get_quant_meta_shape(rows, columns); - - ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0, - "Does not support odd number of rows or columns!"); ORT_ENFORCE(offsets_prepacked.size() == size_t(meta_shape.product()), "Wrong buffer size for prepacked quantization offsets!"); ORT_ENFORCE(offsets.size() == size_t(((meta_shape[0] + 1) / 2) * meta_shape[1]), @@ -346,10 +368,7 @@ template inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { using Base = BlockwiseQuantization; - if (weight_cols % Base::QuantBlocking::kColumn != 0) { - return false; - } - if (weight_rows % Base::QuantBlocking::kRow != 0) { + if (!Base::weight_dimension_supported(weight_rows, weight_cols)) { return false; } return IsSm80WithWholeBlocks(weight_rows, weight_cols, major, minor); diff --git a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc index b2caac6efd869..5fb90c59684f3 100644 --- a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc +++ b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc @@ -49,8 +49,8 @@ void GpuPrepackTester( TransformerLevel baseline_level, TransformerLevel target_level, int opset_version = 12, - double per_sample_tolerance = 0.0, - double relative_per_sample_tolerance = 0.0, + double per_sample_tolerance = 0.001, + double relative_per_sample_tolerance = 0.001, const std::function& add_session_options = {}, const InlinedHashSet& disabled_optimizers = {}) { // Build the model for this test. 
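
Editorial note, not part of the patch: the tolerance defaults above move from 0.0 to 0.001, presumably because the prepacked SM80 kernel and the baseline path accumulate fp16 rounding differently, so the re-enabled output comparison needs a little slack. The sketch below shows one common combined absolute/relative check; close_enough is a hypothetical helper and the exact formula CompareOrtValue applies may differ.

    #include <cmath>

    // Hypothetical helper mirroring a combined absolute + relative tolerance test.
    static bool close_enough(float expected, float actual,
                             double abs_tol = 0.001, double rel_tol = 0.001) {
      const double diff = std::fabs(static_cast<double>(expected) - static_cast<double>(actual));
      return diff <= abs_tol + rel_tol * std::fabs(static_cast<double>(expected));
    }

    int main() { return close_enough(1.0f, 1.0005f) ? 0 : 1; }
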
@@ -114,15 +114,15 @@ void GpuPrepackTester( size_t num_outputs = baseline_fetches.size(); ASSERT_EQ(num_outputs, target_fetches.size()); - // for (size_t i = 0; i < num_outputs; i++) { - // std::pair ret = - // CompareOrtValue(target_fetches[i], - // baseline_fetches[i], - // per_sample_tolerance, - // relative_per_sample_tolerance, - // false); - // EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; - // } + for (size_t i = 0; i < num_outputs; i++) { + std::pair ret = + CompareOrtValue(target_fetches[i], + baseline_fetches[i], + per_sample_tolerance, + relative_per_sample_tolerance, + false); + EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; + } } inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) { @@ -225,7 +225,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise_blocking, K, N, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - auto* input_arg = builder.MakeInput({M, K}, MLFloat16(0.0f), MLFloat16(31.0f)); + auto* input_arg = builder.MakeInput({M, K}, MLFloat16(-2.0f), MLFloat16(2.0f)); auto* output_arg = builder.MakeOutput(); auto* weight_arg = builder.MakeInitializer({q_weight_shape[1], q_weight_shape[0]}, q_weights); auto* scale_arg = builder.MakeInitializer({static_cast(q_scales.size())}, q_scales); @@ -289,7 +289,6 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr } else { ASSERT_LE(node.InputDefs().size(), 3); } - std::cout << "Prepacked weights verified." << std::endl; } } }; @@ -308,19 +307,49 @@ TEST(GpuOpPrepackTests, MatmulNBits) { GTEST_SKIP() << "Skipping tests when CUDA EP is not available"; } - MatMulQ4Test<64, true, true>(1, 128, 64, provider); - MatMulQ4Test<64, false, true>(1, 128, 64, provider); - MatMulQ4Test<64, true, false>(1, 128, 64, provider); - MatMulQ4Test<64, false, false>(1, 128, 64, provider); - - // MatMulQ4Test<32, true, true>(1, 64, 64, provider); - // MatMulQ4Test<32, true, false>(1, 64, 64, provider); - // MatMulQ4Test<32, false, true>(1, 64, 64, provider); - // MatMulQ4Test<32, false, false>(1, 64, 64, provider); - // MatMulQ4Test<32, true, true>(1, 64, 128, provider); - // MatMulQ4Test<32, true, false>(1, 64, 128, provider); - // MatMulQ4Test<32, false, true>(1, 64, 128, provider); - // MatMulQ4Test<32, false, false>(1, 64, 128, provider); + // + // GpuPrepackTester function implements two different verifications. + // First is the hook check_graph, which we use to verify the prepacked weights, scales and zero points. + // Second is the comparison of the outputs of the model with and without prepacking, this actually + // doubles as a verification of kernel correctness and prepacking correctness. + // + // We do have other sets of tests for the prepack and kernel correctness, defined in + // onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc + // + // What we are doing here is to verify we correctly connected the prepacking logic in + // the graph transformer, and that the prepacked weights, scales and zero points are correctly + // passed to the kernel in MatMulNBits cuda op. Plus the redundant verifications allows us to + // locate the problem more easily. + // + // So we don't need to test all the combinations here, just a few representative ones. 
+ // + + std::cout << "Testing MatMulQ4Test<64, true, true>(4, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, true, true>(4, 128, 64, provider); + std::cout << "Testing MatMulQ4Test<64, false, true>(4, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, false, true>(4, 128, 64, provider); + std::cout << "Testing MatMulQ4Test<64, true, false>(8, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, true, false>(8, 128, 64, provider); + std::cout << "Testing MatMulQ4Test<64, false, false>(8, 128, 64, provider)" << std::endl; + MatMulQ4Test<64, false, false>(8, 128, 64, provider); + + std::cout << "Testing MatMulQ4Test<32, true, true>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, true, true>(16, 64, 128, provider); + std::cout << "Testing MatMulQ4Test<32, true, false>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, true, false>(16, 64, 128, provider); + std::cout << "Testing MatMulQ4Test<32, false, true>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, false, true>(16, 64, 128, provider); + std::cout << "Testing MatMulQ4Test<32, false, false>(16, 64, 128, provider)" << std::endl; + MatMulQ4Test<32, false, false>(16, 64, 128, provider); + + std::cout << "Testing MatMulQ4Test<16, true, true>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, true, true>(32, 96, 128, provider); + std::cout << "Testing MatMulQ4Test<16, true, false>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, true, false>(32, 96, 128, provider); + std::cout << "Testing MatMulQ4Test<16, false, true>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, false, true>(32, 96, 128, provider); + std::cout << "Testing MatMulQ4Test<16, false, false>(32, 96, 128, provider)" << std::endl; + MatMulQ4Test<16, false, false>(32, 96, 128, provider); } #endif diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index b1d122c571341..0f25769e24e96 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -33,6 +33,9 @@ void testPrepack(int rows, int columns) { 4, col_blocking>; + EXPECT_TRUE(Base::weight_dimension_supported(rows, columns)) + << "Test setup problem, unsupported weight dimension: [" << rows << ", " << columns << "]"; + using QuantBlocking = typename Base::QuantBlocking; using ElementW = typename Base::ElementW; using LayoutWPack = typename Base::LayoutWPack; diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu index a7f9cf16de054..7202d33fa70fd 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -27,6 +27,7 @@ #include "core/common/common.h" #include "blkq4_fp16_gemm_sm80.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.cuh" namespace onnxruntime { namespace cuda{ @@ -200,6 +201,8 @@ void run_blkq4_gemm(int m, int n, int k) { block_size, 4, column_wise_blocking>; + ORT_ENFORCE(PrepackT::weight_dimension_supported(k, n), + "Test setup problem, unsupported weight dimension: [", k, ", ", n, "]"); std::vector packed_w(q_weight_shape.product()); PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); @@ -256,19 +259,15 @@ void run_blkq4_gemm(int m, int n, int k) { tensor_d.sync_device(); // run GEMM - 
cutlass::Status status; - if constexpr (has_offsets){ - status = GemmRunner::run( - nullptr, problem_size, tensor_a.device_ref(), ref_W, - ref_scales, ref_zp, - tensor_c.device_ref(), tensor_d.device_ref()); - } else { - status = GemmRunner::run( - nullptr, problem_size, tensor_a.device_ref(), ref_W, - ref_scales, - tensor_c.device_ref(), tensor_d.device_ref()); - } - ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + ORT_THROW_IF_ERROR(onnxruntime::contrib::cuda::blkq4_fp16_gemm_sm80_dispatch( + block_size, column_wise_blocking, + m, n, k, + nullptr, + tensor_a.device_data(), tensor_a.size(), + d_packed_w.data().get(), d_packed_w.size(), + d_packed_scales.data().get(), d_packed_scales.size(), + d_packed_zp.data().get(), d_packed_zp.size(), + tensor_d.device_data(), tensor_d.size())); // Running reference kernel using ElementInputB = ElementInputA; From 6bd5c0c1e7fcef39cfc5848733770a3d286c6b74 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 7 Mar 2024 05:12:39 +0000 Subject: [PATCH 3/5] resolve conflicts --- .../cuda/quantization/matmul_nbits.cc | 29 +++++++++++++++++++ onnxruntime/core/optimizer/gpu_ops_prepack.cc | 9 +++++- .../cuda_execution_provider_test.cc | 1 + 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index c2bc76d7ee1aa..1d1853725aac4 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -15,6 +15,31 @@ namespace onnxruntime { namespace contrib { namespace cuda { +template<> +Status MatMulNBits::PrepackedGemm( + cudaStream_t stream, + const Tensor* a, + const Tensor* b, + const Tensor* scales, + const Tensor* zero_points, + Tensor* Y) const { + int64_t M = a->Shape()[0]; + uint8_t const* zero_points_ptr = nullptr; + size_t zero_points_size = 0; + if (zero_points != nullptr) { + zero_points_ptr = zero_points->Data(); + zero_points_size = zero_points->Shape().Size(); + } + + return blkq4_fp16_gemm_sm80_dispatch( + int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream, + a->Data(), a->Shape().Size(), + b->Data(), b->Shape().Size(), + scales->Data(), scales->Shape().Size(), + zero_points_ptr, zero_points_size, + Y->MutableData(), Y->Shape().Size()); +} + template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); @@ -35,6 +60,10 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if (Y->Shape().Size() == 0) return Status::OK(); if (prepack_ > 0){ + ORT_RETURN_IF(reorder_idx != nullptr, + "Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!"); + ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType(), + "Internal Error: Prepacked gemm does not support zero points of type T. 
Fix the prepacking logic!"); return PrepackedGemm( static_cast(ctx->GetComputeStream()->GetHandle()), a, b, scales, zero_points, Y); diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.cc b/onnxruntime/core/optimizer/gpu_ops_prepack.cc index acfe660eff395..b0219124c13de 100644 --- a/onnxruntime/core/optimizer/gpu_ops_prepack.cc +++ b/onnxruntime/core/optimizer/gpu_ops_prepack.cc @@ -195,6 +195,9 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { // auto& node_name = node.Name(); auto& mutable_input_defs = node.MutableInputDefs(); + if (mutable_input_defs.size() < 3 || mutable_input_defs.size() > 4) { + return Status::OK(); // not supported + } NodeArg* old_weights_arg = mutable_input_defs[1]; NodeArg* old_scales_arg = mutable_input_defs[2]; @@ -222,7 +225,11 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { old_zp_arg = mutable_input_defs[3]; if (old_zp_arg != nullptr && old_zp_arg->Exists()) { ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val)); - zp = zp_val.GetMutable()->DataAsSpan(); + Tensor* zp_tensor_ptr = zp_val.GetMutable(); + if (!zp_tensor_ptr->IsDataType()) { + return Status::OK(); // not supported + } + zp = zp_tensor_ptr->DataAsSpan(); } } diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 8dfaaedcbb378..da87a487d0343 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -6,6 +6,7 @@ // 1. slow down performance critical applications and // 2. increase binary size of ORT. #include +#include "core/framework/framework_common.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_allocator.h" #include "core/providers/cuda/cuda_stream_handle.h" From f71cfa6ab2764bcda43465236cebc3f1adb674ba Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 7 Mar 2024 17:32:01 +0000 Subject: [PATCH 4/5] bug fix: can't use A tensor shape[0] for M --- cmake/onnxruntime_providers_cuda.cmake | 4 - .../cuda/quantization/matmul_nbits.cc | 34 ++++---- .../cuda/quantization/matmul_nbits.cu | 5 +- .../cuda/quantization/matmul_nbits.h | 9 +- onnxruntime/core/graph/graph_utils.h | 4 +- .../core/mickey/blk_q4/f16_prepack_sm80.h | 56 ++++++------- .../threadblock/quantb_mma_multistage.h | 2 +- .../quantb_meta_mma_tensor_op_tile_iterator.h | 14 ++-- onnxruntime/core/optimizer/gpu_ops_prepack.cc | 84 +++++++++---------- onnxruntime/core/util/matrix_layout.h | 8 +- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 8 +- .../test/optimizer/gpu_op_prepack_test.cc | 68 ++++++++++----- .../test/optimizer/graph_transform_test.cc | 6 +- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 2 +- .../optimizer/graph_transformer_utils_test.cc | 16 ++-- 15 files changed, 175 insertions(+), 145 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 2a73a14f1588d..858c6c8b09abe 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -285,10 +285,6 @@ endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) - # TODO only needed in DEBUG builds, need cmake expert advice on how to do that - set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ") - 
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ") - install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 1d1853725aac4..b8cfde96200b1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -15,15 +15,15 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template<> +template <> Status MatMulNBits::PrepackedGemm( - cudaStream_t stream, - const Tensor* a, - const Tensor* b, - const Tensor* scales, - const Tensor* zero_points, - Tensor* Y) const { - int64_t M = a->Shape()[0]; + cudaStream_t stream, + int M, + const Tensor* a, + const Tensor* b, + const Tensor* scales, + const Tensor* zero_points, + Tensor* Y) const { uint8_t const* zero_points_ptr = nullptr; size_t zero_points_size = 0; if (zero_points != nullptr) { @@ -32,12 +32,12 @@ Status MatMulNBits::PrepackedGemm( } return blkq4_fp16_gemm_sm80_dispatch( - int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream, - a->Data(), a->Shape().Size(), - b->Data(), b->Shape().Size(), - scales->Data(), scales->Shape().Size(), - zero_points_ptr, zero_points_size, - Y->MutableData(), Y->Shape().Size()); + int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream, + a->Data(), a->Shape().Size(), + b->Data(), b->Shape().Size(), + scales->Data(), scales->Shape().Size(), + zero_points_ptr, zero_points_size, + Y->MutableData(), Y->Shape().Size()); } template @@ -59,14 +59,14 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - if (prepack_ > 0){ + if (prepack_ > 0) { ORT_RETURN_IF(reorder_idx != nullptr, "Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!"); ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType(), "Internal Error: Prepacked gemm does not support zero points of type T. 
Fix the prepacking logic!"); return PrepackedGemm( - static_cast(ctx->GetComputeStream()->GetHandle()), - a, b, scales, zero_points, Y); + static_cast(ctx->GetComputeStream()->GetHandle()), + helper.M(), a, b, scales, zero_points, Y); } const auto* a_data = a->Data(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 9e394cc82e355..4b03cd58ab0b2 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -12,6 +12,9 @@ #include "blk_q4/f16_gemm_sm80.h" +using namespace onnxruntime::cuda; +using namespace cub; + namespace onnxruntime { namespace contrib { namespace cuda { @@ -390,7 +393,7 @@ Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream, const cutlass::gemm::GemmCoord problem_size = {m, n, k}; - ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct"); + ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct: ", a.size_bytes(), " vs m: ", m, "k: ", k , " size: ", m * k * sizeof(ElementDequant)); cutlass::TensorRef ref_a( reinterpret_cast(a.data()), LayoutInputA::packed({m, k})); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index 9e672d724c7a8..4ba7283283581 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -34,8 +34,13 @@ class MatMulNBits final : public CudaKernel { info.GetAttrOrDefault("prepacked", &prepack_, int64_t(0)); } - Status PrepackedGemm(cudaStream_t stream, const Tensor* a, const Tensor* b, - const Tensor* scales, const Tensor* zero_points, Tensor* Y) const { + Status PrepackedGemm([[maybe_unused]] cudaStream_t stream, + [[maybe_unused]] int M, + [[maybe_unused]] const Tensor* a, + [[maybe_unused]] const Tensor* b, + [[maybe_unused]] const Tensor* scales, + [[maybe_unused]] const Tensor* zero_points, + [[maybe_unused]] Tensor* Y) const { ORT_THROW("Prepacked gemm is not supported for MatMulNBits op."); } diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index be1049d78fdf0..3273e709077d5 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -26,8 +26,7 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Returns the attribute of a Node with a given name. */ -static inline -const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { +static inline const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { const auto& attrs = node.GetAttributes(); const auto iter = attrs.find(attr_name); return iter == attrs.end() ? nullptr : &iter->second; @@ -49,7 +48,6 @@ inline Status TryGetNodeAttribute(const Node& node, const std::string& return Status::OK(); } - /** Add a new initializer to 'graph'. Checks that new_initializer does not already exist in 'graph' before adding it. @returns The NodeArg for the new initializer. 
diff --git a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index a1f9b80d7754b..7e0a53d739121 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -187,8 +187,8 @@ struct BlockwiseQuantization { static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1; static void prepack_quant_scales( - size_t rows, - size_t columns, + int rows, + int columns, const gsl::span& scales, // <- quant scales, column major layout const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer ) { @@ -261,8 +261,8 @@ struct BlockwiseQuantization { } static void prepack_quant_offsets( - size_t rows, - size_t columns, + int rows, + int columns, const gsl::span& offsets, // <- quant offsets, int4, column major layout const gsl::span& offsets_prepacked // <- quant offsets prepacked, double size buffer ) { @@ -345,8 +345,8 @@ struct BlockwiseQuantization { }; static inline bool IsSm80WithWholeBlocks( - int weight_rows, [[maybe_unused]] int weight_cols, - int major, [[maybe_unused]] int minor) { + int weight_rows, [[maybe_unused]] int weight_cols, + int major, [[maybe_unused]] int minor) { if (major < 8) { return false; } @@ -364,9 +364,8 @@ static inline bool IsSm80WithWholeBlocks( return (weight_rows % 64 == 0); } -template -inline -bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { +template +inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { using Base = BlockwiseQuantization; if (!Base::weight_dimension_supported(weight_rows, weight_cols)) { return false; @@ -375,26 +374,25 @@ bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int } static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) { - switch (block_size) - { - case 16: - if (col_blocking) { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } else { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } - case 32: - if (col_blocking) { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } else { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } - case 64: - if (col_blocking) { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } else { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } + switch (block_size) { + case 16: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 32: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 64: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } } return false; } diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h 
b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h index 8b6bac8c5099a..ad17abf73ef37 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -132,7 +132,7 @@ struct DummyType{ } CUTLASS_HOST_DEVICE - std::monostate& operator[](int idx) { + std::monostate& operator[]([[maybe_unused]] int idx) { return dummy_; } }; diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 4ba39dda3db8d..016e7dcf27869 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -453,7 +453,7 @@ class QuantBMetaMmaTensorOpTileIterator(dest.data()); const b64* scales_ptr = reinterpret_cast(scales.data()); - const ElementOffset* offsets_ptr = nullptr; + [[maybe_unused]] const ElementOffset* offsets_ptr = nullptr; if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } CUTLASS_PRAGMA_UNROLL @@ -461,11 +461,11 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) const uint32_t* p = reinterpret_cast(offsets_ptr); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) asm volatile( "{\n\t" " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands @@ -796,12 +796,12 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); - uint32_t* addon_ptr = reinterpret_cast(addon); - if constexpr (kHasOffset){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) // possible buffer over read 2 bytes here. const uint32_t* p = reinterpret_cast(offsets.data()); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t* addon_ptr = reinterpret_cast(addon); + asm volatile( "{\n\t" " .reg .b32 rb0, rb1, rb2;\n" @@ -823,6 +823,8 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) + uint32_t* addon_ptr = reinterpret_cast(addon); + asm volatile( "{\n\t" " .reg .b32 rb0;\n" diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.cc b/onnxruntime/core/optimizer/gpu_ops_prepack.cc index b0219124c13de..484f6c42ba5a9 100644 --- a/onnxruntime/core/optimizer/gpu_ops_prepack.cc +++ b/onnxruntime/core/optimizer/gpu_ops_prepack.cc @@ -24,7 +24,6 @@ // 3. The logic of prepacking depends on underlying GPU // hardware. Currently this part is hard-coded for SM80. - #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/gpu_ops_prepack.h" @@ -43,17 +42,17 @@ extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); /** * @brief Read initialized tensor from protobuf, and store it in ort_value. * Keep in mind that ort_value is the owner of the tensor memory after calling this function. 
-*/ + */ inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) { const ONNX_NAMESPACE::TensorProto* tensor_proto; ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), "Missing initializer for ", arg->Name()); - const auto* path_c_str = graph.ModelPath().ToPathString().c_str(); + const auto path_str = graph.ModelPath().ToPathString(); return utils::TensorProtoToOrtValue( - Env::Default(), path_c_str, *tensor_proto, - std::make_shared(), ort_value); + Env::Default(), path_str.c_str(), *tensor_proto, + std::make_shared(), ort_value); } template @@ -65,7 +64,7 @@ inline gsl::span make_span(std::string& str) { // Prepacking logic specific to MatMulNBits on sm80 // -static inline bool IsNodeMatMulNbitsFp16(const Node& node){ +static inline bool IsNodeMatMulNbitsFp16(const Node& node) { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain)) { return false; } @@ -78,13 +77,13 @@ static inline bool IsNodeMatMulNbitsFp16(const Node& node){ template void Sm80BlkQ4PrepackT( - int rows, int columns, - gsl::span weights, - gsl::span scales, - gsl::span zp, - std::string& packed_w, - std::string& packed_scales, - std::string& packed_zp) { + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { using Base = onnxruntime::cuda::BlockwiseQuantization< MLFloat16, block_size, @@ -93,33 +92,33 @@ void Sm80BlkQ4PrepackT( const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - packed_w.resize(q_weight_shape.product() * sizeof(uint8_t)); + packed_w.resize(SafeInt(q_weight_shape.product() * sizeof(uint8_t))); Base::prepack_weights( - rows, columns, weights, - make_span(packed_w)); + rows, columns, weights, + make_span(packed_w)); - packed_scales.resize(meta_shape.product() * sizeof(MLFloat16)); + packed_scales.resize(SafeInt(meta_shape.product() * sizeof(MLFloat16))); Base::prepack_quant_scales( - rows, columns, scales, - make_span(packed_scales)); + rows, columns, scales, + make_span(packed_scales)); if (!zp.empty()) { - packed_zp.resize(meta_shape.product() * sizeof(uint8_t)); + packed_zp.resize(SafeInt(meta_shape.product() * sizeof(uint8_t))); Base::prepack_quant_offsets( - rows, columns, zp, - make_span(packed_zp)); + rows, columns, zp, + make_span(packed_zp)); } } void Sm80BlkQ4Prepack( - int block_size, bool column_quant_blk, - int rows, int columns, - gsl::span weights, - gsl::span scales, - gsl::span zp, - std::string& packed_w, - std::string& packed_scales, - std::string& packed_zp) { + int block_size, bool column_quant_blk, + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { switch (block_size) { case 16: if (column_quant_blk) { @@ -161,21 +160,23 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { Status status = graph_utils::TryGetNodeAttribute(node, "prepacked", att_i); bool prepacked = status.IsOK() ? 
att_i != 0 : false; if (prepacked) { - return Status::OK(); // already prepacked, nothing to do + return Status::OK(); // already prepacked, nothing to do } ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "bits", att_i)); - int nbits = static_cast(att_i); + int nbits = SafeInt(att_i); if (nbits != 4) { - return Status::OK(); // only support 4 bits for now + return Status::OK(); // only support 4 bits for now } + // A single dimension can not exceed 2G yet. ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "K", att_i)); - int k = static_cast(att_i); + int k = SafeInt(att_i); ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "N", att_i)); - int n = static_cast(att_i); + int n = SafeInt(att_i); + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "block_size", att_i)); - int block_size = static_cast(att_i); + int block_size = SafeInt(att_i); status = graph_utils::TryGetNodeAttribute(node, "column_wise_blocking", att_i); bool column_wise_quant_blk = status.IsOK() ? att_i != 0 : true; @@ -184,10 +185,10 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { ORT_ENFORCE(provider_info != nullptr, "Failed to query CUDA provider info while prepacking cuda operators."); int major, minor; ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, - "Failed to query CUDA device version while prepacking cuda operators."); + "Failed to query CUDA device version while prepacking cuda operators."); if (!onnxruntime::cuda::BlkQuantGemmSm80Supported(block_size, column_wise_quant_blk, k, n, major, minor)) { - return Status::OK(); // not supported + return Status::OK(); // not supported } // @@ -196,7 +197,7 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { auto& node_name = node.Name(); auto& mutable_input_defs = node.MutableInputDefs(); if (mutable_input_defs.size() < 3 || mutable_input_defs.size() > 4) { - return Status::OK(); // not supported + return Status::OK(); // not supported } NodeArg* old_weights_arg = mutable_input_defs[1]; @@ -227,7 +228,7 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val)); Tensor* zp_tensor_ptr = zp_val.GetMutable(); if (!zp_tensor_ptr->IsDataType()) { - return Status::OK(); // not supported + return Status::OK(); // not supported } zp = zp_tensor_ptr->DataAsSpan(); } @@ -289,7 +290,6 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { return Status::OK(); } - Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -304,7 +304,7 @@ Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, c ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (node.GetExecutionProviderType() != onnxruntime::kCudaExecutionProvider) { - continue; // only interested in CUDA nodes + continue; // only interested in CUDA nodes } // Run prepack if the node is MatMulNBits. 
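One of the fixes in the gpu_ops_prepack.cc hunk above replaces `graph.ModelPath().ToPathString().c_str()` — a pointer into a temporary string — with a named string that outlives the call. As an editorial aside (not part of the patch), here is a minimal stand-alone illustration of that lifetime bug and its fix; the helper below is hypothetical and not an ORT API:

```
#include <cstdio>
#include <string>

// Hypothetical helper returning a path by value, similar in spirit to
// Graph::ModelPath().ToPathString() in the hunk above.
std::string MakePath() { return "/tmp/model.onnx"; }

void Broken() {
  // BUG: the temporary std::string returned by MakePath() is destroyed at the
  // end of this full expression, so 'p' dangles before it is ever read.
  const char* p = MakePath().c_str();
  std::printf("%s\n", p);  // undefined behavior
}

void Fixed() {
  // FIX: keep the string alive for as long as the pointer is needed.
  const std::string path = MakePath();
  std::printf("%s\n", path.c_str());  // OK
}

int main() {
  Fixed();  // Broken() is shown only for contrast and is intentionally not called.
  return 0;
}
```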
diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index 783a29d8a2055..b4cf8cf518564 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -378,7 +378,7 @@ class MatrixRef { MatrixRef( NonConstMatrixRef const& ref, ///< MatrixRef to non-const data /// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const - _Magic magic = (typename std::enable_if::type)0 + [[maybe_unused]] _Magic magic = (typename std::enable_if::type)0 ) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {} ORT_FORCEINLINE @@ -428,18 +428,18 @@ class MatrixRef { /// Returns a reference to the element at a given Coord ORT_FORCEINLINE Reference at(MatCoord const& coord) const { - return data_[offset(coord)]; + return data_[static_cast(offset(coord))]; } ORT_FORCEINLINE Reference at(int row, int col) const { - return data_[offset(make_Position(row, col))]; + return data_[static_cast(offset(make_Position(row, col)))]; } /// Returns a reference to the element at a given Coord ORT_FORCEINLINE Reference operator[](MatCoord const& coord) const { - return data_[offset(coord)]; + return data_[static_cast(offset(coord))]; } }; diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index ca5f80030dd6d..18700c9de84ca 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -68,7 +68,6 @@ inline void blkq4_weights_gen( const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); // // For testing quantization and dequantization, it is not straight @@ -114,9 +113,9 @@ inline void blkq4_weights_gen( q_scales.resize(meta_shape.product()); for (size_t i = 0; i < q_scales.size(); i++) { - uint32_t v = dis(gen); - uint32_t m = (v % 63) + 1; - uint32_t e = (v >> 6) % 4; + uint32_t vscale = dis(gen); + uint32_t m = (vscale % 63) + 1; + uint32_t e = (vscale >> 6) % 4; q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); } MatrixRef tensor_scale( @@ -124,6 +123,7 @@ inline void blkq4_weights_gen( MatrixRef tensor_offset; if constexpr (has_offsets) { + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); q_zp.resize(zp_shape.product()); tensor_offset = MatrixRef( q_zp, zp_shape); diff --git a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc index 5fb90c59684f3..d9573d30cf296 100644 --- a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc +++ b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc @@ -13,6 +13,8 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cuda/cuda_provider_factory_creator.h" #include "core/mickey/blk_q4/f16_prepack_sm80.h" +#include "core/providers/cuda/cuda_provider_factory.h" +#include "core/providers/cuda/cuda_execution_provider_info.h" #include "test/cuda_host/blkq4_fp16_quant_sm80.h" #include "test/compare_ortvalue.h" @@ -21,16 +23,24 @@ #include "test/util/include/inference_session_wrapper.h" namespace onnxruntime { + +extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); + namespace test { #ifndef DISABLE_CONTRIB_OPS std::shared_ptr LoadCudaEp() { - OrtCUDAProviderOptions cuda_options; - auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options); - if (!factory) { + try { + 
OrtCUDAProviderOptions cuda_options; + auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options); + if (!factory) { + return nullptr; + } + return factory->CreateProvider(); + } catch (const ::onnxruntime::OnnxRuntimeException& e) { + std::cerr << "LoadCudaEp: " << e.what() << std::endl; return nullptr; } - return std::move(factory->CreateProvider()); } /** @@ -41,18 +51,18 @@ std::shared_ptr LoadCudaEp() { * - the addition of cuda execution provider in the session. * - a different location for the model checker, right after session initialization * as the initializers will be deleted during session run. -*/ + */ void GpuPrepackTester( - const std::shared_ptr& cuda_ep, - const std::function& build_test_case, - const std::function& check_transformed_graph, - TransformerLevel baseline_level, - TransformerLevel target_level, - int opset_version = 12, - double per_sample_tolerance = 0.001, - double relative_per_sample_tolerance = 0.001, - const std::function& add_session_options = {}, - const InlinedHashSet& disabled_optimizers = {}) { + const std::shared_ptr& cuda_ep, + const std::function& build_test_case, + const std::function& check_transformed_graph, + TransformerLevel baseline_level, + TransformerLevel target_level, + int opset_version = 12, + double per_sample_tolerance = 0.001, + double relative_per_sample_tolerance = 0.001, + const std::function& add_session_options = {}, + const InlinedHashSet& disabled_optimizers = {}) { // Build the model for this test. std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = opset_version; @@ -130,15 +140,15 @@ inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_ ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), "Missing initializer for ", arg->Name()); - const auto* path_c_str = graph.ModelPath().ToPathString().c_str(); + const auto path_str = graph.ModelPath().ToPathString(); return utils::TensorProtoToOrtValue( - Env::Default(), path_c_str, *tensor_proto, - std::make_shared(), ort_value); + Env::Default(), path_str.c_str(), *tensor_proto, + std::make_shared(), ort_value); } template -void MatMulQ4Test(int M, int N, int K, const std::shared_ptr& cuda_ep){ +void MatMulQ4Test(int M, int N, int K, const std::shared_ptr& cuda_ep) { // // Type definitions // @@ -160,7 +170,6 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr // const auto q_weight_shape = Base::get_quant_weights_shape(K, N); const auto meta_shape = Base::get_quant_meta_shape(K, N); - const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); std::vector q_weights; std::vector q_scales; @@ -176,6 +185,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr q_scales, meta_shape); MatrixRef tensor_offset; if constexpr (has_offsets) { + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); tensor_offset = MatrixRef(q_zp, zp_shape); } @@ -260,7 +270,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr for (size_t i = 0; i < packed_w_ref.size(); ++i) { int expected = packed_w_ref[i]; int found = weights_data[i]; - ASSERT_EQ(expected, found) << "prepacked weight mismatch index i = " << i << " shape[" << K << "," << N/2 << "]"; + ASSERT_EQ(expected, found) << "prepacked weight mismatch index i = " << i << " shape[" << K << "," << N / 2 << "]"; } } { @@ -287,7 +297,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr ASSERT_EQ(expected, found) << "prepacked zero-point mismatch index i = " << i << " shape[" << 
meta_shape[0] << "," << meta_shape[1] << "]"; } } else { - ASSERT_LE(node.InputDefs().size(), 3); + ASSERT_LE(node.InputDefs().size(), static_cast(3)); } } } @@ -298,7 +308,6 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr check_graph, TransformerLevel::Level2, TransformerLevel::Level3); - } TEST(GpuOpPrepackTests, MatmulNBits) { @@ -307,6 +316,19 @@ TEST(GpuOpPrepackTests, MatmulNBits) { GTEST_SKIP() << "Skipping tests when CUDA EP is not available"; } + // + // Currently these tests only work on sm_80. Going forward, however, + // we need a better solution when we may have different tests for different + // hardware. + // + auto* provider_info = TryGetProviderInfo_CUDA(); + int major, minor; + ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, + "Failed to query CUDA device version while prepacking cuda operators."); + if (major < 8) { + GTEST_SKIP() << "Skipping tests when CUDA EP is not sm_80"; + } + // // GpuPrepackTester function implements two different verifications. // First is the hook check_graph, which we use to verify the prepacked weights, scales and zero points. diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 039560bece0d4..84ed86459a77c 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -823,7 +823,7 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map e = std::make_shared(CPUExecutionProviderInfo()); ExecutionProviders execution_providers; - execution_providers.Add(kCpuExecutionProvider, std::move(e)); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(e))); bool has_constant_folding = false; onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -4570,7 +4570,7 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt std::shared_ptr e = std::make_shared(CPUExecutionProviderInfo()); ExecutionProviders execution_providers; - execution_providers.Add(kCpuExecutionProvider, std::move(e)); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(e))); bool has_gelu_approximation = false; auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, @@ -4589,7 +4589,7 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { std::shared_ptr cpu_ep = std::make_shared(CPUExecutionProviderInfo()); ExecutionProviders execution_providers; - execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep)); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep))); bool has_double_qdq_remover = false; auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 0f25769e24e96..6092ba9ff098a 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -34,7 +34,7 @@ void testPrepack(int rows, int columns) { col_blocking>; EXPECT_TRUE(Base::weight_dimension_supported(rows, columns)) - << "Test setup problem, unsupported weight dimension: [" << rows << ", " << columns << "]"; + << "Test setup problem, unsupported 
weight dimension: [" << rows << ", " << columns << "]"; using QuantBlocking = typename Base::QuantBlocking; using ElementW = typename Base::ElementW; diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 548f39bb0150c..2e8d85f38b2db 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -21,10 +21,15 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { std::string l1_transformer = "ConstantFolding"; std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + std::shared_ptr cpu_ep = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + auto status = execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep)); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, execution_providers); + auto filtered_transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, {}, execution_providers, disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -47,8 +52,9 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, execution_providers); + filtered_transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level2, {}, execution_providers, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } From 38d1851d5dfb99172a3af5209267e403ad08acbc Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 14 Mar 2024 18:13:19 +0000 Subject: [PATCH 5/5] strange compilation error --- cmake/onnxruntime_unittests.cmake | 7 + docs/ContribOperators.md | 256 +++++++++--------- .../cuda/quantization/matmul_nbits.cc | 2 +- .../core/mickey/blk_q4/f16_gemm_sm80.h | 10 + 4 files changed, 149 insertions(+), 126 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b004054c616a5..40ed0021f7f07 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -267,6 +267,13 @@ if(NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) "${TEST_SRC_DIR}/optimizer/*.h" ) + if (MSVC AND ((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC"))) + set_source_files_properties("${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" PROPERTIES COMPILE_FLAGS "/bigobj") + list(REMOVE_ITEM onnxruntime_test_optimizer_src + "${TEST_SRC_DIR}/optimizer/gpu_op_prepack_test.cc" + ) +endif() + 
set(onnxruntime_test_framework_src_patterns "${TEST_SRC_DIR}/framework/*.cc" "${TEST_SRC_DIR}/framework/*.h" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 5f0100fad95a2..1c96a6ec796e2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -116,24 +116,24 @@ Do not modify directly.* ### **com.microsoft.Attention** Multi-Head Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). - + The weights for input projection of Q, K and V are merged. The data is stacked on the second dimension. Its shape is (input_hidden_size, hidden_size + hidden_size + v_hidden_size). Here hidden_size is the hidden dimension of Q and K, and v_hidden_size is that of V. - + The mask_index is optional. Besides raw attention mask with shape (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) with value 0 for masked and 1 otherwise, we support other two formats: When input has right-side padding, mask_index is one dimension with shape (batch_size), where value is actual sequence length excluding padding. When input has left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by the inclusive start positions. - + When unidirectional is 1, each token only attends to previous tokens. - + Both past and present state are optional. They shall be used together, and not allowed to use only one of them. The qkv_hidden_sizes is required only when K and V have different hidden sizes. - + When there is past state, hidden dimension for Q, K and V shall be the same. - + The total_sequence_length is past_sequence_length + kv_sequence_length. Here kv_sequence_length is the length of K or V. For self attention, kv_sequence_length equals to sequence_length (sequence length of Q). For cross attention, query and key might have different lengths. @@ -205,133 +205,133 @@ This version of the operator has been available since version 1 of the 'com.micr Computes an one-layer RNN where its RNN Cell is an AttentionWrapper wrapped a LSTM Cell. The RNN layer contains following basic component: LSTM Cell, Bahdanau Attention Mechanism, AttentionWrapp. - + Activation functions: - + Relu(x) - max(0, x) - + Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) - + Sigmoid(x) - 1/(1 + e^{-x}) - + (NOTE: Below are optional) - + Affine(x) - alpha*x + beta - + LeakyRelu(x) - x if x >= 0 else alpha * x - + ThresholdedRelu(x) - x if x >= alpha else 0 - + ScaledTanh(x) - alpha*Tanh(beta*x) - + HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) - + Elu(x) - x if x >= 0 else alpha*(e^x - 1) - + Softsign(x) - x/(1 + |x|) - + Softplus(x) - log(1 + e^x) - + Softmax(x) - exp(x) / sum(exp(x)) - + Bahdanau Attention Mechanism: `M` - Memory tensor. - + `VALUES` - masked Memory by its real sequence length. - + `MW` - Memory layer weight. - + `KEYS` - Processed memory tensor by the memory layer. KEYS = M * MW - + `Query` - Query tensor, normally at specific time step in sequence. 
- + `QW` - Query layer weight in the attention mechanism - + `PQ` - processed query, = `Query` * `QW` - + `V' - attention vector - + `ALIGN` - calculated alignment based on Query and KEYS ALIGN = softmax(reduce_sum(`V` * Tanh(`KEYS` + `PQ`))) - + `CONTEXT` - context based on `ALIGN` and `VALUES` CONTEXT = `ALIGN` * `VALUES` - - + + LSTM Cell: `X` - input tensor concat with attention state in the attention wrapper - + `i` - input gate - + `o` - output gate - + `f` - forget gate - + `c` - cell gate - + `t` - time step (t-1 means previous time step) - + `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates - + `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates - + `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates - + `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates - + `P[iof]` - P peephole weight vector for input, output, and forget gates - + `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates - + `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates - + `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates - + `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates - + `PB[iof]` - P peephole weight vector for backward input, output, and forget gates - + `H` - Hidden state - + `num_directions` - 2 if direction == bidirectional else 1 - + Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): - + - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - + - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - + - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - + - Ct = ft (.) Ct-1 + it (.) ct - + - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - + - Ht = ot (.) h(Ct) - - + + AttentionWrapp Notations: `lstm()' - wrapped inner cell. Ht, Ct = lstm(concat(Xt, ATTNt-1), Ct-1) - + `am()` - attention mechanism the wrapper used. CONTEXTt, ALIGNt = am(Ht, ALIGNt-1) - + `AW` - attention layer weights, optional. - + `ATTN` - attention state, initial is zero. If `AW` provided, it is the output of the attention layer, ATTNt = concat(Ht, CONTEXTt) * AW otherwise, ATTNt = CONTEXTt - + RNN layer output: `Y` - if needed is the sequence of Ht from lstm cell. - + `Y_h` - is the last valid H from lstm cell. - + `Y_c` - is the last valid C from lstm cell. - + #### Version @@ -585,7 +585,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.BiasGelu** Bias Gelu. - It's an extension of Gelu. It takes the sum of input A and bias input B as the input of Gelu activation. + It's an extension of Gelu. It takes the sum of input A and bias input B as the input of Gelu activation. #### Version @@ -810,7 +810,7 @@ This version of the operator has been available since version 1 of the 'com.micr ``` scale = 1. / (1. - ratio). ``` - + This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt that the mask is output as a bit-packed uint32 tensor, instead of a boolean tensor. #### Version @@ -1206,17 +1206,17 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.DecoderMaskedSelfAttention** Self attention that supports input sequence length of 1. - + The weights for input projection of Q, K and V are merged. The data is stacked on the second dimension. Its shape is (input_hidden_size, hidden_size + hidden_size + v_hidden_size). 
Here hidden_size is the hidden dimension of Q and K, and v_hidden_size is that of V. - + The mask_index is optional. If it is provided, only raw attention mask with shape (batch_size, total_sequence_length) is supported currently. - + Both past and present state need to be provided. - + The qkv_hidden_sizes is required only when K and V have different hidden sizes. - + The total_sequence_length is past_sequence_length + kv_sequence_length. Here kv_sequence_length is the length of K or V. Currently, only self attention is supported which means that kv_sequence_length equals to sequence_length (sequence length of Q). @@ -2285,7 +2285,7 @@ This version of the operator has been available since version 1 of the 'com.micr which are used to interpolate the output value `output[n, :, h, w]`. The GridSample operator is often used in doing grid generator and sampler in the [Spatial Transformer Networks](https://arxiv.org/abs/1506.02025). See also in [torch.nn.functional.grid_sample](https://pytorch.org/docs/master/generated/torch.nn.functional.grid_sample.html#torch-nn-functional-grid-sample). - + #### Version @@ -2331,13 +2331,13 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.GroupNorm** Applies Group Normalization over a mini-batch of inputs as described in the paper Group Normalization (https://arxiv.org/abs/1803.08494). - + This operator transforms input according to y = gamma * (x - mean) / sqrt(variance + epsilon) + beta - + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. num_channels must be divisible by num_groups. The mean and standard-deviation are calculated separately over the each group. The weight and bias are per-channel affine transform parameter vectors of size num_channels. - + The activation attribute can be used to enable activation after group normalization. #### Version @@ -2388,7 +2388,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.GroupQueryAttention** Group Query Self/Cross Attention. - + Supports different number of heads for q and kv. Only supports causal or local attention. #### Version @@ -2530,10 +2530,10 @@ This version of the operator has been available since version 1 of the 'com.micr Longformer Self Attention with a local context and a global context. Tokens attend locally: Each token attends to its W previous tokens and W succeeding tokens with W being the window length. A selected few tokens attend globally to all other tokens. - + The attention mask is of shape (batch_size, sequence_length), where sequence_length is a multiple of 2W after padding. Mask value < 0 (like -10000.0) means the token is masked, 0 otherwise. - + Global attention flags have value 1 for the tokens attend globally and 0 otherwise. #### Version @@ -2592,32 +2592,32 @@ This version of the operator has been available since version 1 of the 'com.micr 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's quantization constants or scales are specified by input 'absmax'. - + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. 
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. - - + + 1. (Default value) transB=True (Majorly used for forward pass) Shape of A: [D0, D1, ..., Dn, K] Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. - + The computation math: dequant_B = dequant(B, absmax, quant_type, block_size) transposed_dequant_B = dequant_B^T output = A @ transposed_dequant_B - + Shape of output: [D0, D1, ..., Dn, N] - + 2. transB=False (Majorly used for backward pass) Shape of A: [D0, D1, ..., Dn, N] Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features]. - + The computation math: dequant_B = dequant(B, absmax, quant_type, block_size) output = A @ dequant_B - + Shape of output: [D0, D1, ..., Dn, K] - + #### Version @@ -2807,7 +2807,7 @@ This version of the operator has been available since version 1 of the 'com.micr 2. Input B is quantized with x bits which is specified by attribute 'bits'. It is quantized blockwisely along dimension 0 (e.g. column) with block size specified by attribute block_size. And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - n_blocks_per_col = (K + block_size - 1) / block_size - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) @@ -2819,8 +2819,8 @@ This version of the operator has been available since version 1 of the 'com.micr 3bit example: |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. The last uint_8 may have some bits unused. - - + + Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] @@ -2843,6 +2843,14 @@ This version of the operator has been available since version 1 of the 'com.micr
number of bits used for weight quantization (default 4)
block_size : int (required)
size of the quantization group used for weight quantization (default 128). It needs to be a power of 2 and not smaller than 16.
+column_wise_blocking : int
+whether to quantize the weight column-wise (value 1, default) or row-wise (value 0)
+prepacked : int
+Indicates whether the weight matrix is prepacked (value 1) or not (value 0, default). This attribute should never be set by the user; it is set internally by ONNX Runtime at model load time.
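To make the shape formulas quoted earlier in this operator's documentation concrete, here is a small self-contained sketch (editorial, not part of the patch; the K, N, bits and block_size values are hypothetical) that evaluates the storage sizes of the default, non-prepacked layout. When `prepacked` is 1, the same inputs instead carry the hardware-specific layout produced by the new GpuOpsPrepack transformer.

```
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t K = 4096, N = 4096;  // hypothetical GEMM dimensions
  const int64_t bits = 4;            // attribute 'bits'
  const int64_t block_size = 128;    // attribute 'block_size'

  // Formulas from the operator description above.
  const int64_t n_blocks_per_col = (K + block_size - 1) / block_size;
  const int64_t blob_size = (block_size * bits + 7) / 8;  // CeilDiv(block_size * bits, 8)

  // Input B:      uint8 tensor of shape [N][n_blocks_per_col][blob_size]
  // Input scales: same element type as A, shape [N * n_blocks_per_col]
  std::printf("B bytes:      %lld\n", (long long)(N * n_blocks_per_col * blob_size));
  std::printf("scale values: %lld\n", (long long)(N * n_blocks_per_col));
  return 0;
}
```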
#### Inputs (3 - 5) @@ -2933,7 +2941,7 @@ This version of the operator has been available since version 1 of the 'com.micr Mixture of experts. Examples: Switch transformer(https://arxiv.org/pdf/2101.03961.pdf) use top 1, GLaM(https://arxiv.org/abs/2112.06905) activates top 2 FFN, and Vision MOE(https://arxiv.org/pdf/2106.05974.pdf) usually uses top 32 experts. - + #### Version @@ -2985,11 +2993,11 @@ This version of the operator has been available since version 1 of the 'com.micr Performs element-wise binary quantized multiplication (with Numpy-style broadcasting support). "This operator supports **multidirectional (i.e., Numpy-style) broadcasting**" The output of this op is the int32 accumulated result of the mul operation - + ``` C (int32) = (A - A_zero_point) * (B - B_zero_point) ``` - + #### Version @@ -3028,7 +3036,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.MultiHeadAttention** Multi-Head Self/Cross Attention. Bias from input projection is included. - + The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), value 0 means padding or 1 otherwise. When key has right-side padding, its shape could be (batch_size): it is actual length of each key sequence excluding paddings. @@ -3329,25 +3337,25 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.PackedAttention** This is the packed version of Attention. - + Sequences in one batch usually don't have same length and they are padded to have same length, e.g., below is a batch with 3 sequences and tokens* are padded. Sequence_0: 0, 1*, 2*, 3* Sequence_1: 4, 5, 6*, 7* Sequence_2: 8, 9, 10, 11 - + PackedAttention is designed to takes in packed input, i.e., only the real tokens without padding. An input as above will be packed into 3 tensors like below: - input ([h0, h4, h5, h8, h9, h10, h11]) - token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7* - cumulated_token_count: 0, 1, 1+2, 1+2+4 - + Input tensors contains the hidden embedding of real tokens. Token_offset records the offset of token in the unpacked input. cumulated_token_count records cumulated length of each sequnces length. - + The operator only supports BERT like model with padding on right now. - + #### Version @@ -3401,13 +3409,13 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.PackedMultiHeadAttention** This is the packed version of MultiHeadAttention. - + Sequences in one batch usually don't have same length and they are padded to have same length, e.g., below is a batch with 3 sequences and * is padding token. Sequence_0: 0, 1*, 2*, 3* Sequence_1: 4, 5, 6*, 7* Sequence_2: 8, 9, 10, 11 - + PackedMultiHeadAttention is designed to takes in packed input, i.e., only the real tokens without padding. An input as above will be packed into 3 tensors like below: - query ([q0, q4, q5, q8, q9, q10, q11]) @@ -3415,11 +3423,11 @@ This version of the operator has been available since version 1 of the 'com.micr - value ([v0, v4, v5, v8, v9, v10, v11]) - token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7* - cumulative_sequence_length: 0, 1, 1+2, 1+2+4 - + The query, key and value tensors contain result of hidden embedding of real tokens after input projections. Token_offset records the offset of token in the unpacked input. cumulative_sequence_length records cumulated length of each sequnces length. - + The operator only supports BERT like model with padding on right now. 
#### Version @@ -3491,7 +3499,7 @@ This version of the operator has been available since version 1 of the 'com.micr [0.0, 0.0, 4.5, 5.7], ], ] - + #### Version @@ -3671,7 +3679,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QLinearAdd** Performs element-wise binary addition on 8 bit data types (with Numpy-style broadcasting support). - + C = (A_scale * (A - A_zero_point) + B_scale * (B - B_zero_point))/C_scale + C_zero_point #### Version @@ -3729,11 +3737,11 @@ This version of the operator has been available since version 1 of the 'com.micr output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) ``` if ceil_mode is enabled - + ``` * pad_shape[i] is sum of pads along axis i ``` - + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: ``` VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - kernel_spatial_shape[i] + 1) / strides_spatial_shape[i]) @@ -3743,9 +3751,9 @@ This version of the operator has been available since version 1 of the 'com.micr ``` pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i] ``` - + The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). - + Input and output scales and zero points are used to convert the output to a new quantization range. Output = Dequantize(Input) -> AveragePool on fp32 data -> Quantize(output) @@ -4013,7 +4021,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.QLinearMul** Performs element-wise binary multiplication on 8 bit data types (with Numpy-style broadcasting support). - + C = ((A - A_zero_point) * (B - B_zero_point)) * (A_scale * B_scale)/C_scale + C_zero_point #### Version @@ -4064,10 +4072,10 @@ This version of the operator has been available since version 1 of the 'com.micr with the exception that numpy default keepdims to False instead of True. Input and Output scales and zero points are used to requantize the output in a new range. This helps to improve accuracy as after ReduceMean operation the range of the output is expected to decrease. - + ``` "Output = Dequantize(Input) -> ReduceMean on fp32 data -> Quantize(output)", - + ``` #### Version @@ -4117,7 +4125,7 @@ This version of the operator has been available since version 1 of the 'com.micr QLinearSigmoid takes quantized input data (Tensor), and quantize parameter for output, and produces one output data (Tensor) where the function `f(x) = quantize(Sigmoid(dequantize(x)))`, is applied to the data tensor elementwise. - Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` + Wwhere the function `Sigmoid(x) = 1 / (1 + exp(-x))` #### Version @@ -4903,10 +4911,10 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RemovePadding** Compress transformer input by removing paddings. It assumes padding is on the right side of sequence. - + The input has padding with shape (batch_size, sequence_length, hidden_size). This will generate two outputs: output has shape (total_tokens, hidden_size); token_offset with shape (batch_size, sequence_length). - + token_offset has offsets of all non-padding tokens first, then offset of all padding tokens. 
It is a list of batch_size * sequence_length elements, which is reshaped to 2D for convenience of shape inference. @@ -4949,7 +4957,7 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.RestorePadding** Restore paddings and fill padding with zeros. - + The input has padding with shape (total_tokens, hidden_size) and token_offset with shape (batch_size, sequence_length). The output has shape (batch_size, sequence_length, hidden_size). @@ -5194,16 +5202,16 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.SkipGroupNorm** This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. - + This operator transforms input according to s = x + skip + bias y = gamma * (s - mean) / sqrt(variance + epsilon) + beta - + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. The num_channels must be divisible by num_groups. The mean and standard-deviation of s are calculated separately over the each group. The weight and bias are per-channel affine transform parameter vectors of size num_channels. - + The activation attribute can be used to enable activation after group normalization. #### Version @@ -5525,7 +5533,7 @@ This version of the operator has been available since version 1 of the 'com.micr Based on Torch operator Embedding, creates a lookup table of embedding vectors of fixed size, for a dictionary of fixed size. - + #### Version @@ -5615,7 +5623,7 @@ This version of the operator has been available since version 1 of the 'com.micr the main diagonal. A negative k value includes as many diagonals below the main diagonal. If upper is set to false, a positive k retains the lower triangular matrix including k diagonals above the main diagonal. A negative k value excludes as many diagonals below the main diagonal. - + #### Version @@ -5707,7 +5715,7 @@ This version of the operator has been available since version 1 of the 'com.micr output_uniques = [2, 1, 3, 4] output_idx = [0, 1, 1, 2, 3, 2] output_counts = [1, 2, 2, 1] - + #### Version @@ -6019,5 +6027,3 @@ No versioning maintained for experimental ops.
T : tensor(float)
Constrain input and output types to float32 tensors.
- - diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index b8cfde96200b1..a29aaf460b7a6 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -66,7 +66,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { "Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!"); return PrepackedGemm( static_cast(ctx->GetComputeStream()->GetHandle()), - helper.M(), a, b, scales, zero_points, Y); + static_cast(helper.M()), a, b, scales, zero_points, Y); } const auto* a_data = a->Data(); diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h index 52bff7e40dbe3..66b54a2504475 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -11,9 +11,19 @@ #pragma once +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + #include "cutlass/cutlass.h" #include "cutlass_ext/q4gemm/device/quantb_gemm.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + namespace onnxruntime { namespace cuda {
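The final hunk wraps the CUTLASS includes in MSVC `#pragma warning(push/pop)` to silence C4100 (unreferenced formal parameter). For reference only (not part of the patch), the same include-wrapping pattern extended to GCC/Clang might look like the sketch below, with `<cstdio>` standing in for the CUTLASS headers:

```
// Sketch: suppress third-party header warnings around an include, for both
// MSVC and GCC/Clang.
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4100)  // unreferenced formal parameter
#elif defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif

#include <cstdio>  // placeholder for cutlass/cutlass.h etc.

#if defined(_MSC_VER)
#pragma warning(pop)
#elif defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

int main() {
  std::puts("warnings from the wrapped header are suppressed");
  return 0;
}
```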