diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 7239e5242543d..087b9d604128e 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -6,7 +6,6 @@ #include "TestCase.h" #include -#include #include #include #include @@ -732,8 +731,6 @@ void LoadTests(const std::vector>& input_paths const std::vector>& whitelisted_test_cases, const TestTolerances& tolerances, const std::unordered_set>& disabled_tests, - std::unique_ptr> broken_tests, - std::unique_ptr> broken_tests_keyword_set, const std::function)>& process_function) { std::vector> paths(input_paths); while (!paths.empty()) { @@ -786,60 +783,11 @@ void LoadTests(const std::vector>& input_paths ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported"); } - auto test_case_dir = model_info->GetDir(); - auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir; - -#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) - // to skip some models like *-int8 or *-qdq - if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || - (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain"); - return true; - } -#endif - - bool has_test_data = false; - LoopDir(test_case_dir, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool { - if (filename[0] == '.') return true; - if (f_type == OrtFileType::TYPE_DIR) { - has_test_data = true; - return false; - } - return true; - }); - if (!has_test_data) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data"); - return true; - } - - if (broken_tests) { - BrokenTest t = {ToUTF8String(test_case_name), ""}; - auto iter = broken_tests->find(t); - auto opset_version = model_info->GetNominalOpsetVersion(); - if (iter != broken_tests->end() && - (opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() || - iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests"); - return true; - } - } - - if (broken_tests_keyword_set) { - for (auto iter2 = broken_tests_keyword_set->begin(); iter2 != broken_tests_keyword_set->end(); ++iter2) { - std::string keyword = *iter2; - if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) { - fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords"); - return true; - } - } - } - const auto tolerance_key = ToUTF8String(my_dir_name); std::unique_ptr l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info), tolerances.absolute(tolerance_key), tolerances.relative(tolerance_key)); - fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str()); process_function(std::move(l)); return true; }); diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h index 96b0b5f6f7c08..4d4b2177019c9 100644 --- a/onnxruntime/test/onnx/TestCase.h +++ b/onnxruntime/test/onnx/TestCase.h @@ -101,6 +101,12 @@ class TestTolerances { const Map relative_overrides_; }; +void LoadTests(const std::vector>& input_paths, + const std::vector>& whitelisted_test_cases, + const TestTolerances& tolerances, + const std::unordered_set>& disabled_tests, + const std::function)>& process_function); + struct BrokenTest { std::string test_name_; std::string reason_; @@ -112,16 +118,6 @@ struct BrokenTest { } }; -void LoadTests(const std::vector>& input_paths, - const std::vector>& whitelisted_test_cases, - const TestTolerances& tolerances, - const std::unordered_set>& disabled_tests, - std::unique_ptr> broken_test_list, - std::unique_ptr> broken_tests_keyword_set, - const std::function)>& process_function); - std::unique_ptr> GetBrokenTests(const std::string& provider_name); std::unique_ptr> GetBrokenTestsKeyWordSet(const std::string& provider_name); - -std::unique_ptr> GetBrokenTestsKeyWordSet(const std::string& provider_name); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index de5431ca4a460..f165b3a4a647a 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -783,14 +783,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); all_disabled_tests.insert(std::begin(x86_disabled_tests), std::end(x86_disabled_tests)); #endif - auto broken_tests = GetBrokenTests(provider_name); - auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet(provider_name); std::vector tests; LoadTests(data_dirs, whitelisted_test_cases, LoadTestTolerances(enable_cuda, enable_openvino, override_tolerance, atol, rtol), all_disabled_tests, - std::move(broken_tests), - std::move(broken_tests_keyword_set), [&owned_tests, &tests](std::unique_ptr l) { tests.push_back(l.get()); owned_tests.push_back(std::move(l)); @@ -807,10 +803,18 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); fwrite(res.c_str(), 1, res.size(), stdout); } + auto broken_tests = GetBrokenTests(provider_name); int result = 0; for (const auto& p : stat.GetFailedTest()) { - fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str()); - result = -1; + BrokenTest t = {p.first, ""}; + auto iter = broken_tests->find(t); + if (iter == broken_tests->end() || (p.second != TestModelInfo::unknown_version && !iter->broken_opset_versions_.empty() && + iter->broken_opset_versions_.find(p.second) == iter->broken_opset_versions_.end())) { + fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str()); + result = -1; + } else { + fprintf(stderr, "test %s failed, but it is a known broken test, so we ignore it\n", p.first.c_str()); + } } return result; } diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index cb5cbbecb5ef0..5057f74046638 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -238,16 +238,11 @@ static std::vector GetAllTestCases() { // Bad onnx test output caused by previously wrong SAME_UPPER/SAME_LOWER for ConvTranspose allDisabledTests.insert(ORT_TSTR("cntk_simple_seg")); - auto broken_tests = GetBrokenTests("dml"); - auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet("dml"); - WINML_EXPECT_NO_THROW(LoadTests( dataDirs, whitelistedTestCases, TestTolerances(1e-3, 1e-3, {}, {}), allDisabledTests, - std::move(broken_tests), - std::move(broken_tests_keyword_set), [&tests](std::unique_ptr l) { tests.push_back(l.get()); ownedTests.push_back(std::move(l));