Commit a89bddd

Matmul_nbits kernel for mlas sqnbits to support Fp16 inputs (microsof…

liqunfu authored Sep 13, 2024
1 parent 7e2c722 commit a89bddd
Showing 11 changed files with 341 additions and 116 deletions.
4 changes: 2 additions & 2 deletions cmake/onnxruntime_mlas.cmake
@@ -580,10 +580,10 @@ message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")

if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11")
message(STATUS "Using -mavx2 -mfma -mavxvnni flags")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni")
else()
message(STATUS "Using -mavx2 -mfma flags")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c")
endif()
set(mlas_platform_srcs_avx512f
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
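A note on the new -mf16c flag: it enables the F16C half-precision convert intrinsics (_mm256_cvtph_ps / _mm256_cvtps_ph) that the AVX2 cast kernels added later in this commit rely on; without it, GCC/Clang reject those intrinsics. A minimal standalone sketch, not part of the commit, of the kind of code that needs the flag:

// f16c_roundtrip_example.cpp -- illustrative only; build with e.g.
//   g++ -mavx2 -mfma -mf16c f16c_roundtrip_example.cpp
#include <immintrin.h>
#include <cstdio>

int main() {
    float src[8] = {0.5f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f};
    float dst[8];
    __m256 v = _mm256_loadu_ps(src);
    __m128i half = _mm256_cvtps_ph(v, _MM_FROUND_TO_NEAREST_INT);  // fp32 -> fp16 (F16C)
    __m256 back = _mm256_cvtph_ps(half);                           // fp16 -> fp32 (F16C)
    _mm256_storeu_ps(dst, back);
    std::printf("%f\n", dst[3]);  // prints 3.000000
    return 0;
}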
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -488,7 +488,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
246 changes: 181 additions & 65 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc

Large diffs are not rendered by default.
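Since the matmul_nbits.cc diff is not rendered here, the following is a hedged sketch of how an fp16 activation path can be layered on the existing fp32 kernels, inferred only from the other changes in this commit (the new MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer routines and the DequantizeBlockwise<float, MLFloat16> instantiation below). Function and variable names are illustrative, not the actual code in matmul_nbits.cc:

// Hypothetical helper: convert fp16 inputs to fp32, reuse the fp32 GEMM, convert back.
#include <cstddef>
#include <vector>
// assumes "core/mlas/inc/mlas.h" for MLAS_FP16 and the buffer-conversion routines

void MatMulNBitsFp16ViaFp32(const MLAS_FP16* A, const float* dequantized_B,
                            MLAS_FP16* Y, size_t M, size_t K, size_t N) {
    std::vector<float> a_fp32(M * K);
    std::vector<float> y_fp32(M * N);
    MlasConvertHalfToFloatBuffer(A, a_fp32.data(), M * K);    // fp16 -> fp32
    // ... existing fp32 path here, e.g. a GEMM over a_fp32 and dequantized_B into y_fp32 ...
    MlasConvertFloatToHalfBuffer(y_fp32.data(), Y, M * N);    // fp32 -> fp16
}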

11 changes: 8 additions & 3 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc
@@ -54,12 +54,12 @@ void Dequantize4BitsKernelReOrder(
T scale = *(scale_data + n_idx * scales_shape_x + rid);
float zp_f = 8;
if (zero_points) {
if constexpr (std::is_same_v<zeroT, T>) {
zp_f = *(zero_points + n_idx * scales_shape_x + rid);
} else {
if constexpr (std::is_same_v<zeroT, uint8_t>) {
uint8_t zp = 8;
zp = zero_points[n_idx * zero_point_shape_x + rid / 2];
zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f);
} else {
zp_f = *(zero_points + static_cast<uint64_t>(n_idx) * static_cast<uint64_t>(scales_shape_x) + static_cast<uint64_t>(rid));
}
}

@@ -112,5 +112,10 @@ template void DequantizeBlockwise<float, float>(
const float* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

template void DequantizeBlockwise<float, MLFloat16>(
float* output, const uint8_t* quant_data, const float* scales_data,
const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size,
bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool);

} // namespace contrib
} // namespace onnxruntime
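The new explicit instantiation above dequantizes the packed 4-bit B tensor to fp32 while taking MLFloat16 zero points. A hedged usage sketch; the buffer names and surrounding setup are assumptions, not code from the commit:

// Dequantize a blockwise 4-bit weight matrix (N x K) to fp32 with fp16 zero points.
std::vector<float> b_dequant(static_cast<size_t>(N) * static_cast<size_t>(K));
onnxruntime::contrib::DequantizeBlockwise<float, MLFloat16>(
    b_dequant.data(),        // fp32 output
    packed_b_data,           // const uint8_t*: packed 4-bit quantized data
    scales_fp32,             // const float*: per-block scales
    zero_points_fp16,        // const MLFloat16*: per-block zero points
    /*reorder_idx=*/nullptr,
    block_size,
    /*columnwise=*/true,
    K, N, thread_pool);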
36 changes: 23 additions & 13 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -20,6 +20,7 @@ Module Name:
#include <cstddef>
#include <cstdlib>
#include <cstdint>
#include <stdexcept>

//
// Define the calling convention for Windows targets.
@@ -1025,18 +1026,6 @@ MlasComputeTanh(
size_t N
);

//
// Half-precision floating-point routines.
//

void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
);

//
// Transpose routines.
//
@@ -1426,7 +1415,27 @@ using MLAS_FP16 = onnxruntime::MLFloat16;

constexpr size_t FP16_SIZE = sizeof(uint16_t);

/**
//
// Half-precision floating-point routines.
//

void
MLASCALL
MlasConvertHalfToFloatBuffer(
const MLAS_FP16* Source,
float* Destination,
size_t Count
);

void
MLASCALL
MlasConvertFloatToHalfBuffer(
const float* Source,
MLAS_FP16* Destination,
size_t Count
);

/**
* @brief Whether current CPU supports FP16 acceleration.
*/
bool MLASCALL
@@ -1787,6 +1796,7 @@ MlasTranspose(
M, N);
}


#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
/**
* @brief Max Pooling for fp16 NHWC
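Note the relocation: the half-precision conversion prototypes now sit after the MLAS_FP16 alias is defined, so they take MLAS_FP16* directly instead of unsigned short*, and a second routine covers the fp32-to-fp16 direction. A short usage sketch (the buffer setup is assumed, not from the commit):

#include <vector>
// assumes "mlas.h" for MLAS_FP16 and the two buffer-conversion routines

void RoundTrip(size_t n) {
    std::vector<MLAS_FP16> half(n);
    std::vector<float> single(n, 1.5f);
    MlasConvertFloatToHalfBuffer(single.data(), half.data(), n);  // fp32 -> fp16
    MlasConvertHalfToFloatBuffer(half.data(), single.data(), n);  // fp16 -> fp32
}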
42 changes: 20 additions & 22 deletions onnxruntime/core/mlas/lib/cast.cpp
@@ -23,37 +23,35 @@ union fp32_bits {
void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
const MLAS_FP16* Source,
float* Destination,
size_t Count
)
{

if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) {
// If there is no kernel use the reference implementation, adapted from mlas_float16.h.
constexpr fp32_bits magic = {113 << 23};
constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
for (size_t i = 0; i < Count; ++i) {
Destination[i] = Source[i].ToFloat();
}
} else {
// If the kernel is available, use it to perform the conversion.
GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast<const unsigned short*>(Source), Destination, Count);
}
}

void
MLASCALL
MlasConvertFloatToHalfBuffer(
const float* Source,
MLAS_FP16* Destination,
size_t Count
)
{
if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) {
for (size_t i = 0; i < Count; ++i) {
fp32_bits o;
o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits
uint32_t exp = shifted_exp & o.u; // just the exponent
o.u += (127 - 15) << 23; // exponent adjust

// handle exponent special cases
if (exp == shifted_exp) { // Inf/NaN?
o.u += (128 - 16) << 23; // extra exp adjust
} else if (exp == 0) { // Zero/Denormal?
o.u += 1 << 23; // extra exp adjust
o.f -= magic.f; // renormalize
}

o.u |= (Source[i] & 0x8000) << 16; // sign bit
Destination[i] = o.f;
Destination[i] = MLAS_FP16(Source[i]);
}

} else {
// If the kernel is available, use it to perform the conversion.
GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count);
GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast<unsigned short*>(Destination), Count);
}
}
11 changes: 10 additions & 1 deletion onnxruntime/core/mlas/lib/mlasi.h
@@ -610,13 +610,19 @@ void
size_t N
);

typedef
void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)(
const unsigned short* Source,
float* Destination,
size_t Count
);

typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)(
const float* Source,
unsigned short* Destination,
size_t Count
);

typedef
void
(MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)(
@@ -880,6 +886,8 @@ extern "C" {
#if defined(MLAS_TARGET_AMD64)
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse;
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx;
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2;
MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2;
#endif

}
@@ -1165,6 +1173,7 @@ struct MLAS_PLATFORM {
const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};

MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
};

inline
4 changes: 4 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
@@ -245,6 +245,7 @@ Return Value:
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
this->CastF16ToF32Kernel = nullptr;
this->CastF32ToF16Kernel = nullptr;

#if defined(MLAS_TARGET_AMD64_IX86)

@@ -387,6 +388,9 @@ Return Value:
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2<int8_t, uint8_t>;
this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3;
this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2;
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;


//
// Check if the processor supports Hybrid core architecture.
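The two cast-kernel pointers are initialized to nullptr and only set to the AVX2 implementations when the AVX2 code path is selected, so older CPUs keep the scalar fallback. This mirrors the dispatch pattern already visible in cast.cpp above; a condensed sketch of that pattern, assuming the MLAS-internal GetMlasPlatform accessor:

// Dispatch sketch: prefer the registered SIMD kernel, otherwise fall back to scalar.
if (GetMlasPlatform().CastF32ToF16Kernel != nullptr) {
    GetMlasPlatform().CastF32ToF16Kernel(src, reinterpret_cast<unsigned short*>(dst), count);
} else {
    for (size_t i = 0; i < count; ++i) {
        dst[i] = MLAS_FP16(src[i]);
    }
}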
45 changes: 45 additions & 0 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
@@ -29,6 +29,51 @@ Module Name:
#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h"
#include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h"

void
MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size)
{
size_t i = 0;

// Process 16 elements at a time using AVX2
for (; i + 15 < size; i += 16) {
// Load 16 FP16 values into an AVX2 register
__m256i fp16_values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src_fp16 + i));

// Convert FP16 values to FP32
__m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values));
__m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1));

// Store the converted FP32 values into the output vector
_mm256_storeu_ps(dst_fp32 + i, fp32_values1);
_mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2);
}

// Process any remaining elements
const MLAS_FP16* fp16 = reinterpret_cast<const MLAS_FP16*>(src_fp16);
for (; i < size; ++i) {
dst_fp32[i] = fp16[i].ToFloat();
}
}

void
MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size)
{
size_t i = 0;

// Process 8 elements at a time using AVX2
for (; i + 8 <= size; i += 8) {
__m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]);
__m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT);
_mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk);
}

// Process any remaining elements
for (; i < size; ++i) {
MLAS_FP16 fp16(src_fp32[i]);
dst_fp16[i] = fp16.val;
}
}

MLAS_FORCEINLINE
__m256
load_float_n_avx2(const float* data, int n)
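The F16-to-F32 kernel converts 16 halves per iteration (two _mm256_cvtph_ps calls) and the F32-to-F16 kernel converts 8 floats per iteration, each finishing with a scalar tail loop. A hedged sanity-check sketch, not part of the commit, that exercises both the vector body and the tail (these kernels are MLAS-internal, declared in mlasi.h):

#include <cassert>
#include <vector>

void CheckAvx2CastKernels() {
    const size_t n = 19;  // 16-wide body plus a 3-element scalar tail
    std::vector<float> src(n), roundtrip(n);
    std::vector<unsigned short> half(n);
    for (size_t i = 0; i < n; ++i) {
        src[i] = 0.125f * static_cast<float>(i);  // exactly representable in fp16
    }
    MlasCastF32ToF16KernelAvx2(src.data(), half.data(), n);
    MlasCastF16ToF32KernelAvx2(half.data(), roundtrip.data(), n);
    for (size_t i = 0; i < n; ++i) {
        assert(roundtrip[i] == src[i]);  // lossless round trip for these values
    }
}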
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/tensor/cast_op.cc
@@ -258,7 +258,7 @@ struct TensorCaster<MLFloat16, float> {
auto out_data = out.MutableData<float>();
auto in_data = in.Data<MLFloat16>();
const size_t shape_size = narrow<size_t>(shape.Size());
MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size);
MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size);
}
};

54 changes: 46 additions & 8 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -262,8 +262,8 @@ void RunTest(const TestOptions& opts,

} // namespace

TEST(MatMulNBits, Float32) {
// onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json");
template <typename AType>
void TestMatMulNBitsTyped() {
for (auto M : {1, 2, 100}) {
for (auto N : {/*2560, */ 1, 2, 32, 288}) {
for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) {
@@ -276,30 +276,53 @@ TEST(MatMulNBits, Float32) {

if (base_opts.accuracy_level == 4) {
base_opts.output_abs_error = 0.1f;
} else {
if constexpr (std::is_same<AType, MLFloat16>::value) {
base_opts.output_abs_error = 0.01f;
}
}

{
TestOptions opts = base_opts;
RunTest<float>(opts);
RunTest<AType>(opts);
}

{
TestOptions opts = base_opts;
opts.has_zero_point = true;
RunTest<float>(opts);
RunTest<AType>(opts);
}

#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML)
{
TestOptions opts = base_opts;
opts.has_g_idx = true;
RunTest<float>(opts);
RunTest<AType>(opts);
}

{
TestOptions opts = base_opts;
opts.has_g_idx = true;
opts.has_bias = true;
if constexpr (std::is_same<AType, float>::value) {
if (opts.accuracy_level == 0 || opts.accuracy_level == 1) {
// CI failure (not able to repro on either local machine):
// M:100, N:288, K:1234, block_size:16, accuracy_level:0, has_zero_point:0, zp_is_4bit:1, has_g_idx:1, has_bias:1
// The difference between cur_expected[i] and cur_actual[i] is 1.0401010513305664e-05, which exceeds tolerance,
// tolerance evaluates to 1.006456386676291e-05.
opts.output_abs_error = 0.0001f;
}
}
// only enabled for CPU EP for now
std::vector<std::unique_ptr<IExecutionProvider>> explicit_eps;
explicit_eps.emplace_back(DefaultCpuExecutionProvider());
RunTest<AType>(opts, std::move(explicit_eps));
}

{
TestOptions opts = base_opts;
opts.has_zero_point = true, opts.zp_is_4bit = false;
RunTest<float>(opts);
RunTest<AType>(opts);
}
#endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML)

Expand All @@ -311,7 +334,7 @@ TEST(MatMulNBits, Float32) {
std::vector<std::unique_ptr<IExecutionProvider>> explicit_eps;
explicit_eps.emplace_back(DefaultCpuExecutionProvider());

RunTest<float>(opts, std::move(explicit_eps));
RunTest<AType>(opts, std::move(explicit_eps));
}
}
}
@@ -320,6 +343,21 @@ TEST(MatMulNBits, Float32) {
}
}

TEST(MatMulNBits, Float32) {
// onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json");
TestMatMulNBitsTyped<float>();
}

#ifdef MLAS_TARGET_AMD64_IX86
#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML)
// Actual and expected difference is over 0.01 with DmlExecutionProvider.
// Skip the tests instead of raising the tolerance to make it pass.
TEST(MatMulNBits, Float16) {
TestMatMulNBitsTyped<MLFloat16>();
}
#endif
#endif

#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML)

namespace {
@@ -367,7 +405,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura
}
} // namespace

TEST(MatMulNBits, Float16) {
TEST(MatMulNBits, Float16Cuda) {
#if defined(USE_CUDA) || defined(USE_ROCM)
auto has_gidx_options = {true, false};
#else
