x64: eltwise_injector: use aligned stack for vmms
nivas-x86 authored and tprimak committed Jan 5, 2024
1 parent c0ae38c commit 2e3c94c
Showing 2 changed files with 122 additions and 67 deletions.
183 changes: 117 additions & 66 deletions src/cpu/x64/injectors/jit_uni_eltwise_injector.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2019-2023 Intel Corporation
* Copyright 2019-2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -54,6 +54,17 @@ bool is_supported(cpu_isa_t isa, alg_kind_t alg) {

using namespace Xbyak;

template <cpu_isa_t isa, typename Wmm>
bool jit_uni_eltwise_injector_f32<isa, Wmm>::need_vmm_stack_ptr() {
return op_vecs_count() > 0 || vecs_to_preserve > 0;
}

template <cpu_isa_t isa, typename Wmm>
size_t jit_uni_eltwise_injector_f32<isa, Wmm>::get_stack_vmm_space() {
return (save_state_ * preserve_vmm_ * vecs_to_preserve + op_vecs_count())
* vlen;
}

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble(
const injector_utils::vmm_index_set_t &vmm_idxs) {
@@ -108,18 +119,40 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble(
}
assert(preserved_gprs_count == aux_gprs_count());

if (need_vmm_stack_ptr()) {
reg_vmm_stack_ptr_ = Reg64(preserved_gpr_idxs[0]);
}

if (save_state_) {
if (preserve_p_table_) h->push(p_table);
for (size_t i = 0; i < preserved_gprs_count; ++i)
h->push(Reg64(preserved_gpr_idxs[i]));
}

if (preserve_vmm_) {
if (preserved_vecs_count)
h->sub(h->rsp, preserved_vecs_count * vlen);
const auto stack_vmm_space = get_stack_vmm_space();
if (stack_vmm_space) {
// - Align the stack space used for vmm spilling at runtime: pad rsp down to
// a vlen boundary and allocate the pre-estimated space required.
// - To keep regular gpr spilling as-is, a separate register
// (reg_vmm_stack_ptr_) tracks the vmm stack space.
// - Finally, the original stack pointer (rsp) is stored just above the vmm
// stack space, so the padding can be reverted on exit.
h->mov(reg_vmm_stack_ptr_, h->rsp);
h->sub(h->rsp, 8);
const uint32_t mask = ~(static_cast<uint32_t>(vlen) - 1);
h->and_(h->rsp, mask);
h->mov(ptr[h->rsp], reg_vmm_stack_ptr_);
h->sub(h->rsp, stack_vmm_space);
h->mov(reg_vmm_stack_ptr_, h->rsp);
}

if (save_state_) {
if (preserve_vmm_) {
for (size_t i = 0; i < preserved_vecs_count; ++i)
h->uni_vmovups(
h->ptr[h->rsp + i * vlen], Vmm(preserved_vec_idxs[i]));
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_ + i * vlen],
Vmm(preserved_vec_idxs[i]));
if (preserved_vecs_count)
h->add(reg_vmm_stack_ptr_, preserved_vecs_count * vlen);
}
load_table_addr();
}
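For reference, the pointer arithmetic the new preamble performs at runtime looks like this in plain C++ (a minimal sketch, not part of the patch; the helper name and raw-pointer types are illustrative, and vlen is assumed to be a power of two):

```cpp
#include <cstdint>

// Mirrors the emitted sequence: remember rsp, pad it down to a vlen
// boundary, stash the original rsp just above the vmm area, then allocate
// the pre-estimated scratch space.
uintptr_t align_vmm_stack(uintptr_t &rsp, uintptr_t vlen, uintptr_t space) {
    const uintptr_t saved_rsp = rsp; // mov reg_vmm_stack_ptr_, rsp
    rsp -= 8;                        // sub rsp, 8 (room for the saved qword)
    rsp &= ~(vlen - 1);              // and rsp, ~(vlen - 1): align down
    *reinterpret_cast<uintptr_t *>(rsp) = saved_rsp; // mov [rsp], saved rsp
    rsp -= space;                    // sub rsp, stack_vmm_space
    return rsp;                      // mov reg_vmm_stack_ptr_, rsp
}
```

After this runs, reg_vmm_stack_ptr_ points at the base of an aligned scratch area, so every uni_vmovups against h->ptr[reg_vmm_stack_ptr_ + off] hits a vlen-aligned address regardless of the incoming rsp.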
@@ -136,19 +169,19 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_preamble_tail(
const int idx_off = vecs_to_preserve - tail_vecs_to_preserve;

if (save_state_) {
if (idx_off) h->add(h->rsp, idx_off * vlen);

for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
h->uni_vmovups(Vmm(preserved_vec_idxs[idx_off + i]),
h->ptr[h->rsp + i * vlen]);
h->ptr[reg_vmm_stack_ptr_
+ (i - tail_vecs_to_preserve) * vlen]);
}

for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
preserved_vec_idxs[idx_off + i] += tail_vecs_to_preserve;

if (save_state_ && preserve_vmm_) {
for (size_t i = 0; i < tail_vecs_to_preserve; ++i)
h->uni_vmovups(h->ptr[h->rsp + i * vlen],
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_
+ (i - tail_vecs_to_preserve) * vlen],
Vmm(preserved_vec_idxs[idx_off + i]));

if (idx_off) h->sub(h->rsp, idx_off * vlen);
@@ -160,16 +193,21 @@
template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::injector_postamble() {
using namespace Xbyak::util;
if (!save_state_) return;
const int stack_vmm_space = get_stack_vmm_space();

if (preserve_vmm_) {
if (save_state_ && preserve_vmm_) {
for (size_t i = 0; i < preserved_vecs_count; ++i)
h->uni_vmovups(
Vmm(preserved_vec_idxs[i]), h->ptr[h->rsp + i * vlen]);

if (preserved_vecs_count) h->add(h->rsp, preserved_vecs_count * vlen);
h->uni_vmovups(Vmm(preserved_vec_idxs[i]),
h->ptr[reg_vmm_stack_ptr_
+ (i - preserved_vecs_count) * vlen]);

if (preserved_vecs_count)
h->mov(h->rsp, ptr[reg_vmm_stack_ptr_ + op_vecs_count() * vlen]);
} else if (stack_vmm_space) {
h->mov(h->rsp, ptr[reg_vmm_stack_ptr_ + stack_vmm_space]);
}

if (!save_state_) return;
for (int i = aux_gprs_count() - 1; i >= 0; --i)
h->pop(Reg64(preserved_gpr_idxs[i]));
if (preserve_p_table_) h->pop(p_table);
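Reading the preamble and postamble offsets together, the scratch area appears to be laid out as follows (low to high addresses, assuming save_state_ and preserve_vmm_; a sketch inferred from the code, with illustrative names):

```cpp
#include <cstddef>
#include <cstdint>

// base                               -> preserved vmm spill slots
// base + preserved_vecs_count * vlen -> op scratch (the preamble leaves
//                                       reg_vmm_stack_ptr_ pointing here)
// base + stack_vmm_space             -> 8-byte slot holding the original rsp
//
// Hence the postamble reads the slot one op-scratch-area above the
// tracking register to restore rsp:
uintptr_t restore_rsp(uintptr_t vmm_stack_ptr, size_t op_vecs, size_t vlen) {
    // mov rsp, [reg_vmm_stack_ptr_ + op_vecs_count() * vlen]
    return *reinterpret_cast<uintptr_t *>(vmm_stack_ptr + op_vecs * vlen);
}
```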
@@ -369,7 +407,7 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::tanh_compute_vector_fwd(
if (isa == sse41 || isa == avx) {
assert(aux_gprs_count() >= XMM_float_lanes_count);
for (int i = 0; i < XMM_float_lanes_count; i++)
gpr_idx[i] = Reg64(preserved_gpr_idxs[i]);
gpr_idx[i] = Reg64(preserved_gpr_idxs[i + need_vmm_stack_ptr()]);
}

// We split the positive domain in 33 intervals:
@@ -561,14 +599,12 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_tanh_compute_vector_fwd(
h->uni_vmulps(vmm_src, vmm_src, table_val(gelu_tanh_sqrt_two_over_pi));

// save x on stack as tanh uses vmm_aux0
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_aux0);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_aux0);

// compute tanh(G(x))
tanh_compute_vector_fwd(vmm_src);

h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux0, h->ptr[reg_vmm_stack_ptr_]);

// compute 0.5 * x * (1 + tanh(G(x)))
h->uni_vaddps(vmm_src, vmm_src, table_val(one));
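For orientation, the scalar function this kernel vectorizes is the tanh approximation of GELU (a reference sketch; the 0.044715 * x^3 term of G(x) is computed in the elided part of the hunk):

```cpp
#include <cmath>

// gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
float gelu_tanh_ref(float x) {
    const float sqrt_two_over_pi = 0.7978845608f; // sqrt(2 / pi)
    const float g = sqrt_two_over_pi * (x + 0.044715f * x * x * x);
    return 0.5f * x * (1.0f + std::tanh(g));
}
```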
@@ -820,15 +856,13 @@ template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::swish_compute_vector_fwd(
const Vmm &vmm_src) {
// Save src data on stack for later usage
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_src);
// x*alpha
h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
// sigmoid(x*alpha)
logistic_compute_vector_fwd(vmm_src);
// x*sigmoid(alpha*x)
h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux0, h->ptr[reg_vmm_stack_ptr_]);
h->uni_vmulps(vmm_src, vmm_src, vmm_aux0);
}
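The scalar equivalent of the sequence above is simply swish(x) = x * sigmoid(alpha * x); the spill to reg_vmm_stack_ptr_ exists only because logistic_compute_vector_fwd clobbers vmm_src. A reference sketch:

```cpp
#include <cmath>

float swish_ref(float x, float alpha) {
    const float q = 1.0f / (1.0f + std::exp(-alpha * x)); // sigmoid(alpha * x)
    return x * q;                                         // x * sigmoid(alpha * x)
}
```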

@@ -858,8 +892,7 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_fwd(
}

// save source on stack to check neg and zero values at the end
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_src);

// compute i
const int approx_order = 5;
@@ -926,19 +959,19 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_fwd(
// The rest of the code puts the indices on the stack, fetches a table
// value for each index, replaces the index with that value, and finally
// moves the fetched values into a vector register.
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_idxs);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_ + vlen], vmm_idxs);

for (size_t i = 0; i < vlen / sizeof(float); ++i) {
h->mov(reg_tmp.cvt32(), h->ptr[h->rsp + i * sizeof(float)]);
h->mov(reg_tmp.cvt32(),
h->ptr[reg_vmm_stack_ptr_ + vlen + i * sizeof(float)]);
h->shl(reg_tmp.cvt32(), 2); // multiply by simd_w
table_idx = h->ptr[p_table + table_start_idx + offt + reg_tmp];
h->mov(reg_tmp.cvt32(), table_idx);
h->mov(h->ptr[h->rsp + i * sizeof(float)], reg_tmp.cvt32());
h->mov(h->ptr[reg_vmm_stack_ptr_ + vlen + i * sizeof(float)],
reg_tmp.cvt32());
}

h->uni_vmovups(vmm_dst, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_dst, h->ptr[reg_vmm_stack_ptr_ + vlen]);
// restore GPR state
h->mov(reg_tmp, h->ptr[h->rsp]);
h->add(h->rsp, gpr_size);
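In scalar terms, the loop above emulates a gather through the stack, which is useful on ISAs without a hardware gather instruction (a sketch with illustrative names; the shl by 2 in the JIT code is the byte scaling that array indexing performs implicitly here):

```cpp
#include <cstddef>
#include <cstdint>

// slot: the scratch area at reg_vmm_stack_ptr_ + vlen where vmm_idxs was
// spilled; table: the lookup table at p_table + table_start_idx + offt.
void gather_via_stack(uint32_t *slot, const uint32_t *table, size_t lanes) {
    for (size_t i = 0; i < lanes; ++i) // lanes == vlen / sizeof(float)
        slot[i] = table[slot[i]];      // replace each index with its value
}
```

Afterwards the whole slot is reloaded into vmm_dst with a single uni_vmovups.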
@@ -982,8 +1015,7 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::log_compute_vector_fwd(

// Check original source for zero and neg values. skip blend w/ extreme
// values if all src values were positive.
h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux1, h->ptr[reg_vmm_stack_ptr_]);

Xbyak::Label end_log_zero_label;
compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_le_os);
@@ -1073,12 +1105,11 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_fwd(
// `isa` as the injector. Once that assumption breaks, `vecs_count` and
// `vlen` should be replaced with `host_isa::vecs_count` and
// `host_isa::vlen`, respectively.
h->sub(h->rsp, (vecs_count + 2) * vlen);
for (size_t i = 2; i < vecs_count + 2; ++i)
h->uni_vmovups(h->ptr[h->rsp + i * vlen], Vmm(i - 2));
h->uni_vmovups(h->ptr[h->rsp + 0 * vlen], vmm_src); // src
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_ + i * vlen], Vmm(i - 2));
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_ + 0 * vlen], vmm_src); // src
h->uni_vmovups(vmm_src, table_val(beta));
h->uni_vmovups(h->ptr[h->rsp + 1 * vlen], vmm_src); // beta
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_ + 1 * vlen], vmm_src); // beta

// save the function address in a gpr to pass to the call instruction
h->mov(h->rbp, reinterpret_cast<uintptr_t>(powf));
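Spilling every vector register above is what makes the following ABI-clobbering calls safe; lane by lane, the emitted sequence computes the scalar equivalent below (a sketch; src and beta sit at reg_vmm_stack_ptr_ + 0 and reg_vmm_stack_ptr_ + vlen):

```cpp
#include <cmath>
#include <cstddef>

void pow_lanes(float *src, float beta, size_t lanes) {
    for (size_t i = 0; i < lanes; ++i)
        src[i] = std::pow(src[i], beta); // result overwrites the spilled lane
}
```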
@@ -1101,9 +1132,10 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_fwd(
// Take src, apply powf on it and replace value on a stack with dst.
Xmm xmm0 = Xmm(0), xmm1 = Xmm(1);
for (size_t i = 0; i < vlen / sizeof(float); ++i) {
const Address &source = h->ptr[h->rsp + h->rbx + i * sizeof(float)];
const Address &source
= h->ptr[reg_vmm_stack_ptr_ + i * sizeof(float)];
h->uni_vmovss(xmm0, source);
h->uni_vmovss(xmm1, h->ptr[h->rsp + h->rbx + vlen]); // beta
h->uni_vmovss(xmm1, h->ptr[reg_vmm_stack_ptr_ + vlen]); // beta
h->uni_vzeroupper(); // eliminate performance penalties on avx
h->call(h->rbp);
// eliminate performance penalties on sse isa
@@ -1115,9 +1147,8 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_fwd(

// restore vector registers
for (size_t i = vecs_count + 1; i >= 2; --i)
h->uni_vmovups(Vmm(i - 2), h->ptr[h->rsp + i * vlen]);
h->uni_vmovups(vmm_src, h->ptr[h->rsp + 0 * vlen]);
h->add(h->rsp, (vecs_count + 2) * vlen);
h->uni_vmovups(Vmm(i - 2), h->ptr[reg_vmm_stack_ptr_ + i * vlen]);
h->uni_vmovups(vmm_src, h->ptr[reg_vmm_stack_ptr_ + 0 * vlen]);

// restore k registers
if (is_avx512) {
@@ -1340,14 +1371,12 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_tanh_compute_vector_bwd(
h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_aux0);

// save G2 on stack as tanh uses all available registers
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_aux2);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_aux2);

// T = tanh(G1(x))
tanh_compute_vector_fwd(vmm_src);

h->uni_vmovups(vmm_aux2, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux2, h->ptr[reg_vmm_stack_ptr_]);

// compute 0.5 * (1 + T) * (1 + G2 * (1 - T))
if (isa == sse41 || isa == avx) {
@@ -1474,12 +1503,10 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::swish_compute_vector_bwd(
// R = alpha * s
h->uni_vmulps(vmm_src, vmm_src, table_val(alpha));
// Save R on stack for later usage
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_src);
// Q = sigmoid(alpha * s)
logistic_compute_vector_fwd(vmm_src);
h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux0, h->ptr[reg_vmm_stack_ptr_]);
// compute Q * (1 + R * (1 - Q))
if (utils::one_of(isa, sse41, avx)) {
h->uni_vmovups(vmm_aux1, table_val(one));
@@ -1535,13 +1562,13 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::pow_compute_vector_bwd(
h->uni_vmovups(vmm_src, table_val(alpha));
} else {
// Save `s` on stack for later usage
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_src);
h->add(reg_vmm_stack_ptr_, vlen);
// R = alpha * pow(s, beta)
pow_compute_vector_fwd(vmm_src);
h->sub(reg_vmm_stack_ptr_, vlen);
// Restore `s` from stack
h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux1, h->ptr[reg_vmm_stack_ptr_]);
// Save mask of zero elements to convert them into zeros at the end
if (beta_ >= 1) compute_cmp_mask(vmm_aux1, table_val(zero), _cmp_eq_oq);
// res = alpha * beta * pow(s, beta - 1) = beta * R / s;
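Note the add/sub of vlen around the nested pow_compute_vector_fwd call: it bumps the scratch pointer so the spilled `s` stays clear of the slots the forward pass writes (the extra +1 in op_vecs_count() for backward pow budgets for this). The identity being used is, in scalar form (a sketch):

```cpp
#include <cmath>

// With R = alpha * pow(s, beta):
//   d/ds [alpha * pow(s, beta)] = alpha * beta * pow(s, beta - 1) = beta * R / s
float pow_bwd_ref(float s, float alpha, float beta) {
    const float R = alpha * std::pow(s, beta);
    return beta * R / s; // zero inputs are masked separately when beta >= 1
}
```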
Expand All @@ -1562,16 +1589,15 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_erf_compute_vector_bwd(
table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_two));

// Save R on stack for later usage
h->sub(h->rsp, vlen);
h->uni_vmovups(h->ptr[h->rsp], vmm_src);
h->uni_vmovups(h->ptr[reg_vmm_stack_ptr_], vmm_src);

// Q = exp(-R*R)
h->uni_vmulps(vmm_src, vmm_src, vmm_src);
h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));
exp_compute_vector_fwd(vmm_src);

// T = R / sqrt(pi) * Q
h->uni_vmovups(vmm_aux2, h->ptr[h->rsp]);
h->uni_vmovups(vmm_aux2, h->ptr[reg_vmm_stack_ptr_]);
h->uni_vmulps(vmm_aux2, vmm_aux2,
table_val(gelu_erf_Abramowitz_Stegun_one_over_sqrt_pi));
h->uni_vmulps(vmm_aux2, vmm_aux2, vmm_src);
@@ -1580,12 +1606,11 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::gelu_erf_compute_vector_bwd(
h->uni_vxorps(vmm_src, vmm_src, table_val(sign_mask));

// get sign
h->uni_vmovups(vmm_aux0, h->ptr[h->rsp]);
h->uni_vmovups(vmm_aux0, h->ptr[reg_vmm_stack_ptr_]);
h->uni_vandps(vmm_aux0, vmm_aux0, table_val(sign_mask));

// abs(x)
h->uni_vmovups(vmm_aux1, h->ptr[h->rsp]);
h->add(h->rsp, vlen);
h->uni_vmovups(vmm_aux1, h->ptr[reg_vmm_stack_ptr_]);
abs_compute_vector_fwd(vmm_aux1);

// W = 1 / (p * s + 1)
@@ -1657,14 +1682,40 @@ void jit_uni_eltwise_injector_f32<isa, Wmm>::hardsigmoid_compute_vector_bwd(
template <cpu_isa_t isa, typename Wmm>
size_t jit_uni_eltwise_injector_f32<isa, Wmm>::aux_gprs_count() {
using namespace alg_kind;
int ret = 0;
switch (alg_) {
case eltwise_tanh_use_dst_for_bwd:
case eltwise_tanh:
case eltwise_gelu_tanh: return isa == sse41 || isa == avx ? 4 : 0;
default: return 0;
case eltwise_gelu_tanh: ret = isa == sse41 || isa == avx ? 4 : 0; break;
default: ret = 0; break;
}
return 0;
};
return ret + need_vmm_stack_ptr();
}

template <cpu_isa_t isa, typename Wmm>
size_t jit_uni_eltwise_injector_f32<isa, Wmm>::op_vecs_count() {
using namespace alg_kind;
int ret = 0;
if (is_fwd_) {
switch (alg_) {
case eltwise_gelu_tanh:
case eltwise_swish: ret = 1; break;
case eltwise_log: ret = 1 + utils::one_of(isa, sse41, avx); break;
case eltwise_pow: ret = vecs_count + 2; break;
default: ret = 0;
}
} else {
switch (alg_) {
case eltwise_gelu_tanh:
case eltwise_swish:
case eltwise_gelu_erf: ret = 1; break;
case eltwise_pow: ret = 1 + (vecs_count + 2 /*calls fwd*/); break;
default: ret = 0;
}
}

return ret;
}
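Putting the two helpers together, a worked sizing example (all concrete values are hypothetical, chosen for illustration: an avx2-class injector with 32-byte Ymm registers, vecs_count = 16, five registers to preserve, save_state_ = preserve_vmm_ = true, forward eltwise_pow):

```cpp
#include <cstddef>

constexpr size_t vlen = 32;                // Ymm width in bytes
constexpr size_t vecs_count = 16;          // vector registers on the ISA
constexpr size_t vecs_to_preserve = 5;     // hypothetical
constexpr size_t op_vecs = vecs_count + 2; // forward pow scratch: 18
// get_stack_vmm_space() = (save_state_ * preserve_vmm_ * vecs_to_preserve
//                          + op_vecs_count()) * vlen
constexpr size_t space = (1 * 1 * vecs_to_preserve + op_vecs) * vlen;
static_assert(space == 736, "the preamble reserves 736 aligned bytes");
```

plus the 8-byte slot for the saved rsp and up to vlen - 1 bytes of alignment padding.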

template <cpu_isa_t isa, typename Wmm>
void jit_uni_eltwise_injector_f32<isa, Wmm>::round_compute_vector_fwd(
