diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 6a205103e227cc..ced702d31ca3cc 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -12,13 +12,17 @@ namespace snippets { namespace { -std::vector> transposedShape_4D(bool with_dynamic = true) { - auto shapes = SNIPPETS_TESTS_STATIC_SHAPES( - {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, - {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, - {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}); +std::vector> transposedShape_4D(bool with_static = true, bool with_dynamic = true) { + std::vector> shapes; + if (with_static) { + auto static_shapes = SNIPPETS_TESTS_STATIC_SHAPES( + {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, + {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, + {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}); + shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end()); + } if (with_dynamic) { std::vector> dynamic_shapes = {{ {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}}, @@ -74,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D, INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D_WithScalarMul, MHA, - ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)), + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)), ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(true), @@ -137,32 +141,80 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)), MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply, MHA, - ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)), + ::testing::Combine(::testing::ValuesIn(transposedShape_4D()), ::testing::ValuesIn(precision_fp16_if_supported(4)), ::testing::Values(ov::element::f16), - ::testing::ValuesIn({false, true}), + ::testing::ValuesIn({false}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::empty_plugin_config)), + MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Static, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)), + ::testing::ValuesIn(precision_fp16_if_supported(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), ::testing::Values(MHA::default_thread_count), ::testing::Values(2), ::testing::Values(1), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::Values(CPUTestUtils::empty_plugin_config)), MHA::getTestCaseName); +// 3 nodes and 2 subgraph for dynamic with multiply case. +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Dynamic, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)), + ::testing::ValuesIn(precision_fp16_if_supported(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(3), + ::testing::Values(2), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::empty_plugin_config)), + MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_Without_Multiply, MHA, - ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)), + ::testing::Combine(::testing::ValuesIn(transposedShape_4D()), ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f16), - ::testing::ValuesIn({false, true}), + ::testing::ValuesIn({false}), ::testing::Values(MHA::default_thread_count), ::testing::Values(2), ::testing::Values(1), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), MHA::getTestCaseName); - +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Static, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), + MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Dynamic, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(3), + ::testing::Values(2), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), + MHA::getTestCaseName); } // namespace } // namespace snippets } // namespace test diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 8d0cb8613bc47e..0a8fcc77717c42 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -65,6 +65,8 @@ void MHABase::SetUp() { #endif if (inType == ov::element::bf16) rel_threshold = 0.05f; + if (inType == ov::element::f16) + abs_threshold = 2e-2; } std::string MHA::getTestCaseName(testing::TestParamInfo obj) {