Skip to content

Commit

Permalink
Added reference matmul weights decompression kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov committed Jul 21, 2023
1 parent f2fc37b commit a95cfa3
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
25 changes: 24 additions & 1 deletion src/plugins/intel_cpu/src/nodes/fullyconnected.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,17 @@ void FullyConnected::getSupportedDescriptors() {
if (getParentEdges().size() != 2 && getParentEdges().size() != 3)
IE_THROW() << errorPrefix << " has incorrect number of input edges";
if (getChildEdges().empty())
IE_THROW()<< errorPrefix << " has incorrect number of output edges";
IE_THROW() << errorPrefix << " has incorrect number of output edges";

withBiases = getOriginalInputsNumber() == 3;

useSparseWeights = useSparseWeightsDecompression();
useWeightsDecompression = canUseWeightsDecompression();
if (!useWeightsDecompression) {
if (!decompressionSubtract.empty() || !decompressionMultiply.empty()) {
IE_THROW() << errorPrefix << " doesn't support weights decompression feature";
}
}

auto inputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(DATA_ID));
outputDataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(DATA_ID));
Expand Down Expand Up @@ -839,6 +845,10 @@ bool FullyConnected::canBeExecutedInConv1x1() const {
bool retVal = false;
const auto inRank = getInputShapeAtPort(DATA_ID).getRank();
const auto weightRank = getInputShapeAtPort(WEIGHTS_ID).getRank();

if (useWeightsDecompression)
return false;

// disable rank=4:
// if layout is nhwc:
// A matrix: N * IC * H * W --> N * (IC*H*W), the M, N', K of matrix multiply will be:
Expand Down Expand Up @@ -953,6 +963,19 @@ void FullyConnected::fuseDecompressionMultiply(const NodePtr& constData) {
elementsCount);
}

// todo: reuse the method in fusion pass for limitations check
bool FullyConnected::canUseWeightsDecompression() {
    // The decompression kernel requires at least AVX2 on the host CPU.
    const bool isaSupported = impl::cpu::x64::mayiuse(impl::cpu::x64::avx2);
    if (!isaSupported)
        return false;

    // Only FP32 activations combined with U8 (compressed) weights are handled.
    const bool precisionsSupported =
        getOriginalInputPrecisionAtPort(DATA_ID) == Precision::FP32 &&
        getOriginalInputPrecisionAtPort(WEIGHTS_ID) == Precision::U8;

    return precisionsSupported;
}

void FullyConnected::fuseDecompressionSubtract(const NodePtr& constData) {
auto *constInputNode = dynamic_cast<node::Input *>(constData.get());
if (!constInputNode) {
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes/fullyconnected.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class FullyConnected : public Node {
bool useSparseWeightsDecompression();
VectorDims expectedBiasDims {};

bool canUseWeightsDecompression();
bool useWeightsDecompression = false;
std::vector<float> decompressionSubtract;
std::vector<float> decompressionMultiply;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,14 @@ TEST_P(MatmulWeightsDecompressionTest, CompareWithRefs) {
namespace {

// Each entry is {activations shape, weights shape}; covers several K (inner
// dimension) and N (output channel) sizes, plus batched activations, to
// exercise different tail/blocking paths of the decompression kernel.
// NOTE(review): entry order is part of the generated parametrized test names —
// keep appends at the end rather than reordering.
std::vector<std::vector<ngraph::Shape>> inputShapes = {
    {{1, 4, 16}, {32, 16}},
    {{1, 4, 16}, {256, 16}},
    {{1, 4, 32}, {256, 32}},
    {{1, 4, 48}, {256, 48}},
    {{1, 4, 512}, {256, 512}},
    {{1, 16, 32}, {64, 32}},
    {{10, 4, 16}, {32, 16}},
    {{10, 40, 496}, {240, 496}},
};

std::vector<size_t> patternTypes = {
Expand Down

0 comments on commit a95cfa3

Please sign in to comment.