Skip to content

Commit

Permalink
ci fix
Browse files Browse the repository at this point in the history
  • Loading branch information
isanghao committed Dec 3, 2024
1 parent 538adef commit 0dd5a8b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ using MatmulWeightsDecompressionParams = std::tuple<ShapeParams, //
bool, // reshape on decompression constants
bool, // extra multiply
bool, // per-tensor zero-point
uint64_t // dynamic_quantization_group_size
uint64_t, // dynamic_quantization_group_size
float // abs_threshold_f16
>;

class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeightsDecompressionParams>,
Expand All @@ -74,6 +75,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
bool extra_multiply;
bool per_tensor_zp;
uint64_t dyn_quan_group_size;
float abs_threshold_f16;

std::tie(shape_params,
weights_precision,
Expand All @@ -83,7 +85,8 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
reshape_on_decompression,
extra_multiply,
per_tensor_zp,
dyn_quan_group_size) = obj.param;
dyn_quan_group_size,
abs_threshold_f16) = obj.param;

std::ostringstream result;
result << "data_shape=";
Expand Down Expand Up @@ -254,6 +257,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
bool extra_multiply;
bool per_tensor_zp;
uint64_t dyn_quan_group_size;
float abs_threshold_f16 = 1.0f;

std::tie(shape_params,
weights_precision,
Expand All @@ -263,7 +267,8 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
reshape_on_decompression,
extra_multiply,
per_tensor_zp,
dyn_quan_group_size) = GetParam();
dyn_quan_group_size,
abs_threshold_f16) = GetParam();

init_input_shapes({shape_params.data_shape, {{}, {{shape_params.weights_shape}}}});

Expand All @@ -282,7 +287,7 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig


if (activations_precision == ov::element::f16) {
abs_threshold = 1.0f;
abs_threshold = abs_threshold_f16;
} else {
abs_threshold = 1e-4f;
}
Expand Down Expand Up @@ -341,7 +346,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
::testing::Values(true),
::testing::Values(false),
::testing::Values(false),
::testing::Values(0)),
::testing::Values(0),
::testing::Values(1.0f)),
MatmulWeightsDecompression::get_test_case_name);

INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply,
Expand All @@ -354,7 +360,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_extra_multiply,
::testing::Values(false),
::testing::Values(true),
::testing::Values(false),
::testing::Values(0)),
::testing::Values(0),
::testing::Values(1.0f)),
MatmulWeightsDecompression::get_test_case_name);

const std::vector<ShapeParams> input_shapes_corner_cases_basic = {
Expand Down Expand Up @@ -384,7 +391,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases_basic,
::testing::ValuesIn(reshape_on_decompression),
::testing::Values(false),
::testing::ValuesIn(per_tensor_zp),
::testing::Values(0)),
::testing::Values(0),
::testing::Values(1.0f)),
MatmulWeightsDecompression::get_test_case_name);

INSTANTIATE_TEST_SUITE_P(MatMulCompressedWeights_corner_cases_big,
Expand All @@ -397,7 +405,8 @@ INSTANTIATE_TEST_SUITE_P(MatMulCompressedWeights_corner_cases_big,
::testing::ValuesIn(reshape_on_decompression),
::testing::Values(false),
::testing::ValuesIn(per_tensor_zp),
::testing::Values(0)),
::testing::Values(0),
::testing::Values(1.0f)),
MatmulWeightsDecompression::get_test_case_name);


Expand All @@ -415,7 +424,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_dyn_quan,
::testing::Values(true),
::testing::Values(false),
::testing::Values(true), // per_tensor_zp
::testing::ValuesIn(group_size)),
::testing::ValuesIn(group_size),
::testing::Values(2.0f)), // Note: this is because of potential cldnn accuracy issue
MatmulWeightsDecompression::get_test_case_name);

} // namespace
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ class check_hash_value: public ::testing::Test {
const auto primitive_hash = primitve->hash();
const auto params_hash = primitve->type->get_fake_aligned_params(*prim_inst->get_impl_params()).hash();
if (!engine.get_device_info().supports_immad) {
ASSERT_EQ(primitive_hash, 8017451717095756666UL);
ASSERT_EQ(params_hash, 8889154389021912103UL);
ASSERT_EQ(primitive_hash, 9510988594087947885UL);
ASSERT_EQ(params_hash, 7833603199176871790UL);
} else {
ASSERT_EQ(primitive_hash, 8017451717095756666UL);
ASSERT_EQ(params_hash, 10847775446937354749UL);
ASSERT_EQ(primitive_hash, 9510988594087947885UL);
ASSERT_EQ(params_hash, 16259702189938020305UL);
}
}

Expand Down

0 comments on commit 0dd5a8b

Please sign in to comment.