Skip to content

Commit

Permalink
Enable AVX NE CONVERT for FP16 to FP32 cast
Browse files Browse the repository at this point in the history
* Developed x86 and amd64 assembly kernel using AVX NE CONVERT.
* Developed x86 assembly kernel using SSE instructions.
* Added fallback implementation for FP16 to FP32 cast.
* Runtime check to determine if CPU supports the ISA required for the kernel.
* Added kernel dispatching logic on platform.cpp
  • Loading branch information
eralmual committed Sep 4, 2024
1 parent 09d786f commit 289d92f
Show file tree
Hide file tree
Showing 10 changed files with 537 additions and 11 deletions.
19 changes: 19 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm.cpp
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
)

target_sources(onnxruntime_mlas PRIVATE
Expand Down Expand Up @@ -212,6 +213,12 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm
)
# The AVX_NE_CONVERT cast kernel is only added for MSVC_VERSION 1933 and newer;
# platform.cpp guards its use of the kernel with the matching _MSC_VER >= 1933 check.
if(MSVC_VERSION GREATER_EQUAL 1933)
  target_sources(onnxruntime_mlas PRIVATE "${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm")
endif()

if (NOT onnxruntime_ORT_MINIMAL_BUILD)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
Expand Down Expand Up @@ -522,6 +529,12 @@ else()
${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S
${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S
)
# The SSE2 FP16 -> FP32 cast kernel is left out of the build on Apple platforms.
if(NOT APPLE)
  list(APPEND mlas_platform_srcs_sse2 "${MLAS_SRC_DIR}/x86_64/cvtfp16a.S")
endif()
set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2")

set(mlas_platform_srcs_avx
Expand Down Expand Up @@ -555,6 +568,12 @@ else()
${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp
)
# Add the AVX_NE_CONVERT FP16 -> FP32 cast kernel when the toolchain is new
# enough and we are not targeting Apple platforms.
#
# CMAKE_CXX_COMPILER_VERSION is a dotted version string (e.g. "13.2.0"), so it
# must be compared component-wise with VERSION_GREATER_EQUAL; the numeric
# GREATER_EQUAL operator is meant for plain numbers, not versions.
#
# NOTE(review): this gate does not look at CMAKE_CXX_COMPILER_ID, while
# platform.cpp only dispatches to the kernel when __GNUC__ >= 13. Confirm that
# every non-GNU compiler reporting a version >= 13.1 can assemble this file.
if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13.1 AND NOT APPLE)
  list(APPEND mlas_platform_srcs_avx2 "${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S")
endif()

message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1029,14 +1029,13 @@ MlasComputeTanh(
// Half-precision floating-point routines.
//

extern "C"
void
MLASCALL
MlasConvertHalfToFloatBuffer(
const unsigned short* Source,
float* Destination,
size_t Count
);
);

//
// Transpose routines.
Expand Down
151 changes: 151 additions & 0 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
;++
;
; Copyright (c) Intel Corporation. All rights reserved.
;
; Licensed under the MIT License.
;
; Module Name:
;
; cvtfp16Avx.asm
;
; Abstract:
;
; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA.
;
;--

.xlist
INCLUDE mlasi.inc
.list

.const

SINGLE_SIZE equ 4
HALF_SIZE equ 2
LOW_SELECTOR equ 00100000b
HIGH_SELECTOR equ 00110001b

SUBTTL "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
; This routine converts the source buffer of half-precision floats to the
; destination buffer of single-precision floats.
;
; This implementation uses AVX and AVX_NE_CONVERT instructions.
;
; Arguments:
;
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
;
; Return Value:
;
; None.
;
;--


LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT

;
; Register usage:
;   rcx - current source pointer (half-precision elements)
;   rdx - current destination pointer (single-precision elements)
;   r8  - number of elements still to convert
;
; Dispatch on the element count:
;   0      -> return immediately
;   1..7   -> masked tail only
;   8..15  -> one 128-bit step, then the tail if anything remains
;   >= 16  -> 256-bit loop, then 128-bit step and/or tail for the remainder
;

        test    r8, r8                                  ; Check if we have any elements to convert
        jz      ExitRoutine
        cmp     r8, 8
        jb      ConvertMaskedVectors
        cmp     r8, 16
        jb      Convert128Vectors

;
; Convert 16 elements per iteration. The even/odd conversions each yield
; 8 floats, which are interleaved per 128-bit lane and then permuted across
; lanes to restore the original element order.
;

Convert256Vectors:
        vcvtneeph2ps ymm0, ymmword PTR [rcx]            ; Load even indexes
        vcvtneoph2ps ymm1, ymmword PTR [rcx]            ; Load odd indexes
        vunpcklps ymm2, ymm0, ymm1                      ; Interleave low part
        vunpckhps ymm1, ymm0, ymm1                      ; Interleave high part
        vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR       ; Fix the order
        vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR      ; Fix the order
        vmovups ymmword PTR [rdx], ymm0                 ; Store the low part
        vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part

        add     rcx, 16*HALF_SIZE                       ; Advance src ptr by 16 elements
        add     rdx, 16*SINGLE_SIZE                     ; Advance dest ptr by 16 elements
        sub     r8, 16                                  ; Reduce the counter by 16 elements

        jz      ExitRoutine                             ; If we are done, exit
        cmp     r8, 16                                  ; If the vector is big enough, we go again
        jae     Convert256Vectors

;
; 1..15 elements remain. The 128-bit step below unconditionally stores
; 8 floats, so it may only run when at least 8 elements are left; otherwise
; it would write past the end of the destination and drive the counter
; negative. Route a remainder of 1..7 straight to the masked tail.
;

        cmp     r8, 8
        jb      ConvertMaskedVectors

;
; Convert exactly 8 elements.
;

Convert128Vectors:
        vcvtneeph2ps xmm2, xmmword PTR [rcx]            ; Load even indexes
        vcvtneoph2ps xmm1, xmmword PTR [rcx]            ; Load odd indexes
        vunpcklps xmm0, xmm2, xmm1                      ; Interleave low part to fix order
        vunpckhps xmm1, xmm2, xmm1                      ; Interleave high part to fix order
        vmovups xmmword PTR [rdx], xmm0                 ; Store the low part
        vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part

        add     rcx, 8*HALF_SIZE                        ; Advance src ptr by 8 elements
        add     rdx, 8*SINGLE_SIZE                      ; Advance dest ptr by 8 elements
        sub     r8, 8                                   ; Reduce the counter by 8 elements

        jz      ExitRoutine                             ; If we are done, exit

;
; Tail: 1..7 elements remain. Eight elements are converted into xmm0 (first
; four) and xmm1 (last four); only the valid ones are stored, using a full
; store for a complete lower vector and vmaskmovps for a partial vector.
;
; NOTE(review): both loads below always read a full 16 bytes from the source
; even though only 2..14 bytes of valid input remain. Confirm that callers
; tolerate this over-read (e.g. it cannot cross into an unmapped page).
;

ConvertMaskedVectors:
        vcvtneeph2ps xmm2, xmmword PTR [rcx]            ; Load even indexes
        vcvtneoph2ps xmm1, xmmword PTR [rcx]            ; Load odd indexes
        vunpcklps xmm0, xmm2, xmm1                      ; Interleave low part to fix order
        vunpckhps xmm1, xmm2, xmm1                      ; Interleave high part to fix order

        cmp     r8, 4                                   ; Check if we can store the complete lower vector
        jae     ConvertLowerVector

;
; 1..3 elements: build a store mask whose low r8 dwords are all ones by
; shifting an all-ones register right by (4 - r8) dwords.
;

        vpcmpeqw xmm2, xmm2, xmm2                       ; Initialize the mask full of ones
        cmp     r8, 2                                   ; Check how many converts we need
        jb      ConvertLower1
        ja      ConvertLower3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               ; Shift the memory store two values
        jmp     ConvertLowerMaskedVector
ConvertLower1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               ; Shift the memory store only one value
        jmp     ConvertLowerMaskedVector
ConvertLower3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE                 ; Shift the memory store three values
ConvertLowerMaskedVector:
        vmaskmovps xmmword PTR [rdx], xmm2, xmm0        ; Store the masked data, the shift is done in 8bit multiples
        jmp     ExitRoutine                             ; If we ran into any of the cases above, means we are done after storing

;
; 4..7 elements: store the complete lower vector, then mask-store the
; remaining 0..3 elements from the upper vector.
;

ConvertLowerVector:
        vmovups xmmword PTR [rdx], xmm0                 ; Store the low part
        sub     r8, 4                                   ; Check if we still need to convert
        jz      ExitRoutine

        add     rdx, 4*SINGLE_SIZE                      ; Advance dest ptr by 4 elements
        vpcmpeqw xmm2, xmm2, xmm2                       ; Initialize the mask full of ones
        cmp     r8, 2                                   ; Check how many converts we need
        jb      ConvertUpper1
        ja      ConvertUpper3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2               ; Shift the memory store two values
        jmp     ConvertMaskedUpperVector
ConvertUpper1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3               ; Shift the memory store only one value
        jmp     ConvertMaskedUpperVector
ConvertUpper3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE                 ; Shift the memory store three values
ConvertMaskedUpperVector:
        vmaskmovps xmmword PTR [rdx], xmm2, xmm1        ; Store the masked data, the shift is done in 8bit multiples

ExitRoutine:
        ret

LEAF_END MlasCastF16ToF32KernelAvx, _TEXT

END
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h)
; Source (rcx) - Supplies the address of the source buffer of half-precision
; floats.
;
; Destination (edx) - Supplies the address of the destination buffer of
; Destination (rdx) - Supplies the address of the destination buffer of
; single-precision floats.
;
; Count (r8) - Supplies the number of elements to convert.
Expand All @@ -53,7 +53,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h)
;
;--

LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT
LEAF_ENTRY MlasCastF16ToF32KernelSse, _TEXT

test r8,r8
jz ExitRoutine
Expand Down Expand Up @@ -119,6 +119,6 @@ StoreLastElement:
ExitRoutine:
ret

LEAF_END MlasConvertHalfToFloatBuffer, _TEXT
LEAF_END MlasCastF16ToF32KernelSse, _TEXT

END
59 changes: 59 additions & 0 deletions onnxruntime/core/mlas/lib/cast.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*++
Copyright (c) Intel Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
cast.cpp
Abstract:
This module implements Half (F16) to Single (F32) precision casting.
--*/
#include "mlasi.h"

// Overlays a 32-bit unsigned integer and a float on the same storage so the
// conversion routine in this file can assemble a float from its bit pattern.
// Member order matters: brace initialization such as `{113 << 23}` sets the
// first member, `u`.
union fp32_bits {
uint32_t u; // bit pattern view
float f; // the same storage viewed as a float
};

void
MLASCALL
MlasConvertHalfToFloatBuffer(
    const unsigned short* Source,
    float* Destination,
    size_t Count
    )
/*++

Routine Description:

    Converts Count half-precision values from Source into single-precision
    values in Destination. The kernel registered on the MLAS platform is used
    when one is present; otherwise a scalar reference conversion (adapted from
    mlas_float16.h) is performed.

Arguments:

    Source - Supplies the half-precision input values as raw 16-bit patterns.

    Destination - Supplies the buffer that receives the converted floats.

    Count - Supplies the number of elements to convert.

Return Value:

    None.

--*/
{
    // Prefer the platform-selected kernel whenever one has been registered.
    auto* Kernel = GetMlasPlatform().CastF16ToF32Kernel;

    if (Kernel != nullptr) {
        Kernel(Source, Destination, Count);
        return;
    }

    // Scalar fallback.
    constexpr fp32_bits DenormalMagic = {113 << 23};
    constexpr uint32_t ShiftedExponentMask = 0x7c00 << 13;  // exponent mask after shift

    for (size_t n = 0; n < Count; ++n) {
        fp32_bits Bits;

        // Move the exponent/mantissa bits into position and rebias the exponent.
        Bits.u = (Source[n] & 0x7fff) << 13;
        const uint32_t Exponent = Bits.u & ShiftedExponentMask;
        Bits.u += (127 - 15) << 23;

        if (Exponent == ShiftedExponentMask) {
            // Inf/NaN: extra exponent adjustment.
            Bits.u += (128 - 16) << 23;
        } else if (Exponent == 0) {
            // Zero/denormal: extra exponent adjustment, then renormalize.
            Bits.u += 1 << 23;
            Bits.f -= DenormalMagic.f;
        }

        // Carry the sign bit across.
        Bits.u |= (Source[n] & 0x8000) << 16;
        Destination[n] = Bits.f;
    }
}
14 changes: 14 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,13 @@ void
size_t N
);

//
// Function type for kernels that convert a buffer of half-precision values
// (passed as unsigned short) into single-precision floats:
//   Source      - input buffer of half-precision values
//   Destination - output buffer of single-precision values
//   Count       - number of elements to convert
//
typedef
void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)(
const unsigned short* Source,
float* Destination,
size_t Count
);

typedef
void
(MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)(
Expand Down Expand Up @@ -870,6 +877,11 @@ extern "C" {
MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx;
#endif

// Half-precision to single-precision cast kernels implemented in assembly:
// an SSE variant and a variant that uses the AVX_NE_CONVERT ISA.
#if defined(MLAS_TARGET_AMD64)
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse;
MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx;
#endif

}

//
Expand Down Expand Up @@ -1151,6 +1163,8 @@ struct MLAS_PLATFORM {
const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr};

const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr};

MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
};

inline
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ Return Value:
this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel<uint8_t, uint8_t>;
this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel<int8_t, int8_t>;
this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel<int8_t, uint8_t>;
this->CastF16ToF32Kernel = nullptr;

#if defined(MLAS_TARGET_AMD64_IX86)

Expand Down Expand Up @@ -283,6 +284,9 @@ Return Value:
this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel;
this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel;
this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel;
// Default to the SSE cast kernel. Skipped on Apple, where the build does not
// add the kernel's source file, so the pointer stays nullptr and
// MlasConvertHalfToFloatBuffer uses its scalar fallback.
#ifndef __APPLE__
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse;
#endif // __APPLE__

this->NchwcBlockSize = 8;
this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT;
Expand Down Expand Up @@ -469,6 +473,16 @@ Return Value:
}

#ifndef __APPLE__
// Only reference the AVX NE CONVERT kernel with the toolchains for which the
// build adds its assembly source (MSVC 19.33+, GCC 13+).
#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))
//
// Check if the processor supports AVX NE CONVERT (bit 5 of Cpuid7_1[3]).
// NOTE(review): Cpuid7_1[3] is presumably EDX of CPUID leaf 7, sub-leaf 1;
// confirm against where Cpuid7_1 is populated.
//
if ((Cpuid7_1[3] & (0b1 << 5)) != 0) {
this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx;
}
#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13))


//
// Check if the processor supports AMX-TILE and AMX-INT8
// features.
Expand Down
Loading

0 comments on commit 289d92f

Please sign in to comment.