From 0d10c7f3c1111cfff064e7990aa897ac9fd05c82 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Wed, 7 Feb 2024 21:04:37 +0000 Subject: [PATCH] Revert NeuralSpeed code for x64 MatMulNBits (#19382) ### Description Revert PR#19016 https://github.com/microsoft/onnxruntime/pull/19016 Revert PR#17669 https://github.com/microsoft/onnxruntime/pull/17669 --- cgmanifests/generated/cgmanifest.json | 10 - cmake/CMakeLists.txt | 12 - cmake/deps.txt | 1 - cmake/external/neural_speed.cmake | 15 - cmake/onnxruntime_providers_cpu.cmake | 15 - .../cpu/quantization/matmul_nbits.cc | 144 ------ .../cpu/quantization/neural_speed_defs.h | 45 -- .../cpu/quantization/neural_speed_gemm.cc | 438 ------------------ .../cpu/quantization/neural_speed_gemm.h | 129 ------ .../cpu/quantization/neural_speed_wrapper.h | 39 -- .../test/contrib_ops/matmul_4bits_test.cc | 175 ------- 11 files changed, 1023 deletions(-) delete mode 100644 cmake/external/neural_speed.cmake delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index efd901787fdb7..fc4ea25603152 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -202,16 +202,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", - "repositoryUrl": "https://github.com/intel/neural-speed.git" - }, - "comments": "neural_speed" - } - }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 90fe8276ea9c7..0ccd874cee3c9 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,7 +88,6 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -902,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if(USE_NEURAL_SPEED) - target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) - endif() - set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span @@ -1193,13 +1188,6 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) - include(neural_speed) - if (USE_NEURAL_SPEED) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) - endif() -endif() - # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) diff --git a/cmake/deps.txt b/cmake/deps.txt index cb431f8c77397..17c3cbf9a6c43 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -35,7 +35,6 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake deleted file mode 100644 index ed711351403a7..0000000000000 --- a/cmake/external/neural_speed.cmake +++ /dev/null @@ -1,15 +0,0 @@ -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") - set(USE_NEURAL_SPEED TRUE) -elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") - set(USE_NEURAL_SPEED TRUE) -endif() - -if(USE_NEURAL_SPEED) - FetchContent_Declare( - neural_speed - URL ${DEP_URL_neural_speed} - URL_HASH SHA1=${DEP_SHA1_neural_speed} - ) - set(BTLA_USE_OPENMP OFF) - onnxruntime_fetchcontent_makeavailable(neural_speed) -endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index b81a5c79ac0cc..f60faa4d39116 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,15 +60,6 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() - set(onnxruntime_cpu_neural_speed_srcs - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" - ) - if(NOT USE_NEURAL_SPEED) - list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) - endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -153,12 +144,6 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() -if(NOT onnxruntime_DISABLE_CONTRIB_OPS) - if(USE_NEURAL_SPEED) - onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) - endif() -endif() - if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 166f5c8f52f54..e8d8bbca66fe7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -10,10 +10,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#ifdef ORT_NEURAL_SPEED -#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" -#endif - namespace onnxruntime { namespace contrib { @@ -23,16 +19,6 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level static_cast(CompMostAccurate), static_cast(CompLeastAccurate)); -#if defined(ORT_NEURAL_SPEED) - - ORT_UNUSED_PARAMETER(nbits); - ORT_UNUSED_PARAMETER(block_size); - - // Neural Speed APIs already expect a minimum accuracy level so just use the given value. - return accuracy_level; - -#else // defined(ORT_NEURAL_SPEED) - // Find a supported accuracy level that is not less accurate than the one given. // CompMostAccurate is always supported with the fallback implementation. // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. @@ -45,8 +31,6 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } return effective_accuracy_level; - -#endif // defined(ORT_NEURAL_SPEED) } } // namespace @@ -61,17 +45,6 @@ class MatMulNBits final : public OpKernel { accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); -#ifdef ORT_NEURAL_SPEED - const Tensor* tensor_B = nullptr; - const Tensor* tensor_scale = nullptr; - const Tensor* tensor_zero_point = nullptr; - bool B_constant = info.TryGetConstantInput(1, &tensor_B); - bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); - bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); - is_asym_ = info.GetInputCount() >= 4; - all_constant_ = B_constant && scale_constant; - all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_; -#endif } Status Compute(OpKernelContext* context) const override; @@ -92,13 +65,6 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; - -#if defined(ORT_NEURAL_SPEED) - - bool is_asym_{false}; - bool all_constant_{false}; - -#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, @@ -106,54 +72,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; -#if defined(ORT_NEURAL_SPEED) - - if (!all_constant_) { - return Status::OK(); - } - MLAS_THREADPOOL* pool = NULL; - if (nbits_ != 4) { - return Status::OK(); - } - auto comp_type = static_cast(accuracy_level_); - auto nbits = static_cast(nbits_); - if (input_idx == 1) { - packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); - if (packed_b_size_ == 0) return Status::OK(); - auto qptr = tensor.Data(); - packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - std::memset(packed_b_.get(), 0, packed_b_size_); - NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, - comp_type, pool); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } - is_packed = true; - } - if (input_idx == 2 && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, - comp_type, pool); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } - is_packed = true; - } - if (input_idx == 3 && packed_b_ != nullptr) { - auto zptr = tensor.Data(); - NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, - comp_type, pool); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } - is_packed = true; - } - -#else // defined(ORT_NEURAL_SPEED) - if (input_idx == 1) { const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { @@ -173,8 +91,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#endif // defined(ORT_NEURAL_SPEED) - return Status::OK(); } @@ -182,31 +98,11 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; -#if defined(ORT_NEURAL_SPEED) - - // Pack three tensors into one buffer - if (input_idx == 1) { - used_shared_buffers = true; - packed_b_ = std::move(prepacked_buffers[0]); - } - if (input_idx == 2) { - used_shared_buffers = true; - packed_b_ = std::move(prepacked_buffers[0]); - } - if (input_idx == 3) { - used_shared_buffers = true; - packed_b_ = std::move(prepacked_buffers[0]); - } - -#else // defined(ORT_NEURAL_SPEED) - if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } -#endif // defined(ORT_NEURAL_SPEED) - return Status::OK(); } @@ -216,46 +112,6 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); -#if defined(ORT_NEURAL_SPEED) - - if (packed_b_) { - TensorShape b_shape({static_cast(N_), static_cast(K_)}); - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); - - Tensor* y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (y->Shape().Size() == 0) return Status::OK(); - - auto* y_data = y->MutableData(); - - const size_t max_len = helper.OutputOffsets().size(); - const size_t M = static_cast(helper.M()); - const size_t N = static_cast(helper.N()); - const size_t K = static_cast(helper.K()); - const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); - AllocatorPtr allocator; - auto status = ctx->GetTempSpaceAllocator(&allocator); - ORT_RETURN_IF_ERROR(status); - for (size_t i = 0; i < max_len; i++) { - gemm_params[i].A = a_data + helper.LeftOffsets()[i]; - gemm_params[i].lda = lda; - gemm_params[i].B = packed_b_.get(); - gemm_params[i].C = y_data + helper.OutputOffsets()[i]; - gemm_params[i].ldc = N; - } - auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); - // workspace for activation process(dynamic quantization and others) - auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); - return Status::OK(); - } - -#endif // defined(ORT_NEURAL_SPEED) - const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const auto* scales_data = scales->Data(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h deleted file mode 100644 index 864abffd131fe..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h +++ /dev/null @@ -1,45 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - ---*/ - -#pragma once - -#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" - -namespace bestla { - -using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; -using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; -using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; -using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; -using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; -using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; -using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; -using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; -using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; -using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; -using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; -using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; - -template -using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; -template -using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; - -class ORTThreading : public parallel::IThreading { - public: - explicit ORTThreading(void* tp); - void parallel_for(const parallel::thread_func& func) const override; - void set_threads(int nthreads) override { - (void)(nthreads); - assert(0); - } - void sync() const override { assert(0); } - void* mTp; -}; - -} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc deleted file mode 100644 index 73aaa4ae61a6e..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc +++ /dev/null @@ -1,438 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - neural_speed_gemm.cpp - -Abstract: - - GEMM template combinations of neural_speed. ---*/ - -#include "contrib_ops/cpu/quantization/neural_speed_defs.h" -#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" -#include "core/platform/threadpool.h" - -using ThreadPool = onnxruntime::concurrency::ThreadPool; - -namespace bestla { - -ORTThreading::ORTThreading(void* tp) - : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} - -void ORTThreading::parallel_for(const parallel::thread_func& func) const { - ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, - [&](ptrdiff_t tid) { func(static_cast(tid)); }); -} - -template -static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, - parallel::IThreading* th) { - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); - if (M <= 16) { - using Parallel = parallel::gemm::SchedulerKBlock; - using Launcher = - wrapper::gemm::LauncherKBlock; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->IsAsym()) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); - } - typename Launcher::Param args{gp, - {A, lda_, &reduceA}, - {B}, - {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), - reduceA.template RPtr(), reduceA.lda}, - {C, ldc_, nullptr}}; - parallel::GemmRun(kernel, args, th); - } else { - using Parallel = parallel::gemm::SchedulerBase; - using Launcher = - wrapper::gemm::LauncherBase; - static Launcher kernel; - typename Launcher::Param args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; - parallel::GemmRun(kernel, args, th); - } -} - -template -static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, - parallel::IThreading* th) { - using Parallel = parallel::gemm::SchedulerKBlockS; - using Launcher = - wrapper::gemm::LauncherIntKBlock; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); - typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; - parallel::GemmRun(kernel, args, th); -} - -template -static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { - auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(A); - (void)(N); - (void)(C); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using ProA = prologue_a::gemm::ActivationKBlockBaseF32; - static ProA proA; - if (B->IsAsym()) { - auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - // using ProA = prologue_a::gemm::ActivationBase; - return 0; - } -} - -template -static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { - (void)(N); - (void)(lda); - (void)(ldc); - (void)(A); - (void)(C); - using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; - static ProA proA; - auto quanA = - proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); - return quanA.mSize; -} - -} // namespace bestla - -using namespace bestla; - -static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, - void* ThreadPool) { - GetCPUDevice(); - bestla::ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto kptr = reinterpret_cast(ptr); - auto BlkSize = kptr->mBlockSize; - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); - } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, - DataParams[i].ldc, WorkSpace, &orth); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && - BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, - &orth); - } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && - BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, - &orth); - } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && - BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto kptr = reinterpret_cast(ptr); - auto NTile = - gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - auto BlkSize = kptr->mBlockSize; - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc), - size); - } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc), - size); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - size = std::max(NSSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), - size); - } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && - BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - size = std::max(NSSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), - size); - } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - size = std::max(NSSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), - size); - } - } - } - } - } - return size; -} - -template -static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { - static T proB; - auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), - BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -static bool NSQ4GemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { - auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto wptr = reinterpret_cast(ptr); - auto BlkSize = wptr->mBlockSize; - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && - BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} - -template -static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, - const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, - void* ThreadPool) { - static T proB; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, - BTLA_DTYPE::BF16, IsAsym); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - proB.reduceWeight(&stor, &orth); - } -} - -static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case NSCompInt8: - if (!isAsym) { // asym int8 is not optimized, so fall through to others. - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - } - [[fallthrough]]; - case NSCompBf16: - case NSCompFp16: - case NSCompFp32: - case NSCompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - [[fallthrough]]; - default: - return 0; - } -} - -static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, - size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, - NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case NSCompInt8: - if (!isAsym) { // asym int8 is not optimized, so fall through to others. - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - NSQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - NSQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, - K, isAsym, lastCall, ldb, ThreadPool); - return true; - } - } - [[fallthrough]]; - case NSCompBf16: - case NSCompFp16: - case NSCompFp32: - case NSCompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, - lastCall, ldb, ThreadPool); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, - ldb, ThreadPool); - return true; - } - [[fallthrough]]; - default: - return false; - } -} - -size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, - NS_SQNBIT_COMPUTE_TYPE CompType) { - if (nbits == 4) { - auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } - return 0; -} - -void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, - size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, - NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { - if (nbits == 4) { - if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -} - -void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { - // only nbits=4 can be packed, so not necessary to check the nbits in DataParams - if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -} - -size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { - // only nbits=4 can be packed, so not necessary to check the nbits in DataParams - return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -} - -void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, - void* ThreadPool) { - // only nbits=4 can be packed, so not necessary to check the nbits in DataParams - if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by bestla - return; - } -} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h deleted file mode 100644 index ebcb3027a209f..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h +++ /dev/null @@ -1,129 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - neural_speed_gemm.h - -Abstract: - - Prepack-weight GEMM APIs of neural_speed. ---*/ - -#pragma once - -#include -#include - -/** - * @brief Define compute types of block quantization - */ -enum NS_SQNBIT_COMPUTE_TYPE { - NSCompUndef = 0, /*!< undef */ - NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ - NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ - NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ - NSCompInt8 = 4 /*!< input int8, accumulator int32 */ -}; - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination - * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, - NS_SQNBIT_COMPUTE_TYPE comp_type); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, - size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, - NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); - -/** - * @brief Get the workspace size required by computation. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes - */ -size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); - -/** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return - */ -void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, - void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h deleted file mode 100644 index d3902f9bd68c7..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h +++ /dev/null @@ -1,39 +0,0 @@ -//----------------------------------------------------------------------------- -// -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. -// -//----------------------------------------------------------------------------- -#pragma once -#if defined(__GNUC__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#pragma GCC diagnostic ignored "-Wsign-compare" -#pragma GCC diagnostic ignored "-Wmissing-field-initializers" -#pragma GCC diagnostic ignored "-Wunused-variable" -#pragma GCC diagnostic ignored "-Wunused-value" -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wunused-function" -#pragma GCC diagnostic ignored "-Wuninitialized" -#pragma GCC diagnostic ignored "-Wclass-memaccess" -#pragma GCC diagnostic ignored "-Wunused-but-set-variable" -#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" - -#elif defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4457) -#pragma warning(disable : 4189) -#pragma warning(disable : 4100) -#pragma warning(disable : 4244) -#pragma warning(disable : 4267) -#pragma warning(disable : 4702) -#endif - -#include "bestla/bestla_prologue_a.h" -#include "bestla/bestla_wrapper.h" - -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..d22da2a3da87f 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -149,17 +149,10 @@ TEST(MatMulNBits, Float32) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { -#ifdef ORT_NEURAL_SPEED - for (auto accuracy_level : {0, 1, 4}) { - RunTest(M, N, K, block_size, accuracy_level, false, false); - RunTest(M, N, K, block_size, accuracy_level, true, false); - } -#else for (auto accuracy_level : {0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); } -#endif } } } @@ -192,174 +185,6 @@ TEST(MatMulNBits, Float16Large) { #endif -void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_size, bool is_asym, - int64_t acc_lvl) { - // (M x K) X (K x N) - - OpTester test("MatMulNBits", 1, kMSDomain); - test.AddAttribute("accuracy_level", acc_lvl); - test.AddAttribute("block_size", int64_t(block_size)); - test.AddAttribute("bits", QBits); - test.AddAttribute("N", N); - test.AddAttribute("K", K); - - std::vector input0_vals(M * K); - float fv = -135.f; - for (auto& f : input0_vals) { - f = fv / 127; - fv++; - if (fv > 135.f) { - fv = -135.f; - } - } - - size_t kblks = K / block_size; - std::vector input1_vals(N * K / 2); - for (size_t i = 0; i < input1_vals.size(); i++) { - input1_vals[i] = uint8_t(i); - } - std::vector input2_vals(N * kblks, 0.002f); - for (size_t i = 0; i < N * kblks; i++) { - input2_vals[i] += (i % 100) * 0.00003f; - } - std::vector input3_vals(N * kblks / 2, static_cast(0x88)); - - std::vector input1_f_vals(N * K); - if (is_asym) { - for (size_t i = 0; i < N * kblks; i += 2) { - input3_vals[i / 2] = static_cast(i + 1); - } - for (int64_t i = 0; i < K; i += 2) { - for (int64_t j = 0; j < N; j++) { - auto srcv = input1_vals[j * K / 2 + i / 2]; - auto koff = i % (block_size * 2); - auto zpv = input3_vals[j * kblks / 2 + i / block_size / 2]; - auto zp0 = koff < block_size ? (zpv & 0xf) - 8 : ((zpv & 0xf0) >> 4) - 8; - auto src0 = (srcv & 0xf) - 8; - auto src1 = ((srcv & 0xf0) >> 4) - 8; - auto scale0 = input2_vals[j * kblks + i / block_size]; - auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; - input1_f_vals[i * N + j] = (static_cast(src0) - zp0) * scale0; - input1_f_vals[(i + 1) * N + j] = (static_cast(src1) - zp0) * scale1; - } - } - } else { - for (int64_t i = 0; i < K; i += 2) { - for (int64_t j = 0; j < N; j++) { - auto srcv = input1_vals[j * K / 2 + i / 2]; - auto src0 = (srcv & 0xf) - 8; - auto src1 = ((srcv & 0xf0) >> 4) - 8; - auto scale0 = input2_vals[j * kblks + i / block_size]; - auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; - input1_f_vals[i * N + j] = static_cast(src0) * scale0; - input1_f_vals[(i + 1) * N + j] = static_cast(src1) * scale1; - } - } - } - - std::vector expected_vals(M * N); - for (int64_t m = 0; m < M; m++) { - for (int64_t n = 0; n < N; n++) { - float sum = 0.0f; - for (int64_t k = 0; k < K; k++) { - sum += input0_vals[m * K + k] * input1_f_vals[k * N + n]; - } - expected_vals[m * N + n] = sum; - } - } - - test.AddInput("A", {M, K}, input0_vals, false); - - test.AddInput("B", {N, static_cast(kblks), static_cast(block_size / 2)}, input1_vals, - true); - test.AddInput("scales", {N, static_cast(kblks)}, input2_vals, true); - if (is_asym) { - test.AddInput("zero_points", {N, static_cast(kblks / 2)}, input3_vals, true); - } - test.AddOutput("Y", {M, N}, expected_vals, false); - if (acc_lvl == 4) { - test.SetOutputAbsErr("Y", 0.1f); - } - - OrtValue b, scale, zp; - Tensor::InitOrtValue(DataTypeImpl::GetType(), - TensorShape({N, static_cast(kblks), static_cast(block_size / 2)}), - input1_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); - - Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks)}), - input2_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), scale); - if (is_asym) { - Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks / 2)}), - input3_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), zp); - } - SessionOptions so; - // Set up B as a shared initializer to be shared between sessions - ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); - ASSERT_EQ(so.AddInitializer("scales", &scale), Status::OK()); - if (is_asym) { - ASSERT_EQ(so.AddInitializer("zero_points", &zp), Status::OK()); - } - - // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable - test.EnableSharingOfPrePackedWeightsAcrossSessions(); - - // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP - // and we want to ensure that it is available in this build - auto cpu_ep = []() -> std::vector> { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - return execution_providers; - }; - - size_t number_of_pre_packed_weights_counter_session_1 = 0; - size_t number_of_shared_pre_packed_weights_counter = 0; - - // Session 1 - { - auto ep_vec = cpu_ep(); - test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, - &number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); - // Assert that no pre-packed weights have been shared thus far - ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); - } - - auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); - // Assert that the number of elements in the shared container - // is the same as the number of weights that have been pre-packed - ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); - - // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements - // that have been pre-packed will be zero in which case we do not continue with the testing - // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. - if (number_of_pre_packed_weights_counter_session_1 == 0) return; - - // Session 2 - { - size_t number_of_pre_packed_weights_counter_session_2 = 0; - auto ep_vec = cpu_ep(); - test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, - &number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); - - // Assert that the same number of weights were pre-packed in both sessions - ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); - - // Assert that the number of pre-packed weights that were shared equals - // the number of pre-packed weights in the second session - ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, - static_cast(number_of_shared_pre_packed_weights_counter)); - } -} - -#ifdef ORT_NEURAL_SPEED -TEST(MatMulNBits, SharedPrepackedWeights) { - RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, true, 1); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, false, 1); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 1); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 4); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 1024, false, 4); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 4096, false, 4); -} -#endif } // namespace test } // namespace onnxruntime