Skip to content

Commit

Permalink
[CPU] Reference FC mxfp4 compression support
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-gorokhov committed Jul 22, 2024
1 parent b8ba903 commit 966ac90
Show file tree
Hide file tree
Showing 13 changed files with 484 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
auto input_pattern = pattern::any_input();
auto convert_pattern = pattern::wrap_type<ov::op::v0::Convert>({input_pattern}, pattern::consumers_count(1));
auto zero_point_pattern = pattern::any_input();
auto scale_pattern = pattern::any_input();
auto subtract_pattern = pattern::wrap_type<ov::op::v1::Subtract>({convert_pattern, zero_point_pattern});
auto multiply_pattern = pattern::wrap_type<ov::op::v1::Multiply>({subtract_pattern, pattern::any_input()});
auto multiply_pattern = pattern::wrap_type<ov::op::v1::Multiply>({subtract_pattern, scale_pattern});
auto multiply_no_subtract_pattern =
pattern::wrap_type<ov::op::v1::Multiply>({convert_pattern, pattern::any_input()});
pattern::wrap_type<ov::op::v1::Multiply>({convert_pattern, scale_pattern});
auto root = std::make_shared<pattern::op::Or>(OutputVector{multiply_pattern, multiply_no_subtract_pattern});

ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) -> bool {
Expand Down Expand Up @@ -100,6 +101,10 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
// mark Multiply as dequantization node
ov::mark_as_dequantization_node(multiply);

auto scale = multiply->get_input_node_shared_ptr(1);
ov::disable_constant_folding(scale);
ov::enable_keep_const_precision(scale->get_input_node_shared_ptr(0));

return false;
};

Expand Down
41 changes: 29 additions & 12 deletions src/plugins/intel_cpu/src/graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,9 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseFCAndWeightsDecompression(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionAndBias");
FuseConvolutionMatMulDeconvAndBias(graph);
graph.RemoveDroppedNodes();
// OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseConvolutionAndBias");
// FuseConvolutionMatMulDeconvAndBias(graph);
// graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseMultiplyAndAdd");
FuseMultiplyAndAdd(graph);
Expand Down Expand Up @@ -135,9 +135,9 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
FuseConvolutionAndSimpleOperation(graph);
graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFullyConnectedAndSimpleOperation");
FuseFullyConnectedAndSimpleOperation(graph);
graph.RemoveDroppedNodes();
// OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFullyConnectedAndSimpleOperation");
// FuseFullyConnectedAndSimpleOperation(graph);
// graph.RemoveDroppedNodes();

OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseMatMulAndSimpleOperation");
FuseMatMulAndSimpleOperation(graph);
Expand Down Expand Up @@ -289,7 +289,8 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
}

void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
std::set<ov::element::Type> supportedWeightsPrecisions{ov::element::u8, ov::element::i8, ov::element::nf4, ov::element::u4, ov::element::i4};
std::set<ov::element::Type> supportedWeightsPrecisions{
ov::element::u8, ov::element::i8, ov::element::nf4, ov::element::u4, ov::element::i4, ov::element::f4e2m1};
const std::set<ov::element::Type> supportedDataPrecisions{ov::element::f32, ov::element::bf16};
auto expectedNode = [](NodePtr node, Type expectedType) {
return node->getType() == expectedType && node->getChildEdges().size() == 1;
Expand Down Expand Up @@ -329,16 +330,24 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
}

CPU_GRAPH_OPTIMIZER_SCOPE(FuseFCAndWeightsDecompression);
const auto multiplyConstNode = multiplyNode->getParentEdgeAt(1)->getParent();
const auto mulParent1 = multiplyNode->getParentEdgeAt(1)->getParent();
NodePtr multiplyParent, multiplyConvertNode, multiplyConstNode;
multiplyParent = mulParent1;
if (multiplyParent->getType() == Type::Convert) {
multiplyConvertNode = multiplyParent;
multiplyParent = multiplyConvertNode->getParentEdgeAt(0)->getParent();
}
multiplyConstNode = multiplyParent;
if (multiplyConstNode->getType() != Type::Input) {
SKIP_FUSION_FOR_NODE(fcNode);
}
const bool withMultiplyConvert = multiplyConvertNode != nullptr;

const auto mulParent = multiplyNode->getParentEdgeAt(0)->getParent();
const bool withSubtract = mulParent->getAlgorithm() == Algorithm::EltwiseSubtract;
const auto mulParent0 = multiplyNode->getParentEdgeAt(0)->getParent();
const bool withSubtract = mulParent0->getAlgorithm() == Algorithm::EltwiseSubtract;
NodePtr subtractNode, subtractConvertNode, subtractConstNode;
if (withSubtract) {
subtractNode = mulParent;
subtractNode = mulParent0;
if (!expectedNode(subtractNode, Type::Eltwise)) {
SKIP_FUSION_FOR_NODE(fcNode);
}
Expand All @@ -354,7 +363,7 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
}

const bool withSubtractConvert = subtractConvertNode != nullptr;
const auto convertNode = withSubtract ? subtractNode->getParentEdgeAt(0)->getParent() : mulParent;
const auto convertNode = withSubtract ? subtractNode->getParentEdgeAt(0)->getParent() : mulParent0;
if (!expectedNode(convertNode, Type::Convert)) {
SKIP_FUSION_FOR_NODE(fcNode);
}
Expand Down Expand Up @@ -461,6 +470,8 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
fcNode->addOriginalLayer(subtractNode->getOriginalLayers());
if (withSubtractConvert)
fcNode->addOriginalLayer(subtractConvertNode->getOriginalLayers());
if (withMultiplyConvert)
fcNode->addOriginalLayer(multiplyConvertNode->getOriginalLayers());

const auto& weightsPrecision = weightsNode->getOriginalOutputPrecisionAtPort(0);
if (withTranspose) {
Expand Down Expand Up @@ -511,6 +522,12 @@ void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
graph.RemoveEdge(subtractConvertNode->getParentEdgeAt(0));
}
graph.RemoveEdge(multiplyNode->getParentEdgeAt(1));
if (withMultiplyConvert) {
// MultiplyConvert is removed only if there are no other consumers (e.g. CompressedGather)
const auto& restChilds = multiplyConvertNode->getChildEdges();
if (restChilds.empty())
graph.RemoveEdge(multiplyConvertNode->getParentEdgeAt(0));
}

graph.DropNode(convertNode);
if (withSubtract)
Expand Down
39 changes: 39 additions & 0 deletions src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,40 @@ struct ConvertFrom4BitPrecision {
parallel_for(ctx.size, [&](size_t i) {
dst[i] = static_cast<DT>(get_i4(src[i / 2], i % 2));
});
} else if (ctx.inType == ov::element::f4e2m1) {
parallel_for(ctx.size, [&](size_t i) {
dst[i] = static_cast<DT>(float4_e2m1::from_bits(get_u4(src[i / 2], i % 2)));
});
} else {
OPENVINO_THROW("cpu_convert doesn't support input data type: ", ctx.inType, ". Not implemented.");
}
ctx.converted = true;
}
};


// Expands to an OV_CASE entry mapping an ov::element type tag to its underlying C++
// value type; used by OV_SWITCH to dispatch ConvertFromByteFPPrecision on the
// destination precision.
#define INTEL_CPU_CVT_FROM_BYTE_FP(DT) OV_CASE(ov::element::DT, PrecisionInfo<ov::element::DT>::value_type)

// Destination precisions supported when converting from byte-sized FP inputs (f8e8m0).
#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \
    INTEL_CPU_CVT_FROM_BYTE_FP(f32), INTEL_CPU_CVT_FROM_BYTE_FP(bf16), INTEL_CPU_CVT_FROM_BYTE_FP(f16)

// Context for converting byte-sized (8-bit) floating point data to a wider precision.
// Filled by cpu_convert() and consumed by ConvertFromByteFPPrecision<DT>.
// NOTE: member order is part of the contract — the call site uses aggregate
// braced-init ({srcPrc, srcPtr, dstPtr, size, false}).
struct ConvertFromByteFPContext {
    ov::element::Type_t inType;  // source element type (currently only f8e8m0 is handled)
    const void *srcPtr;          // source buffer, one byte per element
    void *dstPtr;                // destination buffer of DT elements
    size_t size;                 // number of elements to convert
    bool converted;              // set by the converter; checked afterwards to report unsupported dst precision
};

// Converts 8-bit floating point input (f8e8m0) to the destination type DT element-wise.
// Instantiated per destination precision via OV_SWITCH over INTEL_CPU_CVT_FROM_BYTE_FP_LIST
// (f32 / bf16 / f16).
template <typename DT>
struct ConvertFromByteFPPrecision {
    void operator()(ConvertFromByteFPContext &ctx) {
        auto src = static_cast<const uint8_t*>(ctx.srcPtr);
        auto dst = static_cast<DT*>(ctx.dstPtr);
        if (ctx.inType == ov::element::f8e8m0) {
            // f8e8m0 (8 exponent bits, 0 mantissa bits, per the type name): each source
            // byte is reinterpreted as the scale value and widened to DT in parallel.
            parallel_for(ctx.size, [&](size_t i) {
                dst[i] = static_cast<DT>(float8_e8m0::from_bits(src[i]));
            });
        } else {
            OPENVINO_THROW("cpu_convert doesn't support input data type: ", ctx.inType, ". Not implemented.");
        }
Expand Down Expand Up @@ -703,6 +737,11 @@ void cpu_convert(const void *srcPtr,
OV_SWITCH(intel_cpu, ConvertFrom4BitPrecision, ctx, dstPrc, INTEL_CPU_CVT_FROM_4BIT_LIST);
if (!ctx.converted)
OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
} else if (srcPrc.bitwidth() == 8u && srcPrc.is_real()) {
ConvertFromByteFPContext ctx{srcPrc, srcPtr, dstPtr, size, false};
OV_SWITCH(intel_cpu, ConvertFromByteFPPrecision, ctx, dstPrc, INTEL_CPU_CVT_FROM_BYTE_FP_LIST);
if (!ctx.converted)
OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc);
} else {
ConvertContext ctx {
srcPtr,
Expand Down
9 changes: 8 additions & 1 deletion src/plugins/intel_cpu/src/nodes/executors/executor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ namespace intel_cpu {
# define OV_CPU_INSTANCE_MLAS_X64(...)
#endif

// Register TPP (Tensor Processing Primitives) executor instances only when the plugin
// is built with TPP support; otherwise the macro expands to nothing, so the executor
// table contains no entries referencing TPP types that were not compiled in.
// (Restores the conditional that was commented out — leaving the macro unconditional
// breaks builds without OV_CPU_WITH_TPP, mirroring the OV_CPU_INSTANCE_MLAS_X64 pattern.)
#if defined(OV_CPU_WITH_TPP)
#    define OV_CPU_INSTANCE_TPP(...) {__VA_ARGS__},
#else
#    define OV_CPU_INSTANCE_TPP(...)
#endif

#define OV_CPU_INSTANCE_COMMON(...) {__VA_ARGS__},

// @todo another option is to determine shape relation by executor type
Expand All @@ -63,7 +69,8 @@ enum class ExecutorType {
Dnnl,
Acl,
Mlas,
jit_aarch64
jit_aarch64,
Tpp,
};

enum class OperationType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "nodes/executors/precision_matcher.hpp"
#include "nodes/executors/precision_translation.hpp"
#include "nodes/executors/type_mask.hpp"
#include "nodes/executors/tpp/tpp_fullyconnected.hpp"
#include "openvino/core/type/element_type.hpp"
#include "ov_optional.hpp"
#include "utils/cpp/maybe_unused.hpp"
Expand Down Expand Up @@ -205,6 +206,28 @@ const std::vector<ExecutorImplementation<FCAttrs>>& getImplementations() {
const ExecutorContext::CPtr context) {
return std::make_shared<MlasGemmExecutor>(attrs, postOps, memory, context);
})
OV_CPU_INSTANCE_TPP(
"fullyconnected_tpp",
ExecutorType::Tpp,
OperationType::FullyConnected,
ShapeTolerance::Agnostic,
// supports
[](const FCConfig& config) -> bool {
return TPPFCExecutor::supports(config);
},
// requiresFallback
[](const FCConfig& config) -> ov::optional<executor::Config<FCAttrs>> {
return {};
},
// acceptsShapes
[](const MemoryArgs& memory) -> bool {
return true;
},
// create
[](const FCAttrs& attrs, const PostOps& postOps, const MemoryArgs& memory, ExecutorContext::CPtr context) {
return std::make_shared<TPPFCExecutor>(attrs, postOps, memory, context);
})

OV_CPU_INSTANCE_X64(
"convolution_1x1_dnnl",
ExecutorType::Dnnl,
Expand Down
Loading

0 comments on commit 966ac90

Please sign in to comment.