Skip to content

Commit

Permalink
x86 fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov committed Nov 18, 2024
1 parent 13d3d83 commit 4137826
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
12 changes: 8 additions & 4 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -630,12 +630,10 @@ static dnnl::memory::dims getGroupDims(const VectorDims& weiDims, const VectorDi
static int getMask(const VectorDims& weiDims, const dnnl::memory::dims& groupDims) {
const int maskN = 1 << (weiDims.size() - 1);
const int maskK = 1 << (weiDims.size() - 2);
int N = weiDims[weiDims.size() - 2];
int K = weiDims[weiDims.size() - 1];
int mask = 0;
if (!groupDims.empty() && groupDims[1] != N)
if (!groupDims.empty())
mask += maskN;
if (!groupDims.empty() && groupDims[0] != K)
if (!groupDims.empty())
mask += maskK;

return mask;
Expand All @@ -650,6 +648,12 @@ void DnnlPostOpsComposer::appendDecompressionScales(
auto groupDims = getGroupDims(weiDims, scaleMem->getStaticDims());
auto mask = getMask(weiDims, groupDims);

// [WA] OneDNN JIT Brgemm Matmul shows bad performance with PER-OC scaling passed via non-empty groups
if (!groupDims.empty() && scaleMem->getStaticDims()[1] == 1) {
groupDims = {};
mask = 1 << (weiDims.size() - 1);
}

attr.set_scales(DNNL_ARG_WEIGHTS, mask, groupDims, DnnlExtensionUtils::ElementTypeToDataType(dstPrecision));
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = std::move(scaleMem);
dnnlArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ std::vector<std::string> disabledTestPatterns() {
// by calc abs_threshold with expected value
retVector.emplace_back(R"(.*smoke_GatherCompressedWeights_basic/GatherWeightsDecompression.CompareWithRefs.*INFERENCE_PRECISION_HINT.*bf16.*)");
retVector.emplace_back(R"(.*smoke_Interaction/IntertactionCPUTest.CompareWithRefs.*Prc=i32.*)");
retVector.emplace_back(R"(.*smoke_MatMulCompressedWeights_(amx|sym_amx|corner_cases_amx)/MatmulWeightsDecompression.CompareWithRefs.*INFERENCE_PRECISION_HINT.*bf16.*)");
retVector.emplace_back(R"(.*smoke_MatMulCompressedWeights_(sym_amx)/MatmulWeightsDecompression.CompareWithRefs.*INFERENCE_PRECISION_HINT.*bf16.*)");
retVector.emplace_back(R"(.*smoke_Snippets_EnforcePrecision_bf16/EnforcePrecisionTest.*)");
retVector.emplace_back(R"(.*smoke_Snippets_MHABF16_4D/MHA.CompareWithRefImpl/.*\[1.58.16.34\]_IS\[1\]=\[1.58.16.34\]_IS\[2\]=\[1.1.1.58\]_IS\[3\]=\[1.58.16.34\].*)");
retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeBF16_(3|4)D/MHAWOTranspose.*)");
Expand Down

0 comments on commit 4137826

Please sign in to comment.