Skip to content

Commit

Permalink
[CPU] Enable compressed FC via oneDNN Matmul primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov committed Dec 12, 2024
1 parent 45bf77b commit 370e2b2
Show file tree
Hide file tree
Showing 18 changed files with 546 additions and 258 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/src/cpu_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void transferData(const IMemory& src, const IMemory& dst, bool ftz) {
if (!ftz) {
return;
}
if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() == ov::element::bf16) {
if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() != ov::element::f32) {
return;
}
size_t offset = 0;
Expand Down
68 changes: 63 additions & 5 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,13 +673,71 @@ static MemoryPtr prepackDecompressionParams(const MemoryCPtr& paramsPtr,
auto srcMem = std::make_shared<Memory>(engine, srcMemoryDesc, paramsPtr->getData());

dstMem->load(*srcMem);

return dstMem;
}

static dnnl::memory::dims getGroupDims(const VectorDims& weiDims, const VectorDims& scaleDims) {
if (scaleDims[0] == 1 && scaleDims[1] == 1)
return {};

int N = weiDims[weiDims.size() - 2];
int K = weiDims[weiDims.size() - 1];
dnnl::memory::dim groupN = N / scaleDims[0];
dnnl::memory::dim groupK = K / scaleDims[1];

return {groupK, groupN};
}

static int getMask(const VectorDims& weiDims, const dnnl::memory::dims& groupDims) {
const int maskN = 1 << (weiDims.size() - 1);
const int maskK = 1 << (weiDims.size() - 2);
int N = weiDims[weiDims.size() - 2];
int K = weiDims[weiDims.size() - 1];
int mask = 0;
if (!groupDims.empty() && groupDims[1] != N)
mask += maskN;
if (!groupDims.empty() && groupDims[0] != K)
mask += maskK;

return mask;
}

void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr,
bool needTranspose,
ov::element::Type dstPrecision) {
ov::element::Type dstPrecision,
const VectorDims& weiDims) {
if (scales_ptr == nullptr)
return;

auto scaleMem = prepackDecompressionParams(scales_ptr, needTranspose, dstPrecision, engine);
auto groupDims = getGroupDims(weiDims, scaleMem->getStaticDims());
auto mask = getMask(weiDims, groupDims);

attr.set_scales(DNNL_ARG_WEIGHTS, mask, groupDims, DnnlExtensionUtils::ElementTypeToDataType(dstPrecision));
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] = std::move(scaleMem);
dnnlArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS]->getPrimitive();
}

void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr,
bool needTranspose,
ov::element::Type dstPrecision,
const VectorDims& weiDims) {
if (zero_points_ptr == nullptr)
return;

auto zeroPointsMem = prepackDecompressionParams(zero_points_ptr, needTranspose, dstPrecision, engine);
auto groupDims = getGroupDims(weiDims, zeroPointsMem->getStaticDims());
auto mask = getMask(weiDims, groupDims);

attr.set_zero_points(DNNL_ARG_WEIGHTS, mask, groupDims, DnnlExtensionUtils::ElementTypeToDataType(dstPrecision));
cpuArgs[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zeroPointsMem;
dnnlArgs[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS] = zeroPointsMem->getPrimitive();
}

void DnnlPostOpsComposer::appendDecompressionScalesLegacy(const MemoryCPtr& scales_ptr,
bool needTranspose,
ov::element::Type dstPrecision) {
if (scales_ptr == nullptr)
return;

Expand All @@ -692,9 +750,9 @@ void DnnlPostOpsComposer::appendDecompressionScales(const MemoryCPtr& scales_ptr
cpuArgs[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS]->getPrimitive();
}

void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr,
bool needTranspose,
ov::element::Type dstPrecision) {
void DnnlPostOpsComposer::appendDecompressionZeroPointsLegacy(const MemoryCPtr& zero_points_ptr,
bool needTranspose,
ov::element::Type dstPrecision) {
if (zero_points_ptr == nullptr)
return;

Expand Down
15 changes: 13 additions & 2 deletions src/plugins/intel_cpu/src/dnnl_postops_composer.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,21 @@ class DnnlPostOpsComposer {
const MemoryArgs& memory,
const dnnl::memory::data_type outDataType);
DnnlPrimitiveAttrs compose();
void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision);

void appendDecompressionScales(const MemoryCPtr& scales_ptr,
bool needTranspose,
ov::element::Type dstPrecision,
const VectorDims& weiDims);
void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr,
bool needTranspose,
ov::element::Type dstPrecision);
ov::element::Type dstPrecision,
const VectorDims& weiDims);
void appendDecompressionScalesLegacy(const MemoryCPtr& scales_ptr,
bool needTranspose,
ov::element::Type dstPrecision);
void appendDecompressionZeroPointsLegacy(const MemoryCPtr& zero_points_ptr,
bool needTranspose,
ov::element::Type dstPrecision);
void setDynamicQuantizationParams(uint64_t groupSize);

private:
Expand Down
13 changes: 7 additions & 6 deletions src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -570,12 +570,13 @@ struct ConvertFromBinPrecision<std::tuple<src_t, dst_t>> {
}
};

#define INTEL_CPU_CVT_FROM_4BIT_LIST \
INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), INTEL_CPU_CVT(u4, i8), \
INTEL_CPU_CVT(u4, u8), INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, bf16), INTEL_CPU_CVT(i4, f16), \
INTEL_CPU_CVT(i4, i8), INTEL_CPU_CVT(i4, u8), INTEL_CPU_CVT(nf4, f32), INTEL_CPU_CVT(nf4, bf16), \
INTEL_CPU_CVT(nf4, f16), INTEL_CPU_CVT(nf4, i8), INTEL_CPU_CVT(nf4, u8), INTEL_CPU_CVT(f4e2m1, f32), \
INTEL_CPU_CVT(f4e2m1, bf16), INTEL_CPU_CVT(f4e2m1, f16), INTEL_CPU_CVT(f4e2m1, i8), INTEL_CPU_CVT(f4e2m1, u8)
#define INTEL_CPU_CVT_FROM_4BIT_LIST \
INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, i32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), \
INTEL_CPU_CVT(u4, i8), INTEL_CPU_CVT(u4, u8), INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, i32), \
INTEL_CPU_CVT(i4, bf16), INTEL_CPU_CVT(i4, f16), INTEL_CPU_CVT(i4, i8), INTEL_CPU_CVT(i4, u8), \
INTEL_CPU_CVT(nf4, f32), INTEL_CPU_CVT(nf4, bf16), INTEL_CPU_CVT(nf4, f16), INTEL_CPU_CVT(nf4, i8), \
INTEL_CPU_CVT(nf4, u8), INTEL_CPU_CVT(f4e2m1, f32), INTEL_CPU_CVT(f4e2m1, bf16), INTEL_CPU_CVT(f4e2m1, f16), \
INTEL_CPU_CVT(f4e2m1, i8), INTEL_CPU_CVT(f4e2m1, u8)

struct ConvertFrom4BitContext {
ov::element::Type_t inType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,16 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
if (dstPrc != f8e8m0 || useDynamicQuantization)
dstPrc = ov::element::f32;

dnnlpoc.appendDecompressionScales(memory.at(ARG_WEI | ARG_ATTR_SCALES), !attrs.weightsNonTransposed, dstPrc);
dnnlpoc.appendDecompressionScalesLegacy(memory.at(ARG_WEI | ARG_ATTR_SCALES),
!attrs.weightsNonTransposed,
dstPrc);
}

if (memory.count(ARG_WEI | ARG_ATTR_ZERO_POINTS)) {
auto dstPrc = useDynamicQuantization ? ov::element::u8 : ov::element::f32;
dnnlpoc.appendDecompressionZeroPoints(memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS),
!attrs.weightsNonTransposed,
dstPrc);
dnnlpoc.appendDecompressionZeroPointsLegacy(memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS),
!attrs.weightsNonTransposed,
dstPrc);
}

if (useDynamicQuantization) {
Expand All @@ -247,9 +249,9 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs,
uint8_t zp_value = (wei_precision == ov::element::i8) ? 128 : 8;
DnnlBlockedMemoryDesc zpMemoryDesc(ov::element::u8, Shape({1}));
auto decompressionSubtractPtr = std::make_shared<Memory>(context->getEngine(), zpMemoryDesc, &zp_value);
dnnlpoc.appendDecompressionZeroPoints(decompressionSubtractPtr,
!attrs.weightsNonTransposed,
ov::element::u8);
dnnlpoc.appendDecompressionZeroPointsLegacy(decompressionSubtractPtr,
!attrs.weightsNonTransposed,
ov::element::u8);
}
dnnlpoc.setDynamicQuantizationParams(attrs.dynamicQuantizationGroupSize);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <common/primitive_attr.hpp>
#include <common/primitive_desc_iface.hpp>
#include <common/primitive_iface.hpp>
#include <cpu/x64/cpu_isa_traits.hpp>
#include <memory>
#include <oneapi/dnnl/dnnl.hpp>
#include <oneapi/dnnl/dnnl_common.hpp>
Expand Down Expand Up @@ -76,6 +77,23 @@ bool DnnlMatMulPrimitive::Key::operator==(const Key& rhs) const {
return result;
}

template <typename dimsType>
static dimsType normalizeToRank(const dimsType& vec, size_t rank) {
if (vec.size() == rank || vec.empty())
return vec;

dimsType result;
result.reserve(rank);

for (size_t i = vec.size(); i < rank; ++i) {
result.push_back(1);
}

result.insert(result.end(), vec.begin(), vec.end());

return result;
}

std::shared_ptr<DnnlMatMulPrimitive> DnnlMatMulPrimitive::create(const MemoryArgs& memory,
const MatMulAttrs& attrs,
const ExecutorContext::CPtr context,
Expand Down Expand Up @@ -105,19 +123,22 @@ DnnlMemoryDescPtr DnnlMatMulPrimitive::makeTransposedWeightDescriptor(const Dnnl
const auto& weiDesc = srcDesc->getDnnlDesc();
auto wDims = weiDesc.get_dims();
auto wDataType = weiDesc.get_data_type();
std::swap(wDims[wDims.size() - 1], wDims[wDims.size() - 2]);
dnnl::memory::dims wDims2D = reshapeDownToRank<2>(wDims);

const auto format = weightsNonTransposed ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
const auto transposedWeiDesc = dnnl::memory::desc{wDims2D, wDataType, format};
const auto reshapedWeiDesc = transposedWeiDesc.reshape(dstDesc->getDnnlDesc().get_dims());

return DnnlExtensionUtils::makeDescriptor(transposedWeiDesc);
return DnnlExtensionUtils::makeDescriptor(reshapedWeiDesc);
}

static DnnlPrimitiveAttrs createPrimitiveAttrs(const MatMulAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
ExecutorContext::CPtr context,
bool useDynamicQuantization) {
bool useWeightsDecompression,
bool weightsNonTransposed) {
const auto& srcDesc = memory.at(ARG_SRC)->getDescPtr();
const auto& weiDesc = memory.at(ARG_WEI)->getDescPtr();
const auto& dstDesc = memory.at(ARG_DST)->getDescPtr();
Expand All @@ -132,7 +153,30 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const MatMulAttrs& attrs,
DnnlPostOpsComposer
dnnlpoc(postOps, context->getEngine(), dims, dims.size() - 1, isINT8, 1 << 0, memory, outputDataType);

return dnnlpoc.compose();
const auto maxRank =
std::max({srcDesc->getShape().getRank(), weiDesc->getShape().getRank(), dstDesc->getShape().getRank()});
auto normWeiDims = normalizeToRank(weiDesc->getShape().getStaticDims(), maxRank);
if (memory.count(ARG_WEI | ARG_ATTR_SCALES)) {
auto dstPrc = ov::element::f32;
dnnlpoc.appendDecompressionScales(memory.at(ARG_WEI | ARG_ATTR_SCALES),
!weightsNonTransposed,
dstPrc,
normWeiDims);
}
if (memory.count(ARG_WEI | ARG_ATTR_ZERO_POINTS)) {
// TODO: clarify oneDNN requirements on ZP precision
auto zp = memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS);
auto zpPrc = zp->getPrecision();
auto dstPrc = one_of(zpPrc, i32, i8, u8, i4, u4) ? zpPrc : i32;
dnnlpoc.appendDecompressionZeroPoints(zp, !weightsNonTransposed, dstPrc, normWeiDims);
}

auto primAttrs = dnnlpoc.compose();
if (useWeightsDecompression) {
primAttrs.attr.set_fpmath_mode(fpmath_mode::any, true);
}

return primAttrs;
}

static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory::desc& inputDesc,
Expand All @@ -143,22 +187,6 @@ static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory:
const dnnl::engine& engine,
const bool useSparseWeights,
const bool useWeightsDecompression) {
auto normalizeToRank = [](const dnnl::memory::dims& vec, size_t rank) -> dnnl::memory::dims {
if (vec.size() == rank || vec.empty())
return vec;

dnnl::memory::dims result;
result.reserve(rank);

for (size_t i = vec.size(); i < rank; ++i) {
result.push_back(1);
}

result.insert(result.end(), vec.begin(), vec.end());

return result;
};

auto weiDims = weightDesc.get_dims();
std::swap(weiDims[weiDims.size() - 1], weiDims[weiDims.size() - 2]);

Expand All @@ -175,7 +203,9 @@ static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory:

auto idt = inputDesc.get_data_type();
auto wdt = idt;
if (idt == dnnl::memory::data_type::u8 || idt == dnnl::memory::data_type::s8) {
if (useWeightsDecompression) {
wdt = weightDesc.get_data_type();
} else if (idt == dnnl::memory::data_type::u8 || idt == dnnl::memory::data_type::s8) {
wdt = memory::data_type::s8;
}

Expand Down Expand Up @@ -245,6 +275,16 @@ static VectorDims makeDummyOutputDims(const VectorDims& inShape, const VectorDim
return outputShape;
}

bool DnnlMatMulPrimitive::useWeightsDecompressionImpl(const ov::element::Type inputType,
const ov::element::Type weightsType) {
#if defined(OPENVINO_ARCH_X86_64)
if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2))
return false;
#endif

return (one_of(inputType, f32, bf16, f16) && one_of(weightsType, u8, i8, u4, i4));
}

DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
Expand All @@ -257,7 +297,9 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAt
auto dstDesc = memory.at(ARG_DST)->getDescPtr();
MatMulAttrs mmAttrs{false, false};

const auto postOpData = createPrimitiveAttrs(mmAttrs, postOps, memory, context, false);
const auto useWeightsDecompression = useWeightsDecompressionImpl(srcDesc->getPrecision(), weiDesc->getPrecision());
const auto postOpData =
createPrimitiveAttrs(mmAttrs, postOps, memory, context, useWeightsDecompression, attrs.weightsNonTransposed);

if (!cacheWeights)
return std::make_shared<DnnlShapeAgnosticData>(postOpData);
Expand Down Expand Up @@ -285,7 +327,7 @@ DnnlShapeAgnosticDataPtr DnnlMatMulPrimitive::createShapeAgnosticData(const FCAt
context->getEngine(),
context->getImplPriorities(),
false,
false);
useWeightsDecompression);

const auto weightsDesc = DnnlExtensionUtils::makeDescriptor(primDesc.weights_desc());
auto originalWeightsDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weiDesc);
Expand Down Expand Up @@ -319,7 +361,7 @@ DnnlMatMulPrimitive::DnnlMatMulPrimitive(const Key& key,
engine,
implPriorities,
false,
false)),
useWeightsDecompressionImpl(key.src->getPrecision(), key.wei->getPrecision()))),
m_implType(implTypeFromPrimDesc(m_primDesc)),
m_srcDesc(DnnlExtensionUtils::makeDescriptor(m_primDesc.src_desc())),
m_weiDesc(DnnlExtensionUtils::makeDescriptor(m_primDesc.weights_desc())),
Expand All @@ -328,8 +370,6 @@ DnnlMatMulPrimitive::DnnlMatMulPrimitive(const Key& key,
m_prim(primitive(m_primDesc)) {}

void DnnlMatMulPrimitive::execute(const dnnl_primitive_args& primArgs) const {
std::cout << "Executing MM primitive"
<< "\n";
m_prim.execute(m_stream, primArgs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class DnnlMatMulPrimitive {
return m_implType;
}

static bool useWeightsDecompressionImpl(const ov::element::Type inputType,
const ov::element::Type weightsType);

static DnnlShapeAgnosticDataPtr createShapeAgnosticData(const FCAttrs& attrs,
const PostOps& postOps,
const MemoryArgs& memory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ static const TypeMapping dnnlMatMulTypeMapping {
// quantization configuration
{{_u8 | _i8, _i8, _u8|_i8|_i32|_bf16|_f16|_f32|_undefined, _u8|_i8|_i32|_bf16|_f16|_f32}, pt(bypass(), bypass(), bypass(), bypass())},
{{_u8 | _i8, _i8, _any, _any}, pt(bypass(), bypass(), just<f32>(), just<f32>())},
// compresses int weights
{{_f32 | _bf16 | _f16, _u8 | _i8, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
// TODO: fp16
// @todo should we fallback to FPXX instead of _f32?
{{_any, _any, _any, _any}, pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
// @todo explicitly cover configuration limitations for oneDNN on ARM
Expand Down Expand Up @@ -404,7 +407,7 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
return std::make_shared<ShlFCExecutor>(attrs, postOps, memory, context);
}
)
OV_CPU_INSTANCE_X64(
OV_CPU_INSTANCE_DNNL(
"matmul_dnnl",
ExecutorType::Dnnl,
OperationType::MatMul,
Expand All @@ -415,7 +418,6 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
CPU_DEBUG_CAP_ENABLE(
if (getEnvBool("OV_CPU_ENABLE_DNNL_MAMTUL_FOR_FC")) {
VERIFY(noSparseDecompression(config), UNSUPPORTED_SPARSE_WEIGHTS);
VERIFY(noWeightsDecompression(config), UNSUPPORTED_WEIGHTS_DECOMPRESSION);
return true;
})
return false;
Expand Down
Loading

0 comments on commit 370e2b2

Please sign in to comment.