diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 7c8c70f913dca..5f8c7bd084d02 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1171,6 +1171,45 @@ endif()
 if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
+  # Accuracy test runner
+  if (onnxruntime_BUILD_SHARED_LIB)
+    set(onnxruntime_acc_test_src_dir ${TEST_SRC_DIR}/acc_test)
+    set(onnxruntime_acc_test_src_patterns
+        "${onnxruntime_acc_test_src_dir}/*.cc"
+        "${onnxruntime_acc_test_src_dir}/*.h")
+
+    file(GLOB onnxruntime_acc_test_src CONFIGURE_DEPENDS
+      ${onnxruntime_acc_test_src_patterns}
+    )
+    onnxruntime_add_executable(onnxruntime_acc_test ${onnxruntime_acc_test_src})
+    target_include_directories(onnxruntime_acc_test PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
+    if (WIN32)
+      target_compile_options(onnxruntime_acc_test PRIVATE ${disabled_warnings})
+    endif()
+    if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
+      set_target_properties(onnxruntime_acc_test PROPERTIES
+        XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
+      )
+    endif()
+
+    set(onnxruntime_acc_test_libs onnxruntime)
+    if(NOT WIN32)
+      list(APPEND onnxruntime_acc_test_libs stdc++fs ${CMAKE_DL_LIBS})
+    endif()
+    if (onnxruntime_LINK_LIBATOMIC)
+      list(APPEND onnxruntime_acc_test_libs atomic)
+    endif()
+    target_link_libraries(onnxruntime_acc_test PRIVATE ${onnxruntime_acc_test_libs} Threads::Threads)
+
+    set_target_properties(onnxruntime_acc_test PROPERTIES FOLDER "ONNXRuntimeTest")
+
+    if (onnxruntime_USE_TVM)
+      if (WIN32)
+        target_link_options(onnxruntime_acc_test PRIVATE "/STACK:4000000")
+      endif()
+    endif()
+  endif()
+
   #perf test runner
   set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest)
   set(onnxruntime_perf_test_src_patterns
diff --git a/onnxruntime/test/acc_test/acc_task.cc b/onnxruntime/test/acc_test/acc_task.cc
new file mode 100644
index 0000000000000..2939b67feafaf
--- /dev/null
+++ b/onnxruntime/test/acc_test/acc_task.cc
@@ -0,0 +1,86 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "acc_task.h" +#include +#include +#include +#include +#include + +static std::vector RunInference(Ort::Session& session, const ModelIOInfo& model_io_info, + Span input_buffer) { + // Setup input + const std::vector& input_infos = model_io_info.inputs; + const size_t num_inputs = input_infos.size(); + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.reserve(num_inputs); + ort_input_names.reserve(num_inputs); + + for (size_t input_offset = 0, i = 0; i < num_inputs; input_offset += input_infos[i].total_data_size, i++) { + assert(input_offset < input_buffer.size()); + const IOInfo& input_info = input_infos[i]; + Span input_data(&input_buffer[input_offset], input_info.total_data_size); + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, (void*)input_data.data(), input_data.size(), + input_info.shape.data(), input_info.shape.size(), + input_info.data_type)); + ort_input_names.push_back(input_info.name.c_str()); + } + + const size_t num_outputs = model_io_info.outputs.size(); + std::vector ort_output_names; + ort_output_names.reserve(num_outputs); + + for (size_t i = 0; i < num_outputs; i++) { + ort_output_names.push_back(model_io_info.outputs[i].name.c_str()); + } + + return session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), ort_output_names.data(), ort_output_names.size()); +} + +void Task::Run() { + std::vector ort_output_vals = RunInference(session_, model_io_info_, input_buffer_); + + AccuracyCheck* accuracy_check_data = std::get_if(&variant_); + if (accuracy_check_data) { + const std::vector& output_infos = model_io_info_.get().outputs; + const size_t num_outputs = output_infos.size(); + Span expected_output_buffer = accuracy_check_data->expected_output_buffer; + + for (size_t output_offset = 0, i = 0; i < num_outputs; output_offset += output_infos[i].total_data_size, i++) { + assert(output_offset < expected_output_buffer.size()); + const IOInfo& output_info = output_infos[i]; + Span raw_expected_output(&expected_output_buffer[output_offset], output_info.total_data_size); + + accuracy_check_data->output_acc_metric[i] = ComputeAccuracyMetric(ort_output_vals[i].GetConst(), + raw_expected_output, + output_info); + } + return; + } + + Inference* inference_data = std::get_if(&variant_); + if (inference_data) { + Span& output_buffer = inference_data->output_buffer; + + // Unfortunately, we have to copy output values (Ort::Value is not copyable, so it is limited when stored in a std::vector) + const std::vector& output_infos = model_io_info_.get().outputs; + const size_t num_outputs = output_infos.size(); + + for (size_t output_offset = 0, i = 0; i < num_outputs; output_offset += output_infos[i].total_data_size, i++) { + assert(output_offset < output_buffer.size()); + std::memcpy(&output_buffer[output_offset], + ort_output_vals[i].GetTensorRawData(), + output_infos[i].total_data_size); + } + return; + } + + // Should not reach this line unless we add a new (unhandled) std::variant type. + std::cerr << "[ERROR]: Unhandled std::variant type for Task::variant_ member." << std::endl; + std::abort(); +} diff --git a/onnxruntime/test/acc_test/acc_task.h b/onnxruntime/test/acc_test/acc_task.h new file mode 100644 index 0000000000000..9fe62da562d1a --- /dev/null +++ b/onnxruntime/test/acc_test/acc_task.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once
+#include <functional>
+#include <variant>
+#include <onnxruntime_cxx_api.h>
+#include "basic_utils.h"
+#include "model_io_utils.h"
+
+class Task {
+ private:
+  struct Inference {
+    Span<char> output_buffer;
+  };
+
+  struct AccuracyCheck {
+    Span<const char> expected_output_buffer;
+    Span<AccMetrics> output_acc_metric;
+  };
+
+ public:
+  Task(Task&& other) = default;
+  Task(const Task& other) = default;
+  Task(Ort::Session& session, const ModelIOInfo& model_io_info,
+       Span<const char> input_buffer, Span<char> output_buffer)
+      : session_(session), model_io_info_(model_io_info), input_buffer_(input_buffer), variant_(Inference{output_buffer}) {}
+  Task(Ort::Session& session, const ModelIOInfo& model_io_info,
+       Span<const char> input_buffer, Span<const char> expected_output_buffer, Span<AccMetrics> output_acc_metric)
+      : session_(session), model_io_info_(model_io_info), input_buffer_(input_buffer), variant_(AccuracyCheck{expected_output_buffer, output_acc_metric}) {}
+
+  static Task CreateInferenceTask(Ort::Session& session, const ModelIOInfo& model_io_info,
+                                  Span<const char> input_buffer, Span<char> output_buffer) {
+    return Task(session, model_io_info, input_buffer, output_buffer);
+  }
+
+  static Task CreateAccuracyCheckTask(Ort::Session& session, const ModelIOInfo& model_io_info,
+                                      Span<const char> input_buffer, Span<const char> expected_output_buffer,
+                                      Span<AccMetrics> output_acc_metric) {
+    return Task(session, model_io_info, input_buffer, expected_output_buffer, output_acc_metric);
+  }
+
+  void Run();
+
+ private:
+  std::reference_wrapper<Ort::Session> session_;
+  std::reference_wrapper<const ModelIOInfo> model_io_info_;
+  Span<const char> input_buffer_;
+  std::variant<Inference, AccuracyCheck> variant_;
+};
diff --git a/onnxruntime/test/acc_test/basic_utils.cc b/onnxruntime/test/acc_test/basic_utils.cc
new file mode 100644
index 0000000000000..d72d496af2bfc
--- /dev/null
+++ b/onnxruntime/test/acc_test/basic_utils.cc
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "basic_utils.h"
+#include <cstdint>
+#include <fstream>
+
+bool FillBytesFromBinaryFile(Span<char> array, const std::string& binary_filepath) {
+  std::ifstream input_ifs(binary_filepath, std::ifstream::binary);
+
+  if (!input_ifs.is_open()) {
+    return false;
+  }
+
+  input_ifs.seekg(0, input_ifs.end);
+  auto file_byte_size = input_ifs.tellg();
+  input_ifs.seekg(0, input_ifs.beg);
+
+  if (static_cast<size_t>(file_byte_size) != array.size()) {
+    return false;
+  }
+
+  input_ifs.read(array.data(), file_byte_size);
+  return static_cast<bool>(input_ifs);
+}
+
+int32_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix) {
+  int32_t index = -1;
+  const char* str = filename_wo_ext.c_str();
+
+  // Move past the prefix.
+  while (*str && *prefix && *str == *prefix) {
+    str++;
+    prefix++;
+  }
+
+  if (*prefix) {
+    return -1;  // File doesn't start with the prefix.
+  }
+
+  // Parse the input index from file name.
+  index = 0;
+  while (*str) {
+    int32_t c = *str;
+    if (!(c >= '0' && c <= '9')) {
+      return -1;  // Not a number.
+    }
+
+    index *= 10;
+    index += (c - '0');
+    str++;
+  }
+
+  return index;
+}
diff --git a/onnxruntime/test/acc_test/basic_utils.h b/onnxruntime/test/acc_test/basic_utils.h
new file mode 100644
index 0000000000000..44c0f0e758b60
--- /dev/null
+++ b/onnxruntime/test/acc_test/basic_utils.h
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
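+//
+// Shared helpers for the accuracy test runner: a minimal std::span substitute
+// for pre-C++20 toolchains, raw binary file I/O, and RMSE/SNR accuracy metrics.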
+#pragma once
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cmath>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+// Make a bootleg std::span for C++ versions older than 20
+template <typename T>
+class Span {
+ public:
+  Span() = default;
+  Span(T* data, size_t size) : data_(data), size_(size) {}
+  Span(std::vector<T>& vec) : data_(vec.data()), size_(vec.size()) {}
+  Span(const std::vector<std::remove_const_t<T>>& vec) : data_(vec.data()), size_(vec.size()) {}
+
+  // Take the array by reference so the span does not point at a temporary.
+  template <size_t N>
+  Span(std::array<T, N>& arr) : data_(arr.data()), size_(N) {}
+
+  Span(const Span& other) = default;
+  Span(Span&& other) = default;
+
+  Span& operator=(const Span& other) = default;
+  Span& operator=(Span&& other) = default;
+
+  T& operator[](size_t index) const {
+    return data_[index];
+  }
+
+  T* data() const { return data_; }
+  size_t size() const { return size_; }
+  bool empty() const { return size_ == 0; }
+
+ private:
+  T* data_{nullptr};
+  size_t size_{0};
+};
+
+template <typename T>
+static Span<T> ReinterpretBytesAsSpan(Span<std::conditional_t<std::is_const_v<T>, const char, char>> bytes_span) {
+  return Span<T>(reinterpret_cast<T*>(bytes_span.data()), bytes_span.size() / sizeof(T));
+}
+
+template <typename T>
+constexpr int64_t GetShapeSize(Span<T> shape) {
+  int64_t size = 1;
+
+  for (size_t i = 0; i < shape.size(); i++) {
+    size *= shape[i];
+  }
+
+  return size;
+}
+
+int32_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix);
+bool FillBytesFromBinaryFile(Span<char> array, const std::string& binary_filepath);
+
+constexpr double EPSILON_DBL = 2e-16;
+
+struct AccMetrics {
+  double rmse = 0.0;
+  double snr = 0.0;
+  double min_val = 0.0;
+  double max_val = 0.0;
+  double min_expected_val = 0.0;
+  double max_expected_val = 0.0;
+  friend bool operator==(const AccMetrics& l, const AccMetrics& r) {
+    if (l.rmse != r.rmse) return false;
+    if (l.min_val != r.min_val) return false;
+    if (l.max_val != r.max_val) return false;
+    if (l.min_expected_val != r.min_expected_val) return false;
+    if (l.max_expected_val != r.max_expected_val) return false;
+    if (l.snr != r.snr) return false;

+    return true;
+  }
+  friend bool operator!=(const AccMetrics& l, const AccMetrics& r) {
+    return !(l == r);
+  }
+};
+
+template <typename T>
+void GetAccuracy(Span<const T> expected_output, Span<const T> actual_output, AccMetrics& metrics) {
+  // Compute RMSE. This is not a great way to measure accuracy, but ....
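+  // In addition to RMSE, compute a signal-to-noise ratio in decibels:
+  //   SNR = 20 * log10(||expected|| / ||expected - actual||),
+  // where both norms are clamped below by EPSILON_DBL to avoid dividing by zero.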
+  assert(expected_output.size() == actual_output.size());
+  const size_t num_outputs = expected_output.size();
+
+  metrics.rmse = 0.0;
+  metrics.min_val = static_cast<double>(actual_output[0]);
+  metrics.max_val = static_cast<double>(actual_output[0]);
+  metrics.min_expected_val = static_cast<double>(expected_output[0]);
+  metrics.max_expected_val = static_cast<double>(expected_output[0]);
+  double tensor_norm = 0.0;
+  double diff_norm = 0.0;
+  for (size_t i = 0; i < num_outputs; i++) {
+    double diff = static_cast<double>(actual_output[i]) - static_cast<double>(expected_output[i]);
+    diff_norm += diff * diff;
+    tensor_norm += static_cast<double>(expected_output[i]) * static_cast<double>(expected_output[i]);
+
+    metrics.rmse += diff * diff;
+    metrics.min_val = std::min(metrics.min_val, static_cast<double>(actual_output[i]));
+    metrics.max_val = std::max(metrics.max_val, static_cast<double>(actual_output[i]));
+    metrics.min_expected_val = std::min(metrics.min_expected_val, static_cast<double>(expected_output[i]));
+    metrics.max_expected_val = std::max(metrics.max_expected_val, static_cast<double>(expected_output[i]));
+  }
+
+  metrics.rmse = std::sqrt(metrics.rmse / static_cast<double>(num_outputs));
+
+  tensor_norm = std::max(std::sqrt(tensor_norm), EPSILON_DBL);
+  diff_norm = std::max(std::sqrt(diff_norm), EPSILON_DBL);
+  metrics.snr = 20.0 * std::log10(tensor_norm / diff_norm);
+}
diff --git a/onnxruntime/test/acc_test/cmd_args.cc b/onnxruntime/test/acc_test/cmd_args.cc
new file mode 100644
index 0000000000000..b12459e02704f
--- /dev/null
+++ b/onnxruntime/test/acc_test/cmd_args.cc
@@ -0,0 +1,294 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "cmd_args.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdlib>
+#include <filesystem>
+#include <iostream>
+#include <iterator>
+#include <sstream>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+
+struct CmdArgs {
+  CmdArgs(int argc, char** argv) noexcept : argc_(argc), argv_(argv), index_(0) {}
+
+  [[nodiscard]] bool HasNext() const { return index_ < argc_; }
+
+  [[nodiscard]] std::string_view GetNext() {
+    assert(HasNext());
+    return argv_[index_++];
+  }
+
+  [[nodiscard]] std::string_view PeekNext() {
+    assert(HasNext());
+    return argv_[index_];
+  }
+
+ private:
+  int argc_;
+  char** argv_;
+  int index_;
+};
+
+static void PrintUsage(std::ostream& stream, std::string_view prog_name) {
+  stream << "Usage: " << prog_name << " [OPTIONS]"
+         << std::endl;
+  stream << "OPTIONS:" << std::endl;
+  stream << "  -h/--help                          Print this help message" << std::endl;
+  stream << "  -t/--test_dir <dir>                Path to test directory with models and inputs/outputs" << std::endl;
+  stream << "  -l/--load_expected_outputs         Load expected outputs from raw output_<index>.raw files" << std::endl;
+  stream << "  -s/--save_expected_outputs         Save outputs from baseline model on CPU EP to disk" << std::endl;
+  stream << "  -e/--execution_provider <ep>       The execution provider to test (e.g., qnn)" << std::endl;
+  stream << "  -o/--output_file <file>            The output file into which to save accuracy results" << std::endl;
+  stream << "  -a/--expected_accuracy_file <file> The file containing expected accuracy results" << std::endl
+         << std::endl;
+}
+
+static bool ParseQnnRuntimeOptions(std::string ep_config_string,
+                                   std::unordered_map<std::string, std::string>& qnn_options) {
+  std::istringstream ss(ep_config_string);
+  std::string token;
+
+  while (ss >> token) {
+    if (token == "") {
+      continue;
+    }
+    std::string_view token_sv(token);
+
+    auto pos = token_sv.find("|");
+    if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) {
+      std::cerr << "Use a '|' to separate the key and value for the run-time option you are trying to use."
+                << std::endl;
+      return false;
+    }
+
+    std::string_view key(token_sv.substr(0, pos));
+    std::string_view value(token_sv.substr(pos + 1));
+
+    if (key == "backend_path") {
+      if (value.empty()) {
+        std::cerr << "[ERROR]: Please provide the QNN backend path." << std::endl;
+        return false;
+      }
+    } else if (key == "qnn_context_cache_enable") {
+      if (value != "1") {
+        std::cerr << "[ERROR]: Set to 1 to enable qnn_context_cache_enable." << std::endl;
+        return false;
+      }
+    } else if (key == "qnn_context_cache_path") {
+      // no validation
+    } else if (key == "profiling_level") {
+      std::unordered_set<std::string_view> supported_profiling_level = {"off", "basic", "detailed"};
+      if (supported_profiling_level.find(value) == supported_profiling_level.end()) {
+        std::cerr << "[ERROR]: Supported profiling_level: off, basic, detailed" << std::endl;
+        return false;
+      }
+    } else if (key == "rpc_control_latency" || key == "vtcm_mb") {
+      // no validation
+    } else if (key == "htp_performance_mode") {
+      std::unordered_set<std::string_view> supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance",
+                                                                      "high_power_saver", "low_balanced", "low_power_saver",
+                                                                      "power_saver", "sustained_high_performance"};
+      if (supported_htp_perf_mode.find(value) == supported_htp_perf_mode.end()) {
+        std::ostringstream str_stream;
+        std::copy(supported_htp_perf_mode.begin(), supported_htp_perf_mode.end(),
+                  std::ostream_iterator<std::string_view>(str_stream, ","));
+        std::string str = str_stream.str();
+        std::cerr << "[ERROR]: Supported htp_performance_mode: " << str << std::endl;
+        return false;
+      }
+    } else if (key == "qnn_saver_path") {
+      // no validation
+    } else if (key == "htp_graph_finalization_optimization_mode") {
+      std::unordered_set<std::string_view> supported_htp_graph_final_opt_modes = {"0", "1", "2", "3"};
+      if (supported_htp_graph_final_opt_modes.find(value) == supported_htp_graph_final_opt_modes.end()) {
+        std::ostringstream str_stream;
+        std::copy(supported_htp_graph_final_opt_modes.begin(), supported_htp_graph_final_opt_modes.end(),
+                  std::ostream_iterator<std::string_view>(str_stream, ","));
+        std::string str = str_stream.str();
+        std::cerr << "[ERROR]: Wrong value for htp_graph_finalization_optimization_mode. select from: " << str << std::endl;
+        return false;
+      }
+    } else if (key == "qnn_context_priority") {
+      std::unordered_set<std::string_view> supported_qnn_context_priority = {"low", "normal", "normal_high", "high"};
+      if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) {
+        std::cerr << "[ERROR]: Supported qnn_context_priority: low, normal, normal_high, high" << std::endl;
+        return false;
+      }
+    } else {
+      std::cerr << R"([ERROR]: Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable',
+'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode',
+'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"
+                << std::endl;
+      return false;
+    }
+
+    qnn_options.insert(std::make_pair(std::string(key), std::string(value)));
+  }
+
+  return true;
+}
+
+static bool ParseQnnArgs(AppArgs& app_args, CmdArgs& cmd_args) {
+  if (!cmd_args.HasNext()) {
+    std::cerr << "[ERROR]: Must specify at least a QNN backend path."
+              << std::endl;
+    return false;
+  }
+
+  std::string_view args = cmd_args.GetNext();
+  std::unordered_map<std::string, std::string> qnn_options;
+
+  if (!ParseQnnRuntimeOptions(std::string(args), qnn_options)) {
+    return false;
+  }
+
+  auto backend_iter = qnn_options.find("backend_path");
+  if (backend_iter == qnn_options.end()) {
+    std::cerr << "[ERROR]: Must provide a backend_path for the QNN execution provider." << std::endl;
+    return false;
+  }
+
+  app_args.session_options.AppendExecutionProvider("QNN", qnn_options);
+  app_args.session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1");  // TODO: Parse config entries
+  app_args.uses_qdq_model = backend_iter->second.rfind("QnnHtp") != std::string::npos;
+  app_args.supports_multithread_inference = false;  // TODO: Work on enabling multi-threaded inference.
+  return true;
+}
+
+static bool GetValidPath(std::string_view prog_name, std::string_view provided_path, bool is_dir,
+                         std::filesystem::path& valid_path) {
+  std::filesystem::path path = provided_path;
+  std::error_code error_code;
+
+  if (!std::filesystem::exists(path, error_code)) {
+    std::cerr << "[ERROR]: Invalid path " << provided_path << ": "
+              << error_code.message() << std::endl
+              << std::endl;
+    return false;
+  }
+
+  std::error_code abs_error_code;
+  std::filesystem::path abs_path = std::filesystem::absolute(path, abs_error_code);
+  if (abs_error_code) {
+    std::cerr << "[ERROR]: Invalid path: " << abs_error_code.message() << std::endl
+              << std::endl;
+    return false;
+  }
+
+  if (is_dir && !std::filesystem::is_directory(abs_path)) {
+    std::cerr << "[ERROR]: " << provided_path << " is not a directory" << std::endl;
+    PrintUsage(std::cerr, prog_name);
+    return false;
+  }
+
+  if (!is_dir && !std::filesystem::is_regular_file(abs_path)) {
+    std::cerr << "[ERROR]: " << provided_path << " is not a regular file" << std::endl;
+    PrintUsage(std::cerr, prog_name);
+    return false;
+  }
+
+  valid_path = std::move(abs_path);
+
+  return true;
+}
+
+bool ParseCmdLineArgs(AppArgs& app_args, int argc, char** argv) {
+  CmdArgs cmd_args(argc, argv);
+  std::string_view prog_name = cmd_args.GetNext();
+
+  // Parse command-line arguments.
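+  // Illustrative invocation (paths and option values below are made-up examples):
+  //   onnxruntime_acc_test -t ./models -e qnn "backend_path|QnnHtp.dll" -o results.csv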
+  while (cmd_args.HasNext()) {
+    std::string_view arg = cmd_args.GetNext();
+
+    if (arg == "-h" || arg == "--help") {
+      PrintUsage(std::cout, prog_name);
+      std::exit(0);  // Print help and exit without requiring (or validating) the other options.
+    } else if (arg == "-t" || arg == "--test_dir") {
+      if (!cmd_args.HasNext()) {
+        std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl;
+        PrintUsage(std::cerr, prog_name);
+        return false;
+      }
+
+      arg = cmd_args.GetNext();
+      if (!GetValidPath(prog_name, arg, true, app_args.test_dir)) {
+        return false;
+      }
+    } else if (arg == "-o" || arg == "--output_file") {
+      if (!cmd_args.HasNext()) {
+        std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl;
+        PrintUsage(std::cerr, prog_name);
+        return false;
+      }
+
+      app_args.output_file = cmd_args.GetNext();
+    } else if (arg == "-a" || arg == "--expected_accuracy_file") {
+      if (!cmd_args.HasNext()) {
+        std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl;
+        PrintUsage(std::cerr, prog_name);
+        return false;
+      }
+
+      arg = cmd_args.GetNext();
+      if (!GetValidPath(prog_name, arg, false, app_args.expected_accuracy_file)) {
+        return false;
+      }
+    } else if (arg == "-e" || arg == "--execution_provider") {
+      if (!cmd_args.HasNext()) {
+        std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl;
+        PrintUsage(std::cerr, prog_name);
+        return false;
+      }
+
+      arg = cmd_args.GetNext();
+      if (arg == "qnn") {
+        if (!ParseQnnArgs(app_args, cmd_args)) {
+          return false;
+        }
+      } else {
+        std::cerr << "[ERROR]: Unsupported execution provider: " << arg << std::endl;
+        PrintUsage(std::cerr, prog_name);
+        return false;
+      }
+
+      app_args.execution_provider = arg;
+    } else if (arg == "-s" || arg == "--save_expected_outputs") {
+      app_args.save_expected_outputs_to_disk = true;
+    } else if (arg == "-l" || arg == "--load_expected_outputs") {
+      app_args.load_expected_outputs_from_disk = true;
+    } else {
+      std::cerr << "[ERROR]: unknown command-line argument `" << arg << "`" << std::endl
+                << std::endl;
+      PrintUsage(std::cerr, prog_name);
+      return false;
+    }
+  }
+
+  //
+  // Final argument validation:
+  //
+
+  if (app_args.test_dir.empty()) {
+    std::cerr << "[ERROR]: Must provide a test directory using the -t/--test_dir option." << std::endl
+              << std::endl;
+    PrintUsage(std::cerr, prog_name);
+    return false;
+  }
+
+  if (app_args.execution_provider.empty()) {
+    std::cerr << "[ERROR]: Must provide an execution provider using the -e/--execution_provider option." << std::endl
+              << std::endl;
+    PrintUsage(std::cerr, prog_name);
+    return false;
+  }
+
+  if (app_args.load_expected_outputs_from_disk && app_args.save_expected_outputs_to_disk) {
+    std::cerr << "[ERROR]: Cannot enable both -s/--save_expected_outputs and -l/--load_expected_outputs" << std::endl
+              << std::endl;
+    PrintUsage(std::cerr, prog_name);
+    return false;
+  }
+
+  return true;
+}
\ No newline at end of file
diff --git a/onnxruntime/test/acc_test/cmd_args.h b/onnxruntime/test/acc_test/cmd_args.h
new file mode 100644
index 0000000000000..b4ea7356b78c8
--- /dev/null
+++ b/onnxruntime/test/acc_test/cmd_args.h
@@ -0,0 +1,21 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#include <filesystem>
+#include <string>
+#include <onnxruntime_cxx_api.h>
+#include <onnxruntime_session_options_config_keys.h>
+
+struct AppArgs {
+  std::filesystem::path test_dir;
+  std::string output_file;
+  std::filesystem::path expected_accuracy_file;
+  std::string execution_provider;
+  bool uses_qdq_model = false;
+  bool supports_multithread_inference = true;
+  bool save_expected_outputs_to_disk = false;
+  bool load_expected_outputs_from_disk = false;
+  Ort::SessionOptions session_options;
+};
+
+bool ParseCmdLineArgs(AppArgs& app_args, int argc, char** argv);
diff --git a/onnxruntime/test/acc_test/data_loader.cc b/onnxruntime/test/acc_test/data_loader.cc
new file mode 100644
index 0000000000000..222cb83705534
--- /dev/null
+++ b/onnxruntime/test/acc_test/data_loader.cc
@@ -0,0 +1,82 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "data_loader.h"
+#include <cassert>
+#include <cstdint>
+#include <filesystem>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+namespace acctest {
+
+bool LoadIODataFromDisk(const std::vector<std::filesystem::path>& dataset_paths,
+                        const std::vector<IOInfo>& io_infos,
+                        const char* data_file_prefix,
+                        std::vector<std::unique_ptr<char[]>>& dataset_data) {
+  size_t total_data_size = 0;
+  for (const auto& io_info : io_infos) {
+    total_data_size += io_info.total_data_size;
+  }
+
+  dataset_data.clear();
+  dataset_data.reserve(dataset_paths.size());
+
+  for (const auto& dataset_path : dataset_paths) {
+    dataset_data.emplace_back(std::make_unique<char[]>(total_data_size));
+
+    size_t num_files_loaded = 0;
+
+    for (const auto& data_file_entry : std::filesystem::directory_iterator{dataset_path}) {
+      const std::filesystem::path& data_file_path = data_file_entry.path();
+
+      if (!std::filesystem::is_regular_file(data_file_path)) {
+        continue;
+      }
+
+      std::string data_filename_wo_ext = data_file_path.stem().string();
+      if (data_filename_wo_ext.rfind(data_file_prefix, 0) != 0) {
+        continue;
+      }
+
+      const int32_t io_index_s32 = GetFileIndexSuffix(data_filename_wo_ext, data_file_prefix);
+      if (io_index_s32 < 0) {
+        std::cerr << "[ERROR]: The file " << data_file_path << " does not have a properly formatted name"
+                  << " (e.g., " << data_file_prefix << "0.raw)" << std::endl;
+        return false;
+      }
+
+      const size_t io_index = static_cast<size_t>(io_index_s32);
+      if (io_index >= io_infos.size()) {
+        std::cerr << "[ERROR]: The input (or output) file index for file " << data_file_path
+                  << " exceeds the number of inputs (or outputs) in the model ("
+                  << io_infos.size() << ")" << std::endl;
+        return false;
+      }
+
+      size_t offset = 0;
+      for (size_t i = 0; i < io_index; i++) {
+        offset += io_infos[i].total_data_size;
+      }
+      assert(offset < total_data_size);
+
+      Span<char> span_to_fill(dataset_data.back().get() + offset, io_infos[io_index].total_data_size);
+      if (!FillBytesFromBinaryFile(span_to_fill, data_file_path.string())) {
+        std::cerr << "[ERROR]: Unable to read raw data from file " << data_file_path << std::endl;
+        return false;
+      }
+
+      num_files_loaded += 1;
+    }
+
+    if (num_files_loaded != io_infos.size()) {
+      std::cerr << "[ERROR]: " << dataset_path << " does not have the expected number of "
+                << data_file_prefix << "<index>.raw files. Loaded " << num_files_loaded
+                << " files, but expected " << io_infos.size() << " files." << std::endl;
+      return false;
+    }
+  }
+
+  return true;
+}
+}  // namespace acctest
diff --git a/onnxruntime/test/acc_test/data_loader.h b/onnxruntime/test/acc_test/data_loader.h
new file mode 100644
index 0000000000000..3f6a3ac7ddf81
--- /dev/null
+++ b/onnxruntime/test/acc_test/data_loader.h
@@ -0,0 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#include <filesystem>
+#include <memory>
+#include <vector>
+
+#include "basic_utils.h"
+#include "model_io_utils.h"
+
+namespace acctest {
+
+bool LoadIODataFromDisk(const std::vector<std::filesystem::path>& dataset_paths,
+                        const std::vector<IOInfo>& io_infos,
+                        const char* data_file_prefix,
+                        std::vector<std::unique_ptr<char[]>>& dataset_data);
+}  // namespace acctest
diff --git a/onnxruntime/test/acc_test/main.cc b/onnxruntime/test/acc_test/main.cc
new file mode 100644
index 0000000000000..486367d38d61a
--- /dev/null
+++ b/onnxruntime/test/acc_test/main.cc
@@ -0,0 +1,367 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include <onnxruntime_cxx_api.h>
+#include <onnxruntime_session_options_config_keys.h>
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <cstdint>
+#include <filesystem>
+#include <fstream>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "cmd_args.h"
+#include "model_io_utils.h"
+#include "data_loader.h"
+#include "acc_task.h"
+#include "task_thread_pool.h"
+
+static std::vector<std::filesystem::path> GetSortedDatasetPaths(const std::filesystem::path& model_dir) {
+  std::vector<std::filesystem::path> dataset_paths;
+  const char* dataset_prefix = "test_data_set_";
+
+  for (const auto& entry : std::filesystem::directory_iterator{model_dir}) {
+    std::filesystem::path entry_path = entry.path();
+    std::string entry_filename = entry_path.filename().string();
+
+    if (std::filesystem::is_directory(entry_path) && entry_filename.rfind(dataset_prefix, 0) == 0) {
+      dataset_paths.push_back(std::move(entry_path));
+    }
+  }
+
+  auto cmp_indexed_paths = [dataset_prefix](const std::filesystem::path& a,
+                                            const std::filesystem::path& b) -> bool {
+    const int32_t a_index = GetFileIndexSuffix(a.filename().string(), dataset_prefix);
+    const int32_t b_index = GetFileIndexSuffix(b.filename().string(), dataset_prefix);
+    return a_index < b_index;
+  };
+
+  std::sort(dataset_paths.begin(), dataset_paths.end(), cmp_indexed_paths);
+
+  return dataset_paths;
+}
+
+static bool GetExpectedOutputsFromModel(Ort::Env& env,
+                                        TaskThreadPool& pool,
+                                        const AppArgs& args,
+                                        const std::filesystem::path& model_path,
+                                        const std::vector<std::filesystem::path>& dataset_paths,
+                                        std::vector<std::unique_ptr<char[]>>& all_inputs,
+                                        std::vector<std::unique_ptr<char[]>>& all_outputs) {
+  Ort::SessionOptions session_options;
+  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+
+  Ort::Session f32_cpu_sess(env, model_path.c_str(), session_options);
+  ModelIOInfo model_io_info;
+
+  if (!ModelIOInfo::Init(model_io_info, f32_cpu_sess.GetConst())) {
+    std::cerr << "[ERROR]: Failed to query model I/O information." << std::endl;
+    return false;
+  }
+
+  if (!acctest::LoadIODataFromDisk(dataset_paths, model_io_info.inputs, "input_", all_inputs)) {
+    std::cerr << "[ERROR]: Failed to load test inputs for model directory " << model_path.parent_path() << std::endl;
+    return false;
+  }
+
+  const size_t num_datasets = dataset_paths.size();
+  std::vector<Task> tasks;
+  tasks.reserve(num_datasets);
+
+  const size_t total_input_data_size = model_io_info.GetTotalInputSize();
+  const size_t total_output_data_size = model_io_info.GetTotalOutputSize();
+
+  all_outputs.reserve(num_datasets);
+
+  for (size_t i = 0; i < num_datasets; i++) {
+    all_outputs.emplace_back(std::make_unique<char[]>(total_output_data_size));
+
+    Task task = Task::CreateInferenceTask(f32_cpu_sess, model_io_info,
+                                          Span<const char>(all_inputs[i].get(), total_input_data_size),
+                                          Span<char>(all_outputs.back().get(), total_output_data_size));
+    tasks.push_back(std::move(task));
+  }
+
+  pool.CompleteTasks(tasks);
+
+  if (args.save_expected_outputs_to_disk) {
+    // Write outputs to disk: output_0.raw, output_1.raw, ...
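+    // Each test_data_set_<index> directory receives one output_<index>.raw file per
+    // model output, in the same flat byte layout used by the in-memory output buffer.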
+    for (size_t dataset_index = 0; dataset_index < num_datasets; dataset_index++) {
+      const std::filesystem::path& dataset_dir = dataset_paths[dataset_index];
+      Span<const char> dataset_output(all_outputs[dataset_index].get(), total_output_data_size);
+      const std::vector<IOInfo>& output_infos = model_io_info.outputs;
+      const size_t num_outputs = output_infos.size();
+
+      for (size_t buf_offset = 0, i = 0; i < num_outputs; buf_offset += output_infos[i].total_data_size, i++) {
+        std::ostringstream oss;
+        oss << "output_" << i << ".raw";
+
+        std::filesystem::path output_filepath = dataset_dir / oss.str();
+        std::ofstream ofs(output_filepath, std::ios::binary);
+
+        assert(buf_offset < dataset_output.size());
+        ofs.write(&dataset_output[buf_offset], output_infos[i].total_data_size);
+      }
+    }
+  }
+  return true;
+}
+
+static bool RunTestModel(Ort::Env& env,
+                         TaskThreadPool& pool,
+                         const std::filesystem::path& model_path,
+                         const std::vector<std::filesystem::path>& dataset_paths,
+                         const Ort::SessionOptions& session_options,
+                         std::vector<std::unique_ptr<char[]>>& all_inputs,
+                         std::vector<std::unique_ptr<char[]>>& all_outputs,
+                         std::vector<std::vector<AccMetrics>>& test_accuracy_results) {
+  Ort::Session session(env, model_path.c_str(), session_options);
+  ModelIOInfo model_io_info;
+
+  if (!ModelIOInfo::Init(model_io_info, session.GetConst())) {
+    std::cerr << "[ERROR]: Failed to query model I/O information "
+              << "for model " << model_path << std::endl;
+    return false;
+  }
+
+  const size_t num_datasets = dataset_paths.size();
+
+  if (all_inputs.empty()) {
+    if (!acctest::LoadIODataFromDisk(dataset_paths, model_io_info.inputs, "input_", all_inputs)) {
+      std::cerr << "[ERROR]: Failed to load test inputs for model directory "
+                << model_path.parent_path() << std::endl;
+      return false;
+    }
+  }
+
+  if (all_outputs.empty()) {
+    if (!acctest::LoadIODataFromDisk(dataset_paths, model_io_info.outputs, "output_", all_outputs)) {
+      std::cerr << "[ERROR]: Failed to load test outputs for model directory "
+                << model_path.parent_path() << std::endl;
+      return false;
+    }
+  }
+
+  assert(all_inputs.size() == num_datasets);
+  assert(all_outputs.size() == num_datasets);
+
+  std::vector<Task> tasks;
+  tasks.reserve(num_datasets);
+
+  test_accuracy_results.resize(num_datasets, std::vector<AccMetrics>(model_io_info.outputs.size()));
+
+  const size_t total_input_data_size = model_io_info.GetTotalInputSize();
+  const size_t total_output_data_size = model_io_info.GetTotalOutputSize();
+
+  for (size_t i = 0; i < num_datasets; i++) {
+    Task task = Task::CreateAccuracyCheckTask(session, model_io_info,
+                                              Span<const char>(all_inputs[i].get(), total_input_data_size),
+                                              Span<const char>(all_outputs[i].get(), total_output_data_size),
+                                              Span<AccMetrics>(test_accuracy_results[i]));
+    tasks.push_back(std::move(task));
+  }
+
+  pool.CompleteTasks(tasks);
+  return true;
+}
+
+static void PrintAccuracyResults(const std::vector<std::vector<AccMetrics>>& test_accuracy_results,
+                                 const std::vector<std::filesystem::path>& dataset_paths,
+                                 const std::filesystem::directory_entry& model_dir,
+                                 const std::string& output_file,
+                                 std::unordered_map<std::string, size_t>& test_name_to_acc_result_index) {
+  assert(test_accuracy_results.size() == dataset_paths.size());
+  std::ostringstream oss;
+  for (size_t i = 0; i < test_accuracy_results.size(); i++) {
+    const std::filesystem::path& test_path = dataset_paths[i];
+    const std::vector<AccMetrics>& metrics = test_accuracy_results[i];
+    std::string key = model_dir.path().filename().string() + "/" + test_path.filename().string();
+    test_name_to_acc_result_index[key] = i;
+
+    oss << key << ",";
+    for (size_t j = 0; j < metrics.size(); j++) {
+      oss << std::setprecision(std::numeric_limits<double>::max_digits10) << metrics[j].snr;
+      if (j < metrics.size() - 1) {
+        oss << ",";
+      }
+    }
+    oss << std::endl;
+  }
+
+  if (output_file.empty()) {
+    std::cout << std::endl;
+    std::cout << "Accuracy Results:" << std::endl;
+    std::cout << "=================" << std::endl;
+    std::cout << oss.str() << std::endl;
+  } else {
+    std::ofstream out_fs(output_file);
+    out_fs << oss.str();
+    out_fs.close();
+  }
+}
+
+static bool CompareAccuracyWithExpectedValues(const std::filesystem::path& expected_accuracy_file,
+                                              const std::vector<std::vector<AccMetrics>>& test_accuracy_results,
+                                              const std::unordered_map<std::string, size_t>& test_name_to_acc_result_index,
+                                              size_t& total_tests,
+                                              size_t& total_failed_tests) {
+  std::cout << std::endl;
+  std::cout << "[INFO]: Comparing accuracy with " << expected_accuracy_file.filename().string() << std::endl;
+  std::cout << "===============================================" << std::endl;
+  std::ifstream in_fs(expected_accuracy_file);
+  constexpr size_t N = 512;
+  std::array<char, N> tmp_buf = {};
+
+  while (in_fs.getline(&tmp_buf[0], tmp_buf.size())) {
+    std::istringstream iss(tmp_buf.data());
+    if (!iss.getline(&tmp_buf[0], tmp_buf.size(), ',')) {
+      std::cerr << "[ERROR]: Failed to parse expected accuracy file " << expected_accuracy_file << std::endl;
+      return false;
+    }
+
+    std::string key(tmp_buf.data());
+    auto it = test_name_to_acc_result_index.find(key);
+    if (it == test_name_to_acc_result_index.end()) {
+      std::cerr << "[ERROR]: " << key << " was not a test that was run." << std::endl;
+      return false;
+    }
+
+    std::vector<double> expected_values;
+    while (iss.getline(&tmp_buf[0], tmp_buf.size(), ',')) {
+      expected_values.push_back(std::stod(tmp_buf.data()));
+    }
+
+    const std::vector<AccMetrics>& actual_output_metrics = test_accuracy_results[it->second];
+    if (actual_output_metrics.size() != expected_values.size()) {
+      std::cerr << "[ERROR]: test " << key << " does not have the expected number of outputs." << std::endl;
+      return false;
+    }
+
+    std::ostringstream oss;
+    bool passed = true;
+    for (size_t i = 0; i < expected_values.size(); i++) {
+      const auto& metrics = actual_output_metrics[i];
+
+      if (!(expected_values[i] - metrics.snr <= EPSILON_DBL)) {
+        passed = false;
+        oss << "\tOutput " << i << " SNR decreased: expected "
+            << std::setprecision(std::numeric_limits<double>::max_digits10) << expected_values[i] << ", actual "
+            << metrics.snr << std::endl;
+      }
+    }
+
+    std::cout << "[INFO]: Checking if " << key << " degraded ... ";
+    if (passed) {
+      std::cout << "PASSED" << std::endl;
+    } else {
+      std::cout << "FAILED" << std::endl;
+      std::cout << oss.str() << std::endl;
+      total_failed_tests += 1;
+    }
+    total_tests += 1;
+  }
+
+  return true;
+}
+
+int main(int argc, char** argv) {
+  try {
+    AppArgs args;
+    if (!ParseCmdLineArgs(args, argc, argv)) {
+      return 1;
+    }
+
+    Ort::Env env;
+
+    constexpr size_t num_pool_threads = 3;
+    TaskThreadPool pool(num_pool_threads);
+    TaskThreadPool dummy_pool(0);  // For EPs that only support single-threaded inference (e.g., QNN).
+    size_t total_tests = 0;
+    size_t total_failed_tests = 0;
+
+    for (const std::filesystem::directory_entry& model_dir : std::filesystem::directory_iterator{args.test_dir}) {
+      const std::filesystem::path& model_dir_path = model_dir.path();
+      const std::vector<std::filesystem::path> dataset_paths = GetSortedDatasetPaths(model_dir_path);
+
+      if (dataset_paths.empty()) {
+        continue;  // Nothing to test.
+      }
+
+      std::filesystem::path base_model_path = model_dir_path / "model.onnx";
+      std::filesystem::path ep_model_path;
+
+      // Some EPs will need to use a QDQ model instead of the original model.
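+      // Expected layout of each model directory:
+      //   model.onnx                                - float32 baseline model
+      //   model.qdq.onnx                            - quantized variant (QDQ-based EPs only)
+      //   test_data_set_<index>/input_<index>.raw   - raw input tensors
+      //   test_data_set_<index>/output_<index>.raw  - expected outputs (with -l/--load_expected_outputs)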
+      if (args.uses_qdq_model) {
+        std::filesystem::path qdq_model_path = model_dir_path / "model.qdq.onnx";
+
+        if (!std::filesystem::is_regular_file(qdq_model_path)) {
+          std::cerr << "[ERROR]: Execution provider '" << args.execution_provider
+                    << "' requires a QDQ model." << std::endl;
+          return 1;
+        }
+        ep_model_path = std::move(qdq_model_path);
+      } else {
+        ep_model_path = base_model_path;
+      }
+
+      std::vector<std::unique_ptr<char[]>> all_inputs;
+      std::vector<std::unique_ptr<char[]>> all_outputs;
+
+      // Load expected outputs from base model running on CPU EP (unless user wants to use outputs from disk).
+      if (!args.load_expected_outputs_from_disk) {
+        if (!std::filesystem::is_regular_file(base_model_path)) {
+          std::cerr << "[ERROR]: Cannot find ONNX model " << base_model_path
+                    << " from which to get expected outputs." << std::endl;
+          return 1;
+        }
+
+        if (!GetExpectedOutputsFromModel(env, pool, args, base_model_path, dataset_paths, all_inputs, all_outputs)) {
+          return 1;
+        }
+      }
+
+      // Run accuracy measurements with the EP under test.
+      std::vector<std::vector<AccMetrics>> test_accuracy_results;
+      TaskThreadPool& ep_pool = args.supports_multithread_inference ? pool : dummy_pool;
+      if (!RunTestModel(env, ep_pool, ep_model_path, dataset_paths, args.session_options,
+                        all_inputs, all_outputs, test_accuracy_results)) {
+        return 1;
+      }
+
+      // Print the accuracy results to file or stdout.
+      std::unordered_map<std::string, size_t> test_name_to_acc_result_index;
+      PrintAccuracyResults(test_accuracy_results,
+                           dataset_paths,
+                           model_dir,
+                           args.output_file,
+                           test_name_to_acc_result_index);
+
+      if (!args.expected_accuracy_file.empty()) {
+        if (!CompareAccuracyWithExpectedValues(args.expected_accuracy_file, test_accuracy_results,
+                                               test_name_to_acc_result_index, total_tests, total_failed_tests)) {
+          return 1;
+        }
+      }
+    }
+
+    if (!args.expected_accuracy_file.empty()) {
+      const size_t total_tests_passed = total_tests - total_failed_tests;
+      std::cout << std::endl
+                << "[INFO]: " << total_tests_passed << "/" << total_tests << " tests passed." << std::endl
+                << "[INFO]: " << total_failed_tests << "/" << total_tests << " tests failed." << std::endl;
+      return total_failed_tests == 0 ? 0 : 1;  // Fail the process only if at least one test regressed.
+    }
+  } catch (const std::exception& e) {
+    std::cerr << "[ORT_QNN_APP EXCEPTION]: " << e.what() << std::endl;
+    return 1;
+  }
+
+  return 0;
+}
diff --git a/onnxruntime/test/acc_test/model_io_utils.cc b/onnxruntime/test/acc_test/model_io_utils.cc
new file mode 100644
index 0000000000000..9ea9882d8f00a
--- /dev/null
+++ b/onnxruntime/test/acc_test/model_io_utils.cc
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "model_io_utils.h" +#include + +bool GetTensorElemDataSize(ONNXTensorElementDataType data_type, size_t& size) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + size = sizeof(float); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + size = sizeof(uint8_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + size = sizeof(int8_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + size = sizeof(uint16_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + size = sizeof(int16_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + size = sizeof(int32_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + size = sizeof(int64_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + size = sizeof(bool); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + size = sizeof(double); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + size = sizeof(uint32_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + size = sizeof(uint64_t); + break; + default: + std::cerr << "[ERROR]: Unsupported tensor element data type: " << data_type << std::endl; + return false; + } + + return true; +} + +AccMetrics ComputeAccuracyMetric(Ort::ConstValue ort_output, Span raw_expected_output, + const IOInfo& output_info) { + AccMetrics metrics = {}; + switch (output_info.data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span 
+      GetAccuracy(expected_output, actual_output, metrics);
+      break;
+    }
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
+      Span<const uint32_t> expected_output = ReinterpretBytesAsSpan<const uint32_t>(raw_expected_output);
+      Span<const uint32_t> actual_output(ort_output.GetTensorData<uint32_t>(), expected_output.size());
+      GetAccuracy(expected_output, actual_output, metrics);
+      break;
+    }
+    case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
+      Span<const uint64_t> expected_output = ReinterpretBytesAsSpan<const uint64_t>(raw_expected_output);
+      Span<const uint64_t> actual_output(ort_output.GetTensorData<uint64_t>(), expected_output.size());
+      GetAccuracy(expected_output, actual_output, metrics);
+      break;
+    }
+    default:
+      // Note: shouldn't get here because we've already validated expected output data types when loading model.
+      std::cerr << "[ERROR]: Unsupported tensor element data type: " << output_info.data_type << std::endl;
+      std::abort();
+  }
+
+  return metrics;
+}
+
+bool ModelIOInfo::Init(ModelIOInfo& model_info, Ort::ConstSession session) {
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  // Get model input info (name, shape, type)
+  {
+    size_t num_inputs = session.GetInputCount();
+    model_info.inputs.reserve(num_inputs);
+
+    for (size_t i = 0; i < num_inputs; i++) {
+      Ort::TypeInfo type_info = session.GetInputTypeInfo(i);
+      if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) {
+        std::cerr << "[ERROR]: Only support models with tensor inputs" << std::endl;
+        return false;
+      }
+
+      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+      IOInfo input_info;
+      if (!IOInfo::Init(input_info, session.GetInputNameAllocated(i, allocator).get(),
+                        tensor_info.GetElementType(), tensor_info.GetShape())) {
+        std::cerr << "[ERROR]: Unsupported tensor element type (" << tensor_info.GetElementType()
+                  << ") for input at index " << i << std::endl;
+        return false;
+      }
+
+      model_info.inputs.push_back(std::move(input_info));
+    }
+  }
+
+  // Get model output info (name, shape, type)
+  {
+    size_t num_outputs = session.GetOutputCount();
+    model_info.outputs.reserve(num_outputs);
+
+    for (size_t i = 0; i < num_outputs; i++) {
+      Ort::TypeInfo type_info = session.GetOutputTypeInfo(i);
+      if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) {
+        std::cerr << "[ERROR]: Only support models with tensor outputs" << std::endl;
+        return false;
+      }
+
+      auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+      IOInfo output_info;
+      if (!IOInfo::Init(output_info, session.GetOutputNameAllocated(i, allocator).get(),
+                        tensor_info.GetElementType(), tensor_info.GetShape())) {
+        std::cerr << "[ERROR]: Unsupported tensor element type (" << tensor_info.GetElementType()
+                  << ") for output at index " << i << std::endl;
+        return false;
+      }
+
+      model_info.outputs.push_back(std::move(output_info));
+    }
+  }
+
+  return true;
+}
+
+size_t ModelIOInfo::GetTotalInputSize() const {
+  size_t total_size = 0;
+
+  for (const auto& input_info : inputs) {
+    total_size += input_info.total_data_size;
+  }
+
+  return total_size;
+}
+
+size_t ModelIOInfo::GetTotalOutputSize() const {
+  size_t total_size = 0;
+
+  for (const auto& output_info : outputs) {
+    total_size += output_info.total_data_size;
+  }
+
+  return total_size;
+}
diff --git a/onnxruntime/test/acc_test/model_io_utils.h b/onnxruntime/test/acc_test/model_io_utils.h
new file mode 100644
index 0000000000000..27bf42cebe7bf
--- /dev/null
+++ b/onnxruntime/test/acc_test/model_io_utils.h
@@ -0,0 +1,87 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
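+//
+// IOInfo records one tensor's name, shape, element type, and total byte size;
+// ModelIOInfo aggregates that for all model inputs and outputs so flat byte
+// buffers can be sliced per tensor without re-querying the session.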
+#pragma once
+#include <string>
+#include <vector>
+#include <onnxruntime_cxx_api.h>
+
+#include "basic_utils.h"
+
+bool GetTensorElemDataSize(ONNXTensorElementDataType data_type, size_t& size);
+
+struct IOInfo {
+  IOInfo() = default;
+  IOInfo(IOInfo&& other) = default;
+  IOInfo(const IOInfo& other) = default;
+
+  IOInfo& operator=(const IOInfo& other) = default;
+  IOInfo& operator=(IOInfo&& other) = default;
+
+  static bool Init(IOInfo& io_info, const char* name,
+                   ONNXTensorElementDataType data_type, std::vector<int64_t> shape) {
+    size_t elem_size = 0;
+    if (!GetTensorElemDataSize(data_type, elem_size)) {
+      return false;
+    }
+
+    const size_t total_data_size = elem_size * static_cast<size_t>(GetShapeSize(Span<int64_t>(shape)));
+
+    io_info.name = name;
+    io_info.shape = std::move(shape);
+    io_info.data_type = data_type;
+    io_info.total_data_size = total_data_size;
+
+    return true;
+  }
+
+  friend bool operator==(const IOInfo& l, const IOInfo& r) {
+    if (l.name != r.name || l.data_type != r.data_type || l.shape.size() != r.shape.size()) {
+      return false;
+    }
+
+    for (size_t i = 0; i < l.shape.size(); i++) {
+      if (l.shape[i] != r.shape[i]) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+
+  friend bool operator!=(const IOInfo& l, const IOInfo& r) {
+    return !(l == r);
+  }
+
+  std::string name;
+  std::vector<int64_t> shape;
+  ONNXTensorElementDataType data_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+  size_t total_data_size = 0;
+};
+
+struct ModelIOInfo {
+  ModelIOInfo() = default;
+  ModelIOInfo(ModelIOInfo&& other) = default;
+  ModelIOInfo(const ModelIOInfo& other) = default;
+
+  ModelIOInfo& operator=(const ModelIOInfo& other) = default;
+  ModelIOInfo& operator=(ModelIOInfo&& other) = default;
+
+  friend bool operator==(const ModelIOInfo& l, const ModelIOInfo& r) {
+    return l.inputs == r.inputs && l.outputs == r.outputs;
+  }
+
+  friend bool operator!=(const ModelIOInfo& l, const ModelIOInfo& r) {
+    return !(l == r);
+  }
+
+  static bool Init(ModelIOInfo& model_info, Ort::ConstSession session);
+
+  size_t GetTotalInputSize() const;
+  size_t GetTotalOutputSize() const;
+
+  std::vector<IOInfo> inputs;
+  std::vector<IOInfo> outputs;
+};
+
+AccMetrics ComputeAccuracyMetric(Ort::ConstValue ort_output, Span<const char> raw_expected_output,
+                                 const IOInfo& output_info);
diff --git a/onnxruntime/test/acc_test/task_thread_pool.cc b/onnxruntime/test/acc_test/task_thread_pool.cc
new file mode 100644
index 0000000000000..8252165fefa4a
--- /dev/null
+++ b/onnxruntime/test/acc_test/task_thread_pool.cc
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "task_thread_pool.h"
+#include <cassert>
+#include "acc_task.h"
+
+TaskThreadPool::TaskThreadPool(size_t num_threads) {
+  threads_.reserve(num_threads);
+  for (size_t i = 0; i < num_threads; i++) {
+    threads_.emplace_back(&TaskThreadPool::ThreadEntry, this);
+  }
+}
+
+TaskThreadPool::~TaskThreadPool() {
+  {
+    // Acquire lock, set shutdown to true, and wake up all threads.
+    std::unique_lock lock(lock_);
+    shutdown_ = true;
+    signal_.notify_all();
+  }
+
+  // Wait for all threads to exit.
+  for (auto& thread : threads_) {
+    thread.join();
+  }
+}
+
+void TaskThreadPool::CompleteTasks(Span<Task> tasks) {
+  // Assert that it is only possible to call CompleteTasks() when either
+  // this is the first set of tasks or we've completely processed the previous tasks.
+  assert(tasks_completed_ == tasks_.size());
+
+  {
+    // Acquire lock, set new tasks, and wake up all threads.
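+    // Note: both the pool threads and this calling thread claim work through
+    // RunNextTask(), which hands out task indices via an atomic fetch-add, so
+    // the hot path needs no lock.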
+    std::unique_lock lock(lock_);
+    tasks_ = tasks;
+    tasks_completed_ = 0;
+    next_task_index_ = 0;
+    signal_.notify_all();
+  }
+
+  // The main thread (calling thread) can also help out until all tasks have been completed.
+  while (tasks_completed_ < tasks_.size()) {
+    while (RunNextTask()) {
+      // Keep helping out the pool threads.
+    }
+  }
+}
+
+void TaskThreadPool::ThreadEntry() {
+  while (true) {
+    // Keep running tasks until they have *all* been claimed by some thread.
+    while (RunNextTask()) {
+    }
+
+    {
+      // Get lock and sleep if all tasks have been taken by some thread. If shutdown_ flag is set, exit.
+      std::unique_lock lock(lock_);
+      while (!shutdown_ && (next_task_index_ >= tasks_.size())) {
+        signal_.wait(lock);  // wait() may be unblocked spuriously (according to docs), so need to call it in a loop.
+      }
+
+      if (shutdown_) {
+        return;
+      }
+    }
+  }
+}
+
+bool TaskThreadPool::RunNextTask() {
+  if (tasks_.empty()) {
+    return false;
+  }
+
+  const size_t task_index = std::atomic_fetch_add(&next_task_index_, static_cast<size_t>(1));
+  if (task_index >= tasks_.size()) {
+    return false;
+  }
+
+  tasks_[task_index].Run();
+
+  std::atomic_fetch_add(&tasks_completed_, static_cast<size_t>(1));
+  return true;
+}
diff --git a/onnxruntime/test/acc_test/task_thread_pool.h b/onnxruntime/test/acc_test/task_thread_pool.h
new file mode 100644
index 0000000000000..bd2f644cca9d1
--- /dev/null
+++ b/onnxruntime/test/acc_test/task_thread_pool.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#pragma once
+#include <atomic>
+#include <condition_variable>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+#include "basic_utils.h"
+#include "acc_task.h"
+
+class TaskThreadPool {
+ public:
+  TaskThreadPool(size_t num_threads);
+  ~TaskThreadPool();
+
+  void CompleteTasks(Span<Task> tasks);
+
+ private:
+  void ThreadEntry();
+  bool RunNextTask();
+
+  std::mutex lock_;
+  std::condition_variable signal_;
+  bool shutdown_ = false;
+  Span<Task> tasks_;
+  std::atomic<size_t> next_task_index_ = 0;
+  std::atomic<size_t> tasks_completed_ = 0;
+  std::vector<std::thread> threads_;
+};