Skip to content

Commit

Permalink
Enable onnx_test_runner to run the whole models dir in CI machine (mi…
Browse files Browse the repository at this point in the history
…crosoft#17863)

### Description
1. If the model should be skipped, don't load it.
2. print loaded tests and skipped tests
3. add more same filters as of the onnxruntime_test_all.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
mszhanyi authored and kleiti committed Mar 22, 2024
1 parent 231f928 commit 9bd9d70
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 16 deletions.
52 changes: 52 additions & 0 deletions onnxruntime/test/onnx/TestCase.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "TestCase.h"

#include <cctype>
#include <filesystem>
#include <fstream>
#include <memory>
#include <sstream>
Expand Down Expand Up @@ -731,6 +732,8 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
const std::vector<std::basic_string<PATH_CHAR_TYPE>>& whitelisted_test_cases,
const TestTolerances& tolerances,
const std::unordered_set<std::basic_string<ORTCHAR_T>>& disabled_tests,
std::unique_ptr<std::set<BrokenTest>> broken_tests,
std::unique_ptr<std::set<std::string>> broken_tests_keyword_set,
const std::function<void(std::unique_ptr<ITestCase>)>& process_function) {
std::vector<std::basic_string<PATH_CHAR_TYPE>> paths(input_paths);
while (!paths.empty()) {
Expand Down Expand Up @@ -783,11 +786,60 @@ void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths
ORT_NOT_IMPLEMENTED(ToUTF8String(filename_str), " is not supported");
}

auto test_case_dir = model_info->GetDir();
auto test_case_name_in_log = test_case_name + ORT_TSTR(" in ") + test_case_dir;

#if !defined(ORT_MINIMAL_BUILD) && !defined(USE_QNN)
// to skip some models like *-int8 or *-qdq
if ((reinterpret_cast<OnnxModelInfo*>(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) ||
(reinterpret_cast<OnnxModelInfo*>(model_info.get()))->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it has training domain");
return true;
}
#endif

bool has_test_data = false;
LoopDir(test_case_dir, [&](const PATH_CHAR_TYPE* filename, OrtFileType f_type) -> bool {
if (filename[0] == '.') return true;
if (f_type == OrtFileType::TYPE_DIR) {
has_test_data = true;
return false;
}
return true;
});
if (!has_test_data) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to no test data");
return true;
}

if (broken_tests) {
BrokenTest t = {ToUTF8String(test_case_name), ""};
auto iter = broken_tests->find(t);
auto opset_version = model_info->GetNominalOpsetVersion();
if (iter != broken_tests->end() &&
(opset_version == TestModelInfo::unknown_version || iter->broken_opset_versions_.empty() ||
iter->broken_opset_versions_.find(opset_version) != iter->broken_opset_versions_.end())) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " due to broken_tests");
return true;
}
}

if (broken_tests_keyword_set) {
for (auto iter2 = broken_tests_keyword_set->begin(); iter2 != broken_tests_keyword_set->end(); ++iter2) {
std::string keyword = *iter2;
if (ToUTF8String(test_case_name).find(keyword) != std::string::npos) {
fprintf(stderr, "Skip test case:: %s %s\n", ToUTF8String(test_case_name_in_log).c_str(), " as it is in broken test keywords");
return true;
}
}
}

const auto tolerance_key = ToUTF8String(my_dir_name);

std::unique_ptr<ITestCase> l = CreateOnnxTestCase(ToUTF8String(test_case_name), std::move(model_info),
tolerances.absolute(tolerance_key),
tolerances.relative(tolerance_key));
fprintf(stdout, "Load Test Case: %s\n", ToUTF8String(test_case_name_in_log).c_str());
process_function(std::move(l));
return true;
});
Expand Down
16 changes: 10 additions & 6 deletions onnxruntime/test/onnx/TestCase.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,6 @@ class TestTolerances {
const Map relative_overrides_;
};

void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths,
const std::vector<std::basic_string<PATH_CHAR_TYPE>>& whitelisted_test_cases,
const TestTolerances& tolerances,
const std::unordered_set<std::basic_string<ORTCHAR_T>>& disabled_tests,
const std::function<void(std::unique_ptr<ITestCase>)>& process_function);

struct BrokenTest {
std::string test_name_;
std::string reason_;
Expand All @@ -118,6 +112,16 @@ struct BrokenTest {
}
};

void LoadTests(const std::vector<std::basic_string<PATH_CHAR_TYPE>>& input_paths,
const std::vector<std::basic_string<PATH_CHAR_TYPE>>& whitelisted_test_cases,
const TestTolerances& tolerances,
const std::unordered_set<std::basic_string<ORTCHAR_T>>& disabled_tests,
std::unique_ptr<std::set<BrokenTest>> broken_test_list,
std::unique_ptr<std::set<std::string>> broken_tests_keyword_set,
const std::function<void(std::unique_ptr<ITestCase>)>& process_function);

std::unique_ptr<std::set<BrokenTest>> GetBrokenTests(const std::string& provider_name);

std::unique_ptr<std::set<std::string>> GetBrokenTestsKeyWordSet(const std::string& provider_name);

std::unique_ptr<std::set<std::string>> GetBrokenTestsKeyWordSet(const std::string& provider_name);
16 changes: 6 additions & 10 deletions onnxruntime/test/onnx/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
all_disabled_tests.insert(std::begin(x86_disabled_tests), std::end(x86_disabled_tests));
#endif

auto broken_tests = GetBrokenTests(provider_name);
auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet(provider_name);
std::vector<ITestCase*> tests;
LoadTests(data_dirs, whitelisted_test_cases,
LoadTestTolerances(enable_cuda, enable_openvino, override_tolerance, atol, rtol),
all_disabled_tests,
std::move(broken_tests),
std::move(broken_tests_keyword_set),
[&owned_tests, &tests](std::unique_ptr<ITestCase> l) {
tests.push_back(l.get());
owned_tests.push_back(std::move(l));
Expand All @@ -803,18 +807,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
fwrite(res.c_str(), 1, res.size(), stdout);
}

auto broken_tests = GetBrokenTests(provider_name);
int result = 0;
for (const auto& p : stat.GetFailedTest()) {
BrokenTest t = {p.first, ""};
auto iter = broken_tests->find(t);
if (iter == broken_tests->end() || (p.second != TestModelInfo::unknown_version && !iter->broken_opset_versions_.empty() &&
iter->broken_opset_versions_.find(p.second) == iter->broken_opset_versions_.end())) {
fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str());
result = -1;
} else {
fprintf(stderr, "test %s failed, but it is a known broken test, so we ignore it\n", p.first.c_str());
}
fprintf(stderr, "test %s failed, please fix it\n", p.first.c_str());
result = -1;
}
return result;
}
Expand Down
5 changes: 5 additions & 0 deletions winml/test/model/model_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,16 @@ static std::vector<ITestCase*> GetAllTestCases() {
// Bad onnx test output caused by previously wrong SAME_UPPER/SAME_LOWER for ConvTranspose
allDisabledTests.insert(ORT_TSTR("cntk_simple_seg"));

auto broken_tests = GetBrokenTests("dml");
auto broken_tests_keyword_set = GetBrokenTestsKeyWordSet("dml");

WINML_EXPECT_NO_THROW(LoadTests(
dataDirs,
whitelistedTestCases,
TestTolerances(1e-3, 1e-3, {}, {}),
allDisabledTests,
std::move(broken_tests),
std::move(broken_tests_keyword_set),
[&tests](std::unique_ptr<ITestCase> l) {
tests.push_back(l.get());
ownedTests.push_back(std::move(l));
Expand Down

0 comments on commit 9bd9d70

Please sign in to comment.