[ARM] MatMulNBits fp16 support - connect kernels (microsoft#22856)
### Description
A breakdown PR of microsoft#22651: it connects the ARM64 NEON fp16 MatMulNBits kernels to the MLAS QNBitGemm dispatch so the fp16 path is reachable at runtime.



### Motivation and Context
Part of the fp16 MatMulNBits support for ARM tracked in microsoft#22651; see that PR for the overall motivation and context.
fajin-corp authored and ankitm3k committed Dec 11, 2024
1 parent 1668346 commit 5388c4e
Showing 11 changed files with 94 additions and 57 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/hqnbitgemm_kernel_neon_fp16.cpp
@@ -24,7 +24,7 @@ Module Name:

#include "fp16_common.h"
#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"
#include "qnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
{
4 changes: 0 additions & 4 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -543,7 +543,6 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
-this->RopeDispatch = &MlasRopeDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
@@ -573,9 +572,6 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot;
-
-// MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions
-this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
}

#if defined(__linux__)
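Net effect in platform.cpp: the NEON QNBitGemm dispatch is now registered once, unconditionally, instead of only inside the dot-product feature branch. A minimal sketch of this registration pattern, with hypothetical names standing in for the real MLAS internals:

```cpp
// Sketch only: hypothetical stand-ins for the MLAS platform/dispatch types.
#include <cstddef>

struct QNBIT_GEMM_DISPATCH {
    void (*Q4BitGemmKernel)(const float* A, float* C, size_t CountN);
};

static void Q4BitGemmKernelNeon(const float*, float*, size_t) { /* baseline NEON path */ }

static const QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeonSketch{Q4BitGemmKernelNeon};

struct PlatformSketch {
    const QNBIT_GEMM_DISPATCH* QNBitGemmDispatch = nullptr;

    explicit PlatformSketch(bool /*HasDotProductInstructions*/) {
        // After this commit the NEON dispatch is registered unconditionally;
        // it no longer hides behind the dot-product feature check.
        QNBitGemmDispatch = &MlasQNBitGemmDispatchNeonSketch;
    }
};
```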
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
@@ -402,7 +402,7 @@ SQ4BitGemm_CompFp32(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

-GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
+GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
@@ -808,7 +808,7 @@ GetQNBitGemm(QNBitGemmVariant variant)
{
switch (variant) {
case HQNBitGemmVariant_BitWidth4_CompFp16:
-return nullptr;
+return HQ4BitGemm_CompFp16;
default:
return nullptr;
}
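Returning HQ4BitGemm_CompFp16 here is what actually "connects" the kernels: callers resolve a driver from the variant and treat nullptr as unsupported. A condensed sketch of that lookup-and-fallback shape, with hypothetical signatures:

```cpp
// Sketch of the variant-to-kernel lookup (hypothetical signatures).
#include <cstddef>

enum QNBitGemmVariant {
    SQNBitGemmVariant_BitWidth4_CompFp32,
    HQNBitGemmVariant_BitWidth4_CompFp16,
};

using QNBitGemmFn = void (*)(size_t M, size_t N, size_t K);

void HQ4BitGemm_CompFp16(size_t, size_t, size_t) { /* fp16 GEMM driver */ }

QNBitGemmFn GetQNBitGemm(QNBitGemmVariant variant) {
    switch (variant) {
        case HQNBitGemmVariant_BitWidth4_CompFp16:
            return HQ4BitGemm_CompFp16;  // was nullptr: fp16 requests fell through
        default:
            return nullptr;  // callers treat nullptr as "variant unsupported"
    }
}
```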
10 changes: 5 additions & 5 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
@@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
//

/** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */
-typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)(
+typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

-SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;
+Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;

/** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */
-typedef void(SQ4BitGemmPackQuantBData_Fn)(
+typedef void(Q4BitGemmPackQuantBData_Fn)(
size_t N,
size_t K,
size_t BlkLen,
@@ -151,7 +151,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

-SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;
+Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;

/**
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
@@ -164,7 +164,7 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

-SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;
+Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;

//
// SQNBIT_CompFp32 kernel function prototypes.
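The header change is a pure rename: the typedefs lose the `S` (fp32, "single") prefix because the packing and workspace hooks now serve both the fp32 (`S`) and fp16 (`H`) front ends, while the struct members keep their names so call sites are untouched. The convention, sketched with illustrative signatures:

```cpp
// Sketch of the dispatch-table naming convention after the rename.
#include <cstddef>

struct MLAS_QNBIT_GEMM_DISPATCH_SKETCH {
    // Type-agnostic hooks (packing, workspace sizing) drop the "S" prefix:
    // they are shared by the fp32 ("S") and fp16 ("H") GEMM front ends.
    typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(size_t N, size_t K, size_t BlkLen);
    Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;

    // Compute kernels stay type-specific, so they keep a typed prefix.
    typedef void(SQ4BitGemmKernel_CompFp32_Fn)(const float* A, float* C);
    SQ4BitGemmKernel_CompFp32_Fn* SQ4BitGemmKernel_CompFp32 = nullptr;
};
```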
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.cpp
@@ -20,7 +20,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
41 changes: 41 additions & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm_kernel_neon.h
@@ -64,6 +64,47 @@ SQ4BitBlkDequantBForSgemm_CompFp32(
size_t BlockCountK
);

+// HQNBIT_CompFp16 declarations
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+void
+HQ4BitGemmPackQuantBData_CompFp16(
+    size_t N,
+    size_t K,
+    size_t BlkLen,
+    MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
+    const std::byte* QuantBDataBegin,
+    std::byte* PackedQuantBDataBegin,
+    MLAS_THREADPOOL* ThreadPool
+);
+
+void
+HQ4BitBlkDequantBForHgemm_CompFp16(
+    size_t BlkLen,
+    MLAS_FP16* FpData,
+    const std::byte* QuantBData,
+    const MLAS_FP16* QuantBScale,
+    const std::byte* QuantBZeroPoint,
+    size_t CountN,
+    size_t K,
+    size_t BlockCountK
+);
+
+void
+HQ4BitGemmKernel_CompFp16(
+    const MLAS_FP16* A,
+    const MLAS_FP16* B,
+    const MLAS_FP16* Bias,
+    MLAS_FP16* C,
+    size_t CountM,
+    size_t CountN,
+    size_t K,
+    size_t lda,
+    size_t ldb,
+    size_t ldc
+);
+#endif  // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+
// SQNBIT_CompInt8 declarations

void
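The three new declarations follow the usual MLAS split for a quantized GEMM: pack B once at setup, then per GEMM dequantize a panel of B to fp16 and run the dense fp16 kernel over it. A rough call-sequence sketch, assuming the declarations above are in scope; the real driver lives in qnbitgemm.cpp, also tiles M/N, and threads the work, and `ldb == K` is an assumption of this sketch:

```cpp
// Assumed orchestration of the declared fp16 hooks (simplified).
void HQ4BitGemmSketch(size_t M, size_t N, size_t K, size_t BlkLen,
                      const MLAS_FP16* A, size_t lda,
                      const std::byte* PackedQuantBData,
                      const MLAS_FP16* QuantBScale,
                      const std::byte* QuantBZeroPoint,
                      const MLAS_FP16* Bias,
                      MLAS_FP16* C, size_t ldc,
                      MLAS_FP16* DequantB /* per-GEMM workspace */) {
    const size_t BlockCountK = (K + BlkLen - 1) / BlkLen;

    // 1) Expand the packed 4-bit weights into an fp16 B panel.
    HQ4BitBlkDequantBForHgemm_CompFp16(BlkLen, DequantB, PackedQuantBData,
                                       QuantBScale, QuantBZeroPoint,
                                       /*CountN=*/N, K, BlockCountK);

    // 2) Run the dense fp16 kernel over the dequantized panel.
    HQ4BitGemmKernel_CompFp16(A, DequantB, Bias, C,
                              /*CountM=*/M, /*CountN=*/N, K,
                              lda, /*ldb=*/K, ldc);
}
```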
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp
@@ -22,7 +22,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"
#include "qnbitgemm_kernel_neon.h"

namespace sqnbitgemm_neon
{
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_int8.cpp
@@ -22,7 +22,7 @@ Module Name:
#include <cassert>

#include "qnbitgemm.h"
#include "sqnbitgemm_kernel_neon.h"
#include "qnbitgemm_kernel_neon.h"
#include "sqnbitgemm_q8_block.h"

namespace sqnbitgemm_neon
13 changes: 4 additions & 9 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -274,16 +274,11 @@ void TestMatMulNBitsTyped() {
base_opts.block_size = block_size;
base_opts.accuracy_level = accuracy_level;

-if (base_opts.accuracy_level == 4) {
+if constexpr (std::is_same<AType, MLFloat16>::value) {
+  base_opts.output_abs_error = 0.055f;
+  base_opts.output_rel_error = 0.02f;
+} else if (base_opts.accuracy_level == 4) {
   base_opts.output_abs_error = 0.1f;
-} else {
-  if constexpr (std::is_same<AType, MLFloat16>::value) {
-#ifdef USE_WEBGPU
-    base_opts.output_abs_error = 0.03f;
-#else
-    base_opts.output_abs_error = 0.01f;
-#endif
-  }
 }

{
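The restructured branch gives fp16 inputs both an absolute and a relative tolerance, since fp16 rounding error grows with the magnitude of the result. The check these fields feed is, in essence, the standard mixed comparison; a sketch of that predicate (assumed semantics, not the test helper itself):

```cpp
#include <cmath>

// Mixed tolerance: pass if within abs_error OR within rel_error of the reference.
bool WithinTolerance(float actual, float expected, float abs_error, float rel_error) {
    const float diff = std::fabs(actual - expected);
    return diff <= abs_error || diff <= rel_error * std::fabs(expected);
}
```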
60 changes: 31 additions & 29 deletions onnxruntime/test/mlas/bench/bench_qnbitgemm.cpp
@@ -17,16 +17,16 @@
#include "core/util/thread_utils.h"
#include "core/platform/env_var_utils.h"

-template <size_t BlkBitWidth>
-void RunSQNBitGemmBenchmark(size_t BlkLen,
-                            size_t M, size_t N, size_t K,
-                            size_t Threads,
-                            bool Symmetric,
-                            bool HasBias,
-                            MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
-                            benchmark::State& state) {
+template <typename AType, size_t BlkBitWidth>
+void RunQNBitGemmBenchmark(size_t BlkLen,
+                           size_t M, size_t N, size_t K,
+                           size_t Threads,
+                           bool Symmetric,
+                           bool HasBias,
+                           MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType,
+                           benchmark::State& state) {
 if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) {
-  state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine.");
+  state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine.");
return;
}

@@ -77,7 +77,7 @@ void RunSQNBitGemmBenchmark(size_t BlkLen,
tp.get());
}

-MLAS_QNBIT_GEMM_DATA_PARAMS<float> params{};
+MLAS_QNBIT_GEMM_DATA_PARAMS<AType> params{};
params.A = A.data();
params.lda = K;
if (PackedQuantBData != nullptr)
@@ -121,14 +121,16 @@ static void QNBitGemmArgs(benchmark::internal::Benchmark* b) {
b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"});

b->ArgsProduct({
-    {128},                                                // BlkLen
-    {1},                                                  // M
-    {4096, 11008},                                        // N
-    {4096, 11008},                                        // K
-    {1, 8},                                               // Threads
-    {int64_t{false}, int64_t{true}},                      // Symmetric
-    {int64_t{false}, int64_t{true}},                      // HasBias
-    {int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType
+    {128},                           // BlkLen
+    {1, 4096},                       // M
+    {4096, 11008},                   // N
+    {4096, 11008},                   // K
+    {1, 8},                          // Threads
+    {int64_t{false}, int64_t{true}}, // Symmetric
+    {int64_t{false}, int64_t{true}}, // HasBias
+    std::is_same_v<AType, MLAS_FP16>
+        ? std::vector<int64_t>{int64_t{HQNBIT_CompFp16}}
+        : std::vector<int64_t>{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType
});
}

@@ -140,19 +142,19 @@ template <typename AType, size_t BlkBitWidth>
void QNBITGEMM_ENV(benchmark::State& state) {
using onnxruntime::ParseEnvironmentVariableWithDefault;

-const auto BlkLen = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_BLKLEN", 32);
-const auto M = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_M", 1);
-const auto N = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_N", 4096);
-const auto K = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_K", 4096);
-const auto Threads = ParseEnvironmentVariableWithDefault<size_t>("ORT_SQNBITGEMM_THREADS", 1);
-const auto Symmetric = ParseEnvironmentVariableWithDefault<bool>("ORT_SQNBITGEMM_SYMMETRIC", true);
-const auto HasBias = ParseEnvironmentVariableWithDefault<bool>("ORT_SQNBITGEMM_HAS_BIAS", false);
-const auto ComputeType = ParseEnvironmentVariableWithDefault<int32_t>("ORT_SQNBITGEMM_COMPUTE_TYPE",
+const auto BlkLen = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_BLKLEN", 32);
+const auto M = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_M", 1);
+const auto N = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_N", 4096);
+const auto K = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_K", 4096);
+const auto Threads = ParseEnvironmentVariableWithDefault<size_t>("ORT_QNBITGEMM_THREADS", 1);
+const auto Symmetric = ParseEnvironmentVariableWithDefault<bool>("ORT_QNBITGEMM_SYMMETRIC", true);
+const auto HasBias = ParseEnvironmentVariableWithDefault<bool>("ORT_QNBITGEMM_HAS_BIAS", false);
+const auto ComputeType = ParseEnvironmentVariableWithDefault<int32_t>("ORT_QNBITGEMM_COMPUTE_TYPE",
static_cast<int32_t>(SQNBIT_CompFp32));

-RunSQNBitGemmBenchmark<BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, HasBias,
-                                    static_cast<MLAS_QNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
-                                    state);
+RunQNBitGemmBenchmark<AType, BlkBitWidth>(BlkLen, M, N, K, Threads, Symmetric, HasBias,
+                                          static_cast<MLAS_QNBIT_GEMM_COMPUTE_TYPE>(ComputeType),
+                                          state);

std::ostringstream s;
s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen
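With RunQNBitGemmBenchmark templated on the activation type, the same machinery can emit fp32 and fp16 suites. A hedged example of how such a templated benchmark is typically registered with Google Benchmark; BM_QNBitGemm and the fp16 line are illustrative, not this file's actual macros:

```cpp
#include <benchmark/benchmark.h>

template <typename AType, size_t BlkBitWidth>
static void BM_QNBitGemm(benchmark::State& state) {
    for (auto _ : state) {
        // RunQNBitGemmBenchmark<AType, BlkBitWidth>(..., state);
    }
}

// One registration per (activation type, bit width) combination:
BENCHMARK_TEMPLATE2(BM_QNBitGemm, float, 4);
// BENCHMARK_TEMPLATE2(BM_QNBitGemm, MLAS_FP16, 4);  // fp16 suite on ARM64 builds
```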
11 changes: 7 additions & 4 deletions onnxruntime/test/mlas/unittest/test_hqnbitgemm_neon.cpp
@@ -81,6 +81,7 @@ class MlasNeonFp16CastTest : public MlasTestBase {

class MlasNeonFp16PrepackTest : public MlasTestBase {
private:
+std::random_device rd_;  // a seed source for the random number engine
unsigned int seed_;
std::mt19937 gen_; // mersenne_twister_engine seeded with rd()
std::uniform_int_distribution<> distrib_;
@@ -172,7 +173,7 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {

public:
MlasNeonFp16PrepackTest()
-    : seed_(19287), gen_(seed_), distrib_(0, 255) {
+    : seed_(rd_()), gen_(seed_), distrib_(0, 255) {
}

static const char* GetTestSuiteName() {
@@ -196,6 +197,7 @@ class MlasNeonFp16PrepackTest : public MlasTestBase {

class MlasNeonFp16DequantBTest : public MlasTestBase {
private:
+std::random_device rd_;  // a seed source for the random number engine
unsigned int seed_;
std::mt19937 gen_; // mersenne_twister_engine seeded with rd()
std::uniform_int_distribution<> distrib_;
@@ -316,7 +318,7 @@ class MlasNeonFp16DequantBTest : public MlasTestBase {

public:
MlasNeonFp16DequantBTest()
-    : seed_(19287), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) {
+    : seed_(rd_()), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) {
}

static const char* GetTestSuiteName() {
@@ -351,6 +353,7 @@ class MlasNeonFp16DequantBTest : public MlasTestBase {

class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase {
private:
+std::random_device rd_;  // a seed source for the random number engine
unsigned int seed_;
std::mt19937 gen_; // mersenne_twister_engine seeded with rd()
MatrixGuardBuffer<MLAS_FP16> A_, B_, C_, ref_, bias_;
@@ -401,7 +404,7 @@ class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase {
for (size_t m = 0; m < M; ++m) {
for (size_t n = 0; n < N; ++n) {
size_t i = m * Ldc + n;
-      ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.02f, 0.055f))
+      ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.015f, 0.03f))
<< " seed " << seed_
<< " v0 " << target[i] << " v1 " << ref[i]
<< " m " << m << " n " << n;
@@ -436,7 +439,7 @@ class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase {

public:
MlasNeonFp16HQ4BitGemmKernelTest()
-    : seed_(19287), gen_(seed_) {
+    : seed_(rd_()), gen_(seed_) {
}

static const char* GetTestSuiteName() {
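The test classes switch from the fixed seed 19287 to a std::random_device seed, while still storing the seed and printing it in the failure message so a failing case can be replayed. A small self-contained illustration of that pattern (hypothetical class, not the MLAS test harness):

```cpp
#include <iostream>
#include <random>

struct ReproducibleRandomTest {
    std::random_device rd_;  // nondeterministic seed source
    unsigned int seed_;
    std::mt19937 gen_;       // engine seeded once per test object

    ReproducibleRandomTest() : seed_(rd_()), gen_(seed_) {}

    void RunOnce() {
        std::uniform_int_distribution<> distrib(0, 255);
        const int v = distrib(gen_);
        if (v < 0) {  // placeholder for a real assertion
            // Report the seed so the run can be reproduced by temporarily
            // hardcoding seed_ to the printed value.
            std::cerr << "failure, reproduce with seed " << seed_ << "\n";
        }
    }
};
```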
