Skip to content

Commit

Permalink
limit bf16_emitters truncation impl to eltwise Constant inputs to kee…
Browse files Browse the repository at this point in the history
…p acc fix and minimize the impact on performance
  • Loading branch information
liubo-intel committed Dec 12, 2024
1 parent b4076b5 commit b43892d
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 25 deletions.
35 changes: 20 additions & 15 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
public:
jit_uni_vcvtneps2bf16(dnnl::impl::cpu::x64::jit_generator* host,
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::bf16)
ov::element::Type exec_prc = ov::element::bf16,
arithmetic_mode mode = arithmetic_mode::none)
: jit_emitter(host, host_isa, exec_prc) {
prepare_table();
mode_ = mode;
}

size_t get_inputs_num() const override {
return 1;
}

private:
arithmetic_mode mode_ = arithmetic_mode::none;
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override {
if (host_isa_ == dnnl::impl::cpu::x64::avx512_core) {
emit_isa<dnnl::impl::cpu::x64::avx512_core>(in_vec_idxs, out_vec_idxs);
Expand All @@ -42,23 +45,25 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
conditional3<isa == dnnl::impl::cpu::x64::sse41, Xmm, isa == dnnl::impl::cpu::x64::avx2, Ymm, Zmm>::type;

Vmm in = Vmm(in_vec_idxs[0]);
Vmm vmm_temp = Vmm(out_vec_idxs[0]);
if (mode_ == arithmetic_mode::constant_saturation) {
Vmm vmm_temp = Vmm(out_vec_idxs[0]);

h->uni_vmaxps(vmm_temp, in, table_val("bf16_min"));
h->uni_vminps(vmm_temp, vmm_temp, table_val("bf16_max"));
h->uni_vmaxps(vmm_temp, in, table_val("bf16_min"));
h->uni_vminps(vmm_temp, vmm_temp, table_val("bf16_max"));

if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) {
h->vfixupimmps(vmm_temp, in, table_val("selector"), 0);
} else {
Vmm mask = Vmm(aux_vec_idxs[0]);
h->uni_vcmpps(mask, in, in, 0x03); // _CMP_UNORD_Q
h->uni_vblendvps(vmm_temp, vmm_temp, table_val("nan"), mask);
h->uni_vcmpps(mask, in, table_val("inf"), 0x00); // _CMP_EQ_OQ
h->uni_vblendvps(vmm_temp, vmm_temp, table_val("inf"), mask);
h->uni_vcmpps(mask, in, table_val("neg_inf"), 0x00); // _CMP_EQ_OQ
h->uni_vblendvps(vmm_temp, vmm_temp, table_val("neg_inf"), mask);
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) {
h->vfixupimmps(vmm_temp, in, table_val("selector"), 0);
} else {
Vmm mask = Vmm(aux_vec_idxs[0]);
h->uni_vcmpps(mask, in, in, 0x03); // _CMP_UNORD_Q
h->uni_vblendvps(vmm_temp, vmm_temp, table_val("nan"), mask);
h->uni_vcmpps(mask, in, table_val("inf"), 0x00); // _CMP_EQ_OQ
h->uni_vblendvps(vmm_temp, vmm_temp, table_val("inf"), mask);
h->uni_vcmpps(mask, in, table_val("neg_inf"), 0x00); // _CMP_EQ_OQ
h->uni_vblendvps(vmm_temp, vmm_temp, table_val("neg_inf"), mask);
}
h->uni_vmovups(in, vmm_temp);
}
h->uni_vmovups(in, vmm_temp);

if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) {
Ymm out = Ymm(out_vec_idxs[0]);
Expand Down
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ enum emitter_in_out_map {
gpr_to_gpr,
};

// Arithmetic modes for data type conversion in store_emitter
enum arithmetic_mode { none, saturation, truncation, constant_saturation };

// structure for storage of emitter parameters to hash in map
struct emitter_params {
virtual size_t hash() const = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ struct store_emitter_params : public emitter_params {
int store_num_;
};

// Arithmetic modes for data type conversion in store_emitter
enum arithmetic_mode {
saturation,
truncation
};

class jit_load_emitter : public jit_emitter {
public:
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator* host,
Expand Down
28 changes: 24 additions & 4 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener
reg_d_bias));
}

if (mayiuse(avx512_core) || mayiuse(avx2_vnni_2))
uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa));
if (mayiuse(avx512_core) || mayiuse(avx2_vnni_2)) {
auto const mode =
jep_.do_constant_saturation ? arithmetic_mode::constant_saturation : arithmetic_mode::none;
uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa, element::bf16, mode));
}

const auto& jep = jep_;

Expand Down Expand Up @@ -1355,6 +1358,7 @@ struct EltwiseKey {
ov::element::Type outPrc;
dnnl::post_ops postOps;
EltwiseImplType implType;
bool doConstantSaturation;

size_t hash() const {
using namespace dnnl::impl;
Expand Down Expand Up @@ -1390,6 +1394,10 @@ struct EltwiseKey {
seed = hash_combine(seed, outPrc.hash());
seed = get_post_op_hash(seed, *postOps.get());
seed = hash_combine(seed, implType);

if (outPrc == ov::element::bf16) {
seed = hash_combine(seed, doConstantSaturation);
}
return seed;
}

Expand All @@ -1416,6 +1424,8 @@ struct EltwiseKey {
result = result && (inpDims[i] == rhs.inpDims[i]);
}
}
if ((outPrc == ov::element::bf16) && (doConstantSaturation != rhs.doConstantSaturation))
return false;
}

return result;
Expand Down Expand Up @@ -1448,7 +1458,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
const std::vector<ov::element::Type>& inpPrc,
const ov::element::Type& outPrc,
const dnnl::post_ops& post_ops,
bool useRuntimePtrs) {
bool useRuntimePtrs,
bool doConstantSaturation) {
auto collapseLastDims = [](std::vector<size_t>& dims, int dimsToCollapse) {
for (size_t i = dims.size() - 2; i > dims.size() - dimsToCollapse - 2; i--) {
dims[dims.size() - 1] *= dims[i];
Expand Down Expand Up @@ -1639,6 +1650,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
jep.dst_prc = outPrc;
jep.work_amount = jep.dst_size = jep.dims.back();
jep.oc_size = oc_size;
jep.do_constant_saturation = doConstantSaturation;

std::transform(jep.oc_offsets.begin(), jep.oc_offsets.end(), jep.oc_offsets.begin(), [](size_t& offset) {
return offset * sizeof(float);
Expand Down Expand Up @@ -2160,7 +2172,8 @@ static Eltwise::executorPtr buildExecutor(const EltwiseKey& key) {
key.inpPrc,
key.outPrc,
key.postOps,
key.implType == EltwiseImplType::optimizedShapeAgnostic);
key.implType == EltwiseImplType::optimizedShapeAgnostic,
key.doConstantSaturation);
}

bool Eltwise::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
Expand Down Expand Up @@ -2861,6 +2874,13 @@ void Eltwise::prepareParams() {
"'");
}
}
key.doConstantSaturation = false;
for (size_t i = 0; i < getParentEdges().size(); i++) {
if (!getParentEdgeAt(i)->getParent()->isConstant()) {
key.doConstantSaturation = true;
break;
}
}

auto cache = context->getParamsCache();
auto result = cache->getOrCreate(key, buildExecutor);
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_cpu/src/nodes/eltwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ struct jit_eltwise_params {

size_t work_amount;
bool use_runtime_ptrs;
bool do_constant_saturation;
};

struct jit_eltwise_call_args_indexes {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ struct jit_eltwise_params {

size_t work_amount;
bool use_runtime_ptrs;
bool do_constant_saturation;
};

struct jit_eltwise_call_args_indexes {
Expand Down

0 comments on commit b43892d

Please sign in to comment.