diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 087b9d604128e..7239e5242543d 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -6,6 +6,7 @@ #include "TestCase.h" #include +#include #include #include #include @@ -731,6 +732,8 @@ void LoadTests(const std::vector>& input_paths const std::vector>& whitelisted_test_cases, const TestTolerances& tolerances, const std::unordered_set>& disabled_tests, + std::unique_ptr> broken_tests, + std::unique_ptr> broken_tests_keyword_set, const std::function)>& process_function) { std::vector> paths(input_paths); while (!paths.empty()) { @@ -783,11 +786,60 @@ void LoadTests(const std::vector>& input_paths ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported"); } + auto test_case_dir = model_info->GetDir(); + auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir; + +#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN) + // to skip some models like *-int8 or *-qdq + if ((reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || + (reinterpret_cast(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain"); + return true; + } +#endif + + bool has_test_data = false; + LoopDir(test_case_dir, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool { + if (filename[0] == '.') return true; + if (f_type == OrtFileType::TYPE_DIR) { + has_test_data = true; + return false; + } + return true; + }); + if (!has_test_data) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data"); + return true; + } + + if (broken_tests) { + BrokenTest t = {ToUTF8String(test_case_name), ""}; + auto iter = broken_tests->find(t); + auto opset_version = model_info->GetNominalOpsetVersion(); + if (iter != broken_tests->end() && + (opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() || + iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests"); + return true; + } + } + + if (broken_tests_keyword_set) { + for (auto iter2 = broken_tests_keyword_set->begin(); iter2 != broken_tests_keyword_set->end(); ++iter2) { + std::string keyword = *iter2; + if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) { + fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords"); + return true; + } + } + } + const auto tolerance_key = ToUTF8String(my_dir_name); std::unique_ptr l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info), tolerances.absolute(tolerance_key), tolerances.relative(tolerance_key)); + fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str()); process_function(std::move(l)); return true; }); diff --git a/onnxruntime/test/onnx/TestCase.h b/onnxruntime/test/onnx/TestCase.h index 4d4b2177019c9..96b0b5f6f7c08 100644 --- a/onnxruntime/test/onnx/TestCase.h +++ b/onnxruntime/test/onnx/TestCase.h @@ -101,12 +101,6 @@ class TestTolerances { const Map relative_overrides_; }; -void LoadTests(const std::vector>& input_paths, - const std::vector>& whitelisted_test_cases, - const TestTolerances& tolerances, - const std::unordered_set>& disabled_tests, - const std::function)>& process_function); - struct BrokenTest { std::string test_name_; std::string reason_; @@ -118,6 +112,16 @@ struct BrokenTest { } }; +void LoadTests(const std::vector>& input_paths, + const std::vector>& whitelisted_test_cases, + const TestTolerances& tolerances, + const std::unordered_set>& disabled_tests, + std::unique_ptr> broken_test_list, + std::unique_ptr> broken_tests_keyword_set, + const std::function)>& process_function); + std::unique_ptr> GetBrokenTests(const std::string& provider_name); std::unique_ptr> GetBrokenTestsKeyWordSet(const std::string& provider_name); + +std::unique_ptr> GetBrokenTestsKeyWordSet(const std::string& provider_name); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index f165b3a4a647a..de5431ca4a460 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -783,10 +783,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); all_disabled_tests.insert(std::begin(x86_disabled_tests), std::end(x86_disabled_tests)); #endif + auto broken_tests = GetBrokenTests(provider_name); + auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet(provider_name); std::vector tests; LoadTests(data_dirs, whitelisted_test_cases, LoadTestTolerances(enable_cuda, enable_openvino, override_tolerance, atol, rtol), all_disabled_tests, + std::move(broken_tests), + std::move(broken_tests_keyword_set), [&owned_tests, &tests](std::unique_ptr l) { tests.push_back(l.get()); owned_tests.push_back(std::move(l)); @@ -803,18 +807,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); fwrite(res.c_str(), 1, res.size(), stdout); } - auto broken_tests = GetBrokenTests(provider_name); int result = 0; for (const auto& p : stat.GetFailedTest()) { - BrokenTest t = {p.first, ""}; - auto iter = broken_tests->find(t); - if (iter == broken_tests->end() || (p.second != TestModelInfo::unknown_version && !iter->broken_opset_versions_.empty() && - iter->broken_opset_versions_.find(p.second) == iter->broken_opset_versions_.end())) { - fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str()); - result = -1; - } else { - fprintf(stderr, "test %s failed, but it is a known broken test, so we ignore it\n", p.first.c_str()); - } + fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str()); + result = -1; } return result; } diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 5057f74046638..cb5cbbecb5ef0 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -238,11 +238,16 @@ static std::vector GetAllTestCases() { // Bad onnx test output caused by previously wrong SAME_UPPER/SAME_LOWER for ConvTranspose allDisabledTests.insert(ORT_TSTR("cntk_simple_seg")); + auto broken_tests = GetBrokenTests("dml"); + auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet("dml"); + WINML_EXPECT_NO_THROW(LoadTests( dataDirs, whitelistedTestCases, TestTolerances(1e-3, 1e-3, {}, {}), allDisabledTests, + std::move(broken_tests), + std::move(broken_tests_keyword_set), [&tests](std::unique_ptr l) { tests.push_back(l.get()); ownedTests.push_back(std::move(l));