From 340ba2196d440c8d6f7e3dfd39c6a83612cd4819 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Mon, 9 Oct 2023 17:01:01 -0700
Subject: [PATCH 01/10] change AddBiasAvx to MlasAddBiasForGemm

---
 onnxruntime/core/mlas/lib/q4gemm.h          | 31 ++++++++++++++++---
 onnxruntime/core/mlas/lib/q4gemm_avx512.cpp | 34 ---------------------
 2 files changed, 27 insertions(+), 38 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h
index 1562f9c0b4236..b4a3941af1edb 100644
--- a/onnxruntime/core/mlas/lib/q4gemm.h
+++ b/onnxruntime/core/mlas/lib/q4gemm.h
@@ -43,9 +43,32 @@ void
 MlasBlkQ4DequantB(float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb);
 
-template<typename KERNEL>
-MLAS_FORCEINLINE void
-AddBiasAvx(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc);
+inline
+MLAS_FORCEINLINE
+void
+MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) {
+    for (size_t m = 0; m < CountM; m++) {
+        const float* bias = Bias;
+        float* sum = C;
+        for (size_t n = 0; n < CountN; n += 4) {
+            if (CountN - n < 4) {
+                for (size_t nn = n; nn < CountN; nn++) {
+                    *sum += *bias;
+                    sum++;
+                    bias++;
+                }
+                break;
+            }
+
+            MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
+            acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
+            MlasStoreFloat32x4(sum, acc_x);
+            bias += 4;
+            sum += 4;
+        }
+        C += ldc;
+    }
+}
 
@@ -135,7 +158,7 @@ MlasQ4GemmOperation(
 #endif
 
         if (bias) {
-            AddBiasAvx<KERNEL>(bias, c_blk, RowsHandled, CountN, ldc);
+            MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
         }
         if (DataParams->OutputProcessor != nullptr) {
             DataParams->OutputProcessor->Process(
diff --git a/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp b/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp
index f7af82ed12e0f..790e67eb407a7 100644
--- a/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp
+++ b/onnxruntime/core/mlas/lib/q4gemm_avx512.cpp
@@ -1030,40 +1030,6 @@ MlasBlkQ4DequantSgemmPackB(
     }
 }
 
-template<>
-MLAS_FORCEINLINE
-void
-AddBiasAvx<MLAS_FP_Q4_GEMM_KERNEL_AVX512>(
-    const float* Bias,
-    float* C,
-    size_t CountM,
-    size_t CountN,
-    size_t ldc
-    )
-{
-    for (size_t m = 0; m < CountM; m++) {
-        const float* bias = Bias;
-        float* sum = C;
-        for (size_t n = 0; n < CountN; n += 4) {
-            if (CountN - n < 4) {
-                for (size_t nn = n; nn < CountN; nn++) {
-                    *sum += *bias;
-                    sum++;
-                    bias++;
-                }
-                break;
-            }
-
-            __m128 acc_x = _mm_loadu_ps(sum);
-            acc_x = _mm_add_ps(acc_x, _mm_loadu_ps(bias));
-            _mm_storeu_ps(sum, acc_x);
-            bias += 4;
-            sum += 4;
-        }
-        C += ldc;
-    }
-}
-
 static MLAS_Q4GEMM_OPERATION* Q4Operations_avx512vnni[] = {
     MlasQ4GemmOperation<MLAS_Q4TYPE_BLK0, MLAS_FP_Q4_GEMM_KERNEL_AVX512>,

From 05c04ef2f2c3fffc464b912c0e2e2ecb44c6933b Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Mon, 9 Oct 2023 18:47:28 -0700
Subject: [PATCH 02/10] use std::cerr instead of logging in cpuid_uarch.cc

---
 onnxruntime/core/common/cpuid_uarch.cc | 28 ++++++++++++++------------
 1 file changed, 15 insertions(+), 13 deletions(-)

diff --git a/onnxruntime/core/common/cpuid_uarch.cc b/onnxruntime/core/common/cpuid_uarch.cc
index 52baad739441b..16634b2bc8744 100644
--- a/onnxruntime/core/common/cpuid_uarch.cc
+++ b/onnxruntime/core/common/cpuid_uarch.cc
@@ -3,7 +3,8 @@
 
 #include "core/common/cpuid_uarch.h"
 
-#include "core/common/logging/logging.h"
+#include <iostream>  // For std::cerr.
+// Writing to stderr instead of logging because logger may not be initialized yet.
 
namespace onnxruntime { @@ -137,7 +138,7 @@ void decodeMIDR( break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } } break; @@ -156,7 +157,7 @@ void decodeMIDR( break; // #endif default: - LOGS_DEFAULT(WARNING) << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__) @@ -172,7 +173,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_thunderx2; break; default: - LOGS_DEFAULT(WARNING) << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif @@ -187,7 +188,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_cortex_a76; break; default: - LOGS_DEFAULT(WARNING) << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -199,7 +200,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xscale; break; default: - LOGS_DEFAULT(WARNING) << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ @@ -215,7 +216,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_carmel; break; default: - LOGS_DEFAULT(WARNING) << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #if !defined(__ANDROID__) @@ -225,7 +226,7 @@ void decodeMIDR( *uarch = cpuinfo_uarch_xgene; break; default: - LOGS_DEFAULT(WARNING) << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #endif @@ -297,7 +298,7 @@ void decodeMIDR( break; // #endif /* ARM64 && !defined(__ANDROID__) */ default: - LOGS_DEFAULT(WARNING) << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; case 'S': @@ -343,8 +344,9 @@ void decodeMIDR( *uarch = cpuinfo_uarch_exynos_m5; break; default: - LOGS_DEFAULT(WARNING) << "unknown Samsung CPU variant 0x" - << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Samsung CPU variant 0x" + << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) + << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -355,12 +357,12 @@ void decodeMIDR( *uarch = cpuinfo_uarch_pj4; break; default: - LOGS_DEFAULT(WARNING) << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored"; + std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ default: - LOGS_DEFAULT(WARNING) << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr; + std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; } } From 
6fe9b81a2ff615d61c718d5d8631d688c59c3eaa Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 10 Oct 2023 09:54:48 -0700 Subject: [PATCH 03/10] add infrastructure for q4gemm neon impl --- cmake/onnxruntime_mlas.cmake | 2 + onnxruntime/core/mlas/lib/mlasi.h | 1 + onnxruntime/core/mlas/lib/platform.cpp | 1 + onnxruntime/core/mlas/lib/q4gemm_neon.cpp | 114 ++++++++++++++++++++++ 4 files changed, 118 insertions(+) create mode 100644 onnxruntime/core/mlas/lib/q4gemm_neon.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 992908392c946..af51df4838505 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -65,6 +65,7 @@ function(setup_mlas_source_for_windows) if(onnxruntime_target_platform STREQUAL "ARM64") target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/q4gemm_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp @@ -331,6 +332,7 @@ else() ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/q4gemm_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index b6ac4a1ca1d6c..d037360cf1028 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -889,6 +889,7 @@ extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni; struct MLAS_FPQ4GEMM_DISPATCH; extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; +extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchNeon; // // Quantized depthwise convolution kernels. diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 96bc1d8010bed..1493e536ae15b 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -449,6 +449,7 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. diff --git a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp new file mode 100644 index 0000000000000..0d3b89085269a --- /dev/null +++ b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp @@ -0,0 +1,114 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + q4gemm_neon.cpp + +Abstract: + + This module implements the fp32 matrix multiplication with compressed + weight tensor (right hand side). The assumption is the right hand side + tensor can be pre-packed and compressed using int-4 quantization to save + memory. + + This implementation is for ARM NEON. 
+
+--*/
+
+#include <arm_neon.h>
+
+#include "q4gemm.h"
+
+struct MLAS_FP_Q4_GEMM_KERNEL_NEON {
+    // static constexpr size_t StrideM = 256;
+};
+
+//
+// MlasQ4GemmKernel and related helper functions
+//
+
+template <typename Q4Type>
+MLAS_FORCEINLINE size_t
+MlasQ4GemmKernelNeon(const float* A,
+                     const uint8_t* PackedB,
+                     float* C,
+                     size_t CountM,
+                     size_t CountN,
+                     size_t CountK,
+                     size_t lda,
+                     size_t ldb,
+                     size_t ldc,
+                     const float* Bias);
+
+template <>
+MLAS_FORCEINLINE size_t
+MlasQ4GemmKernelNeon<MLAS_Q4TYPE_BLK0>(const float* A,
+                                       const uint8_t* PackedB,
+                                       float* C,
+                                       size_t CountM,
+                                       size_t CountN,
+                                       size_t CountK,
+                                       size_t lda,
+                                       size_t ldb,
+                                       size_t ldc,
+                                       const float* Bias)
+{
+    static_cast<void>((A, PackedB, C, CountM, CountN, CountK, lda, ldb, ldc, Bias));
+    return 1;  // TODO ...
+}
+
+template <>
+MLAS_FORCEINLINE size_t
+MlasQ4GemmKernel<MLAS_Q4TYPE_BLK0, MLAS_FP_Q4_GEMM_KERNEL_NEON>(const float* A,
+                                                                const uint8_t* PackedB,
+                                                                float* C,
+                                                                size_t CountM,
+                                                                size_t CountN,
+                                                                size_t CountK,
+                                                                size_t lda,
+                                                                size_t ldb,
+                                                                size_t ldc,
+                                                                const float* Bias)
+{
+    return MlasQ4GemmKernelNeon<MLAS_Q4TYPE_BLK0>(A, PackedB, C, CountM, CountN, CountK, lda, ldb,
+                                                  ldc, Bias);
+}
+
+//
+// MlasBlkQ4DequantB and related helper functions
+//
+
+template <typename Q4Type>
+MLAS_FORCEINLINE void
+MlasBlkQ4DequantBNeon(
+    float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb)
+{
+    static_cast<void>((FpData, PackedB, CountN, CountK, ldb));
+    // TODO ...
+}
+
+template <>
+MLAS_FORCEINLINE void
+MlasBlkQ4DequantB<MLAS_Q4TYPE_BLK0>(
+    float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb)
+{
+    MlasBlkQ4DequantBNeon<MLAS_Q4TYPE_BLK0>(FpData, PackedB, CountN, CountK, ldb);
+}
+
+//
+// MlasFpQ4GemmDispatchNeon structure population
+//
+
+static MLAS_Q4GEMM_OPERATION* Q4Operations_neon[] = {
+    MlasQ4GemmOperation<MLAS_Q4TYPE_BLK0, MLAS_FP_Q4_GEMM_KERNEL_NEON>,
+    nullptr,
+    nullptr,
+    nullptr,
+    nullptr,
+};
+
+const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchNeon = {Q4Operations_neon};
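Note on the packed format these kernels consume: a MLAS_Q4TYPE_BLK0 blob is a
float scale followed by BlkLen/2 packed bytes, where low nibbles hold elements
0-15 of the block, high nibbles hold elements 16-31, and the zero point is an
implied 8. The following standalone sketch is illustrative only -- DecodeQ4Blk0
is a hypothetical name, and the layout assumptions (BlkLen = 32, leading float
scale) are taken from the reference implementations later in this series.

    #include <cstdint>
    #include <cstring>

    void DecodeQ4Blk0(const uint8_t* blob, float out[32]) {
        float scale;
        std::memcpy(&scale, blob, sizeof(float));    // scale is stored first
        const uint8_t* data = blob + sizeof(float);  // then 16 packed bytes
        for (int i = 0; i < 16; ++i) {
            out[i] = ((data[i] & 0x0F) - 8) * scale;     // low nibble -> element i
            out[i + 16] = ((data[i] >> 4) - 8) * scale;  // high nibble -> element i + 16
        }
    }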
From f7bd0299b4204c759c7b7a5ccd5c4223498575e8 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Wed, 11 Oct 2023 10:20:09 -0700
Subject: [PATCH 04/10] make test_main.cc threadpool a unique_ptr to avoid memory leak output in debug build

---
 onnxruntime/test/mlas/unittest/test_main.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/onnxruntime/test/mlas/unittest/test_main.cpp b/onnxruntime/test/mlas/unittest/test_main.cpp
index 66b5a6a15db2b..505c0c01dfa90 100644
--- a/onnxruntime/test/mlas/unittest/test_main.cpp
+++ b/onnxruntime/test/mlas/unittest/test_main.cpp
@@ -1,17 +1,18 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "test_util.h"
-
-#include <memory>
 #include <cstdlib>
+#include <memory>
+#include <string>
+
+#include "test_util.h"
 
 #if !defined(BUILD_MLAS_NO_ONNXRUNTIME)
 
 MLAS_THREADPOOL* GetMlasThreadPool(void) {
-  static MLAS_THREADPOOL* threadpool = new onnxruntime::concurrency::ThreadPool(
+  static auto threadpool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
       &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 2, true);
-  return threadpool;
+  return threadpool.get();
 }
 
 #else
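Note: the unique_ptr change above works because a function-local static
std::unique_ptr is destroyed during normal process exit, while a bare `new` with
no matching `delete` shows up in the debug-build leak report. A minimal sketch of
the pattern (Expensive and Get are hypothetical names, not onnxruntime APIs):

    #include <memory>

    struct Expensive { /* owns resources released by its destructor */ };

    Expensive* Get() {
        static auto instance = std::make_unique<Expensive>();  // freed at exit
        return instance.get();
    }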
From 85e76cea4ffc7da5e8c97c59c3ea5095f94f9a8d Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Thu, 12 Oct 2023 18:16:09 -0700
Subject: [PATCH 05/10] initial neon impl of q4gemm

---
 onnxruntime/core/mlas/lib/q4gemm_neon.cpp | 176 +++++++++++++++++++++-
 1 file changed, 173 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
index 0d3b89085269a..84d0cb4075062 100644
--- a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
+++ b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
@@ -57,8 +57,172 @@ MlasQ4GemmKernelNeon<MLAS_Q4TYPE_BLK0>(const float* A,
                                        size_t ldc,
                                        const float* Bias)
 {
-    static_cast<void>((A, PackedB, C, CountM, CountN, CountK, lda, ldb, ldc, Bias));
-    return 1;  // TODO ...
+    using Q4Type = MLAS_Q4TYPE_BLK0;
+
+    auto impl0_reference = [&]() {
+        for (size_t m = 0; m < CountM; ++m) {
+            for (size_t n = 0; n < CountN; ++n) {
+                const uint8_t* PackedBBlock = PackedB + n * ldb;
+
+                for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) {
+                    float b_blk_unpacked[Q4Type::BlkLen]{};
+
+                    const size_t kblocklen = std::min(CountK - k, Q4Type::BlkLen);
+
+                    const float s = MlasQ4BlkScale<Q4Type>(PackedBBlock);
+                    const uint8_t z = 8;  // MlasQ4BlkZeroPoint<Q4Type>(PackedBBlock);
+                    const uint8_t* PackedBData = MlasQ4BlkData<Q4Type>(PackedBBlock);
+
+                    for (size_t kk = 0; kk < kblocklen; kk += 32) {
+                        const size_t ksubblocklen = std::min(size_t{32}, kblocklen - kk);
+
+                        for (size_t l0 = 0; l0 < 16; ++l0) {
+                            const uint8_t PackedByte = PackedBData[l0];
+
+                            if (l0 < ksubblocklen) {
+                                const int8_t PackedByteLo = PackedByte & 0x0F;
+                                const float UnpackedValue0 = (PackedByteLo - z) * s;
+                                b_blk_unpacked[kk + l0] = UnpackedValue0;
+                            }
+
+                            const size_t l1 = l0 + 16;
+                            if (l1 < ksubblocklen) {
+                                const int8_t PackedByteHi = PackedByte >> 4;
+                                const float UnpackedValue1 = (PackedByteHi - z) * s;
+                                b_blk_unpacked[kk + l1] = UnpackedValue1;
+                            }
+                        }
+
+                        PackedBData += 16;
+                    }
+
+                    for (size_t kk = 0; kk < kblocklen; ++kk) {
+                        C[m * ldc + n] += A[m * lda + k + kk] * b_blk_unpacked[kk];
+                    }
+
+                    PackedBBlock += Q4Type::BlobSize;
+                }
+
+                if (Bias) {
+                    C[m * ldc + n] += Bias[n];
+                }
+            }
+        }
+
+        return CountM;
+    };
+
+    auto impl1 = [&]() {
+        const float* ARowPtr = A;
+        float* CRowPtr = C;
+        const float* BiasPtr = Bias;
+
+        const int8x16_t LowMask = vdupq_n_s8(0x0F);
+
+        for (size_t m = 0; m < CountM; ++m) {
+            const uint8_t* PackedBColPtr = PackedB;
+
+            for (size_t n = 0; n < CountN; ++n) {
+                float32x4_t acc = vdupq_n_f32(0.0f);
+                const uint8_t* PackedBBlobPtr = PackedBColPtr;
+
+                for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) {
+                    const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen);
+
+                    const float scale = MlasQ4BlkScale<Q4Type>(PackedBBlobPtr);
+                    const uint8_t zp = 8;
+                    const uint8_t* b_data = MlasQ4BlkData<Q4Type>(PackedBBlobPtr);
+
+                    for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) {
+                        // load A row vector elements
+
+                        // load 32 elements from A padding with 0's if there aren't enough
+                        float a_segment[32];
+                        {
+                            const size_t k_subblk_len =
+                                std::min(k_blk_len - k_idx_in_blk, size_t{32});
+                            const float* a_begin = ARowPtr + k + k_idx_in_blk;
+                            std::copy(a_begin, a_begin + k_subblk_len, a_segment);
+                            std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f);
+                        }
+
+                        // 32 floats of A
+                        float32x4_t av[8];
+                        for (int i = 0; i < 8; ++i) {
+                            av[i] = vld1q_f32(a_segment + i * 4);
+                        }
+
+                        // load B column vector
+                        int8x16_t bv_packed = vld1q_s8(b_data);
+
+                        int8x16_t bv_bytes_0 = vandq_s8(bv_packed, LowMask);
+                        int8x16_t bv_bytes_1 = vandq_s8(vshrq_n_s8(bv_packed, 4), LowMask);
+
+                        // dequantize B
+
+                        // subtract zero point
+                        const int8x16_t zpv = vdupq_n_s8(zp);
+
+                        bv_bytes_0 = vsubq_s8(bv_bytes_0, zpv);
+                        bv_bytes_1 = vsubq_s8(bv_bytes_1, zpv);
+
+                        // widen to int16
+                        int16x8_t bv_int16_0 = vmovl_s8(vget_low_s8(bv_bytes_0));
+                        int16x8_t bv_int16_1 = vmovl_s8(vget_high_s8(bv_bytes_0));
+                        int16x8_t bv_int16_2 = vmovl_s8(vget_low_s8(bv_bytes_1));
+                        int16x8_t bv_int16_3 = vmovl_s8(vget_high_s8(bv_bytes_1));
+
+                        // 32 floats of B
+                        float32x4_t bv[8];
+
+                        // widen to int32, cast to float32
+
+                        int32x4_t bv_int32_0 = vmovl_s16(vget_low_s16(bv_int16_0));
+
+                        bv[0] = vcvtq_f32_s32(bv_int32_0);
+                        bv[1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_0)));
+
+                        bv[2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_1)));
+                        bv[3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_1)));
+
+                        bv[4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_2)));
+                        bv[5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_2)));
+
+                        bv[6] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_3)));
+
+                        bv[6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_3)));
+                        bv[7] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_3)));
+
+                        // multiply by scale
+                        for (int i = 0; i < 8; ++i) {
+                            bv[i] = vmulq_n_f32(bv[i], scale);
+                        }
+
+                        // c += a * b
+                        for (int i = 0; i < 8; ++i) {
+                            acc = vfmaq_f32(acc, av[i], bv[i]);
+                        }
+                    }
+
+                    PackedBBlobPtr += Q4Type::BlobSize;
+                }
+
+                float sum = vpadds_f32(vpadd_f32(vget_low_f32(acc), vget_high_f32(acc)));
+
+                sum += BiasPtr ? BiasPtr[n] : 0.0f;
+
+                CRowPtr[n] = sum;
+
+                PackedBColPtr += ldb;
+            }
+
+            ARowPtr += lda;
+            CRowPtr += ldc;
+        }
+
+        return CountM;
+    };
+
+    // return impl0_reference();
+    return impl1();
 }
 
 template <>
@@ -85,9 +249,15 @@ MlasQ4GemmKernel<MLAS_Q4TYPE_BLK0, MLAS_FP_Q4_GEMM_KERNEL_NEON>(const float* A,
 template <typename Q4Type>
 MLAS_FORCEINLINE void
 MlasBlkQ4DequantBNeon(
+    float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb);
+
+template <>
+MLAS_FORCEINLINE void
+MlasBlkQ4DequantBNeon<MLAS_Q4TYPE_BLK0>(
     float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb)
 {
-    static_cast<void>((FpData, PackedB, CountN, CountK, ldb));
+    using Q4Type = MLAS_Q4TYPE_BLK0;
+    static_cast<void>((FpData, PackedB, CountN, CountK, ldb));
     // TODO ...
 }
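Note: the NEON pipeline in impl1 above -- vandq/vshrq to split nibbles apart,
vsubq to remove the zero point while the values are still int8, then the vmovl
widening chain and vcvtq/vmulq_n to reach scaled floats -- is equivalent to this
scalar sketch (DequantizeStep is a hypothetical name used only for illustration):

    #include <cstdint>

    void DequantizeStep(const uint8_t packed[16], float scale, float out[32]) {
        for (int i = 0; i < 16; ++i) {
            const int lo = (packed[i] & 0x0F) - 8;  // vandq_s8 + vsubq_s8
            const int hi = (packed[i] >> 4) - 8;    // vshrq_n_s8 + vandq_s8 + vsubq_s8
            out[i] = static_cast<float>(lo) * scale;       // widen + vcvtq + vmulq_n_f32
            out[i + 16] = static_cast<float>(hi) * scale;
        }
    }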
From 922434f2e094b966f2e2d5313a8224e54fe4e4d7 Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Mon, 16 Oct 2023 16:35:02 -0700
Subject: [PATCH 06/10] reference implementation of MlasBlkQ4DequantBNeon

---
 onnxruntime/core/mlas/lib/q4gemm_neon.cpp | 70 ++++++++++++++++++++++-
 1 file changed, 68 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
index 84d0cb4075062..33f5d78d84367 100644
--- a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
+++ b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
@@ -257,8 +257,74 @@ MlasBlkQ4DequantBNeon<MLAS_Q4TYPE_BLK0>(
     float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb)
 {
     using Q4Type = MLAS_Q4TYPE_BLK0;
-    static_cast<void>((FpData, PackedB, CountN, CountK, ldb));
-    // TODO ...
+
+    // unpack B in format suitable for MlasSgemmKernelZero
+
+    auto impl0_reference = [&]() {
+        float* Dst = FpData;
+        const uint8_t* PackedBCol = PackedB;
+
+        for (size_t n = 0; n < CountN; n += 16) {
+            const size_t nnlen = std::min(CountN - n, size_t{16});
+
+            for (size_t nn = 0; nn < nnlen; ++nn) {
+                const uint8_t* PackedBBlock = PackedBCol;
+
+                for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) {
+                    float b_blk_unpacked[32]{};
+
+                    const size_t kblocklen = std::min(CountK - k, Q4Type::BlkLen);
+
+                    const float s = MlasQ4BlkScale<Q4Type>(PackedBBlock);
+                    const uint8_t z = 8;  // MlasQ4BlkZeroPoint<Q4Type>(PackedBBlock);
+                    const uint8_t* PackedBData = MlasQ4BlkData<Q4Type>(PackedBBlock);
+
+                    for (size_t kk = 0; kk < kblocklen; kk += 32) {
+                        const size_t ksubblocklen = std::min(size_t{32}, kblocklen - kk);
+
+                        for (size_t l0 = 0; l0 < 16; ++l0) {
+                            const uint8_t PackedByte = PackedBData[l0];
+
+                            if (l0 < ksubblocklen) {
+                                const int8_t PackedByteLo = PackedByte & 0x0F;
+                                const float UnpackedValue0 = (PackedByteLo - z) * s;
+                                b_blk_unpacked[kk + l0] = UnpackedValue0;
+                            }
+
+                            const size_t l1 = l0 + 16;
+                            if (l1 < ksubblocklen) {
+                                const int8_t PackedByteHi = PackedByte >> 4;
+                                const float UnpackedValue1 = (PackedByteHi - z) * s;
+                                b_blk_unpacked[kk + l1] = UnpackedValue1;
+                            }
+                        }
+
+                        PackedBData += 16;
+                    }
+
+                    for (size_t kk = 0; kk < kblocklen; ++kk) {
+                        Dst[(k + kk) * 16 + nn] = b_blk_unpacked[kk];
+                    }
+
+                    PackedBBlock += Q4Type::BlobSize;
+                }
+
+                PackedBCol += ldb;
+            }
+
+            // zero out any remaining columns
+
+            if (nnlen < 16) {
+                for (size_t k = 0; k < CountK; ++k) {
+                    std::fill_n(Dst + (k * 16) + nnlen, 16 - nnlen, 0.0f);
+                }
+            }
+
+            Dst += CountK * 16;
+        }
+    };
+
+    impl0_reference();
 }
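Note: the destination layout produced by the reference dequantization above (the
one MlasSgemmKernelZero expects) stores B in panels of 16 columns, with element
(k, nn) of a panel at Dst[k * 16 + nn] and panels of size CountK * 16 laid out
back to back. A sketch of the resulting index calculation (PackedIndex is a
hypothetical helper, shown only to make the addressing explicit):

    #include <cstddef>

    size_t PackedIndex(size_t k, size_t n, size_t count_k) {
        const size_t panel = n / 16;         // which 16-column panel
        const size_t col_in_panel = n % 16;  // "nn" in the code above
        return panel * (count_k * 16) + k * 16 + col_in_panel;
    }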
From 6048f8c9fe33483092555b4bfa8454a82787098a Mon Sep 17 00:00:00 2001
From: edgchen1 <18449977+edgchen1@users.noreply.github.com>
Date: Wed, 18 Oct 2023 19:22:39 -0700
Subject: [PATCH 07/10] WIP optimize MlasQ4GemmKernelNeon

---
 onnxruntime/core/mlas/lib/q4gemm_neon.cpp | 607 +++++++++++++++++++++-
 1 file changed, 605 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
index 33f5d78d84367..611509f2ff8f2 100644
--- a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
+++ b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp
@@ -44,6 +44,184 @@ MlasQ4GemmKernelNeon(const float* A,
                      size_t ldc,
                      const float* Bias);
 
+namespace q4gemm_neon
+{
+
+template <size_t... Indices, typename IterationFn>
+MLAS_FORCEINLINE void
+UnrolledLoopIterations(IterationFn&& f, std::index_sequence<Indices...> /* indices */)
+{
+    (f(Indices), ...);
+}
+
+template <size_t N, typename IterationFn>
+MLAS_FORCEINLINE void
+UnrolledLoop(IterationFn&& f)
+{
+    UnrolledLoopIterations(std::forward<IterationFn>(f), std::make_index_sequence<N>());
+}
+
+MLAS_FORCEINLINE float32x4_t
+FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3)
+{
+    // aN: aN_0 aN_1 aN_2 aN_3
+
+    float32x4_t b0 = vzip1q_f32(a0, a1);  // a0_0 a1_0 a0_1 a1_1
+    float32x4_t b1 = vzip2q_f32(a0, a1);  // a0_2 a1_2 a0_3 a1_3
+    float32x4_t b2 = vzip1q_f32(a2, a3);  // a2_0 a3_0 a2_1 a3_1
+    float32x4_t b3 = vzip2q_f32(a2, a3);  // a2_2 a3_2 a2_3 a3_3
+
+    // a0_0 a1_0 a2_0 a3_0
+    a0 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2)));
+    // a0_1 a1_1 a2_1 a3_1
+    a1 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b0), vreinterpretq_f64_f32(b2)));
+    // a0_2 a1_2 a2_2 a3_2
+    a2 = vreinterpretq_f32_f64(vzip1q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
+    // a0_3 a1_3 a2_3 a3_3
+    a3 = vreinterpretq_f32_f64(vzip2q_f64(vreinterpretq_f64_f32(b1), vreinterpretq_f64_f32(b3)));
+
+    return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3));
+}
+
+template <size_t NCols>
+MLAS_FORCEINLINE void
+ComputeDotProducts(const float* ARowPtr,
+                   const uint8_t* PackedBColPtr,
+                   float* SumPtr,
+                   size_t CountK,
+                   size_t ldb,
+                   const float* BiasPtr)
+{
+    using Q4Type = MLAS_Q4TYPE_BLK0;
+
+    static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4");
+
+    const int8x16_t LowMask = vdupq_n_s8(0x0F);
+
+    float32x4_t acc[NCols];
+    for (int i = 0; i < NCols; ++i) {
+        acc[i] = vdupq_n_f32(0.0f);
+    }
+
+    const uint8_t* PackedBBlobPtr = PackedBColPtr;
+
+    for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) {
+        const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen);
+
+        float scale[NCols];
+        for (int i = 0; i < NCols; ++i) {
+            scale[i] = MlasQ4BlkScale<Q4Type>(PackedBBlobPtr + i * ldb);
+        }
+
+        const uint8_t zp = 8;
+
+        const uint8_t* b_data[NCols];
+        for (int i = 0; i < NCols; ++i) {
+            b_data[i] = MlasQ4BlkData<Q4Type>(PackedBBlobPtr + i * ldb);
+        }
+
+        for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) {
+            // load A row vector elements
+
+            // load 32 elements from A padding with 0's if there aren't enough
+            float a_segment[32];
+            {
+                const size_t k_subblk_len = std::min(k_blk_len - k_idx_in_blk, size_t{32});
+                const float* a_begin = ARowPtr + k + k_idx_in_blk;
+                std::copy(a_begin, a_begin + k_subblk_len, a_segment);
+                std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f);
+            }
+
+            // 32 floats of A
+            float32x4_t av[8];
+            UnrolledLoop<8>([&](size_t i) { av[i] = vld1q_f32(a_segment + i * 4); });
+
+            // load B column vectors
+            int8x16_t bv_packed[NCols];
+            UnrolledLoop<NCols>([&](size_t i) { bv_packed[i] = vld1q_s8(b_data[i]); });
+
+            int8x16_t bv_bytes[NCols][2];
+
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv_bytes[i][0] = vandq_s8(bv_packed[i], LowMask);
+                bv_bytes[i][1] = vandq_s8(vshrq_n_s8(bv_packed[i], 4), LowMask);
+            });
+
+            // dequantize B
+
+            // subtract zero point
+            const int8x16_t zpv = vdupq_n_s8(zp);
+
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv_bytes[i][0] = vsubq_s8(bv_bytes[i][0], zpv);
+                bv_bytes[i][1] = vsubq_s8(bv_bytes[i][1], zpv);
+            });
+
+            // widen to int16
+            int16x8_t bv_int16[NCols][4];
+
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv_int16[i][0] = vmovl_s8(vget_low_s8(bv_bytes[i][0]));
+                bv_int16[i][1] = vmovl_s8(vget_high_s8(bv_bytes[i][0]));
+                bv_int16[i][2] = vmovl_s8(vget_low_s8(bv_bytes[i][1]));
+                bv_int16[i][3] = vmovl_s8(vget_high_s8(bv_bytes[i][1]));
+            });
+
+            // 32 floats of B
+            float32x4_t bv[NCols][8];
+
+            // widen to int32, cast to float32
+
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv[i][0] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][0])));
+                bv[i][1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][0])));
+
+                bv[i][2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][1])));
+                bv[i][3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][1])));
+            });
+
+            UnrolledLoop<NCols>([&](size_t i) {
+                bv[i][4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][2])));
+                bv[i][5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][2])));
+
+                bv[i][6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][3])));
+                bv[i][7] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][3])));
+            });
+
+            // multiply by scale
+            UnrolledLoop<NCols>([&](size_t i) {
+                UnrolledLoop<8>([&](size_t j) { bv[i][j] = vmulq_n_f32(bv[i][j], scale[i]); });
+            });
+
+            // c += a * b
+            UnrolledLoop<8>([&](size_t j) {
+                UnrolledLoop<NCols>([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); });
+            });
+        }
+
+        PackedBBlobPtr += Q4Type::BlobSize;
+    }
+
+    if constexpr (NCols == 4) {
+        float32x4_t sum =
FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (BiasPtr != nullptr) { + sum = vaddq_f32(sum, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sum); + } else { + for (int i = 0; i < NCols; ++i) { + SumPtr[i] = vaddvq_f32(acc[i]); + if (BiasPtr != nullptr) { + SumPtr[i] += BiasPtr[i]; + } + } + } +} + +} // namespace q4gemm_neon + template <> MLAS_FORCEINLINE size_t MlasQ4GemmKernelNeon(const float* A, @@ -59,6 +237,8 @@ MlasQ4GemmKernelNeon(const float* A, { using Q4Type = MLAS_Q4TYPE_BLK0; + static constexpr size_t NCols = 4; // columns to handle at once + auto impl0_reference = [&]() { for (size_t m = 0; m < CountM; ++m) { for (size_t n = 0; n < CountN; ++n) { @@ -205,7 +385,121 @@ MlasQ4GemmKernelNeon(const float* A, PackedBBlobPtr += Q4Type::BlobSize; } - float sum = vpadds_f32(vpadd_f32(vget_low_f32(acc), vget_high_f32(acc))); + float sum = vaddvq_f32(acc); + + sum += BiasPtr ? BiasPtr[n] : 0.0f; + + CRowPtr[n] = sum; + + PackedBColPtr += ldb; + } + + ARowPtr += lda; + CRowPtr += ldc; + } + + return CountM; + }; + + auto impl2_two_accumulators = [&]() { + const float* ARowPtr = A; + float* CRowPtr = C; + const float* BiasPtr = Bias; + + const int8x16_t LowMask = vdupq_n_s8(0x0F); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* PackedBColPtr = PackedB; + + for (size_t n = 0; n < CountN; ++n) { + float32x4_t acc0 = vdupq_n_f32(0.0f); + float32x4_t acc1 = vdupq_n_f32(0.0f); + + const uint8_t* PackedBBlobPtr = PackedBColPtr; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen); + + const float scale = MlasQ4BlkScale(PackedBBlobPtr); + const uint8_t zp = 8; + const uint8_t* b_data = MlasQ4BlkData(PackedBBlobPtr); + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { + // load A row vector elements + + // load 32 elements from A padding with 0's if there aren't enough + float a_segment[32]; + { + const size_t k_subblk_len = + std::min(k_blk_len - k_idx_in_blk, size_t{32}); + const float* a_begin = ARowPtr + k + k_idx_in_blk; + std::copy(a_begin, a_begin + k_subblk_len, a_segment); + std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f); + } + + // 32 floats of A + float32x4_t av[8]; + for (int i = 0; i < 8; ++i) { + av[i] = vld1q_f32(a_segment + i * 4); + } + + // load B column vector + int8x16_t bv_packed = vld1q_s8(b_data); + + int8x16_t bv_bytes_0 = vandq_s8(bv_packed, LowMask); + int8x16_t bv_bytes_1 = vandq_s8(vshrq_n_s8(bv_packed, 4), LowMask); + + // dequantize B + + // subtract zero point + const int8x16_t zpv = vdupq_n_s8(zp); + + bv_bytes_0 = vsubq_s8(bv_bytes_0, zpv); + bv_bytes_1 = vsubq_s8(bv_bytes_1, zpv); + + // widen to int16 + int16x8_t bv_int16_0 = vmovl_s8(vget_low_s8(bv_bytes_0)); + int16x8_t bv_int16_1 = vmovl_s8(vget_high_s8(bv_bytes_0)); + int16x8_t bv_int16_2 = vmovl_s8(vget_low_s8(bv_bytes_1)); + int16x8_t bv_int16_3 = vmovl_s8(vget_high_s8(bv_bytes_1)); + + // 32 floats of B + float32x4_t bv[8]; + + // widen to int32, cast to float32 + + int32x4_t bv_int32_0 = vmovl_s16(vget_low_s16(bv_int16_0)); + + bv[0] = vcvtq_f32_s32(bv_int32_0); + bv[1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_0))); + + bv[2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_1))); + bv[3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_1))); + + bv[4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_2))); + bv[5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_2))); + + bv[6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_3))); + bv[7] = 
vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_3))); + + // multiply by scale + for (int i = 0; i < 8; ++i) { + bv[i] = vmulq_n_f32(bv[i], scale); + } + + // c += a * b + for (int i = 0; i < 4; ++i) { + acc0 = vfmaq_f32(acc0, av[i], bv[i]); + acc1 = vfmaq_f32(acc1, av[i + 4], bv[i + 4]); + } + } + + PackedBBlobPtr += Q4Type::BlobSize; + } + + float sum = + vpadds_f32(vpadd_f32(vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0)), + vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1)))); sum += BiasPtr ? BiasPtr[n] : 0.0f; @@ -221,8 +515,317 @@ MlasQ4GemmKernelNeon(const float* A, return CountM; }; + auto impl3_four_cols = [&]() { + const float* ARowPtr = A; + float* CRowPtr = C; + + for (size_t m = 0; m < CountM; ++m) { + const float* BiasPtr = Bias; + const uint8_t* PackedBColPtr = PackedB; + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { + q4gemm_neon::ComputeDotProducts(ARowPtr, PackedBColPtr, SumPtr, CountK, ldb, + BiasPtr); + + // move to next `NCols` columns + + PackedBColPtr += NCols * ldb; + BiasPtr += BiasPtr != nullptr ? NCols : 0; + SumPtr += NCols; + + nblk -= NCols; + } + + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { + q4gemm_neon::ComputeDotProducts<1>(ARowPtr, PackedBColPtr, SumPtr, CountK, ldb, + BiasPtr); + + PackedBColPtr += ldb; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + + ARowPtr += lda; + CRowPtr += ldc; + } + + return CountM; + }; + + auto impl4_one_col_with_helper = [&]() { + const float* ARowPtr = A; + float* CRowPtr = C; + + for (size_t m = 0; m < CountM; ++m) { + const float* BiasPtr = Bias; + const uint8_t* PackedBColPtr = PackedB; + float* SumPtr = CRowPtr; + + for (size_t n = 0; n < CountN; ++n) { + q4gemm_neon::ComputeDotProducts<1>(ARowPtr, PackedBColPtr, SumPtr, CountK, ldb, + BiasPtr); + + PackedBColPtr += ldb; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + + ARowPtr += lda; + CRowPtr += ldc; + } + + return CountM; + }; + + auto impl5_four_cols_inline = [&]() { + + const float* ARowPtr = A; + float* CRowPtr = C; + + const int8x16_t LowMask = vdupq_n_s8(0x0F); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* PackedBColPtr = PackedB; + const float* BiasPtr = Bias; + float* SumPtr = CRowPtr; + + int64_t nblk = CountN - NCols; + + while (nblk >= 0) { + float32x4_t acc[NCols]{}; + + const uint8_t* PackedBBlobPtr = PackedBColPtr; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen); + + float scale[NCols]; + q4gemm_neon::UnrolledLoop([&](size_t i) { + scale[i] = MlasQ4BlkScale(PackedBBlobPtr + i * ldb); + }); + + const uint8_t zp = 8; + + const uint8_t* b_data[NCols]; + q4gemm_neon::UnrolledLoop([&](size_t i) { + b_data[i] = MlasQ4BlkData(PackedBBlobPtr + i * ldb); + }); + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { + // load A row vector elements + + // load 32 elements from A padding with 0's if there aren't enough + float a_segment[32]; + { + const size_t k_subblk_len = + std::min(k_blk_len - k_idx_in_blk, size_t{32}); + const float* a_begin = ARowPtr + k + k_idx_in_blk; + std::copy(a_begin, a_begin + k_subblk_len, a_segment); + std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f); + } + + // 32 floats of A + float32x4_t av[8]; + q4gemm_neon::UnrolledLoop<8>( + [&](size_t i) { av[i] = vld1q_f32(a_segment + i * 4); }); + + // load B column vector + int8x16_t bv_packed[NCols]; + q4gemm_neon::UnrolledLoop( + [&](size_t i) { bv_packed[i] = vld1q_s8(b_data[i]); }); + + int8x16_t bv_bytes[NCols][2]; + q4gemm_neon::UnrolledLoop([&](size_t i) { + bv_bytes[i][0] = vandq_s8(bv_packed[i], LowMask); + bv_bytes[i][1] = vandq_s8(vshrq_n_s8(bv_packed[i], 4), LowMask); + }); + + // dequantize B + + // subtract zero point + const int8x16_t zpv = vdupq_n_s8(zp); + + q4gemm_neon::UnrolledLoop([&](size_t i) { + bv_bytes[i][0] = vsubq_s8(bv_bytes[i][0], zpv); + bv_bytes[i][1] = vsubq_s8(bv_bytes[i][1], zpv); + }); + + // widen to int16 + int16x8_t bv_int16[NCols][4]; + q4gemm_neon::UnrolledLoop([&](size_t i) { + bv_int16[i][0] = vmovl_s8(vget_low_s8(bv_bytes[i][0])); + bv_int16[i][1] = vmovl_s8(vget_high_s8(bv_bytes[i][0])); + bv_int16[i][2] = vmovl_s8(vget_low_s8(bv_bytes[i][1])); + bv_int16[i][3] = vmovl_s8(vget_high_s8(bv_bytes[i][1])); + }); + + // 32 floats of B + float32x4_t bv[NCols][8]; + + // widen to int32, cast to float32 + + q4gemm_neon::UnrolledLoop([&](size_t i) { + bv[i][0] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][0]))); + bv[i][1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][0]))); + + bv[i][2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][1]))); + bv[i][3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][1]))); + }); + + q4gemm_neon::UnrolledLoop([&](size_t i) { + bv[i][4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][2]))); + bv[i][5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][2]))); + + bv[i][6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[i][3]))); + bv[i][7] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[i][3]))); + }); + + // multiply by scale + q4gemm_neon::UnrolledLoop([&](size_t i) { + q4gemm_neon::UnrolledLoop<8>( + [&](size_t j) { bv[i][j] = vmulq_n_f32(bv[i][j], scale[i]); }); + }); + + // c += a * b + q4gemm_neon::UnrolledLoop<8>([&](size_t j) { + q4gemm_neon::UnrolledLoop( + [&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); + }); + } 
+ + PackedBBlobPtr += Q4Type::BlobSize; + } + + float32x4_t sums = q4gemm_neon::FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + if (Bias) { + sums = vaddq_f32(sums, vld1q_f32(BiasPtr)); + } + + vst1q_f32(SumPtr, sums); + + PackedBColPtr += NCols * ldb; + BiasPtr += NCols; + SumPtr += NCols; + + nblk -= NCols; + } + + nblk += NCols; + + if (nblk > 0) { + + float32x4_t acc[NCols]{}; + + const uint8_t* PackedBBlobPtr = PackedBColPtr; + + for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { + const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen); + + float scale[NCols]; + const uint8_t* b_data[NCols]; + const uint8_t zp = 8; + + for (int64_t nn = 0; nn < nblk; ++nn) { + scale[nn] = MlasQ4BlkScale(PackedBBlobPtr + nn * ldb); + b_data[nn] = MlasQ4BlkData(PackedBBlobPtr + nn * ldb); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { + // load A row vector elements + + // load 32 elements from A padding with 0's if there aren't enough + float a_segment[32]; + { + const size_t k_subblk_len = + std::min(k_blk_len - k_idx_in_blk, size_t{32}); + const float* a_begin = ARowPtr + k + k_idx_in_blk; + std::copy(a_begin, a_begin + k_subblk_len, a_segment); + std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f); + } + + // 32 floats of A + float32x4_t av[8]; + q4gemm_neon::UnrolledLoop<8>( + [&](size_t i) { av[i] = vld1q_f32(a_segment + i * 4); }); + + for (int64_t nn = 0; nn < nblk; ++nn) { + // load B column vector + int8x16_t bv_packed = vld1q_s8(b_data[nn]); + + int8x16_t bv_bytes[2]; + bv_bytes[0] = vandq_s8(bv_packed, LowMask); + bv_bytes[1] = vandq_s8(vshrq_n_s8(bv_packed, 4), LowMask); + + // dequantize B + + // subtract zero point + const int8x16_t zpv = vdupq_n_s8(zp); + + bv_bytes[0] = vsubq_s8(bv_bytes[0], zpv); + bv_bytes[1] = vsubq_s8(bv_bytes[1], zpv); + + // widen to int16 + int16x8_t bv_int16[4]; + bv_int16[0] = vmovl_s8(vget_low_s8(bv_bytes[0])); + bv_int16[1] = vmovl_s8(vget_high_s8(bv_bytes[0])); + bv_int16[2] = vmovl_s8(vget_low_s8(bv_bytes[1])); + bv_int16[3] = vmovl_s8(vget_high_s8(bv_bytes[1])); + + // 32 floats of B + float32x4_t bv[8]; + + // widen to int32, cast to float32 + bv[0] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[0]))); + bv[1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[0]))); + + bv[2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[1]))); + bv[3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[1]))); + + bv[4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[2]))); + bv[5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[2]))); + + bv[6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16[3]))); + bv[7] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16[3]))); + + // multiply by scale + q4gemm_neon::UnrolledLoop<8>( + [&](size_t j) { bv[j] = vmulq_n_f32(bv[j], scale[nn]); }); + + // c += a * b + q4gemm_neon::UnrolledLoop<8>([&](size_t j) { + acc[nn] = vfmaq_f32(acc[nn], av[j], bv[j]); + }); + } + } + + PackedBBlobPtr += Q4Type::BlobSize; + } + + for (int64_t nn = 0; nn < nblk; ++nn) { + SumPtr[nn] = vaddvq_f32(acc[nn]); + SumPtr[nn] += Bias != nullptr ? 
BiasPtr[nn] : 0.0f; + } + } + + ARowPtr += lda; + CRowPtr += ldc; + } + + return CountM; + }; + // return impl0_reference(); - return impl1(); + // return impl1(); + // return impl2_two_accumulators(); + // return impl3_four_cols(); + // return impl4_one_col_with_helper(); + return impl5_four_cols_inline(); } template <> From a88eedc7cf947da40d071bd18cec02e6eee645fc Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:51:34 -0700 Subject: [PATCH 08/10] fix some bugs, enable implementation for other Q4Types, remove old impls which don't seem useful --- onnxruntime/core/mlas/lib/q4common.h | 14 + onnxruntime/core/mlas/lib/q4gemm_neon.cpp | 462 ++++++++-------------- 2 files changed, 170 insertions(+), 306 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4common.h b/onnxruntime/core/mlas/lib/q4common.h index 532437797a084..e2d2172fe9148 100644 --- a/onnxruntime/core/mlas/lib/q4common.h +++ b/onnxruntime/core/mlas/lib/q4common.h @@ -43,6 +43,13 @@ MlasQ4BlkScale(const uint8_t* BlkPtr) return *reinterpret_cast(BlkPtr); } +template +constexpr bool +MlasQ4BlkHasZeroPoint() +{ + return false; +} + template uint8_t& MlasQ4BlkZeroPoint(uint8_t* BlkPtr); @@ -100,6 +107,13 @@ struct MLAS_Q4TYPE_BLK1 { static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float) + sizeof(uint8_t); }; +template <> +constexpr bool +MlasQ4BlkHasZeroPoint() +{ + return true; +} + template<> inline uint8_t& MlasQ4BlkZeroPoint(uint8_t* BlkPtr) diff --git a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp index 611509f2ff8f2..d6aba1357ecbd 100644 --- a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp +++ b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp @@ -24,7 +24,6 @@ Module Name: #include "q4gemm.h" struct MLAS_FP_Q4_GEMM_KERNEL_NEON { - // static constexpr size_t StrideM = 256; }; // @@ -83,7 +82,7 @@ FoldAccumulators(float32x4_t a0, float32x4_t a1, float32x4_t a2, float32x4_t a3) return vaddq_f32(vaddq_f32(a0, a1), vaddq_f32(a2, a3)); } -template +template MLAS_FORCEINLINE void ComputeDotProducts(const float* ARowPtr, const uint8_t* PackedBColPtr, @@ -92,8 +91,6 @@ ComputeDotProducts(const float* ARowPtr, size_t ldb, const float* BiasPtr) { - using Q4Type = MLAS_Q4TYPE_BLK0; - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); const int8x16_t LowMask = vdupq_n_s8(0x0F); @@ -109,16 +106,21 @@ ComputeDotProducts(const float* ARowPtr, const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen); float scale[NCols]; - for (int i = 0; i < NCols; ++i) { - scale[i] = MlasQ4BlkScale(PackedBBlobPtr + i * ldb); - } - - const uint8_t zp = 8; + UnrolledLoop( + [&](size_t i) { scale[i] = MlasQ4BlkScale(PackedBBlobPtr + i * ldb); }); + + uint8_t zp[NCols]; + UnrolledLoop([&](size_t i) { + if constexpr (MlasQ4BlkHasZeroPoint()) { + zp[i] = MlasQ4BlkZeroPoint(PackedBBlobPtr + i * ldb); + } else { + zp[i] = 8; + } + }); const uint8_t* b_data[NCols]; - for (int i = 0; i < NCols; ++i) { - b_data[i] = MlasQ4BlkData(PackedBBlobPtr + i * ldb); - } + UnrolledLoop( + [&](size_t i) { b_data[i] = MlasQ4BlkData(PackedBBlobPtr + i * ldb); }); for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { // load A row vector elements @@ -150,9 +152,8 @@ ComputeDotProducts(const float* ARowPtr, // dequantize B // subtract zero point - const int8x16_t zpv = vdupq_n_s8(zp); - UnrolledLoop([&](size_t i) { + const int8x16_t zpv = vdupq_n_s8(zp[i]); bv_bytes[i][0] = vsubq_s8(bv_bytes[i][0], zpv); bv_bytes[i][1] = vsubq_s8(bv_bytes[i][1], 
zpv); }); @@ -197,6 +198,9 @@ ComputeDotProducts(const float* ARowPtr, UnrolledLoop<8>([&](size_t j) { UnrolledLoop([&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); }); + + // increment b data pointers to next 32 elements + UnrolledLoop([&](size_t i) { b_data[i] += 16; }); } PackedBBlobPtr += Q4Type::BlobSize; @@ -222,21 +226,19 @@ ComputeDotProducts(const float* ARowPtr, } // namespace q4gemm_neon -template <> +template MLAS_FORCEINLINE size_t -MlasQ4GemmKernelNeon(const float* A, - const uint8_t* PackedB, - float* C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t lda, - size_t ldb, - size_t ldc, - const float* Bias) +MlasQ4GemmKernelNeon(const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias) { - using Q4Type = MLAS_Q4TYPE_BLK0; - static constexpr size_t NCols = 4; // columns to handle at once auto impl0_reference = [&]() { @@ -250,7 +252,13 @@ MlasQ4GemmKernelNeon(const float* A, const size_t kblocklen = std::min(CountK - k, Q4Type::BlkLen); const float s = MlasQ4BlkScale(PackedBBlock); - const uint8_t z = 8; // MlasQ4BlkZeroPoint(PackedBBlock); + const uint8_t z = [PackedBBlock]() -> uint8_t { + if constexpr (MlasQ4BlkHasZeroPoint()) { + return MlasQ4BlkZeroPoint(PackedBBlock); + } else { + return 8; + } + }(); const uint8_t* PackedBData = MlasQ4BlkData(PackedBBlock); for (size_t kk = 0; kk < kblocklen; kk += 32) { @@ -292,229 +300,6 @@ MlasQ4GemmKernelNeon(const float* A, return CountM; }; - auto impl1 = [&]() { - const float* ARowPtr = A; - float* CRowPtr = C; - const float* BiasPtr = Bias; - - const int8x16_t LowMask = vdupq_n_s8(0x0F); - - for (size_t m = 0; m < CountM; ++m) { - const uint8_t* PackedBColPtr = PackedB; - - for (size_t n = 0; n < CountN; ++n) { - float32x4_t acc = vdupq_n_f32(0.0f); - const uint8_t* PackedBBlobPtr = PackedBColPtr; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen); - - const float scale = MlasQ4BlkScale(PackedBBlobPtr); - const uint8_t zp = 8; - const uint8_t* b_data = MlasQ4BlkData(PackedBBlobPtr); - - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { - // load A row vector elements - - // load 32 elements from A padding with 0's if there aren't enough - float a_segment[32]; - { - const size_t k_subblk_len = - std::min(k_blk_len - k_idx_in_blk, size_t{32}); - const float* a_begin = ARowPtr + k + k_idx_in_blk; - std::copy(a_begin, a_begin + k_subblk_len, a_segment); - std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f); - } - - // 32 floats of A - float32x4_t av[8]; - for (int i = 0; i < 8; ++i) { - av[i] = vld1q_f32(a_segment + i * 4); - } - - // load B column vector - int8x16_t bv_packed = vld1q_s8(b_data); - - int8x16_t bv_bytes_0 = vandq_s8(bv_packed, LowMask); - int8x16_t bv_bytes_1 = vandq_s8(vshrq_n_s8(bv_packed, 4), LowMask); - - // dequantize B - - // subtract zero point - const int8x16_t zpv = vdupq_n_s8(zp); - - bv_bytes_0 = vsubq_s8(bv_bytes_0, zpv); - bv_bytes_1 = vsubq_s8(bv_bytes_1, zpv); - - // widen to int16 - int16x8_t bv_int16_0 = vmovl_s8(vget_low_s8(bv_bytes_0)); - int16x8_t bv_int16_1 = vmovl_s8(vget_high_s8(bv_bytes_0)); - int16x8_t bv_int16_2 = vmovl_s8(vget_low_s8(bv_bytes_1)); - int16x8_t bv_int16_3 = vmovl_s8(vget_high_s8(bv_bytes_1)); - - // 32 floats of B - float32x4_t bv[8]; - - // widen to int32, cast to float32 - - int32x4_t bv_int32_0 = vmovl_s16(vget_low_s16(bv_int16_0)); - - 
bv[0] = vcvtq_f32_s32(bv_int32_0); - bv[1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_0))); - - bv[2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_1))); - bv[3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_1))); - - bv[4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_2))); - bv[5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_2))); - - bv[6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_3))); - bv[7] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_3))); - - // multiply by scale - for (int i = 0; i < 8; ++i) { - bv[i] = vmulq_n_f32(bv[i], scale); - } - - // c += a * b - for (int i = 0; i < 8; ++i) { - acc = vfmaq_f32(acc, av[i], bv[i]); - } - } - - PackedBBlobPtr += Q4Type::BlobSize; - } - - float sum = vaddvq_f32(acc); - - sum += BiasPtr ? BiasPtr[n] : 0.0f; - - CRowPtr[n] = sum; - - PackedBColPtr += ldb; - } - - ARowPtr += lda; - CRowPtr += ldc; - } - - return CountM; - }; - - auto impl2_two_accumulators = [&]() { - const float* ARowPtr = A; - float* CRowPtr = C; - const float* BiasPtr = Bias; - - const int8x16_t LowMask = vdupq_n_s8(0x0F); - - for (size_t m = 0; m < CountM; ++m) { - const uint8_t* PackedBColPtr = PackedB; - - for (size_t n = 0; n < CountN; ++n) { - float32x4_t acc0 = vdupq_n_f32(0.0f); - float32x4_t acc1 = vdupq_n_f32(0.0f); - - const uint8_t* PackedBBlobPtr = PackedBColPtr; - - for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - const size_t k_blk_len = std::min(CountK - k, Q4Type::BlkLen); - - const float scale = MlasQ4BlkScale(PackedBBlobPtr); - const uint8_t zp = 8; - const uint8_t* b_data = MlasQ4BlkData(PackedBBlobPtr); - - for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { - // load A row vector elements - - // load 32 elements from A padding with 0's if there aren't enough - float a_segment[32]; - { - const size_t k_subblk_len = - std::min(k_blk_len - k_idx_in_blk, size_t{32}); - const float* a_begin = ARowPtr + k + k_idx_in_blk; - std::copy(a_begin, a_begin + k_subblk_len, a_segment); - std::fill(a_segment + k_subblk_len, a_segment + 32, 0.0f); - } - - // 32 floats of A - float32x4_t av[8]; - for (int i = 0; i < 8; ++i) { - av[i] = vld1q_f32(a_segment + i * 4); - } - - // load B column vector - int8x16_t bv_packed = vld1q_s8(b_data); - - int8x16_t bv_bytes_0 = vandq_s8(bv_packed, LowMask); - int8x16_t bv_bytes_1 = vandq_s8(vshrq_n_s8(bv_packed, 4), LowMask); - - // dequantize B - - // subtract zero point - const int8x16_t zpv = vdupq_n_s8(zp); - - bv_bytes_0 = vsubq_s8(bv_bytes_0, zpv); - bv_bytes_1 = vsubq_s8(bv_bytes_1, zpv); - - // widen to int16 - int16x8_t bv_int16_0 = vmovl_s8(vget_low_s8(bv_bytes_0)); - int16x8_t bv_int16_1 = vmovl_s8(vget_high_s8(bv_bytes_0)); - int16x8_t bv_int16_2 = vmovl_s8(vget_low_s8(bv_bytes_1)); - int16x8_t bv_int16_3 = vmovl_s8(vget_high_s8(bv_bytes_1)); - - // 32 floats of B - float32x4_t bv[8]; - - // widen to int32, cast to float32 - - int32x4_t bv_int32_0 = vmovl_s16(vget_low_s16(bv_int16_0)); - - bv[0] = vcvtq_f32_s32(bv_int32_0); - bv[1] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_0))); - - bv[2] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_1))); - bv[3] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_1))); - - bv[4] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_2))); - bv[5] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_2))); - - bv[6] = vcvtq_f32_s32(vmovl_s16(vget_low_s16(bv_int16_3))); - bv[7] = vcvtq_f32_s32(vmovl_s16(vget_high_s16(bv_int16_3))); - - // multiply by scale - for (int i = 0; i < 8; ++i) { - bv[i] = vmulq_n_f32(bv[i], scale); - } - - // c += 
a * b - for (int i = 0; i < 4; ++i) { - acc0 = vfmaq_f32(acc0, av[i], bv[i]); - acc1 = vfmaq_f32(acc1, av[i + 4], bv[i + 4]); - } - } - - PackedBBlobPtr += Q4Type::BlobSize; - } - - float sum = - vpadds_f32(vpadd_f32(vpadd_f32(vget_low_f32(acc0), vget_high_f32(acc0)), - vpadd_f32(vget_low_f32(acc1), vget_high_f32(acc1)))); - - sum += BiasPtr ? BiasPtr[n] : 0.0f; - - CRowPtr[n] = sum; - - PackedBColPtr += ldb; - } - - ARowPtr += lda; - CRowPtr += ldc; - } - - return CountM; - }; - auto impl3_four_cols = [&]() { const float* ARowPtr = A; float* CRowPtr = C; @@ -527,8 +312,8 @@ MlasQ4GemmKernelNeon(const float* A, int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - q4gemm_neon::ComputeDotProducts(ARowPtr, PackedBColPtr, SumPtr, CountK, ldb, - BiasPtr); + q4gemm_neon::ComputeDotProducts(ARowPtr, PackedBColPtr, SumPtr, + CountK, ldb, BiasPtr); // move to next `NCols` columns @@ -542,33 +327,8 @@ MlasQ4GemmKernelNeon(const float* A, // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - q4gemm_neon::ComputeDotProducts<1>(ARowPtr, PackedBColPtr, SumPtr, CountK, ldb, - BiasPtr); - - PackedBColPtr += ldb; - BiasPtr += BiasPtr != nullptr ? 1 : 0; - SumPtr += 1; - } - - ARowPtr += lda; - CRowPtr += ldc; - } - - return CountM; - }; - - auto impl4_one_col_with_helper = [&]() { - const float* ARowPtr = A; - float* CRowPtr = C; - - for (size_t m = 0; m < CountM; ++m) { - const float* BiasPtr = Bias; - const uint8_t* PackedBColPtr = PackedB; - float* SumPtr = CRowPtr; - - for (size_t n = 0; n < CountN; ++n) { - q4gemm_neon::ComputeDotProducts<1>(ARowPtr, PackedBColPtr, SumPtr, CountK, ldb, - BiasPtr); + q4gemm_neon::ComputeDotProducts(ARowPtr, PackedBColPtr, SumPtr, CountK, + ldb, BiasPtr); PackedBColPtr += ldb; BiasPtr += BiasPtr != nullptr ? 
1 : 0; @@ -583,7 +343,6 @@ MlasQ4GemmKernelNeon(const float* A, }; auto impl5_four_cols_inline = [&]() { - const float* ARowPtr = A; float* CRowPtr = C; @@ -609,7 +368,14 @@ MlasQ4GemmKernelNeon(const float* A, scale[i] = MlasQ4BlkScale(PackedBBlobPtr + i * ldb); }); - const uint8_t zp = 8; + uint8_t zp[NCols]; + q4gemm_neon::UnrolledLoop([&](size_t i) { + if constexpr (MlasQ4BlkHasZeroPoint()) { + zp[i] = MlasQ4BlkZeroPoint(PackedBBlobPtr + i * ldb); + } else { + zp[i] = 8; + } + }); const uint8_t* b_data[NCols]; q4gemm_neon::UnrolledLoop([&](size_t i) { @@ -648,9 +414,9 @@ MlasQ4GemmKernelNeon(const float* A, // dequantize B // subtract zero point - const int8x16_t zpv = vdupq_n_s8(zp); - q4gemm_neon::UnrolledLoop([&](size_t i) { + const int8x16_t zpv = vdupq_n_s8(zp[i]); + bv_bytes[i][0] = vsubq_s8(bv_bytes[i][0], zpv); bv_bytes[i][1] = vsubq_s8(bv_bytes[i][1], zpv); }); @@ -696,6 +462,11 @@ MlasQ4GemmKernelNeon(const float* A, q4gemm_neon::UnrolledLoop( [&](size_t i) { acc[i] = vfmaq_f32(acc[i], av[j], bv[i][j]); }); }); + + // increment b data pointers to next 32 elements + q4gemm_neon::UnrolledLoop([&](size_t i) { + b_data[i] += 16; + }); } PackedBBlobPtr += Q4Type::BlobSize; @@ -719,7 +490,6 @@ MlasQ4GemmKernelNeon(const float* A, nblk += NCols; if (nblk > 0) { - float32x4_t acc[NCols]{}; const uint8_t* PackedBBlobPtr = PackedBColPtr; @@ -729,11 +499,16 @@ MlasQ4GemmKernelNeon(const float* A, float scale[NCols]; const uint8_t* b_data[NCols]; - const uint8_t zp = 8; + uint8_t zp[NCols]; for (int64_t nn = 0; nn < nblk; ++nn) { scale[nn] = MlasQ4BlkScale(PackedBBlobPtr + nn * ldb); b_data[nn] = MlasQ4BlkData(PackedBBlobPtr + nn * ldb); + if constexpr (MlasQ4BlkHasZeroPoint()) { + zp[nn] = MlasQ4BlkZeroPoint(PackedBBlobPtr + nn * ldb); + } else { + zp[nn] = 8; + } } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += 32) { @@ -765,7 +540,7 @@ MlasQ4GemmKernelNeon(const float* A, // dequantize B // subtract zero point - const int8x16_t zpv = vdupq_n_s8(zp); + const int8x16_t zpv = vdupq_n_s8(zp[nn]); bv_bytes[0] = vsubq_s8(bv_bytes[0], zpv); bv_bytes[1] = vsubq_s8(bv_bytes[1], zpv); @@ -798,9 +573,11 @@ MlasQ4GemmKernelNeon(const float* A, [&](size_t j) { bv[j] = vmulq_n_f32(bv[j], scale[nn]); }); // c += a * b - q4gemm_neon::UnrolledLoop<8>([&](size_t j) { - acc[nn] = vfmaq_f32(acc[nn], av[j], bv[j]); - }); + q4gemm_neon::UnrolledLoop<8>( + [&](size_t j) { acc[nn] = vfmaq_f32(acc[nn], av[j], bv[j]); }); + + // increment b data pointers to next 32 elements + b_data[nn] += 16; } } @@ -821,11 +598,8 @@ MlasQ4GemmKernelNeon(const float* A, }; // return impl0_reference(); - // return impl1(); - // return impl2_two_accumulators(); // return impl3_four_cols(); - // return impl4_one_col_with_helper(); - return impl5_four_cols_inline(); + // return impl5_four_cols_inline(); } template <> @@ -845,6 +619,57 @@ MlasQ4GemmKernel(const float* A, ldc, Bias); } +template <> +MLAS_FORCEINLINE size_t +MlasQ4GemmKernel(const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias) +{ + return MlasQ4GemmKernelNeon(A, PackedB, C, CountM, CountN, CountK, lda, ldb, + ldc, Bias); +} + +template <> +MLAS_FORCEINLINE size_t +MlasQ4GemmKernel(const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias) +{ + return MlasQ4GemmKernelNeon(A, PackedB, C, CountM, CountN, CountK, lda, ldb, + 
ldc, Bias); +} + +template <> +MLAS_FORCEINLINE size_t +MlasQ4GemmKernel(const float* A, + const uint8_t* PackedB, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + const float* Bias) +{ + return MlasQ4GemmKernelNeon(A, PackedB, C, CountM, CountN, CountK, lda, ldb, + ldc, Bias); +} + // // MlasBlkQ4DequantB and related helper functions // @@ -852,15 +677,8 @@ MlasQ4GemmKernel(const float* A, template MLAS_FORCEINLINE void MlasBlkQ4DequantBNeon( - float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb); - -template <> -MLAS_FORCEINLINE void -MlasBlkQ4DequantBNeon( float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) { - using Q4Type = MLAS_Q4TYPE_BLK0; - // unpack B in format suitable for MlasSgemmKernelZero auto impl0_reference = [&]() { @@ -874,12 +692,18 @@ MlasBlkQ4DequantBNeon( const uint8_t* PackedBBlock = PackedBCol; for (size_t k = 0; k < CountK; k += Q4Type::BlkLen) { - float b_blk_unpacked[32]{}; + float b_blk_unpacked[Q4Type::BlkLen]{}; const size_t kblocklen = std::min(CountK - k, Q4Type::BlkLen); const float s = MlasQ4BlkScale(PackedBBlock); - const uint8_t z = 8; // MlasQ4BlkZeroPoint(PackedBBlock); + const uint8_t z = [PackedBBlock]() -> uint8_t { + if constexpr (MlasQ4BlkHasZeroPoint()) { + return MlasQ4BlkZeroPoint(PackedBBlock); + } else { + return 8; + } + }(); const uint8_t* PackedBData = MlasQ4BlkData(PackedBBlock); for (size_t kk = 0; kk < kblocklen; kk += 32) { @@ -928,6 +752,8 @@ MlasBlkQ4DequantBNeon( }; impl0_reference(); + + // TODO optimized implementation } template <> @@ -938,16 +764,40 @@ MlasBlkQ4DequantB( MlasBlkQ4DequantBNeon(FpData, PackedB, CountN, CountK, ldb); } +template <> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + MlasBlkQ4DequantBNeon(FpData, PackedB, CountN, CountK, ldb); +} + +template <> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + MlasBlkQ4DequantBNeon(FpData, PackedB, CountN, CountK, ldb); +} + +template <> +MLAS_FORCEINLINE void +MlasBlkQ4DequantB( + float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb) +{ + MlasBlkQ4DequantBNeon(FpData, PackedB, CountN, CountK, ldb); +} + // // MlasFpQ4GemmDispatchNeon structure population // static MLAS_Q4GEMM_OPERATION* Q4Operations_neon[] = { MlasQ4GemmOperation, + MlasQ4GemmOperation, + MlasQ4GemmOperation, nullptr, - nullptr, - nullptr, - nullptr, + MlasQ4GemmOperation, }; const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchNeon = {Q4Operations_neon}; From 34d40a8fcdd22d2765c0f98abbdde56eabf88942 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 19 Oct 2023 11:52:57 -0700 Subject: [PATCH 09/10] uncomment an impl --- onnxruntime/core/mlas/lib/q4gemm_neon.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp index d6aba1357ecbd..ec143bdc1b070 100644 --- a/onnxruntime/core/mlas/lib/q4gemm_neon.cpp +++ b/onnxruntime/core/mlas/lib/q4gemm_neon.cpp @@ -598,7 +598,7 @@ MlasQ4GemmKernelNeon(const float* A, }; // return impl0_reference(); - // return impl3_four_cols(); + return impl3_four_cols(); // return impl5_four_cols_inline(); } From 0a4c43439ac284ee257f38ee4f6558109f3011f8 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> 
Date: Fri, 20 Oct 2023 10:57:15 -0700 Subject: [PATCH 10/10] remove redundant inline --- onnxruntime/core/mlas/lib/q4gemm.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/q4gemm.h b/onnxruntime/core/mlas/lib/q4gemm.h index b4a3941af1edb..2528513cf24a8 100644 --- a/onnxruntime/core/mlas/lib/q4gemm.h +++ b/onnxruntime/core/mlas/lib/q4gemm.h @@ -42,8 +42,6 @@ MLAS_FORCEINLINE void MlasBlkQ4DequantB(float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb); - -inline MLAS_FORCEINLINE void MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) {
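Note: across patches 07-10, the kernel settles on computing NCols = 4 output
columns per pass and reducing the four lane-wise accumulators with
FoldAccumulators, which transposes them so one final vaddq_f32 tree yields a
4-vector of column sums storable with a single vst1q_f32. A scalar model of that
reduction (FoldAccumulatorsScalar is a hypothetical name for illustration):

    void FoldAccumulatorsScalar(const float acc[4][4], float out[4]) {
        // out[i] is the horizontal sum of accumulator i, matching
        // vaddvq_f32(acc[i]) in the single-column path.
        for (int i = 0; i < 4; ++i) {
            out[i] = acc[i][0] + acc[i][1] + acc[i][2] + acc[i][3];
        }
    }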