-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enable AVX NE CONVERT for FP16 to FP32 cast
* Enable AVX_NE_CONVERT detection via CPUID. * Developed x86 and amd64 assembly kernel using the new ISA. * Integrated kernel.
- Loading branch information
Showing
7 changed files
with
343 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
;++ | ||
; | ||
; Copyright (c) Intel Corporation. All rights reserved. | ||
; | ||
; Licensed under the MIT License. | ||
; | ||
; Module Name: | ||
; | ||
; cvtfp16Avx2.asm | ||
; | ||
; Abstract: | ||
; | ||
; This module implements routines to convert between FP16 and FP32 formats using the AVX_NE_CONVERT ISA. | ||
; | ||
;-- | ||
|
||
.xlist | ||
INCLUDE mlasi.inc | ||
.list | ||
|
||
.const
; Element sizes in bytes; used to scale pointer advances and store offsets.
SINGLE_SIZE     equ     4
HALF_SIZE       equ     2
; vperm2f128 immediates. Bits [1:0] pick the 128-bit lane written to the low
; half of the destination and bits [5:4] the lane written to the high half
; (0/1 = low/high lane of the first source, 2/3 = low/high lane of the second).
LOW_SELECTOR    equ     00100000b       ; dst = low lane of src1  : low lane of src2
HIGH_SELECTOR   equ     00110001b       ; dst = high lane of src1 : high lane of src2

        SUBTTL  "Convert buffer of half-precision floats to single-precision floats"
;++
;
; Routine Description:
;
;   This routine converts the source buffer of half-precision floats to the
;   destination buffer of single-precision floats.
;
;   This implementation uses the AVX_NE_CONVERT instructions vcvtneeph2ps and
;   vcvtneoph2ps together with AVX shuffles and masked stores. The two
;   conversions produce the even-indexed and odd-indexed elements in separate
;   registers, so every step re-interleaves them before storing.
;
; Arguments:
;
;   Source (rcx) - Supplies the address of the source buffer of half-precision
;       floats.
;
;   Destination (rdx) - Supplies the address of the destination buffer of
;       single-precision floats.
;
;   Count (r8) - Supplies the number of elements to convert.
;
; Return Value:
;
;   None.
;
; Register usage:
;
;   rcx, rdx and r8 are advanced/decremented in place; ymm0-ymm2 are scratch.
;   The upper halves of the ymm registers are cleared on exit (vzeroupper).
;
; NOTE(review): the tail path (ConvertMaskedVectors) always loads a full
;   16 bytes from Source even when fewer than 8 elements remain. Stores are
;   masked, loads are not - confirm that callers guarantee those bytes are
;   readable.
;
;--

LEAF_ENTRY MlasConvertHalfToFloatBufferAVX, _TEXT

        test    r8, r8                          ; nothing to convert?
        jz      ExitRoutine
        cmp     r8, 8
        jb      ConvertMaskedVectors            ; 1..7 elements: masked tail only
        cmp     r8, 16
        jb      Convert128Vectors               ; 8..15 elements: one 128-bit step

;
; Main loop: 16 elements per iteration. Invariant on entry: r8 >= 16.
;

Convert256Vectors:
        vcvtneeph2ps ymm0, ymmword PTR [rcx]    ; even-indexed elements -> fp32
        vcvtneoph2ps ymm1, ymmword PTR [rcx]    ; odd-indexed elements  -> fp32
        vunpcklps ymm2, ymm0, ymm1              ; per lane: elements 0..3 | 8..11
        vunpckhps ymm1, ymm0, ymm1              ; per lane: elements 4..7 | 12..15
        vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR   ; elements 0..7
        vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR  ; elements 8..15
        vmovups ymmword PTR [rdx], ymm0
        vmovups ymmword PTR [rdx + 8*SINGLE_SIZE], ymm1

        add     rcx, 16*HALF_SIZE               ; advance src by 16 elements
        add     rdx, 16*SINGLE_SIZE             ; advance dest by 16 elements
        sub     r8, 16

        jz      ExitRoutine                     ; exact multiple of 16: done
        cmp     r8, 16
        jae     Convert256Vectors

;
; 1..15 elements remain. The 128-bit step below stores 8 floats unmasked, so
; it may only be entered with at least 8 elements left; otherwise it would
; write past the end of Destination and wrap the counter.
;

        cmp     r8, 8
        jb      ConvertMaskedVectors

;
; Single 128-bit step: 8 elements. Invariant on entry: 8 <= r8 < 16.
;

Convert128Vectors:
        vcvtneeph2ps xmm2, xmmword PTR [rcx]    ; even-indexed elements -> fp32
        vcvtneoph2ps xmm1, xmmword PTR [rcx]    ; odd-indexed elements  -> fp32
        vunpcklps xmm0, xmm2, xmm1              ; elements 0..3
        vunpckhps xmm1, xmm2, xmm1              ; elements 4..7
        vmovups xmmword PTR [rdx], xmm0
        vmovups xmmword PTR [rdx + 4*SINGLE_SIZE], xmm1

        add     rcx, 8*HALF_SIZE                ; advance src by 8 elements
        add     rdx, 8*SINGLE_SIZE              ; advance dest by 8 elements
        sub     r8, 8

        jz      ExitRoutine                     ; nothing left: done

;
; Tail: 1..7 elements. Convert a full vector, then store only r8 results.
;

ConvertMaskedVectors:
        vcvtneeph2ps xmm2, xmmword PTR [rcx]    ; even-indexed elements -> fp32
        vcvtneoph2ps xmm1, xmmword PTR [rcx]    ; odd-indexed elements  -> fp32
        vunpcklps xmm0, xmm2, xmm1              ; elements 0..3
        vunpckhps xmm1, xmm2, xmm1              ; elements 4..7

        cmp     r8, 4                           ; Check if the lower vector can be stored whole
        jae     ConvertLowerVector

;
; 1..3 elements: build a store mask with ones in the low r8 dword lanes.
; vpsrldq shifts the whole register right by bytes, so shifting all-ones by
; (4 - r8) * SINGLE_SIZE bytes leaves ones only in the lanes to be written.
;

        vpcmpeqw xmm2, xmm2, xmm2               ; mask = all ones
        cmp     r8, 2
        jb      ConvertLower1
        ja      ConvertLower3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2       ; keep lanes 0..1
        jmp     ConvertLowerMaskedVector
ConvertLower1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3       ; keep lane 0
        jmp     ConvertLowerMaskedVector
ConvertLower3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE         ; keep lanes 0..2
ConvertLowerMaskedVector:
        vmaskmovps xmmword PTR [rdx], xmm2, xmm0 ; store only the enabled lanes
        jmp     ExitRoutine                     ; fewer than 4 remained, so this was the last store

;
; 4..7 elements: store elements 0..3 whole, then mask the remaining 0..3.
;

ConvertLowerVector:
        vmovups xmmword PTR [rdx], xmm0
        sub     r8, 4
        jz      ExitRoutine                     ; exactly 4 remained: done

        add     rdx, 4*SINGLE_SIZE              ; advance dest by 4 elements
        vpcmpeqw xmm2, xmm2, xmm2               ; mask = all ones
        cmp     r8, 2
        jb      ConvertUpper1
        ja      ConvertUpper3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2       ; keep lanes 0..1
        jmp     ConvertMaskedUpperVector
ConvertUpper1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3       ; keep lane 0
        jmp     ConvertMaskedUpperVector
ConvertUpper3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE         ; keep lanes 0..2
ConvertMaskedUpperVector:
        vmaskmovps xmmword PTR [rdx], xmm2, xmm1 ; store only the enabled lanes

ExitRoutine:
        vzeroupper                              ; ymm registers may have been written; avoid
                                                ; AVX/SSE transition penalties in the caller
        ret

LEAF_END MlasConvertHalfToFloatBufferAVX, _TEXT
|
||
END |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/*++

Routine Description:

    This routine converts the source buffer of half-precision floats to the
    destination buffer of single-precision floats.

    This implementation uses the AVX_NE_CONVERT instructions vcvtneeph2ps and
    vcvtneoph2ps together with AVX shuffles and masked stores. The two
    conversions produce the even-indexed and odd-indexed elements in separate
    registers, so every step re-interleaves them before storing.

Arguments:

    Source (rdi) - Supplies the address of the source buffer of half-precision
        floats.

    Destination (rsi) - Supplies the address of the destination buffer of
        single-precision floats.

    Count (rdx) - Supplies the number of elements to convert.

Return Value:

    None.

Register usage:

    rdi, rsi and rdx are advanced/decremented in place; ymm0-ymm2 are scratch.
    The upper halves of the ymm registers are cleared on exit (vzeroupper).

NOTE(review): the tail path (ConvertMaskedVectors) always loads a full
    16 bytes from Source even when fewer than 8 elements remain. Stores are
    masked, loads are not - confirm that callers guarantee those bytes are
    readable.

--*/
        .data
        // Element sizes in bytes; used to scale pointer advances and store offsets.
        .equ    SINGLE_SIZE, 4
        .equ    HALF_SIZE, 2
        // vperm2f128 immediates. Bits [1:0] pick the 128-bit lane written to the
        // low half of the destination and bits [5:4] the lane written to the high
        // half (0/1 = low/high lane of source 1, 2/3 = low/high lane of source 2).
        .equ    LOW_SELECTOR, 0b00100000        // dst = low lane of src1  : low lane of src2
        .equ    HIGH_SELECTOR, 0b00110001       // dst = high lane of src1 : high lane of src2

        .text
        .globl  MlasConvertHalfToFloatBufferAVX
        .intel_syntax noprefix

MlasConvertHalfToFloatBufferAVX:
        test    rdx, rdx                        // nothing to convert?
        jz      ExitRoutine
        cmp     rdx, 8
        jb      ConvertMaskedVectors            // 1..7 elements: masked tail only
        cmp     rdx, 16
        jb      Convert128Vectors               // 8..15 elements: one 128-bit step

        // Main loop: 16 elements per iteration. Invariant on entry: rdx >= 16.
Convert256Vectors:
        vcvtneeph2ps ymm0, ymmword PTR [rdi]    // even-indexed elements -> fp32
        vcvtneoph2ps ymm1, ymmword PTR [rdi]    // odd-indexed elements  -> fp32
        vunpcklps ymm2, ymm0, ymm1              // per lane: elements 0..3 | 8..11
        vunpckhps ymm1, ymm0, ymm1              // per lane: elements 4..7 | 12..15
        vperm2f128 ymm0, ymm2, ymm1, LOW_SELECTOR   // elements 0..7
        vperm2f128 ymm1, ymm2, ymm1, HIGH_SELECTOR  // elements 8..15
        vmovups ymmword PTR [rsi], ymm0
        vmovups ymmword PTR [rsi + 8*SINGLE_SIZE], ymm1

        add     rdi, 16*HALF_SIZE               // advance src by 16 elements
        add     rsi, 16*SINGLE_SIZE             // advance dest by 16 elements
        sub     rdx, 16

        jz      ExitRoutine                     // exact multiple of 16: done
        cmp     rdx, 16
        jae     Convert256Vectors

        // 1..15 elements remain. The 128-bit step below stores 8 floats
        // unmasked, so it may only be entered with at least 8 elements left;
        // otherwise it would write past the end of Destination and wrap the
        // counter.
        cmp     rdx, 8
        jb      ConvertMaskedVectors

        // Single 128-bit step: 8 elements. Invariant on entry: 8 <= rdx < 16.
Convert128Vectors:
        vcvtneeph2ps xmm2, xmmword PTR [rdi]    // even-indexed elements -> fp32
        vcvtneoph2ps xmm1, xmmword PTR [rdi]    // odd-indexed elements  -> fp32
        vunpcklps xmm0, xmm2, xmm1              // elements 0..3
        vunpckhps xmm1, xmm2, xmm1              // elements 4..7
        vmovups xmmword PTR [rsi], xmm0
        vmovups xmmword PTR [rsi + 4*SINGLE_SIZE], xmm1

        add     rdi, 8*HALF_SIZE                // advance src by 8 elements
        add     rsi, 8*SINGLE_SIZE              // advance dest by 8 elements
        sub     rdx, 8

        jz      ExitRoutine                     // nothing left: done

        // Tail: 1..7 elements. Convert a full vector, then store only rdx results.
ConvertMaskedVectors:
        vcvtneeph2ps xmm2, xmmword PTR [rdi]    // even-indexed elements -> fp32
        vcvtneoph2ps xmm1, xmmword PTR [rdi]    // odd-indexed elements  -> fp32
        vunpcklps xmm0, xmm2, xmm1              // elements 0..3
        vunpckhps xmm1, xmm2, xmm1              // elements 4..7

        cmp     rdx, 4                          // Check if the lower vector can be stored whole
        jae     ConvertLowerVector

        // 1..3 elements: build a store mask with ones in the low rdx dword
        // lanes. vpsrldq shifts the whole register right by bytes, so shifting
        // all-ones by (4 - rdx) * SINGLE_SIZE bytes leaves ones only in the
        // lanes to be written.
        vpcmpeqw xmm2, xmm2, xmm2               // mask = all ones
        cmp     rdx, 2
        jb      ConvertLower1
        ja      ConvertLower3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2       // keep lanes 0..1
        jmp     ConvertLowerMaskedVector
ConvertLower1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3       // keep lane 0
        jmp     ConvertLowerMaskedVector
ConvertLower3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE         // keep lanes 0..2
ConvertLowerMaskedVector:
        vmaskmovps xmmword PTR [rsi], xmm2, xmm0 // store only the enabled lanes
        jmp     ExitRoutine                     // fewer than 4 remained, so this was the last store

        // 4..7 elements: store elements 0..3 whole, then mask the remaining 0..3.
ConvertLowerVector:
        vmovups xmmword PTR [rsi], xmm0
        sub     rdx, 4
        jz      ExitRoutine                     // exactly 4 remained: done

        add     rsi, 4*SINGLE_SIZE              // advance dest by 4 elements
        vpcmpeqw xmm2, xmm2, xmm2               // mask = all ones
        cmp     rdx, 2
        jb      ConvertUpper1
        ja      ConvertUpper3
        vpsrldq xmm2, xmm2, SINGLE_SIZE*2       // keep lanes 0..1
        jmp     ConvertMaskedUpperVector
ConvertUpper1:
        vpsrldq xmm2, xmm2, SINGLE_SIZE*3       // keep lane 0
        jmp     ConvertMaskedUpperVector
ConvertUpper3:
        vpsrldq xmm2, xmm2, SINGLE_SIZE         // keep lanes 0..2
ConvertMaskedUpperVector:
        vmaskmovps xmmword PTR [rsi], xmm2, xmm1 // store only the enabled lanes

ExitRoutine:
        vzeroupper                              // ymm registers may have been written; avoid
                                                // AVX/SSE transition penalties in the caller
        ret
Oops, something went wrong.