[CPU EP] optimize qlinearsoftmax #22686

Open · wants to merge 5 commits into base: main
9 changes: 9 additions & 0 deletions cmake/onnxruntime_mlas.cmake
@@ -43,6 +43,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/cast.cpp
${MLAS_SRC_DIR}/rotary_embedding.h
${MLAS_SRC_DIR}/rotary_embedding.cpp
${MLAS_SRC_DIR}/qsoftmax.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_naive.cpp
)

target_sources(onnxruntime_mlas PRIVATE
@@ -169,6 +171,10 @@ function(setup_mlas_source_for_windows)
file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS
"${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp"
)
set(mlas_platform_srcs_avx2
${mlas_platform_srcs_avx2}
"${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp"
)
set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2")

target_sources(onnxruntime_mlas PRIVATE
@@ -177,6 +183,7 @@
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
@@ -581,6 +588,7 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp
)
if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE))
set(mlas_platform_srcs_avx2
@@ -621,6 +629,7 @@ endif()

set(mlas_platform_srcs_avx512vnni
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp
)
set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")

150 changes: 6 additions & 144 deletions onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc
@@ -36,7 +36,7 @@ void QlinearBuildLookupTableUint32(gsl::span<QLinearSoftmax::EXP_OUT_DTYPE> tabl
for (int32_t i = 0; i < 256; i++) {
double scaled_exp_xi = exp((static_cast<double>(i) - 255 + bit_shift) * static_cast<double>(x_scale));
// we can't get the real max value of the input tensor here, so we just assume 255 - bit_shift.
// in the function of `QlinearSoftmaxCPU`,
// in the function of `MlasComputeQSoftmax`,
// all numbers get a shift of (255 - bit_shift - max_value) when the actual max value is not 255
//
// if is_signed index = [1 2 3 ......126 127 -128 -127 ..... -3 -2 -1]
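For reference, the table construction this comment block describes amounts to the following standalone sketch (hypothetical helper name; the real QlinearBuildLookupTableUint32 fills a gsl::span and handles the signed index wrapping noted above):

```cpp
// Hedged sketch of the lookup-table build described above; illustrative only.
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> BuildExpLookupTable(float x_scale, int bit_shift) {
  std::vector<float> table(256);
  for (int32_t i = 0; i < 256; i++) {
    // Entry i holds exp((i - 255 + bit_shift) * x_scale). At compute time each
    // row is shifted so its max lands on index 255, keeping exponents in a
    // small, well-conditioned range.
    table[static_cast<size_t>(i)] = static_cast<float>(
        std::exp((static_cast<double>(i) - 255 + bit_shift) * static_cast<double>(x_scale)));
  }
  return table;
}
```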
@@ -123,136 +123,6 @@ Status QLinearSoftmax::Compute(OpKernelContext* ctx) const {
}
}

template <typename T>
common::Status QlinearSoftmaxCPU(size_t N,
size_t D,
const T* x_data,
T* y_data,
const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table,
QLinearSoftmax::EXP_OUT_DTYPE y_scale,
T yzp,
onnxruntime::concurrency::ThreadPool* thread_pool);

template <>
common::Status QlinearSoftmaxCPU<uint8_t>(size_t N,
size_t D,
const uint8_t* x_data,
uint8_t* y_data,
const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table,
QLinearSoftmax::EXP_OUT_DTYPE y_scale,
uint8_t yzp,
onnxruntime::concurrency::ThreadPool* thread_pool) {
using onnxruntime::TensorOpCost;
using onnxruntime::concurrency::ThreadPool;
ThreadPool::TryParallelFor(
thread_pool, N,
// Per row of length D: read 3*D (max, sum, div), write D (div); compute cost matches the reads
TensorOpCost{static_cast<double>(D) * 3.0,
static_cast<double>(D),
static_cast<double>(D) * 3.0},
[x_data, y_data, D, y_scale, yzp, &lookup_table](std::ptrdiff_t first, std::ptrdiff_t last) {
const auto c_y_scale = y_scale;
const auto c_y_zp = yzp;
const uint8_t* x_t = x_data + first * D;
uint8_t* y_t = y_data + first * D;
for (; first < last; first++) {
// reduceMaxUint8
uint8_t xmax = *std::max_element(x_t, x_t + D);
// we want xmax to align with 255 for higher precision.
// since the lookup table was built with X-255, we can apply the adjustment here
// so that every element is shifted consistently into the table.
// 1 2 3 4 5 ...........................254 255
// 1 3 5 ... 10
// after the shift --->
// 235 237 239 .. 255
const QLinearSoftmax::EXP_OUT_DTYPE* shifted_lookuptable = lookup_table + 255 - xmax;
size_t elements_n = D;
// reduceSumUint8ToUint32: needs speedup
// vsum = \sum_i e^{x_i}
QLinearSoftmax::EXP_OUT_DTYPE vsum = 0;
const uint8_t* x_t_cur = x_t;
do {
const size_t vx = *x_t_cur++;
vsum += shifted_lookuptable[vx];
} while (--elements_n != 0);
if (vsum == 0) {
return;
}
elements_n = D;
x_t_cur = x_t;
// elementwise div, y_i = \frac{e^{x_i}}{vsum}
do {
const size_t vx = *x_t_cur++;
const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx];
// simulate rounding, then re-quantize to uint8
const uint32_t vq = static_cast<uint32_t>(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp;
const uint8_t vy = vq > 255 ? static_cast<uint8_t>(255) : static_cast<uint8_t>(vq);
*y_t++ = vy;
} while (--elements_n != 0);
x_t = x_t_cur;
}
});

return Status::OK();
}

template <>
common::Status QlinearSoftmaxCPU<int8_t>(size_t N,
size_t D,
const int8_t* x_data,
int8_t* y_data,
const QLinearSoftmax::EXP_OUT_DTYPE* lookup_table,
QLinearSoftmax::EXP_OUT_DTYPE y_scale,
int8_t yzp,
onnxruntime::concurrency::ThreadPool* thread_pool) {
using onnxruntime::TensorOpCost;
using onnxruntime::concurrency::ThreadPool;
ThreadPool::TryParallelFor(
thread_pool, N,
// Per row of length D: read 3*D (max, sum, div), write D (div); compute cost matches the reads
TensorOpCost{static_cast<double>(D) * 3.0,
static_cast<double>(D),
static_cast<double>(D) * 3.0},
[x_data, y_data, D, y_scale, yzp, &lookup_table](std::ptrdiff_t first, std::ptrdiff_t last) {
const auto c_y_scale = y_scale;
const auto c_y_zp = yzp;

const int8_t* x_t = x_data + first * D;
int8_t* y_t = y_data + first * D;
for (; first < last; first++) {
// reduceMaxInt8
int8_t xmax = *std::max_element(x_t, x_t + D);
const int32_t adjustment = int32_t(127) - xmax;
const QLinearSoftmax::EXP_OUT_DTYPE* shifted_lookuptable = lookup_table;
size_t elements_n = D;
// reduceSumUint8ToUint32: needs speedup
QLinearSoftmax::EXP_OUT_DTYPE vsum = 0;
const int8_t* x_t_cur = x_t;
do {
const uint8_t vx = uint8_t(adjustment + (*x_t_cur++));
vsum += shifted_lookuptable[vx];
} while (--elements_n != 0);
if (vsum == 0) {
return;
}
elements_n = D;
x_t_cur = x_t;
// elementwise div
do {
const uint8_t vx = uint8_t(adjustment + (*x_t_cur++));
const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx];
// simulate rounding, then re-quantize to int8
const int32_t vq = static_cast<int32_t>(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp;
const int8_t vy = static_cast<int32_t>(vq) > 255 ? static_cast<int8_t>(255) : static_cast<int8_t>(vq);
*y_t++ = vy;
} while (--elements_n != 0);
x_t = x_t_cur;
}
});

return Status::OK();
}

gsl::span<const QLinearSoftmax::EXP_OUT_DTYPE> QLinearSoftmax::GetLookupTable(
OpKernelContext* context,
gsl::span<EXP_OUT_DTYPE> lookup_table_span,
@@ -270,25 +140,17 @@ gsl::span<const QLinearSoftmax::EXP_OUT_DTYPE> QLinearSoftmax::GetLookupTable(
Status QLinearSoftmax::ComputeInternal(OpKernelContext* context, const Tensor& input, Tensor& output,
gsl::span<const EXP_OUT_DTYPE> lookup_table, int axis,
concurrency::ThreadPool* thread_pool) const {
const auto* X_scale_tensor = context->Input<Tensor>(1);
const auto* Y_scale_tensor = context->Input<Tensor>(3);
const auto* Y_zp_tensor = context->Input<Tensor>(4);
const QLinearSoftmax::EXP_OUT_DTYPE Y_scale = std::floor(1.0F / (*(Y_scale_tensor->Data<float>())));
const auto& X_shape = input.Shape();
const size_t N = onnxruntime::narrow<size_t>(X_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis)));
const size_t D = onnxruntime::narrow<size_t>(X_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis)));
common::Status status;
if (is_signed_) {
using T = int8_t;
const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data<T>()) : 0;
status = QlinearSoftmaxCPU<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
} else {
using T = uint8_t;
const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data<T>()) : 0;
status = QlinearSoftmaxCPU<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
}
return status;
const int Y_zp = Y_zp_tensor ? (is_signed_ ? *(Y_zp_tensor->Data<int8_t>()) : *(Y_zp_tensor->Data<uint8_t>())) : 0;
MlasComputeQSoftmax(input.DataRaw(), output.MutableDataRaw(), N, D, lookup_table.data(),
*X_scale_tensor->Data<float>(), Y_scale, Y_zp, is_signed_, thread_pool);
return Status::OK();
}

// opset-13 and above
15 changes: 15 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
@@ -1018,6 +1018,21 @@ MlasComputeSoftmax(
MLAS_THREADPOOL* ThreadPool
);

void
MLASCALL
MlasComputeQSoftmax(
const void* Input,
void* Output,
size_t N,
size_t D,
const float* LookupTable,
float X_Scale,
float Scale,
int ZeroPoint,
bool is_signed,
MLAS_THREADPOOL* ThreadPool
);

void
MLASCALL
MlasComputeTanh(
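As a usage note for the new entry point: a hedged sketch of a caller for an unsigned [N, D] input, mirroring how ComputeInternal in qlinear_softmax.cc invokes it. Names and the include path here are illustrative; note that the contrib op passes floor(1.0f / y_scale) as the Scale argument.

```cpp
// Illustrative caller for MlasComputeQSoftmax; signature as declared above.
#include <cmath>
#include <cstdint>
#include <vector>
#include "core/mlas/inc/mlas.h"  // assumed include path

void RunQSoftmaxU8(const std::vector<uint8_t>& x, std::vector<uint8_t>& y,
                   size_t N, size_t D, const std::vector<float>& lookup_table,
                   float x_scale, float y_scale, int y_zero_point) {
  // The QLinearSoftmax contrib op passes floor(1/y_scale) as Scale.
  const float scale = std::floor(1.0f / y_scale);
  // is_signed=false selects the uint8 kernels; a null thread pool should run
  // single-threaded, as with the other MlasCompute* routines (assumption).
  MlasComputeQSoftmax(x.data(), y.data(), N, D, lookup_table.data(),
                      x_scale, scale, y_zero_point, /*is_signed=*/false,
                      /*ThreadPool=*/nullptr);
}
```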
32 changes: 32 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -56,6 +56,7 @@ Module Name:
#if defined(__GNUC__) && __GNUC__ >= 12
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
#pragma GCC diagnostic ignored "-Wuninitialized" // GCC 12 warns about uninitialized variables in immintrin.h.
#include <immintrin.h>
#pragma GCC diagnostic pop
#else
@@ -727,6 +728,28 @@ void
float Scale,
int8_t ZeroPoint);

typedef
void
(MLASCALL MLAS_QUANTIZE_SOFTMAX_I8_KERNEL)(
size_t D,
const int8_t* Xdata,
int8_t* Ydata,
const float* LookupTable,
float Yscale,
int8_t YZeroPoint,
float* Buff);

typedef
void
(MLASCALL MLAS_QUANTIZE_SOFTMAX_U8_KERNEL)(
size_t D,
const uint8_t* Xdata,
uint8_t* Ydata,
const float* LookupTable,
float Yscale,
uint8_t YZeroPoint,
float* Buff);

template<typename InputType, typename FilterType>
struct MLAS_QUANT_KERNEL
{
@@ -896,7 +919,13 @@ extern "C" {
MLAS_QLINEAR_BINARY_OP_U8_KERNEL MlasQLinearAddU8KernelAvx2;
MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelAvx512F;
MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelAvx512F;
MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelAvx2;
MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelAvx512;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelAvx2;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelAvx512;
#endif
MLAS_QUANTIZE_SOFTMAX_I8_KERNEL MlasQuantizeSoftmaxI8KernelNaive;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL MlasQuantizeSoftmaxU8KernelNaive;

MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL MlasReduceMaximumF32Kernel;
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32Kernel;
@@ -1217,6 +1246,9 @@ struct MLAS_PLATFORM {
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;

const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};

MLAS_QUANTIZE_SOFTMAX_I8_KERNEL* QuantizeSoftmaxI8Kernel;
MLAS_QUANTIZE_SOFTMAX_U8_KERNEL* QuantizeSoftmaxU8Kernel;
};

inline
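The typedefs above pin down the per-row kernel contract. A minimal sketch of what a naive uint8 kernel of that shape could look like, reconstructed from the removed QlinearSoftmaxCPU<uint8_t> loop — this is an assumption, not the contents of qsoftmax_kernel_naive.cpp, and the MLASCALL macro plus the scratch Buff argument (presumably used by the vectorized kernels) are glossed over:

```cpp
// Assumed naive per-row kernel matching MLAS_QUANTIZE_SOFTMAX_U8_KERNEL;
// reconstructed from the deleted QlinearSoftmaxCPU<uint8_t> loop above.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

void NaiveQSoftmaxU8Kernel(size_t D, const uint8_t* Xdata, uint8_t* Ydata,
                           const float* LookupTable, float Yscale,
                           uint8_t YZeroPoint, float* /*Buff: scratch, unused*/) {
  // Shift the table so the row max maps to index 255.
  const uint8_t xmax = *std::max_element(Xdata, Xdata + D);
  const float* shifted = LookupTable + 255 - xmax;
  float vsum = 0.0f;
  for (size_t i = 0; i < D; ++i) vsum += shifted[Xdata[i]];
  if (vsum == 0.0f) return;
  for (size_t i = 0; i < D; ++i) {
    // Re-quantize e^{x_i} / vsum with rounding, then clamp to uint8.
    const uint32_t q = static_cast<uint32_t>(
        std::nearbyintf(shifted[Xdata[i]] * Yscale / vsum)) + YZeroPoint;
    Ydata[i] = q > 255 ? static_cast<uint8_t>(255) : static_cast<uint8_t>(q);
  }
}
```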
7 changes: 6 additions & 1 deletion onnxruntime/core/mlas/lib/platform.cpp
@@ -246,6 +246,8 @@ Return Value:
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
this->CastF16ToF32Kernel = nullptr;
this->CastF32ToF16Kernel = nullptr;
this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelNaive;
this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelNaive;

#if defined(MLAS_TARGET_AMD64_IX86)

@@ -258,7 +260,6 @@
this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchSse;

#if defined(MLAS_TARGET_AMD64)

this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Sse;
this->GemmDoubleKernel = MlasGemmDoubleKernelSse;
this->ConvNchwFloatKernel = MlasConvNchwFloatKernelSse;
@@ -402,6 +403,8 @@ Return Value:
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2;
this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2;

this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx2;
this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx2;

//
// Check if the processor supports Hybrid core architecture.
@@ -471,6 +474,8 @@ Return Value:
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

this->QuantizeSoftmaxI8Kernel = MlasQuantizeSoftmaxI8KernelAvx512;
this->QuantizeSoftmaxU8Kernel = MlasQuantizeSoftmaxU8KernelAvx512;
//
// Check if the processor supports AVX512VNNI.
//
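With the MLAS_PLATFORM fields populated as above, the dispatch inside MlasComputeQSoftmax presumably reduces to a per-row indirect call through the selected kernel pointer. A hedged sketch of that pattern, assuming the GetMlasPlatform() accessor used elsewhere in MLAS (the real routine also parallelizes rows over the thread pool):

```cpp
// Assumed per-row dispatch; illustrative only, not the PR's qsoftmax.cpp.
static void DispatchQSoftmaxRow(size_t D, const void* x, void* y,
                                const float* lookup_table, float y_scale,
                                int zero_point, bool is_signed, float* buffer) {
  if (is_signed) {
    GetMlasPlatform().QuantizeSoftmaxI8Kernel(
        D, static_cast<const int8_t*>(x), static_cast<int8_t*>(y),
        lookup_table, y_scale, static_cast<int8_t>(zero_point), buffer);
  } else {
    GetMlasPlatform().QuantizeSoftmaxU8Kernel(
        D, static_cast<const uint8_t*>(x), static_cast<uint8_t*>(y),
        lookup_table, y_scale, static_cast<uint8_t>(zero_point), buffer);
  }
}
```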