Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/edgchen1/arm64_q4gemm' into edgchen1/mlas_qnbitgemm_driver
Browse files Browse the repository at this point in the history
  • Loading branch information
edgchen1 committed Oct 25, 2023
2 parents cc7e8cc + 0a4c434 commit 6de6339
Show file tree
Hide file tree
Showing 9 changed files with 867 additions and 56 deletions.
2 changes: 2 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ function(setup_mlas_source_for_windows)
if(onnxruntime_target_platform STREQUAL "ARM64")
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/q4gemm_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
Expand Down Expand Up @@ -331,6 +332,7 @@ else()
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S
${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S
${MLAS_SRC_DIR}/q4gemm_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp
Expand Down
28 changes: 15 additions & 13 deletions onnxruntime/core/common/cpuid_uarch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

#include "core/common/cpuid_uarch.h"

#include "core/common/logging/logging.h"
#include <iostream> // For std::cerr.
// Writing to stderr instead of logging because logger may not be initialized yet.

namespace onnxruntime {

Expand Down Expand Up @@ -137,7 +138,7 @@ void decodeMIDR(
break;
// #endif /* ARM */
default:
LOGS_DEFAULT(WARNING) << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
}
break;
Expand All @@ -156,7 +157,7 @@ void decodeMIDR(
break;
// #endif
default:
LOGS_DEFAULT(WARNING) << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
// #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__)
Expand All @@ -172,7 +173,7 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_thunderx2;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
// #endif
Expand All @@ -187,7 +188,7 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_cortex_a76;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
// #if defined(_M_ARM) || defined(__arm__)
Expand All @@ -199,7 +200,7 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_xscale;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
// #endif /* ARM */
Expand All @@ -215,7 +216,7 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_carmel;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
#if !defined(__ANDROID__)
Expand All @@ -225,7 +226,7 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_xgene;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
#endif
Expand Down Expand Up @@ -297,7 +298,7 @@ void decodeMIDR(
break;
// #endif /* ARM64 && !defined(__ANDROID__) */
default:
LOGS_DEFAULT(WARNING) << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
case 'S':
Expand Down Expand Up @@ -343,8 +344,9 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_exynos_m5;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Samsung CPU variant 0x"
<< std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Samsung CPU variant 0x"
<< std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr)
<< " ignored\n";
}
break;
// #if defined(_M_ARM) || defined(__arm__)
Expand All @@ -355,12 +357,12 @@ void decodeMIDR(
*uarch = cpuinfo_uarch_pj4;
break;
default:
LOGS_DEFAULT(WARNING) << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored";
std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n";
}
break;
// #endif /* ARM */
default:
LOGS_DEFAULT(WARNING) << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr;
std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n";
}
}

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,7 @@ extern const MLAS_Q8Q4GEMM_DISPATCH MlasQ8Q4GemmDispatchAvx512vnni;
struct MLAS_FPQ4GEMM_DISPATCH;

extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512;
extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchNeon;

//
// Quantized depthwise convolution kernels.
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ Return Value:
this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon;
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchNeon;

//
// Check if the processor supports ASIMD dot product instructions.
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/mlas/lib/q4common.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ MlasQ4BlkScale(const uint8_t* BlkPtr)
return *reinterpret_cast<const float*>(BlkPtr);
}

//
// Compile-time query: does the Q4 block type Q4Type carry a zero point?
// The primary template answers "no"; a block type that does carry one
// overrides this through an explicit specialization.
//
template <typename Q4Type>
constexpr bool
MlasQ4BlkHasZeroPoint()
{
    return false;
}

template <typename T>
uint8_t&
MlasQ4BlkZeroPoint(uint8_t* BlkPtr);
Expand Down Expand Up @@ -100,6 +107,13 @@ struct MLAS_Q4TYPE_BLK1 {
static constexpr size_t BlobSize = BlkLen / 2 + sizeof(float) + sizeof(uint8_t);
};

//
// Explicit specialization for MLAS_Q4TYPE_BLK1: this block type reports
// that it has a zero point.
//
template <>
constexpr bool
MlasQ4BlkHasZeroPoint<MLAS_Q4TYPE_BLK1>()
{
    return true;
}

template<>
inline uint8_t&
MlasQ4BlkZeroPoint<MLAS_Q4TYPE_BLK1>(uint8_t* BlkPtr)
Expand Down
29 changes: 25 additions & 4 deletions onnxruntime/core/mlas/lib/q4gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,31 @@ MLAS_FORCEINLINE
void
MlasBlkQ4DequantB(float* FpData, const uint8_t* PackedB, size_t CountN, size_t CountK, size_t ldb);

//
// Kernel-specific bias addition. Only declared here; a definition must be
// supplied as a specialization for each KERNEL type that uses it.
//
template <typename KERNEL>
MLAS_FORCEINLINE void
AddBiasAvx(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc);

//
// Adds a bias row to every row of the output matrix, in place.
//
//   Bias   - CountN bias values; the same values are applied to each row.
//   C      - first element of the output block to update.
//   CountM - number of rows of C to update.
//   CountN - number of columns to update in each row.
//   ldc    - distance, in elements, between the starts of consecutive rows
//            of C.
//
// Columns are processed four at a time through the MLAS float32x4 helpers;
// the final (CountN % 4) columns of each row are handled one by one.
//
MLAS_FORCEINLINE
void
MlasAddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t ldc) {
    for (size_t m = 0; m < CountM; m++) {
        // Restart at the beginning of the bias vector for every row.
        const float* bias = Bias;
        float* sum = C;
        size_t remaining = CountN;

        // Full groups of four columns.
        while (remaining >= 4) {
            MLAS_FLOAT32X4 acc_x = MlasLoadFloat32x4(sum);
            acc_x = MlasAddFloat32x4(acc_x, MlasLoadFloat32x4(bias));
            MlasStoreFloat32x4(sum, acc_x);
            bias += 4;
            sum += 4;
            remaining -= 4;
        }

        // Fewer than four columns left: finish with scalar additions.
        while (remaining > 0) {
            *sum += *bias;
            sum++;
            bias++;
            remaining--;
        }

        C += ldc;
    }
}



Expand Down Expand Up @@ -135,7 +156,7 @@ MlasQ4GemmOperation(
#endif

if (bias) {
AddBiasAvx<KERNEL>(bias, c_blk, RowsHandled, CountN, ldc);
MlasAddBiasForGemm(bias, c_blk, RowsHandled, CountN, ldc);
}
if (DataParams->OutputProcessor != nullptr) {
DataParams->OutputProcessor->Process(
Expand Down
34 changes: 0 additions & 34 deletions onnxruntime/core/mlas/lib/q4gemm_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,40 +1030,6 @@ MlasBlkQ4DequantSgemmPackB(
}
}

//
// Bias addition for the MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI kernel type.
//
// Adds Bias[0..CountN) to each of the CountM rows of C in place, where
// consecutive rows of C start ldc elements apart. Four columns are updated
// per step with unaligned 128-bit loads/stores; the trailing (CountN % 4)
// columns of each row are updated with scalar additions.
//
template<>
MLAS_FORCEINLINE
void
AddBiasAvx<MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI>(
    const float* Bias,
    float* C,
    size_t CountM,
    size_t CountN,
    size_t ldc
    )
{
    float* row = C;

    for (size_t rowIdx = 0; rowIdx < CountM; ++rowIdx, row += ldc) {
        const float* src = Bias;
        float* dst = row;
        size_t left = CountN;

        // Vector part: whole groups of four floats.
        for (; left >= 4; left -= 4, src += 4, dst += 4) {
            const __m128 current = _mm_loadu_ps(dst);
            const __m128 addend = _mm_loadu_ps(src);
            _mm_storeu_ps(dst, _mm_add_ps(current, addend));
        }

        // Scalar part: whatever did not fill a group of four.
        for (; left > 0; --left, ++src, ++dst) {
            *dst += *src;
        }
    }
}


static MLAS_Q4GEMM_OPERATION* Q4Operations_avx512vnni[] = {
MlasQ4GemmOperation<MLAS_Q4TYPE_BLK0, MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI>,
Expand Down
Loading

0 comments on commit 6de6339

Please sign in to comment.