[CPU] Enable compressed FC via oneDNN Matmul primitive
dmitry-gorokhov committed Nov 7, 2024
1 parent 696bdde commit d702d96
Showing 13 changed files with 460 additions and 229 deletions.
58 changes: 55 additions & 3 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.cpp
@@ -611,11 +611,63 @@ static MemoryPtr prepackDecompressionParams(const MemoryCPtr& paramsPtr,
auto srcMem = std::make_shared<Memory>(engine, srcMemoryDesc, paramsPtr->getData());

dstMem->load(*srcMem);

return dstMem;
}

void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision) {

static dnnl::memory::dims getGroupDims(const VectorDims& weiDims, const VectorDims& scaleDims) {
int N = weiDims[weiDims.size() - 2];
int K = weiDims[weiDims.size() - 1];
dnnl::memory::dim groupN = N / scaleDims[0];
dnnl::memory::dim groupK = K / scaleDims[1];

return {groupK, groupN};
}

static int getMask(const VectorDims& weiDims, const dnnl::memory::dims& groupDims) {
const int maskN = 1 << (weiDims.size() - 1);
const int maskK = 1 << (weiDims.size() - 2);
int N = weiDims[weiDims.size() - 2];
int K = weiDims[weiDims.size() - 1];
int mask = 0;
if (groupDims[1] != N)
mask += maskN;
if (groupDims[0] != K)
mask += maskK;

return mask;
}

void DnnlPostOpsComposer::appendDecompressionScales(
const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision, const VectorDims& weiDims) {
if (scales_ptr == nullptr)
return;

auto scaleMem = prepackDecompressionParams(scales_ptr, needTranspose, dstPrecision, engine);
auto groupDims = getGroupDims(weiDims, scaleMem->getStaticDims());
auto mask = getMask(weiDims, groupDims);

attr.set_scales(DNNL_ARG_WEIGHTS, mask, groupDims, DnnlExtensionUtils::ElementTypeToDataType(dstPrecision));
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = std::move(scaleMem);
dnnlArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS]->getPrimitive();
}

void DnnlPostOpsComposer::appendDecompressionZeroPoints(
const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision, const VectorDims& weiDims) {
if (zero_points_ptr == nullptr)
return;

auto zeroPointsMem = prepackDecompressionParams(zero_points_ptr, needTranspose, dstPrecision, engine);
auto groupDims = getGroupDims(weiDims, zeroPointsMem->getStaticDims());
auto mask = getMask(weiDims, groupDims);

attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groupDims, DnnlExtensionUtils::ElementTypeToDataType(dstPrecision));
cpuArgs[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zeroPointsMem;
dnnlArgs[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zeroPointsMem->getPrimitive();
}

void DnnlPostOpsComposer::appendDecompressionScalesLegacy(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision) {
if (scales_ptr == nullptr)
return;

@@ -627,7 +679,7 @@ void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS]->getPrimitive();
}

void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision) {
void DnnlPostOpsComposer::appendDecompressionZeroPointsLegacy(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision) {
if (zero_points_ptr == nullptr)
return;

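The grouped-scales plumbing above is compact, so a worked example helps. The following standalone sketch (illustrative shapes, not plugin code) mirrors the new getGroupDims()/getMask() helpers for 2-D weights [N, K] with per-channel scales over N and groups of 32 along K:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Standalone mirror of getGroupDims()/getMask() above, for 2-D weights [N, K].
// Scales of shape [N / groupN, K / groupK] produce group dims {groupK, groupN};
// a mask bit is set for each weights dimension whose scales are not broadcast.
int main() {
    std::vector<int64_t> weiDims{64, 128};  // N = 64, K = 128 (illustrative)
    std::vector<int64_t> scaleDims{64, 4};  // per-channel over N, group size 32 over K

    const int64_t N = weiDims[weiDims.size() - 2];
    const int64_t K = weiDims[weiDims.size() - 1];
    const int64_t groupN = N / scaleDims[0];  // 64 / 64 = 1
    const int64_t groupK = K / scaleDims[1];  // 128 / 4 = 32

    int mask = 0;
    if (groupN != N) mask += 1 << (weiDims.size() - 1);  // scales vary along N
    if (groupK != K) mask += 1 << (weiDims.size() - 2);  // scales vary along K

    std::printf("groupDims = {%lld, %lld}, mask = %d\n",
                static_cast<long long>(groupK), static_cast<long long>(groupN), mask);
    // Prints: groupDims = {32, 1}, mask = 3
    return 0;
}
```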
6 changes: 4 additions & 2 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.h
@@ -31,8 +31,10 @@ class DnnlPostOpsComposer {
const bool hasBias,
const dnnl::memory::data_type outDataType);
DnnlPrimitiveAttrs compose();
void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision);
void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision);
void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision, const VectorDims& weiDims);
void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision, const VectorDims& weiDims);
void appendDecompressionScalesLegacy(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision);
void appendDecompressionZeroPointsLegacy(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision);
void setDynamicQuantizationParams(uint64_t groupSize);

private:
4 changes: 2 additions & 2 deletions src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp
@@ -591,8 +591,8 @@ struct ConvertFromBinPrecision<std::tuple<src_t, dst_t>> {
};

#define INTEL_CPU_CVT_FROM_4BIT_LIST \
INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), INTEL_CPU_CVT(u4, i8), INTEL_CPU_CVT(u4, u8), \
INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, bf16), INTEL_CPU_CVT(i4, f16), INTEL_CPU_CVT(i4, i8), INTEL_CPU_CVT(i4, u8), \
INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, i32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), INTEL_CPU_CVT(u4, i8), INTEL_CPU_CVT(u4, u8), \
INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, i32), INTEL_CPU_CVT(i4, bf16), INTEL_CPU_CVT(i4, f16), INTEL_CPU_CVT(i4, i8), INTEL_CPU_CVT(i4, u8), \
INTEL_CPU_CVT(nf4, f32), INTEL_CPU_CVT(nf4, bf16), INTEL_CPU_CVT(nf4, f16), INTEL_CPU_CVT(nf4, i8), INTEL_CPU_CVT(nf4, u8), \
INTEL_CPU_CVT(f4e2m1, f32), INTEL_CPU_CVT(f4e2m1, bf16), INTEL_CPU_CVT(f4e2m1, f16), INTEL_CPU_CVT(f4e2m1, i8), INTEL_CPU_CVT(f4e2m1, u8)

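The new u4/i4 to i32 entries back the matmul changes below, where zero points are prepacked into i32 before being handed to oneDNN. As a rough illustration of what such a widening conversion does (not the plugin's actual templated kernel, and the nibble order is an assumption):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: widen packed unsigned 4-bit values (two per byte,
// low nibble first - an assumed layout) into int32.
std::vector<int32_t> unpack_u4_to_i32(const uint8_t* packed, size_t count) {
    std::vector<int32_t> out(count);
    for (size_t i = 0; i < count; ++i) {
        const uint8_t b = packed[i / 2];
        out[i] = (i % 2 == 0) ? (b & 0x0F) : (b >> 4);
    }
    return out;
}
```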
@@ -232,11 +232,11 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
if (dstPrc != f8e8m0 || useDynamicQuantization)
dstPrc = ov::element::f32;

dnnlpoc.appendDecompressionScales(attrs.decompressionMultiplyPtr, !attrs.weightsNonTransposed, dstPrc);
dnnlpoc.appendDecompressionScalesLegacy(attrs.decompressionMultiplyPtr, !attrs.weightsNonTransposed, dstPrc);
}
if (attrs.decompressionSubtractPtr) {
auto dstPrc = useDynamicQuantization ? ov::element::u8 : ov::element::f32;
dnnlpoc.appendDecompressionZeroPoints(attrs.decompressionSubtractPtr, !attrs.weightsNonTransposed, dstPrc);
dnnlpoc.appendDecompressionZeroPointsLegacy(attrs.decompressionSubtractPtr, !attrs.weightsNonTransposed, dstPrc);
}
if (useDynamicQuantization) {
auto wei_precision = weiDesc->getPrecision();
@@ -247,7 +247,7 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
uint8_t zp_value = (wei_precision == ov::element::i8) ? 128 : 8;
DnnlBlockedMemoryDesc zpMemoryDesc(ov::element::u8, Shape({1}));
auto decompressionSubtractPtr = std::make_shared<Memory>(context->getEngine(), zpMemoryDesc, &zp_value);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr,
dnnlpoc.appendDecompressionZeroPointsLegacy(decompressionSubtractPtr,
!attrs.weightsNonTransposed,
ov::element::u8);
}
@@ -13,6 +13,7 @@
#include <memory>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_common.hpp>
#include <cpu/x64/cpu_isa_traits.hpp>

#include "cpu_memory.h"
#include "cpu_types.h"
@@ -75,6 +76,23 @@ bool DnnlMatMulPrimitive::Key::operator==(const Key& rhs) const {
return result;
}

template <typename dimsType>
static dimsType normalizeToRank(const dimsType& vec, size_t rank) {
if (vec.size() == rank || vec.empty())
return vec;

dimsType result;
result.reserve(rank);

for (size_t i = vec.size(); i < rank; ++i) {
result.push_back(1);
}

result.insert(result.end(), vec.begin(), vec.end());

return result;
}
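
A quick illustration of the helper just added: dims are padded with leading 1s up to the requested rank, while already-matching or empty vectors pass through unchanged. A standalone copy of its behavior (illustrative values):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standalone copy of normalizeToRank()'s behavior for illustration.
static std::vector<int64_t> normalizeToRank(const std::vector<int64_t>& vec, size_t rank) {
    if (vec.size() == rank || vec.empty())
        return vec;
    std::vector<int64_t> result;
    result.reserve(rank);
    for (size_t i = vec.size(); i < rank; ++i)
        result.push_back(1);  // leading 1s
    result.insert(result.end(), vec.begin(), vec.end());
    return result;
}

int main() {
    assert((normalizeToRank({64, 128}, 4) == std::vector<int64_t>{1, 1, 64, 128}));
    assert((normalizeToRank({64, 128}, 2) == std::vector<int64_t>{64, 128}));
    return 0;
}
```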

std::shared_ptr<DnnlMatMulPrimitive> DnnlMatMulPrimitive::create(const MemoryArgs& memory,
const MatMulAttrs& attrs,
const ExecutorContext::CPtr context,
@@ -108,15 +126,17 @@ DnnlMemoryDescPtr DnnlMatMulPrimitive::makeTransposedWeightDescriptor(const Dnnl

const auto format = weightsNonTransposed ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
const auto transposedWeiDesc = dnnl::memory::desc{wDims, wDataType, format};
const auto reshapedWeiDesc = transposedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims());

return DnnlExtensionUtils::makeDescriptor(transposedWeiDesc);
return DnnlExtensionUtils::makeDescriptor(reshapedWeiDesc);
}
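
The added reshape() matters because the FC weights descriptor is 2-D while the matmul path may run with rank-normalized (e.g. 3-D) shapes; oneDNN's memory::desc::reshape can prepend unit dims to a strided descriptor. A minimal sketch, assuming oneDNN 3.x and illustrative dims:

```cpp
#include <oneapi/dnnl/dnnl.hpp>

// Sketch: a transposed ("ba") 2-D weights desc reshaped to 3-D by adding a
// leading unit dim, matching a rank-normalized destination descriptor.
void reshapeExample() {
    using dt = dnnl::memory::data_type;
    using tag = dnnl::memory::format_tag;
    dnnl::memory::desc wei2d({64, 128}, dt::f32, tag::ba);
    dnnl::memory::desc wei3d = wei2d.reshape({1, 64, 128});
    (void)wei3d;
}
```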

static DnnlPrimitiveAttrs createPrimitiveAttrs(const MatMulAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
ExecutorContext::CPtr context,
bool useDynamicQuantization) {
bool useWeightsDecompression,
bool weightsNonTransposed) {
const auto& srcDesc = memory.at(ARG_SRC)->getDescPtr();
const auto& weiDesc = memory.at(ARG_WEI)->getDescPtr();
const auto& dstDesc = memory.at(ARG_DST)->getDescPtr();
@@ -138,7 +158,24 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const MatMulAttrs& attrs,
!memory.at(ARG_BIAS)->getDesc().empty(),
outputDataType);

return dnnlpoc.compose();
const auto maxRank = std::max({srcDesc->getShape().getRank(), weiDesc->getShape().getRank(), dstDesc->getShape().getRank()});
auto normWeiDims = normalizeToRank(weiDesc->getShape().getStaticDims(), maxRank);
if (useWeightsDecompression && attrs.decompressionMultiplyPtr) {
auto dstPrc = ov::element::f32;
dnnlpoc.appendDecompressionScales(attrs.decompressionMultiplyPtr, !weightsNonTransposed, dstPrc, normWeiDims);
}
if (useWeightsDecompression && attrs.decompressionSubtractPtr) {
// TODO: clarify oneDNN requirements on ZP precision
auto dstPrc = ov::element::i32;
dnnlpoc.appendDecompressionZeroPoints(attrs.decompressionSubtractPtr, !weightsNonTransposed, dstPrc, normWeiDims);
}

auto primAttrs = dnnlpoc.compose();
if (useWeightsDecompression) {
primAttrs.attr.set_fpmath_mode(fpmath_mode::bf16, true);
}

return primAttrs;
}
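
Two details above deserve a note: the grouped scales and zero points are attached through the primitive attr (with the mask/group dims computed by the composer), and set_fpmath_mode(bf16, true) is what lets oneDNN accept integer weights next to floating-point activations by up-converting them on the fly. A minimal sketch of the resulting attribute setup, assuming oneDNN 3.x and illustrative mask/group values:

```cpp
#include <oneapi/dnnl/dnnl.hpp>

// Sketch (illustrative mask/groups): the attribute state this path builds up.
dnnl::primitive_attr makeDecompressionAttr() {
    dnnl::primitive_attr attr;
    // Grouped weight scales: mask 3 = vary along both weight dims, groups {32, 1}.
    attr.set_scales(DNNL_ARG_WEIGHTS, /*mask=*/3, /*groups=*/{32, 1},
                    dnnl::memory::data_type::f32);
    // Zero points prepacked to s32 (see the new i32 cpu_convert entries above).
    attr.set_zero_points(DNNL_ARG_WEIGHTS, /*mask=*/3, /*groups=*/{32, 1},
                         dnnl::memory::data_type::s32);
    // Allow implicit int -> bf16 weight up-conversion during computation.
    attr.set_fpmath_mode(dnnl::fpmath_mode::bf16, /*apply_to_int=*/true);
    return attr;
}
```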

static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory::desc& inputDesc,
@@ -149,22 +186,6 @@ static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory:
const dnnl::engine& engine,
const bool useSparseWeights,
const bool useWeightsDecompression) {
auto normalizeToRank = [](const dnnl::memory::dims& vec, size_t rank) -> dnnl::memory::dims {
if (vec.size() == rank || vec.empty())
return vec;

dnnl::memory::dims result;
result.reserve(rank);

for (size_t i = vec.size(); i < rank; ++i) {
result.push_back(1);
}

result.insert(result.end(), vec.begin(), vec.end());

return result;
};

auto weiDims = weightDesc.get_dims();
std::swap(weiDims[weiDims.size() - 1], weiDims[weiDims.size() - 2]);

@@ -181,7 +202,9 @@ static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory:

auto idt = inputDesc.get_data_type();
auto wdt = idt;
if (idt == dnnl::memory::data_type::u8 || idt == dnnl::memory::data_type::s8) {
if (useWeightsDecompression) {
wdt = weightDesc.get_data_type();
} else if (idt == dnnl::memory::data_type::u8 || idt == dnnl::memory::data_type::s8) {
wdt = memory::data_type::s8;
}
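
With decompression enabled, the weights descriptor keeps its integer data type while activations stay floating point; the fpmath hint above is what makes that combination valid. A hedged end-to-end sketch (oneDNN 3.x, illustrative shapes, reusing the makeDecompressionAttr() helper sketched earlier):

```cpp
#include <oneapi/dnnl/dnnl.hpp>

dnnl::primitive_attr makeDecompressionAttr();  // sketched above

// Sketch: f32 activations with u8 weights in one matmul primitive descriptor.
dnnl::matmul::primitive_desc makeDecompressedMatmulPd(const dnnl::engine& eng) {
    using dt = dnnl::memory::data_type;
    using tag = dnnl::memory::format_tag;
    dnnl::memory::desc src({1, 128}, dt::f32, tag::ab);
    dnnl::memory::desc wei({128, 64}, dt::u8, tag::any);
    dnnl::memory::desc dst({1, 64}, dt::f32, tag::ab);
    return dnnl::matmul::primitive_desc(eng, src, wei, dst, makeDecompressionAttr());
}
```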

@@ -252,6 +275,16 @@ static VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDim
return outputShape;
}

bool DnnlMatMulPrimitive::useWeightsDecompressionImpl(const ov::element::Type inputType,
const ov::element::Type weightsType) {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
// TODO: fp16
if (one_of(inputType, f32, bf16) && one_of(weightsType, u8, i8, u4, i4))
return true;
}
return false;
}

DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
@@ -262,9 +295,10 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAt
const auto& weiDesc = memory.at(ARG_WEI)->getDescPtr();
const auto& biasDesc = memory.at(ARG_BIAS)->getDescPtr();
auto dstDesc = memory.at(ARG_DST)->getDescPtr();
MatMulAttrs mmAttrs{false, false, attrs.dequantizationScales};
MatMulAttrs mmAttrs{false, false, attrs.dequantizationScales, attrs.decompressionSubtractPtr, attrs.decompressionMultiplyPtr};

const auto postOpData = createPrimitiveAttrs(mmAttrs, postOps, memory, context, false);
const auto useWeightsDecompression = useWeightsDecompressionImpl(srcDesc->getPrecision(), weiDesc->getPrecision());
const auto postOpData = createPrimitiveAttrs(mmAttrs, postOps, memory, context, useWeightsDecompression, attrs.weightsNonTransposed);

if (!cacheWeights)
return std::make_shared<DnnlShapeAgnosticData>(postOpData);
@@ -292,7 +326,7 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAt
context->getEngine(),
context->getImplPriorities(),
false,
false);
useWeightsDecompression);

const auto weightsDesc = DnnlExtensionUtils::makeDescriptor(primDesc.weights_desc());
auto originalWeightsDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weiDesc);
@@ -326,7 +360,7 @@ DnnlMatMulPrimitive::DnnlMatMulPrimitive(const Key& key,
engine,
implPriorities,
false,
false)),
useWeightsDecompressionImpl(key.src->getPrecision(), key.wei->getPrecision()))),
m_implType(implTypeFromPrimDesc(m_primDesc)),
m_srcDesc(DnnlExtensionUtils::makeDescriptor(m_primDesc.src_desc())),
m_weiDesc(DnnlExtensionUtils::makeDescriptor(m_primDesc.weights_desc())),
@@ -335,7 +369,6 @@ DnnlMatMulPrimitive::DnnlMatMulPrimitive(const Key& key,
m_prim(primitive(m_primDesc)) {}

void DnnlMatMulPrimitive::execute(const dnnl_primitive_args& primArgs) const {
std::cout << "Executing MM primitive" << "\n";
m_prim.execute(m_stream, primArgs);
}

@@ -53,6 +53,9 @@ class DnnlMatMulPrimitive {
return m_implType;
}

static bool useWeightsDecompressionImpl(const ov::element::Type inputType,
const ov::element::Type weightsType);

static DnnlShapeAgnosticDataPtr createShapeAgnosticData(const FCAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
@@ -124,6 +124,12 @@ static const TypeMapping dnnlMatMulTypeMapping {
// quantization configuration
{{_u8 | _i8, _i8, _u8|_i8|_i32|_bf16|_f16|_f32|_undefined, _u8|_i8|_i32|_bf16|_f16|_f32}, pt(bypass(), bypass(), bypass(), bypass())},
{{_u8 | _i8, _i8, _any, _any}, pt(bypass(), bypass(), just<f32>(), just<f32>())},
// compressed int weights
{{_bf16, _u8 | _i8 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>()),
Require<dnnl::impl::cpu::x64::avx512_core_bf16>()},
{{_bf16, _u8 | _i8 | _u4 | _i4, _any, _any}, pt(just<f32>(), bypass(), just<f32>(), just<f32>())},
{{_f32, _u8 | _i8 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
// TODO: fp16
// @todo should we fallback to FPXX instead of _f32?
{{_any, _any, _any, _any}, pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
// @todo explicitly cover configuration limitations for oneDNN on ARM
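Reading the new rows: bf16 activations with compressed integer weights stay bf16 end-to-end only when the Require<avx512_core_bf16> check passes; otherwise the next row converts activations (and bias/dst) to f32 while the weights bypass untouched, and f32 activations always pass through. A hypothetical distillation of that fallback (names are illustrative, not plugin API):

```cpp
// Hypothetical distillation of the new mapping rows.
enum class ActPrec { f32, bf16 };

ActPrec pickActivationPrecision(ActPrec requested, bool hasAvx512CoreBf16) {
    if (requested == ActPrec::bf16 && !hasAvx512CoreBf16)
        return ActPrec::f32;  // second row: just<f32>() fallback
    return requested;         // first/third rows: bypass()
}
```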
@@ -411,7 +417,6 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
CPU_DEBUG_CAP_ENABLE(
if (getEnvBool("OV_CPU_ENABLE_DNNL_MAMTUL_FOR_FC")) {
VERIFY(noSparseDecompression(config), UNSUPPORTED_SPARSE_WEIGHTS);
VERIFY(noWeightsDecompression(config), UNSUPPORTED_WEIGHTS_DECOMPRESSION);
return true;
})
return false;
@@ -440,7 +445,9 @@
std::shared_ptr<DnnlShapeAgnosticData> shareAgnosticData) const {
MatMulAttrs matMulAttrs{false,
false,
attrs.dequantizationScales};
attrs.dequantizationScales,
attrs.decompressionSubtractPtr,
attrs.decompressionMultiplyPtr};
auto primitive =
DefaultInstantiator<DnnlMatMulPrimitive, MatMulAttrs, DnnlShapeAgnosticData>{}(
memory,
2 changes: 2 additions & 0 deletions src/plugins/intel_cpu/src/nodes/executors/matmul_config.hpp
@@ -13,6 +13,8 @@ struct MatMulAttrs {
bool transposeA;
bool transposeB;
std::vector<float> dequantizationScales;
MemoryCPtr decompressionSubtractPtr;
MemoryCPtr decompressionMultiplyPtr;
};

using MatMulConfig = executor::Config<MatMulAttrs>;
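As used at the executor call sites above, the decompression parameters now travel through MatMulAttrs via aggregate initialization, with fields in declaration order. A usage fragment mirroring those call sites (attrs here stands for the incoming FCAttrs):

```cpp
// Mirrors the call sites above; attrs is the incoming FCAttrs.
MatMulAttrs mmAttrs{/*transposeA=*/false,
                    /*transposeB=*/false,
                    attrs.dequantizationScales,
                    attrs.decompressionSubtractPtr,
                    attrs.decompressionMultiplyPtr};
```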
@@ -64,11 +64,11 @@ endif()

# find all the source files with the name of a class file
if(X86_64)
file(GLOB_RECURSE LIST_OF_TEST_ARCH_INSTANCES ${TEST_DIR}/instances/x64/${TEST_CLASS_FILE_NAME})
file(GLOB_RECURSE LIST_OF_TEST_ARCH_INSTANCES ${TEST_DIR}/x64/${TEST_CLASS_FILE_NAME})
elseif(ARM OR AARCH64)
file(GLOB_RECURSE LIST_OF_TEST_ARCH_INSTANCES ${TEST_DIR}/instances/arm/${TEST_CLASS_FILE_NAME})
file(GLOB_RECURSE LIST_OF_TEST_ARCH_INSTANCES ${TEST_DIR}/arm/${TEST_CLASS_FILE_NAME})
endif()
file(GLOB_RECURSE LIST_OF_TEST_COMMON_INSTANCES ${TEST_DIR}/instances/common/${TEST_CLASS_FILE_NAME})
file(GLOB_RECURSE LIST_OF_TEST_COMMON_INSTANCES ${TEST_DIR}/common/${TEST_CLASS_FILE_NAME})
set(LIST_OF_TEST_INSTANCES ${LIST_OF_TEST_COMMON_INSTANCES} ${LIST_OF_TEST_ARCH_INSTANCES})

set(TEST_INSTANCES "${LIST_OF_TEST_INSTANCES}")