Skip to content

Commit

Permalink
[CPU] Merge bf16 weights type conversion with repack stage
Browse files (browse the repository at this point in the history)
  • Loading branch information
dmitry-gorokhov committed May 20, 2024
1 parent: ffb1732; commit: 5def15b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/src/graph_optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,8 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
continue;
}
const auto convert = fullyConnected->getParentEdgeAt(1)->getParent();
if (convert->getType() != Type::Convert || convert->getOriginalInputPrecisionAtPort(0) != ov::element::f16 ||
if (convert->getType() != Type::Convert ||
!one_of(convert->getOriginalInputPrecisionAtPort(0), ov::element::f16, ov::element::bf16) ||
!one_of(convert->getOriginalOutputPrecisionAtPort(0), ov::element::f32, ov::element::bf16) ||
!convert->isConstant()) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class MatMulDecompressConvertTest : public testing::WithParamInterface<MatMulDec

ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(inType, inShapeA)};
std::shared_ptr<ov::Node> inputB = ov::test::utils::make_constant(weiConstElemType, inShapeB.get_shape());
if (weiConstElemType == ElementType::f16) {
if (weiConstElemType == ElementType::f16 || weiConstElemType == ElementType::bf16) {
inputB = std::make_shared<ov::op::v0::Convert>(inputB, convertOutType);
mark_as_decompression(inputB);
}
Expand Down Expand Up @@ -311,15 +311,15 @@ INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP32,
testParams2D_FP32_smoke,
MatMulDecompressConvertTest::getTestCaseName);

const auto testParams2D_FP16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes2D),
::testing::ValuesIn(transposeParams),
::testing::Values(ElementType::f16),
::testing::Values(emptyConfig),
::testing::ValuesIn(filter_specific_params(false)));
const auto testParams2D_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes2D),
::testing::ValuesIn(transposeParams),
::testing::Values(ElementType::f16, ElementType::bf16),
::testing::Values(emptyConfig),
::testing::ValuesIn(filter_specific_params(false)));

INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP16,
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D,
MatMulDecompressConvertTest,
testParams2D_FP16_smoke,
testParams2D_smoke,
MatMulDecompressConvertTest::getTestCaseName);

const auto testParams2D_BF16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes2D),
Expand All @@ -344,15 +344,15 @@ INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_FP32,
testParams3D_FP32_smoke,
MatMulDecompressConvertTest::getTestCaseName);

const auto testParams3D_FP16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes3D),
const auto testParams3D_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes3D),
::testing::ValuesIn(transposeParams),
::testing::Values(ElementType::f16),
::testing::Values(ElementType::f16, ElementType::bf16),
::testing::Values(emptyConfig),
::testing::ValuesIn(filter_specific_params(false)));

INSTANTIATE_TEST_SUITE_P(smoke_FC_3D_FP16,
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D,
MatMulDecompressConvertTest,
testParams3D_FP16_smoke,
testParams3D_smoke,
MatMulDecompressConvertTest::getTestCaseName);

const auto testParams3D_BF16_smoke = ::testing::Combine(::testing::ValuesIn(inputShapes3D),
Expand Down

0 comments on commit 5def15b

Please sign in to comment.