From c6c93092f333650f126d0a83bce3340a4a179eb4 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Wed, 24 Apr 2024 10:35:12 -0700 Subject: [PATCH 01/13] [mlas]: Speed up tanhf activation function Use the Intel SVML tanhf function, which speeds up tanhf computation by up to ~38%. The algorithm has a max ULP error of 1536. A comparison of benchmark numbers against the main branch is provided below (generated on a Tiger Lake Dell XPS laptop using: https://github.com/google/benchmark/blob/main/tools/compare.py) |-----------------+---------+---------+----------+----------+---------+---------| | Benchmark | Time | CPU | Time Old | Time New | CPU Old | CPU New | |-----------------+---------+---------+----------+----------+---------+---------| | BM_Tanh/40000 | -0.3822 | -0.3825 | 15059 | 9304 | 15035 | 9283 | | BM_Tanh/80000 | -0.3845 | -0.3844 | 30055 | 18499 | 29998 | 18467 | | BM_Tanh/160000 | -0.3146 | -0.3144 | 17803 | 12203 | 17762 | 12178 | | BM_Tanh/320000 | -0.3495 | -0.3491 | 32840 | 21362 | 32724 | 21300 | | BM_Tanh/640000 | -0.3563 | -0.3568 | 62902 | 40487 | 62754 | 40361 | | BM_Tanh/1280000 | -0.3326 | -0.3333 | 128536 | 85780 | 128102 | 85408 | |-----------------+---------+---------+----------+----------+---------+---------| | OVERALL_GEOMEAN | -0.3538 | -0.3539 | 0 | 0 | 0 | 0 | |-----------------+---------+---------+----------+----------+---------+---------| --- .../core/mlas/lib/amd64/TanhKernelFma3.asm | 85 +++++------ onnxruntime/core/mlas/lib/mlasi.h | 41 +++++ onnxruntime/core/mlas/lib/tanh.cpp | 144 +++++++----------- .../core/mlas/lib/x86_64/TanhKernelFma3.S | 105 ++++++------- 4 files changed, 186 insertions(+), 189 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm index 6d94d533d72ad..8e003e8f34df0 100644 --- a/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm +++ b/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm @@ -64,39 +64,33 @@ INCLUDE TransKernelCommon.inc END_PROLOGUE 
lea rax,MlasTanhConstants - vbroadcastss ymm4,TanhConstants.LowerRange[rax] - vbroadcastss ymm5,TanhConstants.UpperRange[rax] - vbroadcastss ymm6,TanhConstants.alpha_13[rax] - vbroadcastss ymm7,TanhConstants.alpha_11[rax] - vbroadcastss ymm8,TanhConstants.alpha_9[rax] - vbroadcastss ymm9,TanhConstants.alpha_7[rax] - vbroadcastss ymm10,TanhConstants.alpha_5[rax] - vbroadcastss ymm11,TanhConstants.alpha_3[rax] - vbroadcastss ymm12,TanhConstants.alpha_1[rax] - vbroadcastss ymm13,TanhConstants.beta_6[rax] - vbroadcastss ymm14,TanhConstants.beta_2[rax] - vbroadcastss ymm15,TanhConstants.beta_0[rax] + vbroadcastss ymm5, DWORD PTR [rax + 0] ; nc2 + vbroadcastss ymm6, DWORD PTR [rax + 4] ; nc1 + vbroadcastss ymm4, DWORD PTR [rax + 8] ; nc0 + vbroadcastss ymm7, DWORD PTR [rax + 12] ; dc2 + vbroadcastss ymm8, DWORD PTR [rax + 16] ; dc1 + vbroadcastss ymm9, DWORD PTR [rax + 20] ; dc0 + vbroadcastss ymm10, DWORD PTR [rax + 24] ; absmask + vbroadcastss ymm11, DWORD PTR [rax + 28] ; bound sub r8,8 jb ProcessRemainingCount ComputeTanhBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rcx] ; clamp lower bound - vmovaps ymm2,ymm7 - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vmulps ymm1,ymm0,ymm0 ; x2 - vbroadcastss ymm3,TanhConstants.beta_4[rax] - vfmadd231ps ymm2,ymm1,ymm6 ; p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm2,ymm1,ymm8 ; p = x2 * p + alpha_9 - vfmadd213ps ymm2,ymm1,ymm9 ; p = x2 * p + alpha_7 - vfmadd213ps ymm2,ymm1,ymm10 ; p = x2 * p + alpha_5 - vfmadd213ps ymm2,ymm1,ymm11 ; p = x2 * p + alpha_3 - vfmadd213ps ymm2,ymm1,ymm12 ; p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 ; q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0 - vmulps ymm2,ymm0,ymm2 ; p = x * p - vdivps ymm0,ymm2,ymm3 ; tanh = p / q + vandps ymm0,ymm10,YMMWORD PTR [rcx] + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rcx] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, 
ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 add rcx,8*4 ; advance input by 8 elements vmovups YMMWORD PTR [rdx],ymm0 add rdx,8*4 ; advance output by 8 elements @@ -108,24 +102,23 @@ ProcessRemainingCount: jz ExitKernel neg r8 lea r10,MlasMaskMoveTableAvx+8*4 - vmovups ymm2,YMMWORD PTR [r10+r8*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rcx] - vmaxps ymm0,ymm4,ymm0 ; clamp lower bound - vminps ymm0,ymm5,ymm0 ; clamp upper bound - vmulps ymm1,ymm0,ymm0 ; x2 - vbroadcastss ymm3,TanhConstants.beta_4[rax] - vfmadd231ps ymm7,ymm1,ymm6 ; p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm7,ymm1,ymm8 ; p = x2 * p + alpha_9 - vfmadd213ps ymm7,ymm1,ymm9 ; p = x2 * p + alpha_7 - vfmadd213ps ymm7,ymm1,ymm10 ; p = x2 * p + alpha_5 - vfmadd213ps ymm7,ymm1,ymm11 ; p = x2 * p + alpha_3 - vfmadd213ps ymm7,ymm1,ymm12 ; p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 ; q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 ; q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 ; q = x2 * q + beta_0 - vmulps ymm7,ymm0,ymm7 ; p = x * p - vdivps ymm0,ymm7,ymm3 ; tanh = p / q - vmaskmovps YMMWORD PTR [rdx],ymm2,ymm0 + vmovups ymm15,YMMWORD PTR [r10+r8*4] + vmaskmovps ymm0,ymm15,YMMWORD PTR [rcx] + vandps ymm0,ymm10,ymm0 + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rcx] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 + vmaskmovps YMMWORD PTR [rdx],ymm15,ymm0 ExitKernel: vzeroupper diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 04da9ab4fd749..263c831331b98 100644 --- 
a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1734,6 +1734,17 @@ MlasLoadFloat32x4(const float* Buffer) #endif } +MLAS_FORCEINLINE +MLAS_FLOAT32X4 +MlasPartialLoadFloat32x4(const float* Buffer, const int N) +{ + float temp[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (int ii = 0; ii < N; ++ii) { + temp[ii] = Buffer[ii]; + } + return MlasLoadFloat32x4(temp); +} + MLAS_FORCEINLINE void MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) @@ -1753,6 +1764,17 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) #endif } +MLAS_FORCEINLINE +void +MlasPartialStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector, const int N) +{ + float temp[4]; + MlasStoreFloat32x4(temp, Vector); + for (int ii = 0; ii < N; ++ii) { + Buffer[ii] = temp[ii]; + } +} + MLAS_FORCEINLINE void MlasStoreAlignedFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) @@ -2047,6 +2069,25 @@ MlasDivideFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) #endif } +MLAS_FORCEINLINE +MLAS_FLOAT32X4 +MlasGreaterThanEqualFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) +{ +#if defined(MLAS_NEON_INTRINSICS) + return vreinterpretq_f32_u32(vcgeq_f32(Vector1, Vector2)); +#elif defined(MLAS_SSE2_INTRINSICS) + return _mm_cmpge_ps(Vector1, Vector2); +#elif defined(MLAS_WASM_SIMD_INTRINSICS) + return wasm_f32x4_ge(Vector1, Vector2); +#elif defined(MLAS_VSX_INTRINSICS) + return MLAS_FLOAT32X4(vec_cmpge(Vector1, Vector2)); +#elif defined(MLAS_LSX_INTRINSICS) + return (MLAS_FLOAT32X4)__lsx_vfcmp_cle_s(Vector2, Vector1); +#else + return Vector1 >= Vector2; +#endif +} + MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index 9750337237b00..d07ab4dcb16bd 100644 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -27,33 +27,23 @@ Module Name: // MLAS_INTERNAL_DATA const struct { - float LowerRange; - float 
UpperRange; - float alpha_13; - float alpha_11; - float alpha_9; - float alpha_7; - float alpha_5; - float alpha_3; - float alpha_1; - float beta_6; - float beta_4; - float beta_2; - float beta_0; + uint32_t _nc2; + uint32_t _nc1; + uint32_t _nc0; + uint32_t _dc2; + uint32_t _dc1; + uint32_t _dc0; + uint32_t _absmask; + uint32_t _ubound; } MlasTanhConstants = { - -9.0f, - 9.0f, - -2.76076847742355e-16f, - 2.00018790482477e-13f, - -8.60467152213735e-11f, - 5.12229709037114e-08f, - 1.48572235717979e-05f, - 6.37261928875436e-04f, - 4.89352455891786e-03f, - 1.19825839466702e-06f, - 1.18534705686654e-04f, - 2.26843463243900e-03f, - 4.89352518554385e-03f, + 0x3c520a84, /* _nc2 */ + 0x3edef102, /* _nc1 */ + 0x3f800000, /* _nc0 */ + 0x3a2fc8e6, /* _dc2 */ + 0x3dd1c060, /* _dc1 */ + 0xb859e195, /* _dc0 */ + 0x7fffffff, /* _absmask */ + 0x40a00000, /* _ubound = +5.0f */ }; void @@ -83,69 +73,51 @@ Return Value: --*/ { - while (N >= 4) { - - MLAS_FLOAT32X4 Value = MlasLoadFloat32x4(Input); - - Value = MlasMaximumFloat32x4(MlasBroadcastFloat32x4(MlasTanhConstants.LowerRange), Value); - Value = MlasMinimumFloat32x4(MlasBroadcastFloat32x4(MlasTanhConstants.UpperRange), Value); - - MLAS_FLOAT32X4 ValueSquared = MlasMultiplyFloat32x4(Value, Value); - - MLAS_FLOAT32X4 p; - p = MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_13), - MlasBroadcastFloat32x4(MlasTanhConstants.alpha_11)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_9)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_7)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_5)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_3)); - p = MlasMultiplyAddFloat32x4(p, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.alpha_1)); - p = MlasMultiplyFloat32x4(p, Value); - - MLAS_FLOAT32X4 q; - q = 
MlasMultiplyAddFloat32x4(ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_6), - MlasBroadcastFloat32x4(MlasTanhConstants.beta_4)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_2)); - q = MlasMultiplyAddFloat32x4(q, ValueSquared, MlasBroadcastFloat32x4(MlasTanhConstants.beta_0)); - - MlasStoreFloat32x4(Output, MlasDivideFloat32x4(p, q)); + const MLAS_FLOAT32X4 nc0 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc0); + const MLAS_FLOAT32X4 nc1 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc1); + const MLAS_FLOAT32X4 nc2 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc2); + const MLAS_FLOAT32X4 dc0 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc0); + const MLAS_FLOAT32X4 dc1 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc1); + const MLAS_FLOAT32X4 dc2 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc2); + const MLAS_FLOAT32X4 ub = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._ubound); + const MLAS_FLOAT32X4 absmask = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._absmask); + MLAS_FLOAT32X4 Val; + + size_t count = 0; + while (count < N) { + + if (N - count >= 4) { + Val = MlasLoadFloat32x4(Input); + } + else { + Val = MlasPartialLoadFloat32x4(Input, N - count); + } + MLAS_FLOAT32X4 ValAbs = MlasAndFloat32x4(Val, absmask); + MLAS_FLOAT32X4 boundmask = MlasGreaterThanEqualFloat32x4(ValAbs, ub); + MLAS_FLOAT32X4 signVal = MlasXorFloat32x4(ValAbs, Val); + MLAS_FLOAT32X4 ValSq = MlasMultiplyFloat32x4(ValAbs, ValAbs); + + MLAS_FLOAT32X4 npoly = MlasMultiplyAddFloat32x4(nc2, ValSq, nc1); + npoly = MlasMultiplyAddFloat32x4(npoly, ValSq, nc0); + + MLAS_FLOAT32X4 dpoly = MlasMultiplyAddFloat32x4(dc2, ValSq, dc1); + dpoly = MlasMultiplyAddFloat32x4(dpoly, ValSq, dc0); + dpoly = MlasMultiplyAddFloat32x4(dpoly, ValAbs, ValAbs); + + MLAS_FLOAT32X4 out = MlasDivideFloat32x4(dpoly, npoly); + out = MlasBlendFloat32x4(out, nc0, boundmask); + out = MlasXorFloat32x4(out, 
signVal); + + if (N - count >= 4) { + MlasStoreFloat32x4(Output, out); + } + else { + MlasPartialStoreFloat32x4(Output, out, N - count); + } Input += 4; Output += 4; - N -= 4; - } - - while (N > 0) { - - float Value = *Input++; - - // This odd two-step process exists to ensure an input value of NaN carries through - // without modification because "std::min" and "std::max" return unreliable results - // when NaNs are involved, and it's clear from the test's reference outputs that - // they want a NaN on output whenever the input is a NaN. - float v_tmp; - v_tmp = (Value < MlasTanhConstants.LowerRange) ? MlasTanhConstants.LowerRange : Value; - Value = (v_tmp > MlasTanhConstants.UpperRange) ? MlasTanhConstants.UpperRange : v_tmp; - - float ValueSquared = Value * Value; - - float p; - p = ValueSquared * MlasTanhConstants.alpha_13 + MlasTanhConstants.alpha_11; - p = p * ValueSquared + MlasTanhConstants.alpha_9; - p = p * ValueSquared + MlasTanhConstants.alpha_7; - p = p * ValueSquared + MlasTanhConstants.alpha_5; - p = p * ValueSquared + MlasTanhConstants.alpha_3; - p = p * ValueSquared + MlasTanhConstants.alpha_1; - p = p * Value; - - float q; - q = ValueSquared * MlasTanhConstants.beta_6 + MlasTanhConstants.beta_4; - q = q * ValueSquared + MlasTanhConstants.beta_2; - q = q * ValueSquared + MlasTanhConstants.beta_0; - - *Output++ = (p / q); - - N -= 1; + count += 4; } } diff --git a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S index d7c2fd1c6e1dd..74f8325c77a26 100644 --- a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S +++ b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S @@ -20,6 +20,7 @@ Abstract: #include "asmmacro.h" #include "TransKernelCommon.h" + .intel_syntax noprefix .text @@ -48,71 +49,61 @@ Return Value: FUNCTION_ENTRY MlasComputeTanhF32KernelFma3 lea rax,C_UNDERSCORE(MlasTanhConstants)[rip] - vbroadcastss ymm4,.LTanhConstants_LowerRange[rax] - vbroadcastss ymm5,.LTanhConstants_UpperRange[rax] - 
vbroadcastss ymm6,.LTanhConstants_alpha_13[rax] - vbroadcastss ymm7,.LTanhConstants_alpha_11[rax] - vbroadcastss ymm8,.LTanhConstants_alpha_9[rax] - vbroadcastss ymm9,.LTanhConstants_alpha_7[rax] - vbroadcastss ymm10,.LTanhConstants_alpha_5[rax] - vbroadcastss ymm11,.LTanhConstants_alpha_3[rax] - vbroadcastss ymm12,.LTanhConstants_alpha_1[rax] - vbroadcastss ymm13,.LTanhConstants_beta_6[rax] - vbroadcastss ymm14,.LTanhConstants_beta_2[rax] - vbroadcastss ymm15,.LTanhConstants_beta_0[rax] - + vbroadcastss ymm5, 0x00[rax] // nc2 + vbroadcastss ymm6, 0x04[rax] // nc1 + vbroadcastss ymm4, 0x08[rax] // nc0 + vbroadcastss ymm7, 0x0c[rax] // dc2 + vbroadcastss ymm8, 0x10[rax] // dc1 + vbroadcastss ymm9, 0x14[rax] // dc0 + vbroadcastss ymm10, 0x18[rax] // absmask + vbroadcastss ymm11, 0x1c[rax] // bound sub rdx,8 jb .LProcessRemainingCount .LComputeTanhBy8Loop: - vmaxps ymm0,ymm4,YMMWORD PTR [rdi] # clamp lower bound - vmovaps ymm2,ymm7 - vminps ymm0,ymm5,ymm0 # clamp upper bound - vmulps ymm1,ymm0,ymm0 # x2 - vbroadcastss ymm3,.LTanhConstants_beta_4[rax] - vfmadd231ps ymm2,ymm1,ymm6 # p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm2,ymm1,ymm8 # p = x2 * p + alpha_9 - vfmadd213ps ymm2,ymm1,ymm9 # p = x2 * p + alpha_7 - vfmadd213ps ymm2,ymm1,ymm10 # p = x2 * p + alpha_5 - vfmadd213ps ymm2,ymm1,ymm11 # p = x2 * p + alpha_3 - vfmadd213ps ymm2,ymm1,ymm12 # p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 # q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 # q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 # q = x2 * q + beta_0 - vmulps ymm2,ymm0,ymm2 # p = x * p - vdivps ymm0,ymm2,ymm3 # tanh = p / q - add rdi,8*4 # advance input by 8 elements - vmovups YMMWORD PTR [rsi],ymm0 - add rsi,8*4 # advance output by 8 elements - sub rdx,8 + vandps ymm0,ymm10,YMMWORD PTR [rdi] + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rdi] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + 
vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 + add rdi,8*4 # advance input by 8 elements + vmovups YMMWORD PTR [rsi],ymm0 + add rsi,8*4 # advance output by 8 elements + sub rdx,8 jae .LComputeTanhBy8Loop .LProcessRemainingCount: - add rdx,8 # correct for over-subtract above - jz .LExitKernel - neg rdx - lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] - vmovups ymm2,YMMWORD PTR [r10+rdx*4] - vmaskmovps ymm0,ymm2,YMMWORD PTR [rdi] - vmaxps ymm0,ymm4,ymm0 # clamp lower bound - vminps ymm0,ymm5,ymm0 # clamp upper bound - vmulps ymm1,ymm0,ymm0 # x2 - vbroadcastss ymm3,.LTanhConstants_beta_4[rax] - vfmadd231ps ymm7,ymm1,ymm6 # p = x2 * alpha_13 + alpha_11 - vfmadd213ps ymm7,ymm1,ymm8 # p = x2 * p + alpha_9 - vfmadd213ps ymm7,ymm1,ymm9 # p = x2 * p + alpha_7 - vfmadd213ps ymm7,ymm1,ymm10 # p = x2 * p + alpha_5 - vfmadd213ps ymm7,ymm1,ymm11 # p = x2 * p + alpha_3 - vfmadd213ps ymm7,ymm1,ymm12 # p = x2 * p + alpha_1 - vfmadd231ps ymm3,ymm1,ymm13 # q = x2 * beta_6 + beta_4 - vfmadd213ps ymm3,ymm1,ymm14 # q = x2 * q + beta_2 - vfmadd213ps ymm3,ymm1,ymm15 # q = x2 * q + beta_0 - vmulps ymm7,ymm0,ymm7 # p = x * p - vdivps ymm0,ymm7,ymm3 # tanh = p / q - vmaskmovps YMMWORD PTR [rsi],ymm2,ymm0 + add rdx,8 # correct for over-subtract above + jz .LExitKernel + neg rdx + lea r10,C_UNDERSCORE(MlasMaskMoveTableAvx)[rip+8*4] + vmovups ymm15,YMMWORD PTR [r10+rdx*4] + vmaskmovps ymm0,ymm15,YMMWORD PTR [rdi] + vandps ymm0,ymm10,ymm0 + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rdi] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 + vmaskmovps YMMWORD PTR [rsi],ymm15,ymm0 
.LExitKernel: vzeroupper ret - - .end From 79a1b840f344c2db23616175e41fb1a7ed3b4637 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 3 May 2024 15:24:44 -0700 Subject: [PATCH 02/13] TEST: Adjust hard coded expected values based on new tanhf algorithm --- onnxruntime/test/framework/inference_session_test.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index d0520ebbcba5a..b0549004f99de 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1851,11 +1851,11 @@ TEST(InferenceSessionTests, TestTruncatedSequence) { 0.56804454f, 0.92559665f, 0.07103606f}; std::vector Y_dims = {5, 1, 2}; - std::vector Y_data = {-1.1730184e-04f, -3.1204990e-04f, - -2.9978977e-04f, -1.0602647e-03f, - -3.8115133e-04f, -2.0684483e-03f, - -2.5120965e-04f, -2.9920202e-03f, - 3.0980256e-05f, -3.5933927e-03f}; + std::vector Y_data = {-1.1725388e-04f, -3.1192770e-04f, + -2.9967332e-04f, -1.0598592e-03f, + -3.8101958e-04f, -2.0676597e-03f, + -2.5116475e-04f, -2.9908563e-03f, + 3.0859868e-05f, -3.5919433e-03f}; OrtValue ml_value; CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], X_dims, X, &ml_value); From cab3a83233777cc1e5e5d926a21a61fdb717af9f Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 9 May 2024 12:26:42 -0700 Subject: [PATCH 03/13] [maint] fix linter failures --- onnxruntime/core/mlas/lib/tanh.cpp | 47 ++++++++++++++---------------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index d07ab4dcb16bd..199d9ab30a423 100644 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -36,14 +36,14 @@ MLAS_INTERNAL_DATA const struct { uint32_t _absmask; uint32_t _ubound; } MlasTanhConstants = { - 0x3c520a84, /* _nc2 */ - 
0x3edef102, /* _nc1 */ - 0x3f800000, /* _nc0 */ - 0x3a2fc8e6, /* _dc2 */ - 0x3dd1c060, /* _dc1 */ - 0xb859e195, /* _dc0 */ - 0x7fffffff, /* _absmask */ - 0x40a00000, /* _ubound = +5.0f */ + 0x3c520a84, /* _nc2 */ + 0x3edef102, /* _nc1 */ + 0x3f800000, /* _nc0 */ + 0x3a2fc8e6, /* _dc2 */ + 0x3dd1c060, /* _dc1 */ + 0xb859e195, /* _dc0 */ + 0x7fffffff, /* _absmask */ + 0x40a00000, /* _ubound = +5.0f */ }; void @@ -73,29 +73,27 @@ Return Value: --*/ { - const MLAS_FLOAT32X4 nc0 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc0); - const MLAS_FLOAT32X4 nc1 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc1); - const MLAS_FLOAT32X4 nc2 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._nc2); - const MLAS_FLOAT32X4 dc0 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc0); - const MLAS_FLOAT32X4 dc1 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc1); - const MLAS_FLOAT32X4 dc2 = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._dc2); - const MLAS_FLOAT32X4 ub = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._ubound); - const MLAS_FLOAT32X4 absmask = MlasBroadcastFloat32x4((float*)&MlasTanhConstants._absmask); + const MLAS_FLOAT32X4 nc0 = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._nc0)); + const MLAS_FLOAT32X4 nc1 = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._nc1)); + const MLAS_FLOAT32X4 nc2 = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._nc2)); + const MLAS_FLOAT32X4 dc0 = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._dc0)); + const MLAS_FLOAT32X4 dc1 = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._dc1)); + const MLAS_FLOAT32X4 dc2 = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._dc2)); + const MLAS_FLOAT32X4 ub = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._ubound)); + const MLAS_FLOAT32X4 absmask = MlasBroadcastFloat32x4(reinterpret_cast(&MlasTanhConstants._absmask)); MLAS_FLOAT32X4 Val; size_t count = 0; while (count < N) { - if (N - 
count >= 4) { Val = MlasLoadFloat32x4(Input); - } - else { + } else { Val = MlasPartialLoadFloat32x4(Input, N - count); } - MLAS_FLOAT32X4 ValAbs = MlasAndFloat32x4(Val, absmask); - MLAS_FLOAT32X4 boundmask = MlasGreaterThanEqualFloat32x4(ValAbs, ub); - MLAS_FLOAT32X4 signVal = MlasXorFloat32x4(ValAbs, Val); - MLAS_FLOAT32X4 ValSq = MlasMultiplyFloat32x4(ValAbs, ValAbs); + MLAS_FLOAT32X4 ValAbs = MlasAndFloat32x4(Val, absmask); + MLAS_FLOAT32X4 boundmask = MlasGreaterThanEqualFloat32x4(ValAbs, ub); + MLAS_FLOAT32X4 signVal = MlasXorFloat32x4(ValAbs, Val); + MLAS_FLOAT32X4 ValSq = MlasMultiplyFloat32x4(ValAbs, ValAbs); MLAS_FLOAT32X4 npoly = MlasMultiplyAddFloat32x4(nc2, ValSq, nc1); npoly = MlasMultiplyAddFloat32x4(npoly, ValSq, nc0); @@ -110,8 +108,7 @@ Return Value: if (N - count >= 4) { MlasStoreFloat32x4(Output, out); - } - else { + } else { MlasPartialStoreFloat32x4(Output, out, N - count); } From accf576ce09c80884dd1687aebd31f3f27fe4627 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 9 May 2024 12:32:47 -0700 Subject: [PATCH 04/13] Use static_cast to avoid compiler error --- onnxruntime/core/mlas/lib/tanh.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/mlas/lib/tanh.cpp b/onnxruntime/core/mlas/lib/tanh.cpp index 199d9ab30a423..64d3c9e87796e 100644 --- a/onnxruntime/core/mlas/lib/tanh.cpp +++ b/onnxruntime/core/mlas/lib/tanh.cpp @@ -88,7 +88,7 @@ Return Value: if (N - count >= 4) { Val = MlasLoadFloat32x4(Input); } else { - Val = MlasPartialLoadFloat32x4(Input, N - count); + Val = MlasPartialLoadFloat32x4(Input, static_cast(N - count)); } MLAS_FLOAT32X4 ValAbs = MlasAndFloat32x4(Val, absmask); MLAS_FLOAT32X4 boundmask = MlasGreaterThanEqualFloat32x4(ValAbs, ub); @@ -109,7 +109,7 @@ Return Value: if (N - count >= 4) { MlasStoreFloat32x4(Output, out); } else { - MlasPartialStoreFloat32x4(Output, out, N - count); + MlasPartialStoreFloat32x4(Output, out, static_cast(N - count)); } Input += 4; From 
cc6abd0c8d0e998a88e501ec825ade64898f301a Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 9 May 2024 12:42:06 -0700 Subject: [PATCH 05/13] Modify hard coded constants in tanhf activation test --- onnxruntime/test/mlas/unittest/test_activation.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/mlas/unittest/test_activation.cpp b/onnxruntime/test/mlas/unittest/test_activation.cpp index a4334c6c80477..3ffcf525c3e30 100644 --- a/onnxruntime/test/mlas/unittest/test_activation.cpp +++ b/onnxruntime/test/mlas/unittest/test_activation.cpp @@ -77,7 +77,7 @@ class MlasActivationTest : public MlasTestBase { {0x3e800000}, {0x3e800000}, {0x3e800000}, - {0x3e7acbf5}, + {0x3e7ac9d6}, {0x3f0feacc}, {0x3e800000}, {0x3e2e147b}, @@ -86,7 +86,7 @@ class MlasActivationTest : public MlasTestBase { {0xbe800000}, {0x00000000}, {0xbd4ccccd}, - {0xbe7acbf5}, + {0xbe7ac9d6}, {0x3ee02a67}, {0x00000000}, {0x3d8f5c28}, @@ -95,7 +95,7 @@ class MlasActivationTest : public MlasTestBase { {0x40800000}, {0x40800000}, {0x40800000}, - {0x3f7fd40a}, + {0x3f7fd390}, {0x3f7b6541}, {0x40800000}, {0x3f6b851f}, @@ -104,7 +104,7 @@ class MlasActivationTest : public MlasTestBase { {0xc0800000}, {0x00000000}, {0xbf4ccccd}, - {0xbf7fd40a}, + {0xbf7fd390}, {0x3c9357e0}, {0x00000000}, {0x00000000}, From 6bd807d034801cff45496472abeb53b028f1b398 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 9 May 2024 12:51:21 -0700 Subject: [PATCH 06/13] Increase fp16 error tol for models_opset7_fp16_coreml_FNSCandy --- onnxruntime/test/providers/cpu/model_tests.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index dcb592a4a254e..2aeecb620c40e 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -97,6 +97,14 @@ TEST_P(ModelTest, Run) { } } + // increase tol for 
models_opset7_fp16_coreml_FNSCandy test on cpu. See + // https://github.com/microsoft/onnxruntime/pull/20612 + if (model_path.find(ORT_TSTR("models_opset7_fp16_coreml_FNSCandy")) > 0) { + if (provider_name == "cpu") { + per_sample_tolerance = 1e-2; + relative_per_sample_tolerance = 1e-2; + } + } std::unique_ptr model_info = std::make_unique(model_path.c_str()); if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || From f3fbe1a2649a348a8577b3cd34f52294d0aaf9e4 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 13 May 2024 13:58:46 -0700 Subject: [PATCH 07/13] Adjust tolerance for new tanhf algorithm --- onnxruntime/test/python/transformers/parity_utilities.py | 2 +- orttraining/orttraining/test/gradient/gradient_ops_test.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/python/transformers/parity_utilities.py b/onnxruntime/test/python/transformers/parity_utilities.py index d7f79304d2d2b..196ac1c50dcbc 100644 --- a/onnxruntime/test/python/transformers/parity_utilities.py +++ b/onnxruntime/test/python/transformers/parity_utilities.py @@ -220,7 +220,7 @@ def run_parity( ort_outputs = onnxruntime_inference(ort_session, input_hidden_states) if tolerance is None: - tolerance = 2e-03 if float16 else 1e-05 + tolerance = 2e-03 if float16 else 2e-04 is_all_close, max_diff = compare_outputs(torch_outputs, ort_outputs, atol=tolerance, verbose=verbose) max_diffs.append(max_diff) if is_all_close: diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 94ca96c68f2ce..7c93cbd990297 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -154,7 +154,7 @@ void UnaryOpGradientTest(const std::string& op_type, const std::string& domain = std::vector>* execution_providers = nullptr, std::function* transformer = nullptr, const std::vector& attributes = {}, - 
float error_tolerance = 1e-3f) { + float error_tolerance = 2e-3f) { TensorShape shape({2, 3, 4}); TensorInfo x_info{shape, true, transformer}; float max_error; From 01b0007477d97e56b80ac7da829d14f234fec375 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 13 May 2024 14:03:40 -0700 Subject: [PATCH 08/13] Retab to replace tabs with spaces --- .../core/mlas/lib/amd64/TanhKernelFma3.asm | 52 +++++++++---------- .../core/mlas/lib/x86_64/TanhKernelFma3.S | 52 +++++++++---------- 2 files changed, 52 insertions(+), 52 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm b/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm index 8e003e8f34df0..d29057d85af96 100644 --- a/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm +++ b/onnxruntime/core/mlas/lib/amd64/TanhKernelFma3.asm @@ -78,19 +78,19 @@ INCLUDE TransKernelCommon.inc ComputeTanhBy8Loop: vandps ymm0,ymm10,YMMWORD PTR [rcx] - vmovaps ymm3, ymm5 - vmovaps ymm13, ymm7 - vxorps ymm1, ymm0, YMMWORD PTR [rcx] - vmulps ymm2, ymm0, ymm0 - vcmpps ymm12, ymm0, ymm11, 29 - vfmadd132ps ymm3, ymm6, ymm2 - vfmadd132ps ymm13, ymm8, ymm2 - vfmadd132ps ymm3, ymm4, ymm2 - vfmadd132ps ymm2, ymm9, ymm13 - vfmadd132ps ymm0, ymm0, ymm2 - vdivps ymm0, ymm0, ymm3 - vblendvps ymm0, ymm0, ymm4, ymm12 - vxorps ymm0, ymm0, ymm1 + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rcx] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 add rcx,8*4 ; advance input by 8 elements vmovups YMMWORD PTR [rdx],ymm0 add rdx,8*4 ; advance output by 8 elements @@ -105,19 +105,19 @@ ProcessRemainingCount: vmovups ymm15,YMMWORD PTR [r10+r8*4] vmaskmovps ymm0,ymm15,YMMWORD PTR [rcx] vandps ymm0,ymm10,ymm0 - vmovaps ymm3, ymm5 - vmovaps ymm13, ymm7 - 
vxorps ymm1, ymm0, YMMWORD PTR [rcx] - vmulps ymm2, ymm0, ymm0 - vcmpps ymm12, ymm0, ymm11, 29 - vfmadd132ps ymm3, ymm6, ymm2 - vfmadd132ps ymm13, ymm8, ymm2 - vfmadd132ps ymm3, ymm4, ymm2 - vfmadd132ps ymm2, ymm9, ymm13 - vfmadd132ps ymm0, ymm0, ymm2 - vdivps ymm0, ymm0, ymm3 - vblendvps ymm0, ymm0, ymm4, ymm12 - vxorps ymm0, ymm0, ymm1 + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rcx] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 vmaskmovps YMMWORD PTR [rdx],ymm15,ymm0 ExitKernel: diff --git a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S index 74f8325c77a26..8ea9913538a11 100644 --- a/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S +++ b/onnxruntime/core/mlas/lib/x86_64/TanhKernelFma3.S @@ -62,19 +62,19 @@ Return Value: .LComputeTanhBy8Loop: vandps ymm0,ymm10,YMMWORD PTR [rdi] - vmovaps ymm3, ymm5 - vmovaps ymm13, ymm7 - vxorps ymm1, ymm0, YMMWORD PTR [rdi] - vmulps ymm2, ymm0, ymm0 - vcmpps ymm12, ymm0, ymm11, 29 - vfmadd132ps ymm3, ymm6, ymm2 - vfmadd132ps ymm13, ymm8, ymm2 - vfmadd132ps ymm3, ymm4, ymm2 - vfmadd132ps ymm2, ymm9, ymm13 - vfmadd132ps ymm0, ymm0, ymm2 - vdivps ymm0, ymm0, ymm3 - vblendvps ymm0, ymm0, ymm4, ymm12 - vxorps ymm0, ymm0, ymm1 + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rdi] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 add rdi,8*4 # advance input by 8 elements vmovups YMMWORD PTR [rsi],ymm0 add rsi,8*4 # 
advance output by 8 elements @@ -89,19 +89,19 @@ Return Value: vmovups ymm15,YMMWORD PTR [r10+rdx*4] vmaskmovps ymm0,ymm15,YMMWORD PTR [rdi] vandps ymm0,ymm10,ymm0 - vmovaps ymm3, ymm5 - vmovaps ymm13, ymm7 - vxorps ymm1, ymm0, YMMWORD PTR [rdi] - vmulps ymm2, ymm0, ymm0 - vcmpps ymm12, ymm0, ymm11, 29 - vfmadd132ps ymm3, ymm6, ymm2 - vfmadd132ps ymm13, ymm8, ymm2 - vfmadd132ps ymm3, ymm4, ymm2 - vfmadd132ps ymm2, ymm9, ymm13 - vfmadd132ps ymm0, ymm0, ymm2 - vdivps ymm0, ymm0, ymm3 - vblendvps ymm0, ymm0, ymm4, ymm12 - vxorps ymm0, ymm0, ymm1 + vmovaps ymm3, ymm5 + vmovaps ymm13, ymm7 + vxorps ymm1, ymm0, YMMWORD PTR [rdi] + vmulps ymm2, ymm0, ymm0 + vcmpps ymm12, ymm0, ymm11, 29 + vfmadd132ps ymm3, ymm6, ymm2 + vfmadd132ps ymm13, ymm8, ymm2 + vfmadd132ps ymm3, ymm4, ymm2 + vfmadd132ps ymm2, ymm9, ymm13 + vfmadd132ps ymm0, ymm0, ymm2 + vdivps ymm0, ymm0, ymm3 + vblendvps ymm0, ymm0, ymm4, ymm12 + vxorps ymm0, ymm0, ymm1 vmaskmovps YMMWORD PTR [rsi],ymm15,ymm0 .LExitKernel: From 5329f70d3b71f9261add1b1307f6c99364b8b624 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Fri, 17 May 2024 10:20:17 -0700 Subject: [PATCH 09/13] Use a better substring to filter fp16_coreml_FNS_Candy model --- onnxruntime/test/providers/cpu/model_tests.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 2aeecb620c40e..b55a7cc724b92 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -99,7 +99,7 @@ TEST_P(ModelTest, Run) { // increase tol for models_opset7_fp16_coreml_FNSCandy test on cpu. 
See // https://github.com/microsoft/onnxruntime/pull/20612 - if (model_path.find(ORT_TSTR("models_opset7_fp16_coreml_FNSCandy")) > 0) { + if (model_path.find(ORT_TSTR("fp16_coreml_FNS")) != std::string::npos) { if (provider_name == "cpu") { per_sample_tolerance = 1e-2; relative_per_sample_tolerance = 1e-2; From 7e1bf2c493c5476af0efa1ac9fb02eb502b64f56 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Wed, 22 May 2024 09:06:36 -0700 Subject: [PATCH 10/13] Add rel and abs error for LSTM.BackwardCompute test --- onnxruntime/test/providers/cpu/model_tests.cc | 6 ++---- .../test/training_ops/cpu/rnn/lstm_test.cc | 20 +++++++++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index b55a7cc724b92..b70801d801211 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -100,10 +100,8 @@ TEST_P(ModelTest, Run) { // increase tol for models_opset7_fp16_coreml_FNSCandy test on cpu. 
See // https://github.com/microsoft/onnxruntime/pull/20612 if (model_path.find(ORT_TSTR("fp16_coreml_FNS")) != std::string::npos) { - if (provider_name == "cpu") { - per_sample_tolerance = 1e-2; - relative_per_sample_tolerance = 1e-2; - } + per_sample_tolerance = 1e-2; + relative_per_sample_tolerance = 1e-2; } std::unique_ptr model_info = std::make_unique(model_path.c_str()); diff --git a/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc b/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc index f27825f276cb1..d51bbbf7c7eac 100644 --- a/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc @@ -165,7 +165,10 @@ TEST(LSTMTest, BackwardCompute) { test.AddOutput( "dX", {sequence_length, batch_size, input_size}, - {9.02288f, 9.77558f, 4.23378f, 4.6432f, 1.92046f, 2.09879f, 1.87627f, 2.06453f}); + {9.02288f, 9.77558f, 4.23378f, 4.6432f, 1.92046f, 2.09879f, 1.87627f, 2.06453f}, + false, + 1e-02, + 1e-02); test.AddOutput( "dW", {directions, 4 * hidden_size, input_size}, {0.030251f, 0.0453894f, @@ -179,7 +182,10 @@ TEST(LSTMTest, BackwardCompute) { 0.0586922f, 0.0936911f, 0.0477309f, 0.0758698f, 0.230594f, 0.623739f, - 0.231839f, 0.440448f}); + 0.231839f, 0.440448f}, + false, + 1e-02, + 1e-02); test.AddOutput( "dR", {directions, 4 * hidden_size, hidden_size}, {0.000595693f, 0.000601335f, 0.000602285f, @@ -193,12 +199,18 @@ TEST(LSTMTest, BackwardCompute) { 0.0132026f, 0.0133107f, 0.0133275f, 0.00346555f, 0.00349401f, 0.00349843f, 0.0081494f, 0.00821467f, 0.00822465f, - 0.0104138f, 0.0104568f, 0.0104593f}); + 0.0104138f, 0.0104568f, 0.0104593f}, + false, + 1e-02, + 1e-02); test.AddOutput( "dB", {directions, 8 * hidden_size}, {0.00756918f, 0.00939937f, 0.0522473f, 0.117724f, 0.444431f, 0.579753f, 0.00210701f, 0.00200243f, 0.0174995f, 0.0140694f, 0.196573f, 0.104304f, 0.00756918f, 0.00939937f, 0.0522473f, 0.117724f, - 0.444431f, 0.579753f, 0.00210701f, 0.00200243f, 0.0174995f, 
0.0140694f, 
0.196573f, 0.104304f}); + 0.444431f, 0.579753f, 0.00210701f, 0.00200243f, 0.0174995f, 0.0140694f, 0.196573f, 0.104304f}, + false, + 1e-02, + 1e-02); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); } From 31fa70f6b3d0e9cfd3995f44ccd210ff1440bc16 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Wed, 22 May 2024 21:44:21 -0700 Subject: [PATCH 11/13] Run clang-format on lstm_test.cc --- .../test/training_ops/cpu/rnn/lstm_test.cc | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc b/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc index d51bbbf7c7eac..0d9af91f59b68 100644 --- a/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc @@ -166,9 +166,9 @@ TEST(LSTMTest, BackwardCompute) { test.AddOutput( "dX", {sequence_length, batch_size, input_size}, {9.02288f, 9.77558f, 4.23378f, 4.6432f, 1.92046f, 2.09879f, 1.87627f, 2.06453f}, - false, - 1e-02, - 1e-02); + false, + 1e-02, + 1e-02); test.AddOutput( "dW", {directions, 4 * hidden_size, input_size}, {0.030251f, 0.0453894f, @@ -183,9 +183,9 @@ TEST(LSTMTest, BackwardCompute) { 0.0477309f, 0.0758698f, 0.230594f, 0.623739f, 0.231839f, 0.440448f}, - false, - 1e-02, - 1e-02); + false, + 1e-02, + 1e-02); test.AddOutput( "dR", {directions, 4 * hidden_size, hidden_size}, {0.000595693f, 0.000601335f, 0.000602285f, @@ -200,17 +200,17 @@ TEST(LSTMTest, BackwardCompute) { 0.00346555f, 0.00349401f, 0.00349843f, 0.0081494f, 0.00821467f, 0.00822465f, 0.0104138f, 0.0104568f, 0.0104593f}, - false, - 1e-02, - 1e-02); + false, + 1e-02, + 1e-02); test.AddOutput( "dB", {directions, 8 * hidden_size}, {0.00756918f, 0.00939937f, 0.0522473f, 0.117724f, 0.444431f, 0.579753f, 0.00210701f, 0.00200243f, 0.0174995f, 0.0140694f, 0.196573f, 0.104304f, 0.00756918f, 0.00939937f, 0.0522473f, 0.117724f, 0.444431f, 0.579753f, 
0.00210701f, 0.00200243f, 0.0174995f, 0.0140694f, 0.196573f, 0.104304f}, - false, - 1e-02, - 1e-02); + false, + 1e-02, + 1e-02); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); } From 1bf2d91c6fb72a0771bcd430b360559fd5b868b4 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Thu, 30 May 2024 22:05:11 -0700 Subject: [PATCH 12/13] Mark literals explicitly as float --- .../test/training_ops/cpu/rnn/lstm_test.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc b/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc index 0d9af91f59b68..5df5d3c36f120 100644 --- a/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/rnn/lstm_test.cc @@ -167,8 +167,8 @@ TEST(LSTMTest, BackwardCompute) { "dX", {sequence_length, batch_size, input_size}, {9.02288f, 9.77558f, 4.23378f, 4.6432f, 1.92046f, 2.09879f, 1.87627f, 2.06453f}, false, - 1e-02, - 1e-02); + 1e-02f, + 1e-02f); test.AddOutput( "dW", {directions, 4 * hidden_size, input_size}, {0.030251f, 0.0453894f, @@ -184,8 +184,8 @@ TEST(LSTMTest, BackwardCompute) { 0.230594f, 0.623739f, 0.231839f, 0.440448f}, false, - 1e-02, - 1e-02); + 1e-02f, + 1e-02f); test.AddOutput( "dR", {directions, 4 * hidden_size, hidden_size}, {0.000595693f, 0.000601335f, 0.000602285f, @@ -201,16 +201,16 @@ TEST(LSTMTest, BackwardCompute) { 0.0081494f, 0.00821467f, 0.00822465f, 0.0104138f, 0.0104568f, 0.0104593f}, false, - 1e-02, - 1e-02); + 1e-02f, + 1e-02f); test.AddOutput( "dB", {directions, 8 * hidden_size}, {0.00756918f, 0.00939937f, 0.0522473f, 0.117724f, 0.444431f, 0.579753f, 0.00210701f, 0.00200243f, 0.0174995f, 0.0140694f, 0.196573f, 0.104304f, 0.00756918f, 0.00939937f, 0.0522473f, 0.117724f, 0.444431f, 0.579753f, 0.00210701f, 0.00200243f, 0.0174995f, 0.0140694f, 0.196573f, 0.104304f}, false, - 1e-02, - 1e-02); + 1e-02f, + 1e-02f); 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); } From f7778140e0b2f36eb750793083f4bad58e8796ec Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Mon, 3 Jun 2024 14:37:29 -0700 Subject: [PATCH 13/13] Change tolerance of ModelTest.Run to 1e-02 --- winml/test/model/model_tests.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 27d74d7d6b034..dca53d6abe919 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -245,7 +245,7 @@ static std::vector GetAllTestCases() { WINML_EXPECT_NO_THROW(LoadTests( dataDirs, whitelistedTestCases, - TestTolerances(1e-3, 1e-3, {}, {}), + TestTolerances(1e-2, 1e-2, {}, {}), allDisabledTests, std::move(broken_tests), std::move(broken_tests_keyword_set),