diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 459a8c71ad611..0f328ae84f8be 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -41,6 +41,14 @@ void DebugTrap() { #endif } // namespace +static inline bool use_cosine_similarity() { + static auto value = [&] { + const char* ptr = std::getenv("ORT_TEST_USE_COSINE_SIMILARITY"); + return ptr != nullptr ? std::atoi(ptr) : 0; + }(); + return value; +} + BaseTester::~BaseTester() { #ifndef NDEBUG if (!testing_function_called_) { @@ -288,7 +296,8 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session, const std::unordered_map& feeds, const std::vector& output_names, const std::string& provider_type, - bool allow_released_onnx_opset_only) { + bool allow_released_onnx_opset_only, + bool use_cosine_similarity) { fetches_.clear(); std::string s1; @@ -377,10 +386,10 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session, } CheckOrtValuesAreEqual(name, expected_data.data, ort_value, expected_data.validation_params, - provider_type); + provider_type, use_cosine_similarity); } else { CheckOrtValuesAreEqual(name, expected_data.data, ort_value, expected_data.validation_params, - provider_type); + provider_type, false); } ++idx; @@ -540,6 +549,7 @@ void BaseTester::Run(SessionOptions so, RunWithConfig(number_of_pre_packed_weights_counter, number_of_shared_pre_packed_weights_counter); } + void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, size_t* number_of_shared_pre_packed_weights_counter) { std::string cur_provider = "not set"; @@ -594,7 +604,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, /*assign_ep_for_nodes=*/false, allow_released_onnx_opset_only, number_of_pre_packed_weights_counter, - number_of_shared_pre_packed_weights_counter); + number_of_shared_pre_packed_weights_counter, + use_cosine_similarity()); } else { #ifdef USE_TENSORRT // only run trt ep to reduce test time @@ -677,7 +688,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, /*try_assign_ep_for_nodes=*/true, allow_released_onnx_opset_only, number_of_pre_packed_weights_counter, - number_of_shared_pre_packed_weights_counter); + number_of_shared_pre_packed_weights_counter, + use_cosine_similarity()); // Run Models with subscribed run_options->config_options if (ctx_.run_options != nullptr && @@ -696,7 +708,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, /*assign_ep_for_nodes=*/true, allow_released_onnx_opset_only, number_of_pre_packed_weights_counter, - number_of_shared_pre_packed_weights_counter); + number_of_shared_pre_packed_weights_counter, + use_cosine_similarity()); } } @@ -735,7 +748,8 @@ void BaseTester::ExecuteModelForEps( bool try_assign_ep_for_nodes, bool allow_released_onnx_opset_only, size_t* number_of_pre_packed_weights_counter, - size_t* number_of_shared_pre_packed_weights_counter) { + size_t* number_of_shared_pre_packed_weights_counter, + bool use_cosine_similarity) { for (auto& entry : execution_providers) { // Be noted, entry in execution providers passed in OpTester will be std::moved in the first BaseTester::Run(), // To make the error more obvious to debug (instead of a segment fault), we do check explicitly here. @@ -785,7 +799,7 @@ void BaseTester::ExecuteModelForEps( ExecuteModel( model, session_object, expect_result, expected_failure_string, - run_options, feeds, output_names, provider_type, allow_released_onnx_opset_only); + run_options, feeds, output_names, provider_type, allow_released_onnx_opset_only, use_cosine_similarity); // After the model has initialized (happens in ExecuteModel), // we should be able to tell how many constant initializers were pre-packed @@ -872,7 +886,8 @@ template void BaseTester::ExecuteModel( const RunOptions* run_options, const std::unordered_map& feeds, const std::vector& output_names, const std::string& provider_type, - bool allow_released_onnx_opset_only); + bool allow_released_onnx_opset_only, + bool use_cosine_similarity); #endif } // namespace test diff --git a/onnxruntime/test/providers/base_tester.h b/onnxruntime/test/providers/base_tester.h index 5607e58315a12..63623cdbab500 100644 --- a/onnxruntime/test/providers/base_tester.h +++ b/onnxruntime/test/providers/base_tester.h @@ -658,7 +658,8 @@ class BaseTester { const std::unordered_map& feeds, const std::vector& output_names, const std::string& provider_type, - bool allow_released_onnx_opset_only = true); + bool allow_released_onnx_opset_only = true, + bool use_cosine_similarity = false); template void AddData(std::vector& data, const char* name, const DimsVariant& dims_var, const T* values, @@ -909,7 +910,8 @@ class BaseTester { bool try_assign_ep_for_nodes, bool allow_released_onnx_opset_only, size_t* number_of_pre_packed_weights_counter, - size_t* number_of_shared_pre_packed_weights_counter); + size_t* number_of_shared_pre_packed_weights_counter, + bool use_cosine_similarity); const std::string test_name_; const std::string domain_; diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 85ccb8f175f62..a16250a5fa91b 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -47,7 +47,7 @@ void sort_expected_and_actual_buffers(std::vector& expected, template struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& /*provider_type*/) const { + const std::string& /*provider_type*/, bool use_cosine_similarity) const { Tensor expected_sorted, actual_sorted; const T* cur_expected; const T* cur_actual; @@ -78,7 +78,8 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& provider_type) const { + const std::string& provider_type, + bool use_cosine_similarity) const { const bool has_abs_err = params.absolute_error.has_value(); const bool has_rel_err = params.relative_error.has_value(); @@ -133,7 +134,8 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& /*provider_type*/) const { + const std::string& /*provider_type*/, + bool use_cosine_similarity) const { Tensor expected_sorted, actual_sorted; const int8_t* cur_expected; const int8_t* cur_actual; @@ -173,7 +175,8 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& /*provider_type*/) const { + const std::string& /*provider_type*/, + bool use_cosine_similarity) const { auto size = actual.Shape().Size(); bool has_abs_err = params.absolute_error.has_value(); @@ -227,7 +230,8 @@ template void InternalNumericalCheck(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& /*provider_type*/) { + const std::string& /*provider_type*/, + bool use_cosine_similarity) { const bool has_abs_err = params.absolute_error.has_value(); const bool has_rel_err = params.relative_error.has_value(); @@ -251,7 +255,12 @@ void InternalNumericalCheck(const Tensor& expected, #else constexpr float threshold = 0.0001f; #endif + constexpr float cosine_similarity_threshold = 0.01f; + TypeToCheck dot = 0.0; + TypeToCheck denom_a = 0.0; + TypeToCheck denom_b = 0.0; + size_t diff_cnt = 0; for (int i = 0; i < size; ++i) { // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted @@ -259,6 +268,13 @@ void InternalNumericalCheck(const Tensor& expected, EXPECT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(cur_expected[i])) { // Test infinity for equality EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; + } else if (use_cosine_similarity) { + if (abs(cur_expected[i] - cur_actual[i]) > threshold) { + dot += cur_expected[i] * cur_actual[i] ; + denom_a += cur_expected[i] * cur_expected[i] ; + denom_b += cur_actual[i] * cur_actual[i] ; + diff_cnt++; + } } else { if (!has_abs_err && !has_rel_err) { // the default for existing tests @@ -275,6 +291,11 @@ void InternalNumericalCheck(const Tensor& expected, } } } + + if (diff_cnt) { + float cos_sim = dot / (sqrt(denom_a) * sqrt(denom_b)); + ASSERT_NEAR(cos_sim, 1.0f, cosine_similarity_threshold)<< "cos_sim is not 1.0 " << cos_sim ; + } } template <> @@ -282,8 +303,9 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& provider_type) const { - InternalNumericalCheck(expected, actual, params, provider_type); + const std::string& provider_type, + bool use_cosine_similarity) const { + InternalNumericalCheck(expected, actual, params, provider_type, use_cosine_similarity); } }; @@ -292,7 +314,8 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& /*provider_type*/) const { + const std::string& /*provider_type*/, + bool use_cosine_similarity) const { auto* cur_expected = expected.Data(); auto* cur_actual = actual.Data(); auto size = actual.Shape().Size(); @@ -346,7 +369,8 @@ struct TensorCheck { void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, - const std::string& /*provider_type*/) const { + const std::string& /*provider_type*/, + bool use_cosine_similarity) const { auto* cur_expected = expected.Data(); auto* cur_actual = actual.Data(); auto size = actual.Shape().Size(); @@ -397,14 +421,14 @@ struct TensorCheck { // default Check template void Check(std::string_view name, const OrtValue& expected, const T& actual, - const ValidateOutputParams& /*params*/, const std::string& /*provider_type*/) { + const ValidateOutputParams& /*params*/, const std::string& /*provider_type*/, bool use_cosine_similarity) { EXPECT_EQ(expected.Get(), actual) << "name: " << name; } // Check for Tensors template <> void Check(std::string_view name, const OrtValue& expected, const Tensor& actual, - const ValidateOutputParams& params, const std::string& provider_type) { + const ValidateOutputParams& params, const std::string& provider_type, bool use_cosine_similarity) { const Tensor& expected_tensor = expected.Get(); ORT_ENFORCE(expected_tensor.Shape() == actual.Shape(), "Expected output shape [", expected_tensor.Shape(), @@ -420,13 +444,13 @@ void Check(std::string_view name, const OrtValue& expected, const Tensor MLFloat16, BFloat16> t_disp(actual.GetElementType()); - t_disp.Invoke(expected_tensor, actual, params, provider_type); + t_disp.Invoke(expected_tensor, actual, params, provider_type, use_cosine_similarity); } // Check for sequence of tensors template <> void Check(std::string_view name, const OrtValue& expected, const TensorSeq& actual, - const ValidateOutputParams& params, const std::string& provider_type) { + const ValidateOutputParams& params, const std::string& provider_type, bool use_cosine_similarity) { const auto& exp_seq = expected.Get(); // first ensure data types match @@ -454,15 +478,15 @@ void Check(std::string_view name, const OrtValue& expected, const Ten t_disp(element_type); for (size_t i = 0; i < actual_num_tensors; ++i) { - t_disp.Invoke(exp_seq.Get(i), actual.Get(i), params, provider_type); + t_disp.Invoke(exp_seq.Get(i), actual.Get(i), params, provider_type, use_cosine_similarity); } } template void CheckDispatch(MLDataType type, std::string_view name, const OrtValue& expected, const OrtValue& actual, - const ValidateOutputParams& params, const std::string& provider_type) { + const ValidateOutputParams& params, const std::string& provider_type, bool use_cosine_similarity) { if (type == DataTypeImpl::GetType()) { - Check(name, expected, actual.Get(), params, provider_type); + Check(name, expected, actual.Get(), params, provider_type, use_cosine_similarity); } else { ORT_THROW("OpTester:Check() not implemented for output tensor type of ", type); } @@ -470,16 +494,16 @@ void CheckDispatch(MLDataType type, std::string_view name, const OrtValue& expec template void CheckDispatch(MLDataType type, std::string_view name, const OrtValue& expected, const OrtValue& actual, - const ValidateOutputParams& params, const std::string& provider_type) { + const ValidateOutputParams& params, const std::string& provider_type, bool use_cosine_similarity) { if (type == DataTypeImpl::GetType()) { - Check(name, expected, actual.Get(), params, provider_type); + Check(name, expected, actual.Get(), params, provider_type, use_cosine_similarity); } else { - CheckDispatch(type, name, expected, actual, params, provider_type); + CheckDispatch(type, name, expected, actual, params, provider_type, use_cosine_similarity); } } void CheckOrtValuesAreEqual(std::string_view name, const OrtValue& expected, const OrtValue& actual, - const ValidateOutputParams& params, const std::string& provider_type) { + const ValidateOutputParams& params, const std::string& provider_type, bool use_cosine_similarity) { // Include provider_type in any error output SCOPED_TRACE(MakeString("provider type: ", provider_type)); @@ -488,7 +512,7 @@ void CheckOrtValuesAreEqual(std::string_view name, const OrtValue& expected, con #if !defined(DISABLE_ML_OPS) VectorMapStringToFloat, VectorMapInt64ToFloat, #endif - TensorSeq>(expected.Type(), name, expected, actual, params, provider_type); + TensorSeq>(expected.Type(), name, expected, actual, params, provider_type, use_cosine_similarity); } } // namespace test diff --git a/onnxruntime/test/providers/checkers.h b/onnxruntime/test/providers/checkers.h index 54f3bb8f0fe5d..a10b05c19902a 100644 --- a/onnxruntime/test/providers/checkers.h +++ b/onnxruntime/test/providers/checkers.h @@ -27,7 +27,8 @@ struct ValidateOutputParams { /// Optional parameters to adjust how the check is performed. /// Execution provider type if relevant. void CheckOrtValuesAreEqual(std::string_view name, const OrtValue& expected, const OrtValue& actual, - const ValidateOutputParams& params = {}, const std::string& provider_type = ""); + const ValidateOutputParams& params = {}, const std::string& provider_type = "", + bool use_cosine_similarity=false); } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index 3d53d4a3a0193..118645ddc0b35 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -65,6 +65,26 @@ const char* ElementTypeToString(MLDataType type) { return DataTypeImpl::ToString(type); } +static inline bool use_cosine_similarity() { + static auto value = [&] { + const char* ptr = std::getenv("ORT_TEST_USE_COSINE_SIMILARITY"); + return ptr != nullptr ? std::atoi(ptr) : 0; + }(); + return value; +} + +template +TypeToCheck cosine_similarity(const TypeToCheck *A, const TypeToCheck *B, size_t Vector_Length) +{ + TypeToCheck dot = 0.0, denom_a = 0.0, denom_b = 0.0 ; + for(size_t i = 0u; i < Vector_Length; ++i) { + dot += A[i] * B[i] ; + denom_a += A[i] * A[i] ; + denom_b += B[i] * B[i] ; + } + return dot / (sqrt(denom_a) * sqrt(denom_b)) ; +} + /** * @brief Check if two values are closely matched with given tolerance. @@ -105,12 +125,24 @@ std::pair CompareFloatResult(const Tensor& outvalue std::pair res = std::make_pair(COMPARE_RESULT::SUCCESS, ""); double max_diff = 0; size_t diff_count = 0; + const float cosine_similarity_threshold = 0.99; for (size_t di = 0; di != size1; ++di) { const double real_value = post_processing ? std::max(0.0, std::min(255.0, real_output[di])) : real_output[di]; const double diff = std::fabs(expected_output[di] - real_value); const double tol = per_sample_tolerance + relative_per_sample_tolerance * std::fabs(expected_output[di]); if (!IsResultCloselyMatch(real_value, expected_output[di], diff, tol)) { + if (use_cosine_similarity()) { + float cos_sim = cosine_similarity(real_output, expected_output, size1); + if (abs(cos_sim) < cosine_similarity_threshold) { + res.first = COMPARE_RESULT::RESULT_DIFFERS; + std::ostringstream oss; + oss << std::hex << "results differed, cosine similarity factor is " << cos_sim << "."; + res.second = oss.str(); + } + return res; + } + res.first = COMPARE_RESULT::RESULT_DIFFERS; // update error message if this is a larger diff if (diff > max_diff || (std::isnan(diff) && !std::isnan(max_diff))) {