diff --git a/linalg/x86_64/fma/fma_mmm_i32_scalars.tmpliq b/linalg/x86_64/fma/fma_mmm_i32_scalars.tmpliq
index 427dbaa774..4de1a5d49c 100644
--- a/linalg/x86_64/fma/fma_mmm_i32_scalars.tmpliq
+++ b/linalg/x86_64/fma/fma_mmm_i32_scalars.tmpliq
@@ -8,18 +8,16 @@
 {% include "fma_mmm_ymm_scalar.tmpliq" label:"scalar_sub_flipped", op:"vpsubd", from:from, to:to, flipped: true%}
 
 {{L}}leaky_relu:
-    // can only use zmm12 to zmm15
+    // can only use ymm12 to ymm15
     // ymm15 <- alpha
-    vbroadcastss    zmm15, dword ptr [rdi + 8]
+    vbroadcastss    ymm15, dword ptr [rdi + 8]
     // ymm14 <- all zero
-    vpxorq          zmm14, zmm14, zmm14
+    vpxor           ymm14, ymm14, ymm14
 
     {% for reg in (from..to) %}
-        // ymm12 <- alpha * x
-        vpmulld     zmm12, zmm{{reg}}, zmm15
-        vpcmpd      k1, zmm14, zmm{{reg}}, 1 // 1 means LT
-        vblendmps   zmm{{reg}} {k1}, zmm12, zmm{{reg}}
+        vpmulld     ymm12, ymm{{reg}}, ymm15 // ymm12 <- alpha * x (low 32 bits of each lane product)
+        vpcmpgtd    ymm13, ymm14, ymm{{reg}} // ymm13 <- all ones in lanes where 0 > x (signed), else zero
+        vblendvps   ymm{{reg}}, ymm{{reg}}, ymm12, ymm13 // lanes with x < 0 take alpha * x, others keep x
     {% endfor %}
-    // select muled of orginal
     jmp    {{L}}non_linear_loop
 