[mlas] Speed up tanhf activation function #20612

Open: wants to merge 13 commits into base: main

Changes from 2 commits
85 changes: 39 additions & 46 deletions onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm
@@ -64,39 +64,33 @@ INCLUDE TransKernelCommon.inc
END_PROLOGUE

lea rax,MlasTanhConstants
vbroadcastss ymm4,TanhConstants.LowerRange[rax]
vbroadcastss ymm5,TanhConstants.UpperRange[rax]
vbroadcastss ymm6,TanhConstants.alpha_13[rax]
vbroadcastss ymm7,TanhConstants.alpha_11[rax]
vbroadcastss ymm8,TanhConstants.alpha_9[rax]
vbroadcastss ymm9,TanhConstants.alpha_7[rax]
vbroadcastss ymm10,TanhConstants.alpha_5[rax]
vbroadcastss ymm11,TanhConstants.alpha_3[rax]
vbroadcastss ymm12,TanhConstants.alpha_1[rax]
vbroadcastss ymm13,TanhConstants.beta_6[rax]
vbroadcastss ymm14,TanhConstants.beta_2[rax]
vbroadcastss ymm15,TanhConstants.beta_0[rax]
vbroadcastss ymm5, DWORD PTR [rax + 0] ; nc2
vbroadcastss ymm6, DWORD PTR [rax + 4] ; nc1
vbroadcastss ymm4, DWORD PTR [rax + 8] ; nc0
vbroadcastss ymm7, DWORD PTR [rax + 12] ; dc2
vbroadcastss ymm8, DWORD PTR [rax + 16] ; dc1
vbroadcastss ymm9, DWORD PTR [rax + 20] ; dc0
vbroadcastss ymm10, DWORD PTR [rax + 24] ; absmask
vbroadcastss ymm11, DWORD PTR [rax + 28] ; bound

sub r8,8
jb ProcessRemainingCount

ComputeTanhBy8Loop:
vmaxps ymm0,ymm4,YMMWORD PTR [rcx] ; clamp lower bound
vmovaps ymm2,ymm7
vminps ymm0,ymm5,ymm0 ; clamp upper bound
vmulps ymm1,ymm0,ymm0 ; x2
vbroadcastss ymm3,TanhConstants.beta_4[rax]
vfmadd231ps ymm2,ymm1,ymm6 ; p = x2 * alpha_13 + alpha_11
vfmadd213ps ymm2,ymm1,ymm8 ; p = x2 * p + alpha_9
vfmadd213ps ymm2,ymm1,ymm9 ; p = x2 * p + alpha_7
vfmadd213ps ymm2,ymm1,ymm10 ; p = x2 * p + alpha_5
vfmadd213ps ymm2,ymm1,ymm11 ; p = x2 * p + alpha_3
vfmadd213ps ymm2,ymm1,ymm12 ; p = x2 * p + alpha_1
vfmadd231ps ymm3,ymm1,ymm13 ; q = x2 * beta_6 + beta_4
vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2
vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0
vmulps ymm2,ymm0,ymm2 ; p = x * p
vdivps ymm0,ymm2,ymm3 ; tanh = p / q
vandps ymm0,ymm10,YMMWORD PTR [rcx]
vmovaps ymm3, ymm5
vmovaps ymm13, ymm7
vxorps ymm1, ymm0, YMMWORD PTR [rcx]
vmulps ymm2, ymm0, ymm0
vcmpps ymm12, ymm0, ymm11, 29
vfmadd132ps ymm3, ymm6, ymm2
vfmadd132ps ymm13, ymm8, ymm2
vfmadd132ps ymm3, ymm4, ymm2
vfmadd132ps ymm2, ymm9, ymm13
vfmadd132ps ymm0, ymm0, ymm2
vdivps ymm0, ymm0, ymm3
vblendvps ymm0, ymm0, ymm4, ymm12
vxorps ymm0, ymm0, ymm1
add rcx,8*4 ; advance input by 8 elements
vmovups YMMWORD PTR [rdx],ymm0
add rdx,8*4 ; advance output by 8 elements
@@ -108,24 +102,23 @@ ProcessRemainingCount:
jz ExitKernel
neg r8
lea r10,MlasMaskMoveTableAvx+8*4
vmovups ymm2,YMMWORD PTR [r10+r8*4]
vmaskmovps ymm0,ymm2,YMMWORD PTR [rcx]
vmaxps ymm0,ymm4,ymm0 ; clamp lower bound
vminps ymm0,ymm5,ymm0 ; clamp upper bound
vmulps ymm1,ymm0,ymm0 ; x2
vbroadcastss ymm3,TanhConstants.beta_4[rax]
vfmadd231ps ymm7,ymm1,ymm6 ; p = x2 * alpha_13 + alpha_11
vfmadd213ps ymm7,ymm1,ymm8 ; p = x2 * p + alpha_9
vfmadd213ps ymm7,ymm1,ymm9 ; p = x2 * p + alpha_7
vfmadd213ps ymm7,ymm1,ymm10 ; p = x2 * p + alpha_5
vfmadd213ps ymm7,ymm1,ymm11 ; p = x2 * p + alpha_3
vfmadd213ps ymm7,ymm1,ymm12 ; p = x2 * p + alpha_1
vfmadd231ps ymm3,ymm1,ymm13 ; q = x2 * beta_6 + beta_4
vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2
vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0
vmulps ymm7,ymm0,ymm7 ; p = x * p
vdivps ymm0,ymm7,ymm3 ; tanh = p / q
vmaskmovps YMMWORD PTR [rdx],ymm2,ymm0
vmovups ymm15,YMMWORD PTR [r10+r8*4]
vmaskmovps ymm0,ymm15,YMMWORD PTR [rcx]
vandps ymm0,ymm10,ymm0
vmovaps ymm3, ymm5
vmovaps ymm13, ymm7
vxorps ymm1, ymm0, YMMWORD PTR [rcx]
vmulps ymm2, ymm0, ymm0
vcmpps ymm12, ymm0, ymm11, 29
vfmadd132ps ymm3, ymm6, ymm2
vfmadd132ps ymm13, ymm8, ymm2
vfmadd132ps ymm3, ymm4, ymm2
vfmadd132ps ymm2, ymm9, ymm13
vfmadd132ps ymm0, ymm0, ymm2
vdivps ymm0, ymm0, ymm3
vblendvps ymm0, ymm0, ymm4, ymm12
vxorps ymm0, ymm0, ymm1
vmaskmovps YMMWORD PTR [rdx],ymm15,ymm0

ExitKernel:
vzeroupper
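For readers less used to FMA3 encodings, here is a rough C++/AVX2 intrinsics sketch of what one 8-element iteration of the rewritten kernel computes. It is illustration only and not part of this change; the parameter names (nc0..nc2, dc0..dc2, bound, absmask) are assumed to hold the broadcast constants labelled in the comments above, with nc0 = 1.0f and bound = 5.0f per the table in tanh.cpp, and the immediate 29 on vcmpps corresponds to _CMP_GE_OQ.

```cpp
#include <immintrin.h>

// Illustrative only: one 8-float step mirroring the rewritten FMA3 kernel.
static inline __m256 TanhStepAvx2(__m256 x,
                                  __m256 nc0, __m256 nc1, __m256 nc2,
                                  __m256 dc0, __m256 dc1, __m256 dc2,
                                  __m256 bound, __m256 absmask)
{
    __m256 xabs = _mm256_and_ps(x, absmask);               // |x|            (vandps)
    __m256 sign = _mm256_xor_ps(xabs, x);                   // sign bit only  (vxorps)
    __m256 x2   = _mm256_mul_ps(xabs, xabs);                // x^2            (vmulps)
    __m256 sat  = _mm256_cmp_ps(xabs, bound, _CMP_GE_OQ);   // |x| >= 5       (vcmpps, imm 29)

    __m256 q = _mm256_fmadd_ps(nc2, x2, nc1);               // q = nc2*x^2 + nc1
    q = _mm256_fmadd_ps(q, x2, nc0);                        // q = q*x^2 + nc0

    __m256 p = _mm256_fmadd_ps(dc2, x2, dc1);               // p = dc2*x^2 + dc1
    p = _mm256_fmadd_ps(p, x2, dc0);                        // p = p*x^2 + dc0
    p = _mm256_fmadd_ps(p, xabs, xabs);                     // p = p*|x| + |x|

    __m256 t = _mm256_div_ps(p, q);                         // tanh(|x|) ~= p / q
    t = _mm256_blendv_ps(t, nc0, sat);                      // saturate to 1.0 when |x| >= bound
    return _mm256_xor_ps(t, sign);                          // reapply the sign
}
```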
41 changes: 41 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
@@ -1734,6 +1734,17 @@
#endif
}

MLAS_FORCEINLINE
MLAS_FLOAT32X4
MlasPartialLoadFloat32x4(const float* Buffer, const int N)
{
float temp[4] = {0.0f, 0.0f, 0.0f, 0.0f};
for (int ii = 0; ii < N; ++ii) {
temp[ii] = Buffer[ii];
}
return MlasLoadFloat32x4(temp);
}
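A note on semantics, with a tiny hypothetical usage (assumes this header is included): lanes at index N and above are zero-filled, so the returned vector is fully defined even when fewer than four elements remain.

```cpp
// Hypothetical example: load a 3-element tail; the fourth lane becomes 0.0f.
float tail[3] = {0.25f, -1.5f, 3.0f};
MLAS_FLOAT32X4 v = MlasPartialLoadFloat32x4(tail, 3);
```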

MLAS_FORCEINLINE
void
MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
@@ -1753,6 +1764,17 @@
#endif
}

MLAS_FORCEINLINE
void
MlasPartialStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector, const int N)
{
float temp[4];
MlasStoreFloat32x4(temp, Vector);
for (int ii = 0; ii < N; ++ii) {
Buffer[ii] = temp[ii];
}
}
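Taken together with MlasPartialLoadFloat32x4 above, this lets a vector loop cover the final one to three elements without a separate scalar tail, which is the structure adopted in tanh.cpp below. A minimal sketch of that pattern (the function name and the Transform callable are hypothetical; assumes this header is included):

```cpp
// Sketch: process a buffer of arbitrary length with a single 4-wide loop.
template <typename VectorOp>
void ApplyElementwise(const float* Input, float* Output, size_t N, VectorOp Transform)
{
    size_t count = 0;
    while (count < N) {
        const size_t remaining = N - count;
        MLAS_FLOAT32X4 v = (remaining >= 4)
            ? MlasLoadFloat32x4(Input)
            : MlasPartialLoadFloat32x4(Input, static_cast<int>(remaining));

        MLAS_FLOAT32X4 r = Transform(v);

        if (remaining >= 4) {
            MlasStoreFloat32x4(Output, r);
        } else {
            MlasPartialStoreFloat32x4(Output, r, static_cast<int>(remaining));
        }

        Input += 4;
        Output += 4;
        count += 4;
    }
}
```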

MLAS_FORCEINLINE
void
MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector)
@@ -2047,6 +2069,25 @@
#endif
}

MLAS_FORCEINLINE
MLAS_FLOAT32X4
MlasGreaterThanEqualFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
{
#if defined(MLAS_NEON_INTRINSICS)
return vreinterpretq_f32_u32(vcgeq_f32(Vector1, Vector2));
#elif defined(MLAS_SSE2_INTRINSICS)
return _mm_cmpge_ps(Vector1, Vector2);
#elif defined(MLAS_WASM_SIMD_INTRINSICS)
return wasm_f32x4_ge(Vector1, Vector2);
#elif defined(MLAS_VSX_INTRINSICS)
return MLAS_FLOAT32X4(vec_cmpge(Vector1, Vector2));
#elif defined(MLAS_LSX_INTRINSICS)
return (MLAS_FLOAT32X4)__lsx_vfcmp_cle_s(Vector2, Vector1);
#else
return Vector1 >= Vector2;
#endif
}
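The comparison yields an all-ones or all-zeros bit pattern per lane, so its result can drive a mask-based select. A hedged sketch of how the rewritten tanh code below uses it together with MlasBlendFloat32x4 (ValAbs and result are hypothetical locals):

```cpp
// Saturate lanes whose magnitude has reached the bound to exactly 1.0f.
MLAS_FLOAT32X4 bound = MlasBroadcastFloat32x4(5.0f);
MLAS_FLOAT32X4 ones  = MlasBroadcastFloat32x4(1.0f);
MLAS_FLOAT32X4 mask  = MlasGreaterThanEqualFloat32x4(ValAbs, bound);  // all-ones where |x| >= 5
MLAS_FLOAT32X4 out   = MlasBlendFloat32x4(result, ones, mask);        // pick 1.0f on those lanes
```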

MLAS_FORCEINLINE
MLAS_FLOAT32X4
MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2)
144 changes: 58 additions & 86 deletions onnxruntime/core/mlas/lib/tanh.cpp
@@ -27,33 +27,23 @@
//

MLAS_INTERNAL_DATA const struct {
float LowerRange;
float UpperRange;
float alpha_13;
float alpha_11;
float alpha_9;
float alpha_7;
float alpha_5;
float alpha_3;
float alpha_1;
float beta_6;
float beta_4;
float beta_2;
float beta_0;
uint32_t _nc2;
uint32_t _nc1;
uint32_t _nc0;
uint32_t _dc2;
uint32_t _dc1;
uint32_t _dc0;
uint32_t _absmask;
uint32_t _ubound;
} MlasTanhConstants = {
-9.0f,
9.0f,
-2.76076847742355e-16f,
2.00018790482477e-13f,
-8.60467152213735e-11f,
5.12229709037114e-08f,
1.48572235717979e-05f,
6.37261928875436e-04f,
4.89352455891786e-03f,
1.19825839466702e-06f,
1.18534705686654e-04f,
2.26843463243900e-03f,
4.89352518554385e-03f,
0x3c520a84, /* _nc2 */
0x3edef102, /* _nc1 */
0x3f800000, /* _nc0 */
0x3a2fc8e6, /* _dc2 */
0x3dd1c060, /* _dc1 */
0xb859e195, /* _dc0 */
0x7fffffff, /* _absmask */
0x40a00000, /* _ubound = +5.0f */
};
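The table now stores the coefficients as raw IEEE-754 bit patterns, so the same dwords can be broadcast from the assembly kernels (via vbroadcastss) and from the portable C++ path. For readers checking the values: 0x3f800000 is 1.0f and 0x40a00000 is 5.0f. A small standalone snippet (illustration only, not part of the change) that decodes a pattern:

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Reinterpret a 32-bit pattern as a float; memcpy avoids strict-aliasing issues.
static float BitsToFloat(uint32_t bits)
{
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    std::printf("%f\n", BitsToFloat(0x3f800000u)); // prints 1.000000 (_nc0)
    std::printf("%f\n", BitsToFloat(0x40a00000u)); // prints 5.000000 (_ubound)
    return 0;
}
```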

void
@@ -83,69 +73,51 @@

--*/
{
while (N >= 4) {

MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input);

Value = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasTanhConstants.LowerRange), Value);
Value = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasTanhConstants.UpperRange), Value);

MLAS_FLOAT32X4 ValueSquared = MlasMultiplyFloat32x4(Value, Value);

MLAS_FLOAT32X4 p;
p = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_13),
MlasBroadcastFloat32x4(MlasTanhConstants.alpha_11));
p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_9));
p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_7));
p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_5));
p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_3));
p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_1));
p = MlasMultiplyFloat32x4(p, Value);

MLAS_FLOAT32X4 q;
q = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_6),
MlasBroadcastFloat32x4(MlasTanhConstants.beta_4));
q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_2));
q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_0));

MlasStoreFloat32x4(Output, MlasDivideFloat32x4(p, q));
const MLAS_FLOAT32X4 nc0 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc0);
const MLAS_FLOAT32X4 nc1 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc1);
const MLAS_FLOAT32X4 nc2 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc2);
const MLAS_FLOAT32X4 dc0 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc0);
const MLAS_FLOAT32X4 dc1 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc1);
const MLAS_FLOAT32X4 dc2 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc2);
const MLAS_FLOAT32X4 ub = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._ubound);
const MLAS_FLOAT32X4 absmask = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._absmask);
MLAS_FLOAT32X4 Val;

size_t count = 0;
while (count < N) {

if (N - count >= 4) {
Review comment (Member): There needs to be a check on each iteration with this change. If N is large, the previous version can save a significant number of instructions.

Reply (@r-devulap, Author, May 13, 2024): For large values of N, the CPU branch predictor should be able to predict this branch pretty easily. It will only miss on the very last iteration for the tail, but when N is large a single branch miss should hardly matter in terms of performance. It does bring the benefit of processing the entire array in a single loop.

Val = MlasLoadFloat32x4(Input);
}
else {
Val = MlasPartialLoadFloat32x4(Input, N - count);
}
MLAS_FLOAT32X4 ValAbs = MlasAndFloat32x4(Val, absmask);
MLAS_FLOAT32X4 boundmask = MlasGreaterThanEqualFloat32x4(ValAbs, ub);
MLAS_FLOAT32X4 signVal = MlasXorFloat32x4(ValAbs, Val);
MLAS_FLOAT32X4 ValSq = MlasMultiplyFloat32x4(ValAbs, ValAbs);

MLAS_FLOAT32X4 npoly = MlasMultiplyAddFloat32x4(nc2, ValSq, nc1);
npoly = MlasMultiplyAddFloat32x4(npoly, ValSq, nc0);

MLAS_FLOAT32X4 dpoly = MlasMultiplyAddFloat32x4(dc2, ValSq, dc1);
dpoly = MlasMultiplyAddFloat32x4(dpoly, ValSq, dc0);
dpoly = MlasMultiplyAddFloat32x4(dpoly, ValAbs, ValAbs);

MLAS_FLOAT32X4 out = MlasDivideFloat32x4(dpoly, npoly);
out = MlasBlendFloat32x4(out, nc0, boundmask);
out = MlasXorFloat32x4(out, signVal);

if (N - count >= 4) {
MlasStoreFloat32x4(Output, out);
}
else {
MlasPartialStoreFloat32x4(Output, out, N - count);
}

Input += 4;
Output += 4;
N -= 4;
}

while (N > 0) {

float Value = *Input++;

// This odd two-step process exists to ensure an input value of NaN carries through
// without modification because "std::min" and "std::max" return unreliable results
// when NaNs are involved, and it's clear from the test's reference outputs that
// they want a NaN on output whenever the input is a NaN.
float v_tmp;
v_tmp = (Value < MlasTanhConstants.LowerRange) ? MlasTanhConstants.LowerRange : Value;
Value = (v_tmp > MlasTanhConstants.UpperRange) ? MlasTanhConstants.UpperRange : v_tmp;

float ValueSquared = Value * Value;

float p;
p = ValueSquared * MlasTanhConstants.alpha_13 + MlasTanhConstants.alpha_11;
p = p * ValueSquared + MlasTanhConstants.alpha_9;
p = p * ValueSquared + MlasTanhConstants.alpha_7;
p = p * ValueSquared + MlasTanhConstants.alpha_5;
p = p * ValueSquared + MlasTanhConstants.alpha_3;
p = p * ValueSquared + MlasTanhConstants.alpha_1;
p = p * Value;

float q;
q = ValueSquared * MlasTanhConstants.beta_6 + MlasTanhConstants.beta_4;
q = q * ValueSquared + MlasTanhConstants.beta_2;
q = q * ValueSquared + MlasTanhConstants.beta_0;

*Output++ = (p / q);

N -= 1;
count += 4;
}
}
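For reference, the loop above evaluates the following rational approximation (a sketch; n1, n2, d0, d1, d2 stand for the decoded _nc1, _nc2, _dc0, _dc1, _dc2 table entries, and the constants 1 and 5 are the decoded _nc0 and _ubound):

```latex
\tanh(x) \approx \operatorname{sign}(x) \cdot
\begin{cases}
1, & |x| \ge 5 \\[4pt]
\dfrac{|x|\,\bigl(1 + d_0 + d_1 x^2 + d_2 x^4\bigr)}{1 + n_1 x^2 + n_2 x^4}, & |x| < 5
\end{cases}
```

The sign is reapplied at the end by XOR-ing back the extracted sign bit, and because the bound comparison is an ordered (quiet) compare, a NaN input should still propagate to a NaN output through the polynomial path.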
