From 65573be864aaf9c0a4dcaa095d99dcde2c9b5a5e Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Thu, 7 Mar 2024 17:32:01 +0000 Subject: [PATCH] bug fix: can't use A tensor shape[0] for M --- cmake/onnxruntime_providers_cuda.cmake | 4 - .../cuda/quantization/matmul_nbits.cc | 34 ++++---- .../cuda/quantization/matmul_nbits.cu | 2 +- .../cuda/quantization/matmul_nbits.h | 9 +- onnxruntime/core/graph/graph_utils.h | 4 +- .../core/mickey/blk_q4/f16_prepack_sm80.h | 56 ++++++------- .../quantb_meta_mma_tensor_op_tile_iterator.h | 14 ++-- onnxruntime/core/optimizer/gpu_ops_prepack.cc | 84 +++++++++---------- onnxruntime/core/util/matrix_layout.h | 8 +- .../test/cuda_host/blkq4_fp16_quant_sm80.h | 8 +- .../test/optimizer/gpu_op_prepack_test.cc | 68 ++++++++++----- .../test/optimizer/graph_transform_test.cc | 6 +- .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 2 +- .../optimizer/graph_transformer_utils_test.cc | 16 ++-- 14 files changed, 171 insertions(+), 144 deletions(-) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 2a73a14f1588d..858c6c8b09abe 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -285,10 +285,6 @@ endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) - # TODO only needed in DEBUG builds, need cmake expert advice on how to do that - set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ") - set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ") - install(TARGETS onnxruntime_providers_cuda ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 1d1853725aac4..b8cfde96200b1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -15,15 +15,15 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template<> +template <> Status MatMulNBits::PrepackedGemm( - cudaStream_t stream, - const Tensor* a, - const Tensor* b, - const Tensor* scales, - const Tensor* zero_points, - Tensor* Y) const { - int64_t M = a->Shape()[0]; + cudaStream_t stream, + int M, + const Tensor* a, + const Tensor* b, + const Tensor* scales, + const Tensor* zero_points, + Tensor* Y) const { uint8_t const* zero_points_ptr = nullptr; size_t zero_points_size = 0; if (zero_points != nullptr) { @@ -32,12 +32,12 @@ Status MatMulNBits::PrepackedGemm( } return blkq4_fp16_gemm_sm80_dispatch( - int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream, - a->Data(), a->Shape().Size(), - b->Data(), b->Shape().Size(), - scales->Data(), scales->Shape().Size(), - zero_points_ptr, zero_points_size, - Y->MutableData(), Y->Shape().Size()); + int(block_size_), column_wise_quant_blk_, int(M), int(N_), int(K_), stream, + a->Data(), a->Shape().Size(), + b->Data(), b->Shape().Size(), + scales->Data(), scales->Shape().Size(), + zero_points_ptr, zero_points_size, + Y->MutableData(), Y->Shape().Size()); } template @@ -59,14 +59,14 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - if (prepack_ > 0){ + if (prepack_ > 0) { 
ORT_RETURN_IF(reorder_idx != nullptr, "Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!"); ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType(), "Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!"); return PrepackedGemm( - static_cast(ctx->GetComputeStream()->GetHandle()), - a, b, scales, zero_points, Y); + static_cast(ctx->GetComputeStream()->GetHandle()), + helper.M(), a, b, scales, zero_points, Y); } const auto* a_data = a->Data(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index 9e394cc82e355..0766c408bd7fd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -390,7 +390,7 @@ Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream, const cutlass::gemm::GemmCoord problem_size = {m, n, k}; - ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct"); + ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct: ", a.size_bytes(), " vs m: ", m, "k: ", k , " size: ", m * k * sizeof(ElementDequant)); cutlass::TensorRef ref_a( reinterpret_cast(a.data()), LayoutInputA::packed({m, k})); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h index 9e672d724c7a8..4ba7283283581 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -34,8 +34,13 @@ class MatMulNBits final : public CudaKernel { info.GetAttrOrDefault("prepacked", &prepack_, int64_t(0)); } - Status PrepackedGemm(cudaStream_t stream, const Tensor* a, const Tensor* b, - const Tensor* scales, const Tensor* zero_points, Tensor* Y) const { + Status PrepackedGemm([[maybe_unused]] cudaStream_t stream, + [[maybe_unused]] int M, + [[maybe_unused]] const Tensor* a, + [[maybe_unused]] const Tensor* b, + [[maybe_unused]] const Tensor* scales, + [[maybe_unused]] const Tensor* zero_points, + [[maybe_unused]] Tensor* Y) const { ORT_THROW("Prepacked gemm is not supported for MatMulNBits op."); } diff --git a/onnxruntime/core/graph/graph_utils.h b/onnxruntime/core/graph/graph_utils.h index be1049d78fdf0..3273e709077d5 100644 --- a/onnxruntime/core/graph/graph_utils.h +++ b/onnxruntime/core/graph/graph_utils.h @@ -26,8 +26,7 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node, #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) /** Returns the attribute of a Node with a given name. */ -static inline -const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { +static inline const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name) { const auto& attrs = node.GetAttributes(); const auto iter = attrs.find(attr_name); return iter == attrs.end() ? nullptr : &iter->second; @@ -49,7 +48,6 @@ inline Status TryGetNodeAttribute(const Node& node, const std::string& return Status::OK(); } - /** Add a new initializer to 'graph'. Checks that new_initializer does not already exist in 'graph' before adding it. @returns The NodeArg for the new initializer. 
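Note on the matmul_nbits changes above: the point of the fix is that the activation A handed to MatMulNBits may be more than 2-D, so its first dimension is not the GEMM M; M has to come from the shape helper (helper.M()) that collapses every leading dimension. A minimal standalone sketch of that computation, with illustrative names rather than the actual ORT helper:

#include <cstdint>
#include <vector>

// M for A[d0, d1, ..., dk, K] multiplied against B[K, N] is the product of
// all leading dimensions of A, not d0 alone.
inline int64_t ComputeGemmM(const std::vector<int64_t>& a_dims) {
  int64_t m = 1;
  for (size_t i = 0; i + 1 < a_dims.size(); ++i) {
    m *= a_dims[i];
  }
  return m;
}

// Example: A with shape [2, 3, 768] against B[768, N] gives M = 6,
// whereas a_dims[0] would have (incorrectly) reported 2.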
diff --git a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index a1f9b80d7754b..7e0a53d739121 100644 --- a/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -187,8 +187,8 @@ struct BlockwiseQuantization { static constexpr bool ShouldRearrangeMeta = sizeof(ElementT) == 2 && QuantBlocking::kRow == 1; static void prepack_quant_scales( - size_t rows, - size_t columns, + int rows, + int columns, const gsl::span& scales, // <- quant scales, column major layout const gsl::span& scales_prepacked // <- quant scales prepacked, same size buffer ) { @@ -261,8 +261,8 @@ struct BlockwiseQuantization { } static void prepack_quant_offsets( - size_t rows, - size_t columns, + int rows, + int columns, const gsl::span& offsets, // <- quant offsets, int4, column major layout const gsl::span& offsets_prepacked // <- quant offsets prepacked, double size buffer ) { @@ -345,8 +345,8 @@ struct BlockwiseQuantization { }; static inline bool IsSm80WithWholeBlocks( - int weight_rows, [[maybe_unused]] int weight_cols, - int major, [[maybe_unused]] int minor) { + int weight_rows, [[maybe_unused]] int weight_cols, + int major, [[maybe_unused]] int minor) { if (major < 8) { return false; } @@ -364,9 +364,8 @@ static inline bool IsSm80WithWholeBlocks( return (weight_rows % 64 == 0); } -template -inline -bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { +template +inline bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int minor) { using Base = BlockwiseQuantization; if (!Base::weight_dimension_supported(weight_rows, weight_cols)) { return false; @@ -375,26 +374,25 @@ bool BlkQuantGemmSm80Supported(int weight_rows, int weight_cols, int major, int } static inline bool BlkQuantGemmSm80Supported(int block_size, bool col_blocking, int weight_rows, int weight_cols, int major, int minor) { - switch (block_size) - { - case 16: - if (col_blocking) { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } else { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } - case 32: - if (col_blocking) { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } else { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } - case 64: - if (col_blocking) { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } else { - return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); - } + switch (block_size) { + case 16: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 32: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } + case 64: + if (col_blocking) { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } else { + return onnxruntime::cuda::BlkQuantGemmSm80Supported(weight_rows, weight_cols, major, minor); + } } return false; } diff --git 
a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h index 4ba39dda3db8d..016e7dcf27869 100644 --- a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -453,7 +453,7 @@ class QuantBMetaMmaTensorOpTileIterator(dest.data()); const b64* scales_ptr = reinterpret_cast(scales.data()); - const ElementOffset* offsets_ptr = nullptr; + [[maybe_unused]] const ElementOffset* offsets_ptr = nullptr; if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } CUTLASS_PRAGMA_UNROLL @@ -461,11 +461,11 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) const uint32_t* p = reinterpret_cast(offsets_ptr); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) asm volatile( "{\n\t" " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands @@ -796,12 +796,12 @@ class QuantBMetaMmaTensorOpTileIterator(scales.data()); - uint32_t* addon_ptr = reinterpret_cast(addon); - if constexpr (kHasOffset){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) // possible buffer over read 2 bytes here. const uint32_t* p = reinterpret_cast(offsets.data()); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + uint32_t* addon_ptr = reinterpret_cast(addon); + asm volatile( "{\n\t" " .reg .b32 rb0, rb1, rb2;\n" @@ -823,6 +823,8 @@ class QuantBMetaMmaTensorOpTileIterator= 800)) + uint32_t* addon_ptr = reinterpret_cast(addon); + asm volatile( "{\n\t" " .reg .b32 rb0;\n" diff --git a/onnxruntime/core/optimizer/gpu_ops_prepack.cc b/onnxruntime/core/optimizer/gpu_ops_prepack.cc index b0219124c13de..484f6c42ba5a9 100644 --- a/onnxruntime/core/optimizer/gpu_ops_prepack.cc +++ b/onnxruntime/core/optimizer/gpu_ops_prepack.cc @@ -24,7 +24,6 @@ // 3. The logic of prepacking depends on underlying GPU // hardware. Currently this part is hard-coded for SM80. - #include "core/graph/graph_utils.h" #include "core/optimizer/initializer.h" #include "core/optimizer/gpu_ops_prepack.h" @@ -43,17 +42,17 @@ extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); /** * @brief Read initialized tensor from protobuf, and store it in ort_value. * Keep in mind that ort_value is the owner of the tensor memory after calling this function. 
-*/ + */ inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_value) { const ONNX_NAMESPACE::TensorProto* tensor_proto; ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), "Missing initializer for ", arg->Name()); - const auto* path_c_str = graph.ModelPath().ToPathString().c_str(); + const auto path_str = graph.ModelPath().ToPathString(); return utils::TensorProtoToOrtValue( - Env::Default(), path_c_str, *tensor_proto, - std::make_shared(), ort_value); + Env::Default(), path_str.c_str(), *tensor_proto, + std::make_shared(), ort_value); } template @@ -65,7 +64,7 @@ inline gsl::span make_span(std::string& str) { // Prepacking logic specific to MatMulNBits on sm80 // -static inline bool IsNodeMatMulNbitsFp16(const Node& node){ +static inline bool IsNodeMatMulNbitsFp16(const Node& node) { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMulNBits", {1}, kMSDomain)) { return false; } @@ -78,13 +77,13 @@ static inline bool IsNodeMatMulNbitsFp16(const Node& node){ template void Sm80BlkQ4PrepackT( - int rows, int columns, - gsl::span weights, - gsl::span scales, - gsl::span zp, - std::string& packed_w, - std::string& packed_scales, - std::string& packed_zp) { + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { using Base = onnxruntime::cuda::BlockwiseQuantization< MLFloat16, block_size, @@ -93,33 +92,33 @@ void Sm80BlkQ4PrepackT( const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - packed_w.resize(q_weight_shape.product() * sizeof(uint8_t)); + packed_w.resize(SafeInt(q_weight_shape.product() * sizeof(uint8_t))); Base::prepack_weights( - rows, columns, weights, - make_span(packed_w)); + rows, columns, weights, + make_span(packed_w)); - packed_scales.resize(meta_shape.product() * sizeof(MLFloat16)); + packed_scales.resize(SafeInt(meta_shape.product() * sizeof(MLFloat16))); Base::prepack_quant_scales( - rows, columns, scales, - make_span(packed_scales)); + rows, columns, scales, + make_span(packed_scales)); if (!zp.empty()) { - packed_zp.resize(meta_shape.product() * sizeof(uint8_t)); + packed_zp.resize(SafeInt(meta_shape.product() * sizeof(uint8_t))); Base::prepack_quant_offsets( - rows, columns, zp, - make_span(packed_zp)); + rows, columns, zp, + make_span(packed_zp)); } } void Sm80BlkQ4Prepack( - int block_size, bool column_quant_blk, - int rows, int columns, - gsl::span weights, - gsl::span scales, - gsl::span zp, - std::string& packed_w, - std::string& packed_scales, - std::string& packed_zp) { + int block_size, bool column_quant_blk, + int rows, int columns, + gsl::span weights, + gsl::span scales, + gsl::span zp, + std::string& packed_w, + std::string& packed_scales, + std::string& packed_zp) { switch (block_size) { case 16: if (column_quant_blk) { @@ -161,21 +160,23 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { Status status = graph_utils::TryGetNodeAttribute(node, "prepacked", att_i); bool prepacked = status.IsOK() ? 
att_i != 0 : false; if (prepacked) { - return Status::OK(); // already prepacked, nothing to do + return Status::OK(); // already prepacked, nothing to do } ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "bits", att_i)); - int nbits = static_cast(att_i); + int nbits = SafeInt(att_i); if (nbits != 4) { - return Status::OK(); // only support 4 bits for now + return Status::OK(); // only support 4 bits for now } + // A single dimension can not exceed 2G yet. ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "K", att_i)); - int k = static_cast(att_i); + int k = SafeInt(att_i); ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "N", att_i)); - int n = static_cast(att_i); + int n = SafeInt(att_i); + ORT_RETURN_IF_ERROR(graph_utils::TryGetNodeAttribute(node, "block_size", att_i)); - int block_size = static_cast(att_i); + int block_size = SafeInt(att_i); status = graph_utils::TryGetNodeAttribute(node, "column_wise_blocking", att_i); bool column_wise_quant_blk = status.IsOK() ? att_i != 0 : true; @@ -184,10 +185,10 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { ORT_ENFORCE(provider_info != nullptr, "Failed to query CUDA provider info while prepacking cuda operators."); int major, minor; ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, - "Failed to query CUDA device version while prepacking cuda operators."); + "Failed to query CUDA device version while prepacking cuda operators."); if (!onnxruntime::cuda::BlkQuantGemmSm80Supported(block_size, column_wise_quant_blk, k, n, major, minor)) { - return Status::OK(); // not supported + return Status::OK(); // not supported } // @@ -196,7 +197,7 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { auto& node_name = node.Name(); auto& mutable_input_defs = node.MutableInputDefs(); if (mutable_input_defs.size() < 3 || mutable_input_defs.size() > 4) { - return Status::OK(); // not supported + return Status::OK(); // not supported } NodeArg* old_weights_arg = mutable_input_defs[1]; @@ -227,7 +228,7 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { ORT_RETURN_IF_ERROR(GetOrtValue(old_zp_arg, graph, zp_val)); Tensor* zp_tensor_ptr = zp_val.GetMutable(); if (!zp_tensor_ptr->IsDataType()) { - return Status::OK(); // not supported + return Status::OK(); // not supported } zp = zp_tensor_ptr->DataAsSpan(); } @@ -289,7 +290,6 @@ Status PackMatMulNBitsFp16(Node& node, Graph& graph, bool& modified) { return Status::OK(); } - Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); @@ -304,7 +304,7 @@ Status GpuOpsPrepack::ApplyImpl(Graph& graph, bool& modified, int graph_level, c ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (node.GetExecutionProviderType() != onnxruntime::kCudaExecutionProvider) { - continue; // only interested in CUDA nodes + continue; // only interested in CUDA nodes } // Run prepack if the node is MatMulNBits. 
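The gpu_ops_prepack.cc changes above keep the pass's structure intact: block size and blocking orientation arrive as node attributes at run time, while the SM80 packing routines are compile-time templates, so a switch bridges the two and any unsupported combination simply leaves the node untouched. A condensed sketch of that dispatch shape, with simplified names and no ORT dependencies:

#include <cstdio>

template <int BlockSize, bool ColumnWiseBlocking>
void PrepackSm80T(int rows, int columns) {
  // Stand-in for the real packing kernel; its parameters are fixed at compile time.
  std::printf("pack %d x %d, block_size=%d, column_blocking=%d\n",
              rows, columns, BlockSize, static_cast<int>(ColumnWiseBlocking));
}

// Returns false when the (block_size, blocking) combination is not handled,
// mirroring how the optimizer pass falls through and skips prepacking.
bool PrepackSm80(int block_size, bool col_blocking, int rows, int columns) {
  switch (block_size) {
    case 16: col_blocking ? PrepackSm80T<16, true>(rows, columns) : PrepackSm80T<16, false>(rows, columns); return true;
    case 32: col_blocking ? PrepackSm80T<32, true>(rows, columns) : PrepackSm80T<32, false>(rows, columns); return true;
    case 64: col_blocking ? PrepackSm80T<64, true>(rows, columns) : PrepackSm80T<64, false>(rows, columns); return true;
    default: return false;
  }
}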
diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index 783a29d8a2055..b4cf8cf518564 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -378,7 +378,7 @@ class MatrixRef { MatrixRef( NonConstMatrixRef const& ref, ///< MatrixRef to non-const data /// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const - _Magic magic = (typename std::enable_if::type)0 + [[maybe_unused]] _Magic magic = (typename std::enable_if::type)0 ) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {} ORT_FORCEINLINE @@ -428,18 +428,18 @@ class MatrixRef { /// Returns a reference to the element at a given Coord ORT_FORCEINLINE Reference at(MatCoord const& coord) const { - return data_[offset(coord)]; + return data_[static_cast(offset(coord))]; } ORT_FORCEINLINE Reference at(int row, int col) const { - return data_[offset(make_Position(row, col))]; + return data_[static_cast(offset(make_Position(row, col)))]; } /// Returns a reference to the element at a given Coord ORT_FORCEINLINE Reference operator[](MatCoord const& coord) const { - return data_[offset(coord)]; + return data_[static_cast(offset(coord))]; } }; diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h index ca5f80030dd6d..18700c9de84ca 100644 --- a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -68,7 +68,6 @@ inline void blkq4_weights_gen( const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); // // For testing quantization and dequantization, it is not straight @@ -114,9 +113,9 @@ inline void blkq4_weights_gen( q_scales.resize(meta_shape.product()); for (size_t i = 0; i < q_scales.size(); i++) { - uint32_t v = dis(gen); - uint32_t m = (v % 63) + 1; - uint32_t e = (v >> 6) % 4; + uint32_t vscale = dis(gen); + uint32_t m = (vscale % 63) + 1; + uint32_t e = (vscale >> 6) % 4; q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); } MatrixRef tensor_scale( @@ -124,6 +123,7 @@ inline void blkq4_weights_gen( MatrixRef tensor_offset; if constexpr (has_offsets) { + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); q_zp.resize(zp_shape.product()); tensor_offset = MatrixRef( q_zp, zp_shape); diff --git a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc index 5fb90c59684f3..d9573d30cf296 100644 --- a/onnxruntime/test/optimizer/gpu_op_prepack_test.cc +++ b/onnxruntime/test/optimizer/gpu_op_prepack_test.cc @@ -13,6 +13,8 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cuda/cuda_provider_factory_creator.h" #include "core/mickey/blk_q4/f16_prepack_sm80.h" +#include "core/providers/cuda/cuda_provider_factory.h" +#include "core/providers/cuda/cuda_execution_provider_info.h" #include "test/cuda_host/blkq4_fp16_quant_sm80.h" #include "test/compare_ortvalue.h" @@ -21,16 +23,24 @@ #include "test/util/include/inference_session_wrapper.h" namespace onnxruntime { + +extern ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); + namespace test { #ifndef DISABLE_CONTRIB_OPS std::shared_ptr LoadCudaEp() { - OrtCUDAProviderOptions cuda_options; - auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options); - if (!factory) { + try { + 
OrtCUDAProviderOptions cuda_options; + auto factory = onnxruntime::CudaProviderFactoryCreator::Create(&cuda_options); + if (!factory) { + return nullptr; + } + return factory->CreateProvider(); + } catch (const ::onnxruntime::OnnxRuntimeException& e) { + std::cerr << "LoadCudaEp: " << e.what() << std::endl; return nullptr; } - return std::move(factory->CreateProvider()); } /** @@ -41,18 +51,18 @@ std::shared_ptr LoadCudaEp() { * - the addition of cuda execution provider in the session. * - a different location for the model checker, right after session initialization * as the initializers will be deleted during session run. -*/ + */ void GpuPrepackTester( - const std::shared_ptr& cuda_ep, - const std::function& build_test_case, - const std::function& check_transformed_graph, - TransformerLevel baseline_level, - TransformerLevel target_level, - int opset_version = 12, - double per_sample_tolerance = 0.001, - double relative_per_sample_tolerance = 0.001, - const std::function& add_session_options = {}, - const InlinedHashSet& disabled_optimizers = {}) { + const std::shared_ptr& cuda_ep, + const std::function& build_test_case, + const std::function& check_transformed_graph, + TransformerLevel baseline_level, + TransformerLevel target_level, + int opset_version = 12, + double per_sample_tolerance = 0.001, + double relative_per_sample_tolerance = 0.001, + const std::function& add_session_options = {}, + const InlinedHashSet& disabled_optimizers = {}) { // Build the model for this test. std::unordered_map domain_to_version; domain_to_version[kOnnxDomain] = opset_version; @@ -130,15 +140,15 @@ inline Status GetOrtValue(const NodeArg* arg, const Graph& graph, OrtValue& ort_ ORT_RETURN_IF_NOT(graph.GetInitializedTensor(arg->Name(), tensor_proto), "Missing initializer for ", arg->Name()); - const auto* path_c_str = graph.ModelPath().ToPathString().c_str(); + const auto path_str = graph.ModelPath().ToPathString(); return utils::TensorProtoToOrtValue( - Env::Default(), path_c_str, *tensor_proto, - std::make_shared(), ort_value); + Env::Default(), path_str.c_str(), *tensor_proto, + std::make_shared(), ort_value); } template -void MatMulQ4Test(int M, int N, int K, const std::shared_ptr& cuda_ep){ +void MatMulQ4Test(int M, int N, int K, const std::shared_ptr& cuda_ep) { // // Type definitions // @@ -160,7 +170,6 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr // const auto q_weight_shape = Base::get_quant_weights_shape(K, N); const auto meta_shape = Base::get_quant_meta_shape(K, N); - const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); std::vector q_weights; std::vector q_scales; @@ -176,6 +185,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr q_scales, meta_shape); MatrixRef tensor_offset; if constexpr (has_offsets) { + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); tensor_offset = MatrixRef(q_zp, zp_shape); } @@ -260,7 +270,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr for (size_t i = 0; i < packed_w_ref.size(); ++i) { int expected = packed_w_ref[i]; int found = weights_data[i]; - ASSERT_EQ(expected, found) << "prepacked weight mismatch index i = " << i << " shape[" << K << "," << N/2 << "]"; + ASSERT_EQ(expected, found) << "prepacked weight mismatch index i = " << i << " shape[" << K << "," << N / 2 << "]"; } } { @@ -287,7 +297,7 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr ASSERT_EQ(expected, found) << "prepacked zero-point mismatch index i = " << i << " shape[" << 
meta_shape[0] << "," << meta_shape[1] << "]"; } } else { - ASSERT_LE(node.InputDefs().size(), 3); + ASSERT_LE(node.InputDefs().size(), static_cast(3)); } } } @@ -298,7 +308,6 @@ void MatMulQ4Test(int M, int N, int K, const std::shared_ptr check_graph, TransformerLevel::Level2, TransformerLevel::Level3); - } TEST(GpuOpPrepackTests, MatmulNBits) { @@ -307,6 +316,19 @@ TEST(GpuOpPrepackTests, MatmulNBits) { GTEST_SKIP() << "Skipping tests when CUDA EP is not available"; } + // + // Currently these tests only work on sm_80. Going forward, however, + // we need a better solution when we may have different tests for different + // hardware. + // + auto* provider_info = TryGetProviderInfo_CUDA(); + int major, minor; + ORT_ENFORCE(provider_info->GetCurrentGpuDeviceVersion(&major, &minor) == nullptr, + "Failed to query CUDA device version while prepacking cuda operators."); + if (major < 8) { + GTEST_SKIP() << "Skipping tests when CUDA EP is not sm_80"; + } + // // GpuPrepackTester function implements two different verifications. // First is the hook check_graph, which we use to verify the prepacked weights, scales and zero points. diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 039560bece0d4..84ed86459a77c 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -823,7 +823,7 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_map e = std::make_shared(CPUExecutionProviderInfo()); ExecutionProviders execution_providers; - execution_providers.Add(kCpuExecutionProvider, std::move(e)); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(e))); bool has_constant_folding = false; onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; @@ -4570,7 +4570,7 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt std::shared_ptr e = std::make_shared(CPUExecutionProviderInfo()); ExecutionProviders execution_providers; - execution_providers.Add(kCpuExecutionProvider, std::move(e)); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(e))); bool has_gelu_approximation = false; auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, @@ -4589,7 +4589,7 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { std::shared_ptr cpu_ep = std::make_shared(CPUExecutionProviderInfo()); ExecutionProviders execution_providers; - execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep)); + ASSERT_STATUS_OK(execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep))); bool has_double_qdq_remover = false; auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc index 0f25769e24e96..6092ba9ff098a 100644 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -34,7 +34,7 @@ void testPrepack(int rows, int columns) { col_blocking>; EXPECT_TRUE(Base::weight_dimension_supported(rows, columns)) - << "Test setup problem, unsupported weight dimension: [" << rows << ", " << columns << "]"; + << "Test setup problem, unsupported 
weight dimension: [" << rows << ", " << columns << "]"; using QuantBlocking = typename Base::QuantBlocking; using ElementW = typename Base::ElementW; diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc index 548f39bb0150c..2e8d85f38b2db 100644 --- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc @@ -21,10 +21,15 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { std::string l1_transformer = "ConstantFolding"; std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; - CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + std::shared_ptr cpu_ep = + std::make_shared(CPUExecutionProviderInfo()); + ExecutionProviders execution_providers; + auto status = execution_providers.Add(kCpuExecutionProvider, std::move(cpu_ep)); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, execution_providers); + auto filtered_transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level1, {}, execution_providers, disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -47,8 +52,9 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, execution_providers); + filtered_transformers = optimizer_utils::GenerateTransformers( + TransformerLevel::Level2, {}, execution_providers, disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif }