diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index c02ac2096db2e..cf23416943c1f 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -40,6 +40,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/sqnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/cast.cpp ) target_sources(onnxruntime_mlas PRIVATE @@ -212,6 +213,12 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm ) + if(MSVC_VERSION GREATER_EQUAL 1933) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm + ) + endif() + if (NOT onnxruntime_ORT_MINIMAL_BUILD) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/q4gemm_avx512.cpp @@ -522,6 +529,12 @@ else() ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S ) + if(NOT APPLE) + set(mlas_platform_srcs_sse2 + ${mlas_platform_srcs_sse2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S + ) + endif() set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") set(mlas_platform_srcs_avx @@ -555,6 +568,12 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ) + if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S + ) + endif() message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bea4b91ebaa79..8b3156d77e57c 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1029,14 +1029,13 @@ MlasComputeTanh( // Half-precision floating-point routines. // -extern "C" void MLASCALL MlasConvertHalfToFloatBuffer( const unsigned short* Source, float* Destination, size_t Count - ); +); // // Transpose routines. diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm new file mode 100644 index 0000000000000..c7f6342c527bf --- /dev/null +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm @@ -0,0 +1,151 @@ +;++ +; +; Copyright (c) Intel Corporation. All rights reserved. +; +; Licensed under the MIT License. +; +; Module Name: +; +; cvtfp16Avx2.asm +; +; Abstract: +; +; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. +; +;-- + + .xlist +INCLUDE mlasi.inc + .list + + .const + +SINGLE_SIZE equ 4 +HALF_SIZE equ 2 +LOW_SELECTOR equ 00100000b +HIGH_SELECTOR equ 00110001b + + SUBTTL "Convert buffer of half-precision floats to single-precision floats" +;++ +; +; Routine Description: +; +; This routine converts the source buffer of half-precision floats to the +; destination buffer of single-precision floats. +; +; This implementation uses AVX2 instructions. +; +; Arguments: +; +; Source (rcx) - Supplies the address of the source buffer of half-precision +; floats. +; +; Destination (rdx) - Supplies the address of the destination buffer of +; single-precision floats. +; +; Count (r8) - Supplies the number of elements to convert. +; +; Return Value: +; +; None. +; +;-- + + +LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT + + test r8, r8 ; Check if we have any elements to convert + jz ExitRoutine + cmp r8, 8 + jb ConvertMaskedVectors + cmp r8, 16 + jb Convert128Vectors + + + +Convert256Vectors: + vcvtneeph2ps ymm0, ymmword PTR [rcx] ; Load even indexes + vcvtneoph2ps ymm1, ymmword PTR [rcx] ; Load odd indexes + vunpcklps ymm2, ymm0, ymm1 ; Interleave low part + vunpckhps ymm1, ymm0, ymm1 ; Interleave high part + vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR ; Fix the order + vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR ; Fix the order + vmovups ymmword PTR [rdx], ymm0 ; Store the low part + vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1 ; Store the high part + + add rcx, 16*HALF_SIZE ; Advance src ptr by 16 elements + add rdx, 16*SINGLE_SIZE ; Advance dest ptr by 16 elements + sub r8, 16 ; Reduce the counter by 16 elements + + jz ExitRoutine ; If we are done, exit + cmp r8, 16 ; If the vector is big enough, we go again + jae Convert256Vectors + + + +Convert128Vectors: + vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes + vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order + vmovups xmmword PTR [rdx], xmm0 ; Store the low part + vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1 ; Store the high part + + add rcx, 8*HALF_SIZE ; Advance src ptr by 8 elements + add rdx, 8*SINGLE_SIZE ; Advance dest ptr by 8 elements + sub r8, 8 ; Reduce the counter by 8 elements + + jz ExitRoutine ; If we are done, exit + + + +ConvertMaskedVectors: + vcvtneeph2ps xmm2, xmmword PTR [rcx] ; Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rcx] ; Load odd indexes + vunpcklps xmm0, xmm2, xmm1 ; Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 ; Interleave high part to fix order + + cmp r8, 4 ; Check if we can store the complete lower vector + jae ConvertLowerVector + + vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones + cmp r8, 2 ; Check how many converts we need + jb ConvertLower1 + ja ConvertLower3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values + jmp ConvertLowerMaskedVector +ConvertLower1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value + jmp ConvertLowerMaskedVector +ConvertLower3: + vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values +ConvertLowerMaskedVector: + vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; Store the masked data, the shift is done in 8bit multiples + jmp ExitRoutine ; If we ran into any of the cases above, means we are done after storing +ConvertLowerVector: + vmovups xmmword PTR [rdx], xmm0 ; Store the low part + sub r8, 4 ; Check if we still need to convert + jz ExitRoutine + + + add rdx, 4*SINGLE_SIZE ; Advance dest ptr by 4 elements + vpcmpeqw xmm2, xmm2, xmm2 ; Initialize the mask full of ones + cmp r8, 2 ; Check how many converts we need + jb ConvertUpper1 + ja ConvertUpper3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 ; Shift the memory store two values + jmp ConvertMaskedUpperVector +ConvertUpper1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 ; Shift the memory store only one value + jmp ConvertMaskedUpperVector +ConvertUpper3: + vpsrldq xmm2, xmm2, SINGLE_SIZE ; Shift the memory store three values +ConvertMaskedUpperVector: + vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; Store the masked data, the shift is done in 8bit multiples + +ExitRoutine: + ret + + LEAF_END MlasCastF16ToF32KernelAvx, _TEXT + + END diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm index 50315146ca79b..0ad98d3115208 100644 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16a.asm @@ -42,7 +42,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h) ; Source (rcx) - Supplies the address of the source buffer of half-precision ; floats. ; -; Destination (edx) - Supplies the address of the destination buffer of +; Destination (rdx) - Supplies the address of the destination buffer of ; single-precision floats. ; ; Count (r8) - Supplies the number of elements to convert. @@ -53,7 +53,7 @@ MlasFp16MagicDenormal DD 4 DUP (38800000h) ; ;-- - LEAF_ENTRY MlasConvertHalfToFloatBuffer, _TEXT + LEAF_ENTRY MlasCastF16ToF32KernelSse, _TEXT test r8,r8 jz ExitRoutine @@ -119,6 +119,6 @@ StoreLastElement: ExitRoutine: ret - LEAF_END MlasConvertHalfToFloatBuffer, _TEXT + LEAF_END MlasCastF16ToF32KernelSse, _TEXT END diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp new file mode 100644 index 0000000000000..24af4064bbd9b --- /dev/null +++ b/onnxruntime/core/mlas/lib/cast.cpp @@ -0,0 +1,59 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cast.cpp + +Abstract: + + This module implements Half (F16) to Single (F32) precision casting. + +--*/ +#include "mlasi.h" + +union fp32_bits { + uint32_t u; + float f; +}; + +void +MLASCALL +MlasConvertHalfToFloatBuffer( + const unsigned short* Source, + float* Destination, + size_t Count +) +{ + + if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { + // If there is no kernel use the reference implementation, adapted from mlas_float16.h. + constexpr fp32_bits magic = {113 << 23}; + constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + + for (size_t i = 0; i < Count; ++i) { + fp32_bits o; + o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits + uint32_t exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // renormalize + } + + o.u |= (Source[i] & 0x8000) << 16; // sign bit + Destination[i] = o.f; + } + + } else { + // If the kernel is available, use it to perform the conversion. + GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count); + } +} diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 4239e2ecaeb6e..6f5db766b7def 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,6 +610,13 @@ void size_t N ); +typedef +void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( + const unsigned short* Source, + float* Destination, + size_t Count +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -870,6 +877,11 @@ extern "C" { MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL MlasReduceMinimumMaximumF32KernelAvx; #endif +#if defined(MLAS_TARGET_AMD64) + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; +#endif + } // @@ -1151,6 +1163,8 @@ struct MLAS_PLATFORM { const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + + MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index ed437f20f7c2a..4cd7faaa9e6ff 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -244,6 +244,7 @@ Return Value: this->ConvDepthwiseU8U8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; + this->CastF16ToF32Kernel = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) @@ -283,6 +284,9 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; +#ifndef __APPLE__ + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; +#endif // __APPLE__ this->NchwcBlockSize = 8; this->PreferredBufferAlignment = MLAS_DEFAULT_PREFERRED_BUFFER_ALIGNMENT; @@ -469,6 +473,16 @@ Return Value: } #ifndef __APPLE__ +#if (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) + // + // Check if the processor supports AVX NE CONVERT. + // + if ((Cpuid7_1[3] & (0b1 << 5)) != 0) { + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx; + } +#endif // (defined(_MSC_VER) && (_MSC_VER >= 1933)) || (defined(__GNUC__) && (__GNUC__ >= 13)) + + // // Check if the processor supports AMX-TILE and AMX-INT8 // features. diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S new file mode 100644 index 0000000000000..1a70061460e50 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S @@ -0,0 +1,143 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cvtfp16Avx2.asm + +Abstract: + + This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. + +--*/ + +#include "asmmacro.h" + +.data +.equ SINGLE_SIZE, 4 +.equ HALF_SIZE, 2 +.equ LOW_SELECTOR, 0b00100000 +.equ HIGH_SELECTOR, 0b00110001 + +.text +.intel_syntax noprefix + +/*++ Routine Description: + + This routine converts the source buffer of half-precision floats to the + destination buffer of single-precision floats. + + This implementation uses AVX2 instructions. + + Arguments: + + Source (rdi) - Supplies the address of the source buffer of half-precision + floats. + + Destination (rsi) - Supplies the address of the destination buffer of + single-precision floats. + + Count (rdx) - Supplies the number of elements to convert. + + Return Value: + + None. + +--*/ +FUNCTION_ENTRY MlasCastF16ToF32KernelAvx + + test rdx, rdx // Check if we have any elements to convert + jz ExitRoutine + +AVX_NE_CONVERT: + cmp rdx, 8 + jb ConvertMaskedVectors + cmp rdx, 16 + jb Convert128Vectors + +Convert256Vectors: + vcvtneeph2ps ymm0, ymmword PTR [rdi] // Load even indexes + vcvtneoph2ps ymm1, ymmword PTR [rdi] // Load odd indexes + vunpcklps ymm2, ymm0, ymm1 // Interleave low part + vunpckhps ymm1, ymm0, ymm1 // Interleave high part + vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR // Fix the order + vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR // Fix the order + vmovups ymmword PTR [rsi], ymm0 // Store the low part + vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1 // Store the high part + + add rdi, 16*HALF_SIZE // Advance src ptr by 16 elements + add rsi, 16*SINGLE_SIZE // Advance dest ptr by 16 elements + sub rdx, 16 // Reduce the counter by 16 elements + + jz ExitRoutine // If we are done, exit + cmp rdx, 16 // If the vector is big enough, we go again + jae Convert256Vectors + + + +Convert128Vectors: + vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes + vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order + vmovups xmmword PTR [rsi], xmm0 // Store the low part + vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1 // Store the high part + + add rdi, 8*HALF_SIZE // Advance src ptr by 8 elements + add rsi, 8*SINGLE_SIZE // Advance dest ptr by 8 elements + sub rdx, 8 // Reduce the counter by 8 elements + + jz ExitRoutine // If we are done, exit + + + +ConvertMaskedVectors: + vcvtneeph2ps xmm2, xmmword PTR [rdi] // Load even indexes + vcvtneoph2ps xmm1, xmmword PTR [rdi] // Load odd indexes + vunpcklps xmm0, xmm2, xmm1 // Interleave low part to fix order + vunpckhps xmm1, xmm2, xmm1 // Interleave high part to fix order + + cmp rdx, 4 // Check if we can store the complete lower vector + jae ConvertLowerVector + + vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones + cmp rdx, 2 // Check how many converts we need + jb ConvertLower1 + ja ConvertLower3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values + jmp ConvertLowerMaskedVector +ConvertLower1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value + jmp ConvertLowerMaskedVector +ConvertLower3: + vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values +ConvertLowerMaskedVector: + vmaskmovps xmmword PTR [rsi], xmm2, xmm0 // Store the masked data, the shift is done in 8bit multiples + jmp ExitRoutine // If we ran into any of the cases above, means we are done after storing +ConvertLowerVector: + vmovups xmmword PTR [rsi], xmm0 // Store the low part + sub rdx, 4 // Check if we still need to convert + jz ExitRoutine + + + add rsi, 4*SINGLE_SIZE // Advance dest ptr by 4 elements + vpcmpeqw xmm2, xmm2, xmm2 // Initialize the mask full of ones + cmp rdx, 2 // Check how many converts we need + jb ConvertUpper1 + ja ConvertUpper3 + vpsrldq xmm2, xmm2, SINGLE_SIZE*2 // Shift the memory store two values + jmp ConvertMaskedUpperVector +ConvertUpper1: + vpsrldq xmm2, xmm2, SINGLE_SIZE*3 // Shift the memory store only one value + jmp ConvertMaskedUpperVector +ConvertUpper3: + vpsrldq xmm2, xmm2, SINGLE_SIZE // Shift the memory store three values +ConvertMaskedUpperVector: + vmaskmovps xmmword PTR [rsi], xmm2, xmm1 // Store the masked data, the shift is done in 8bit multiples + + jmp ExitRoutine +ExitRoutine: + ret diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S new file mode 100644 index 0000000000000..f27114c183f44 --- /dev/null +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16a.S @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + cvtfp16a.S + +Abstract: + + This module implements routines to convert between FP16 and FP32 formats using SSE2 isntructions. + +--*/ + +#include "asmmacro.h" + +// We use RIP relative addressing to avoid reallication related errors +.section .rodata +MlasFp16MaskSign: .long 0x00007FFF +MlasFp16CompareInfinity: .long 0x00007C00 +MlasFp16CompareSmallest: .long 0x00000400 +MlasFp16AdjustExponent: .long 0x38000000 +MlasFp16MagicDenormal: .long 0x38800000 + +.text +.intel_syntax noprefix + +/*++ Routine Description: + + This routine converts the source buffer of half-precision floats to the + destination buffer of single-precision floats. + + This implementation uses SSE2 instructions. + + Arguments: + + Source (rdi) - Supplies the address of the source buffer of half-precision + floats. + + Destination (rsi) - Supplies the address of the destination buffer of + single-precision floats. + + Count (rdx) - Supplies the number of elements to convert. + + Return Value: + + None. + +--*/ + +FUNCTION_ENTRY MlasCastF16ToF32KernelSse + + test rdx,rdx + jz ExitRoutine + + // Load xmm constants + movd xmm5, DWORD PTR [rip + MlasFp16MaskSign] + pshufd xmm5, xmm5, 0x00 + movd xmm6, DWORD PTR [rip + MlasFp16AdjustExponent] + pshufd xmm6, xmm6, 0x00 + movd xmm7, DWORD PTR [rip + MlasFp16MagicDenormal] + pshufd xmm7, xmm7, 0x00 + + + cmp rdx,4 + jb LoadPartialVector + +LoadFullVector: + movq xmm0,QWORD PTR [rdi] + add rdi,4*2 // advance S by 4 elements + +ConvertHalfToFloat: + punpcklwd xmm0,xmm0 // duplicate 4 WORDs to 4 DWORDs + movaps xmm1,xmm0 // isolate exponent/mantissa + pand xmm1,xmm5 + pxor xmm0,xmm1 // isolate sign bit + movd xmm2, DWORD PTR [rip + MlasFp16CompareInfinity] + pshufd xmm2, xmm2, 0x00 + pcmpgtd xmm2,xmm1 // test for infinity/NaNs + movd xmm3, DWORD PTR [rip + MlasFp16CompareSmallest] + pshufd xmm3, xmm3, 0x00 + pcmpgtd xmm3,xmm1 // test for denormals + pandn xmm2,xmm6 + pslld xmm1,13 // shift exponent/mask into place + movaps xmm4,xmm1 + paddd xmm1,xmm6 + paddd xmm1,xmm2 // adjust exponent again for infinity/NaNs + paddd xmm4,xmm7 + pslld xmm0,16 // shift sign into place + subps xmm4,xmm7 + pand xmm4,xmm3 // select elements that are denormals + pandn xmm3,xmm1 // select elements that are not denormals + por xmm3,xmm4 // blend the selected values together + por xmm0,xmm3 // merge sign into exponent/mantissa + + cmp rdx,4 // storing full vector? + jb StorePartialVector + movups XMMWORD PTR [rsi],xmm0 + add rsi,4*4 // advance D by 4 elements + sub rdx,4 + jz ExitRoutine + cmp rdx,4 + jae LoadFullVector + +LoadPartialVector: + pxor xmm0,xmm0 + pinsrw xmm0,WORD PTR [rdi],0 + cmp rdx,2 + jb ConvertHalfToFloat + pinsrw xmm0,WORD PTR [rdi+2],1 + je ConvertHalfToFloat + pinsrw xmm0,WORD PTR [rdi+4],2 + jmp ConvertHalfToFloat + +StorePartialVector: + cmp rdx,2 + jb StoreLastElement + movsd QWORD PTR [rsi],xmm0 + je ExitRoutine + movhlps xmm0,xmm0 // shift third element down + add rsi,4*2 // advance D by 2 elements + +StoreLastElement: + movss DWORD PTR [rsi],xmm0 + +ExitRoutine: + ret diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 6742bab4fa4a2..f2aaa75cadd8d 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -22,9 +22,8 @@ #include "Eigen/src/Core/arch/Default/BFloat16.h" #include "Eigen/src/Core/arch/Default/Half.h" -#if defined(_M_AMD64) && !defined(_M_ARM64EC) #include "core/mlas/inc/mlas.h" -#endif +#include "core/common/cpuid_info.h" namespace onnxruntime { @@ -252,10 +251,6 @@ struct TensorCasterNoSat { #endif -#if defined(_M_AMD64) && !defined(_M_ARM64EC) -// specializations to use optimized and Windows x64-specific -// MlasConvertHalfToFloatBuffer() routine for MLFloat16 -> float conversion - // tensor MLFloat16 -> float template <> struct TensorCaster { @@ -267,6 +262,9 @@ struct TensorCaster { } }; +#if defined(_M_AMD64) && !defined(_M_ARM64EC) +// specializations to use optimized and Windows x64-specific + Tensor GetIntermediateMLFloat16ToFloatTensor( const OpKernelContext& context, const TensorShape& shape, const Tensor& in) { AllocatorPtr allocator;