Skip to content

Commit

Permalink
add dynamic shape test
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Dec 10, 2024
1 parent ab92158 commit 2b9470c
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ namespace snippets {

namespace {

std::vector<std::vector<InputShape>> transposedShape_4D(bool with_dynamic = true) {
auto shapes = SNIPPETS_TESTS_STATIC_SHAPES(
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
std::vector<std::vector<InputShape>> transposedShape_4D(bool with_static = true, bool with_dynamic = true) {
std::vector<std::vector<ov::test::InputShape>> shapes;
if (with_static) {
auto static_shapes = SNIPPETS_TESTS_STATIC_SHAPES(
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}});
shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end());
}
if (with_dynamic) {
std::vector<std::vector<ov::test::InputShape>> dynamic_shapes = {{
{PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}},
Expand Down Expand Up @@ -74,7 +78,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D,

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D_WithScalarMul,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::Values(true),
Expand Down Expand Up @@ -137,32 +141,80 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_fp16_if_supported(4)),
::testing::Values(ov::element::f16),
::testing::ValuesIn({false, true}),
::testing::ValuesIn({false}),
::testing::Values(MHA::default_thread_count),
::testing::Values(2),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Static,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
::testing::ValuesIn(precision_fp16_if_supported(4)),
::testing::Values(ov::element::f16),
::testing::ValuesIn({true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(2),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);
// 3 nodes and 2 subgraph for dynamic with multiply case.
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Dynamic,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)),
::testing::ValuesIn(precision_fp16_if_supported(4)),
::testing::Values(ov::element::f16),
::testing::ValuesIn({true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(3),
::testing::Values(2),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_Without_Multiply,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f16),
::testing::ValuesIn({false, true}),
::testing::ValuesIn({false}),
::testing::Values(MHA::default_thread_count),
::testing::Values(2),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Static,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f16),
::testing::ValuesIn({true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(2),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Dynamic,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f16),
::testing::ValuesIn({true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(3),
::testing::Values(2),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
MHA::getTestCaseName);
} // namespace
} // namespace snippets
} // namespace test
Expand Down
2 changes: 2 additions & 0 deletions src/tests/functional/plugin/shared/src/snippets/mha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ void MHABase::SetUp() {
#endif
if (inType == ov::element::bf16)
rel_threshold = 0.05f;
if (inType == ov::element::f16)
abs_threshold = 2e-2;
}

std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
Expand Down

0 comments on commit 2b9470c

Please sign in to comment.