From 9258a8d6d1851911e682403650299f583d425423 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 17 Dec 2023 20:58:35 -0800 Subject: [PATCH 1/8] Add initial implementation of accuracy test tool --- cmake/onnxruntime_unittests.cmake | 33 ++ onnxruntime/test/acc_test/acc_task.cc | 77 ++++ onnxruntime/test/acc_test/acc_task.h | 48 +++ onnxruntime/test/acc_test/basic_utils.cc | 57 +++ onnxruntime/test/acc_test/basic_utils.h | 116 ++++++ onnxruntime/test/acc_test/cmd_args.cc | 291 ++++++++++++++ onnxruntime/test/acc_test/cmd_args.h | 18 + onnxruntime/test/acc_test/data_loader.cc | 79 ++++ onnxruntime/test/acc_test/data_loader.h | 17 + onnxruntime/test/acc_test/main.cc | 375 ++++++++++++++++++ onnxruntime/test/acc_test/model_io_utils.cc | 203 ++++++++++ onnxruntime/test/acc_test/model_io_utils.h | 88 ++++ onnxruntime/test/acc_test/task_thread_pool.cc | 84 ++++ onnxruntime/test/acc_test/task_thread_pool.h | 29 ++ 14 files changed, 1515 insertions(+) create mode 100644 onnxruntime/test/acc_test/acc_task.cc create mode 100644 onnxruntime/test/acc_test/acc_task.h create mode 100644 onnxruntime/test/acc_test/basic_utils.cc create mode 100644 onnxruntime/test/acc_test/basic_utils.h create mode 100644 onnxruntime/test/acc_test/cmd_args.cc create mode 100644 onnxruntime/test/acc_test/cmd_args.h create mode 100644 onnxruntime/test/acc_test/data_loader.cc create mode 100644 onnxruntime/test/acc_test/data_loader.h create mode 100644 onnxruntime/test/acc_test/main.cc create mode 100644 onnxruntime/test/acc_test/model_io_utils.cc create mode 100644 onnxruntime/test/acc_test/model_io_utils.h create mode 100644 onnxruntime/test/acc_test/task_thread_pool.cc create mode 100644 onnxruntime/test/acc_test/task_thread_pool.h diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 7c8c70f913dca..757c1ed4c27dc 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1171,6 +1171,39 @@ endif() if (NOT 
onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + # Accuracy test runner + set(onnxruntime_acc_test_src_dir ${TEST_SRC_DIR}/acc_test) + set(onnxruntime_acc_test_src_patterns + "${onnxruntime_acc_test_src_dir}/*.cc" + "${onnxruntime_acc_test_src_dir}/*.h") + + file(GLOB onnxruntime_acc_test_src CONFIGURE_DEPENDS + ${onnxruntime_acc_test_src_patterns} + ) + onnxruntime_add_executable(onnxruntime_acc_test ${onnxruntime_acc_test_src}) + target_include_directories(onnxruntime_acc_test PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) + if (WIN32) + target_compile_options(onnxruntime_acc_test PRIVATE ${disabled_warnings}) + endif() + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set_target_properties(onnxruntime_acc_test PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + ) + endif() + + if (onnxruntime_BUILD_SHARED_LIB) + set(onnxruntime_acc_test_libs onnxruntime) + target_link_libraries(onnxruntime_acc_test PRIVATE ${onnxruntime_acc_test_libs}) + endif() + + set_target_properties(onnxruntime_acc_test PROPERTIES FOLDER "ONNXRuntimeTest") + + if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnxruntime_acc_test PRIVATE "/STACK:4000000") + endif() + endif() + #perf test runner set(onnxruntime_perf_test_src_dir ${TEST_SRC_DIR}/perftest) set(onnxruntime_perf_test_src_patterns diff --git a/onnxruntime/test/acc_test/acc_task.cc b/onnxruntime/test/acc_test/acc_task.cc new file mode 100644 index 0000000000000..430d4703786e1 --- /dev/null +++ b/onnxruntime/test/acc_test/acc_task.cc @@ -0,0 +1,77 @@ +#include "acc_task.h" + +static std::vector RunInference(Ort::Session& session, const ModelIOInfo& model_io_info, + Span input_buffer) { + // Setup input + const std::vector& input_infos = model_io_info.inputs; + const size_t num_inputs = input_infos.size(); + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.reserve(num_inputs); + ort_input_names.reserve(num_inputs); + + for (size_t input_offset = 0, i = 0; i < num_inputs; input_offset += 
input_infos[i].total_data_size, i++) { + assert(input_offset < input_buffer.size()); + const IOInfo& input_info = input_infos[i]; + Span input_data(&input_buffer[input_offset], input_info.total_data_size); + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + + ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, (void*)input_data.data(), input_data.size(), + input_info.shape.data(), input_info.shape.size(), + input_info.data_type)); + ort_input_names.push_back(input_info.name.c_str()); + } + + const size_t num_outputs = model_io_info.outputs.size(); + std::vector ort_output_names; + ort_output_names.reserve(num_outputs); + + for (size_t i = 0; i < num_outputs; i++) { + ort_output_names.push_back(model_io_info.outputs[i].name.c_str()); + } + + return session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), ort_output_names.data(), ort_output_names.size()); +} + +void Task::Run() { + std::vector ort_output_vals = RunInference(session_, model_io_info_, input_buffer_); + + AccuracyCheck* accuracy_check_data = std::get_if(&variant_); + if (accuracy_check_data) { + const std::vector& output_infos = model_io_info_.get().outputs; + const size_t num_outputs = output_infos.size(); + Span expected_output_buffer = accuracy_check_data->expected_output_buffer; + + for (size_t output_offset = 0, i = 0; i < num_outputs; output_offset += output_infos[i].total_data_size, i++) { + assert(output_offset < expected_output_buffer.size()); + const IOInfo& output_info = output_infos[i]; + Span raw_expected_output(&expected_output_buffer[output_offset], output_info.total_data_size); + + accuracy_check_data->output_acc_metric[i] = ComputeAccuracyMetric(ort_output_vals[i].GetConst(), + raw_expected_output, + output_info); + } + return; + } + + Inference* inference_data = std::get_if(&variant_); + if (inference_data) { + Span& output_buffer = inference_data->output_buffer; + + // Unfortunately, we have to 
copy output values (Ort::Value is not copyable, so it is limited when stored in a std::vector) + const std::vector& output_infos = model_io_info_.get().outputs; + const size_t num_outputs = output_infos.size(); + + for (size_t output_offset = 0, i = 0; i < num_outputs; output_offset += output_infos[i].total_data_size, i++) { + assert(output_offset < output_buffer.size()); + std::memcpy(&output_buffer[output_offset], + ort_output_vals[i].GetTensorRawData(), + output_infos[i].total_data_size); + } + return; + } + + std::abort(); +} diff --git a/onnxruntime/test/acc_test/acc_task.h b/onnxruntime/test/acc_test/acc_task.h new file mode 100644 index 0000000000000..980a90ac74fd2 --- /dev/null +++ b/onnxruntime/test/acc_test/acc_task.h @@ -0,0 +1,48 @@ +#pragma once +#include +#include +#include +#include "basic_utils.h" +#include "model_io_utils.h" + +class Task { + private: + struct Inference { + Span output_buffer; + }; + + struct AccuracyCheck { + Span expected_output_buffer; + Span output_acc_metric; + }; + + public: + Task() = default; + Task(Task&& other) = default; + Task(const Task& other) = default; + Task(Ort::Session& session, const ModelIOInfo& model_io_info, + Span input_buffer, Span output_buffer) + : session_(session), model_io_info_(model_io_info), input_buffer_(input_buffer), variant_(Inference{output_buffer}) {} + Task(Ort::Session& session, const ModelIOInfo& model_io_info, + Span input_buffer, Span expected_output_buffer, Span output_acc_metric) + : session_(session), model_io_info_(model_io_info), input_buffer_(input_buffer), variant_(AccuracyCheck{expected_output_buffer, output_acc_metric}) {} + + static Task CreateInferenceTask(Ort::Session& session, const ModelIOInfo& model_io_info, + Span input_buffer, Span output_buffer) { + return Task(session, model_io_info, input_buffer, output_buffer); + } + + static Task CreateAccuracyCheckTask(Ort::Session& session, const ModelIOInfo& model_io_info, + Span input_buffer, Span expected_output_buffer, + Span 
output_acc_metric) { + return Task(session, model_io_info, input_buffer, expected_output_buffer, output_acc_metric); + } + + void Run(); + + private: + std::reference_wrapper session_; + std::reference_wrapper model_io_info_; + Span input_buffer_; + std::variant variant_; +}; diff --git a/onnxruntime/test/acc_test/basic_utils.cc b/onnxruntime/test/acc_test/basic_utils.cc new file mode 100644 index 0000000000000..b874f4f02b84b --- /dev/null +++ b/onnxruntime/test/acc_test/basic_utils.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "basic_utils.h" +#include +#include + +bool FillBytesFromBinaryFile(Span array, const std::string& binary_filepath) { + std::ifstream input_ifs(binary_filepath, std::ifstream::binary); + + if (!input_ifs.is_open()) { + return false; + } + + size_t file_byte_size = 0; + input_ifs.seekg(0, input_ifs.end); + file_byte_size = input_ifs.tellg(); + input_ifs.seekg(0, input_ifs.beg); + + if (file_byte_size != array.size()) { + return false; + } + + input_ifs.read(array.data(), file_byte_size); + return static_cast(input_ifs); +} + +int64_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix) { + int64_t index = -1; + const char* str = filename_wo_ext.c_str(); + + // Move past the prefix. + while (*str && *prefix && *str == *prefix) { + str++; + prefix++; + } + + if (*prefix) { + return -1; // File doesn't start with the prefix. + } + + // Parse the input index from file name. + index = 0; + while (*str) { + int64_t c = *str; + if (!(c >= '0' && c <= '9')) { + return -1; // Not a number. + } + + index *= 10; + index += (c - '0'); + str++; + } + + return index; +} + diff --git a/onnxruntime/test/acc_test/basic_utils.h b/onnxruntime/test/acc_test/basic_utils.h new file mode 100644 index 0000000000000..885992d6f13e2 --- /dev/null +++ b/onnxruntime/test/acc_test/basic_utils.h @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include +#include +#include +#include +#include +#include + +// Make a bootleg std::span for C++ versions older than 20 +template +class Span { + public: + Span() = default; + Span(T* data, size_t size) : data_(data), size_(size) {} + Span(std::vector& vec) : data_(vec.data()), size_(vec.size()) {} + Span(const std::vector>& vec) : data_(vec.data()), size_(vec.size()) {} + + template + Span(std::array arr) : data_(arr.data()), size_(N) {} + + Span(const Span& other) = default; + Span(Span&& other) = default; + + Span& operator=(const Span& other) = default; + Span& operator=(Span&& other) = default; + + T& operator[](size_t index) const { + return data_[index]; + } + + T* data() const { return data_; } + size_t size() const { return size_; } + bool empty() const { return size_ == 0; } + + private: + T* data_{nullptr}; + size_t size_{0}; +}; + +template +static Span ReinterpretBytesAsSpan(Span, const char, char>> bytes_span) { + return Span(reinterpret_cast(bytes_span.data()), bytes_span.size() / sizeof(T)); +} + +template +constexpr int64_t GetShapeSize(Span shape) { + int64_t size = 1; + + for (size_t i = 0; i < shape.size(); i++) { + size *= shape[i]; + } + + return size; +} + +int64_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix); +bool FillBytesFromBinaryFile(Span array, const std::string& binary_filepath); + +constexpr double EPSILON_DBL = 2e-16; + +struct AccMetrics { + double rmse = 0.0; + double snr = 0.0; + double min_val = 0.0; + double max_val = 0.0; + double min_expected_val = 0.0; + double max_expected_val = 0.0; + friend bool operator==(const AccMetrics& l, const AccMetrics& r) { + if (l.rmse != r.rmse) return false; + if (l.min_val != r.min_val) return false; + if (l.max_val != r.max_val) return false; + if (l.min_expected_val != r.min_expected_val) return false; + if (l.max_expected_val != r.max_expected_val) return false; + if (l.snr != 
r.snr) return false; + + return true; + } + friend bool operator!=(const AccMetrics& l, const AccMetrics& r) { + return !(l == r); + } +}; + +template +void GetAccuracy(Span expected_output, Span actual_output, AccMetrics& metrics) { + // Compute RMSE. This is not a great way to measure accuracy, but .... + assert(expected_output.size() == actual_output.size()); + const size_t num_outputs = expected_output.size(); + + metrics.rmse = 0.0; + metrics.min_val = static_cast(actual_output[0]); + metrics.max_val = static_cast(actual_output[0]); + metrics.min_expected_val = static_cast(expected_output[0]); + metrics.max_expected_val = static_cast(expected_output[0]); + double tensor_norm = 0.0; + double diff_norm = 0.0; + for (size_t i = 0; i < num_outputs; i++) { + double diff = static_cast(actual_output[i]) - static_cast(expected_output[i]); + diff_norm += diff * diff; + tensor_norm += static_cast(expected_output[i]) * static_cast(expected_output[i]); + + metrics.rmse += diff * diff; + metrics.min_val = std::min(metrics.min_val, static_cast(actual_output[i])); + metrics.max_val = std::max(metrics.max_val, static_cast(actual_output[i])); + metrics.min_expected_val = std::min(metrics.min_expected_val, static_cast(expected_output[i])); + metrics.max_expected_val = std::max(metrics.max_expected_val, static_cast(expected_output[i])); + } + + metrics.rmse = std::sqrt(metrics.rmse / static_cast(num_outputs)); + + tensor_norm = std::max(std::sqrt(tensor_norm), EPSILON_DBL); + diff_norm = std::max(std::sqrt(diff_norm), EPSILON_DBL); + metrics.snr = 20.0 * std::log10(tensor_norm / diff_norm); +} diff --git a/onnxruntime/test/acc_test/cmd_args.cc b/onnxruntime/test/acc_test/cmd_args.cc new file mode 100644 index 0000000000000..b1b8286389249 --- /dev/null +++ b/onnxruntime/test/acc_test/cmd_args.cc @@ -0,0 +1,291 @@ +#include "cmd_args.h" +#include +#include +#include +#include +#include +#include +#include +#include + +struct CmdArgs { + CmdArgs(int argc, char** argv) noexcept : 
argc_(argc), argv_(argv), index_(0) {} + + [[nodiscard]] bool HasNext() const { return index_ < argc_; } + + [[nodiscard]] std::string_view GetNext() { + assert(HasNext()); + return argv_[index_++]; + } + + [[nodiscard]] std::string_view PeekNext() { + assert(HasNext()); + return argv_[index_]; + } + + private: + int argc_; + char** argv_; + int index_; +}; + +static void PrintUsage(std::ostream& stream, std::string_view prog_name) { + stream << "Usage: " << prog_name << " [OPTIONS]" + << std::endl; + stream << "OPTIONS:" << std::endl; + stream << " -h/--help Print this help message" << std::endl; + stream << " -t/--test_dir Path to test directory with models and inputs/outputs" << std::endl; + stream << " -l/--load_expected_outputs Load expected outputs from raw output_.raw files" << std::endl; + stream << " -s/--save_expected_outputs Save outputs from baseline model on CPU EP to disk" << std::endl; + stream << " -e/--execution_provider The execution provider to test (e.g., qnn)" << std::endl; + stream << " -o/--output_file The output file into which to save accuracy results" << std::endl; + stream << " -a/--expected_accuracy_file The file containing expected accuracy results" << std::endl + << std::endl; +} + +static bool ParseQnnRuntimeOptions(std::string ep_config_string, + std::unordered_map& qnn_options) { + std::istringstream ss(ep_config_string); + std::string token; + + while (ss >> token) { + if (token == "") { + continue; + } + std::string_view token_sv(token); + + auto pos = token_sv.find("|"); + if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { + std::cerr << "Use a '|' to separate the key and value for the run-time option you are trying to use." << std::endl; + return false; + } + + std::string_view key(token_sv.substr(0, pos)); + std::string_view value(token_sv.substr(pos + 1)); + + if (key == "backend_path") { + if (value.empty()) { + std::cerr << "[ERROR]: Please provide the QNN backend path." 
<< std::endl; + return false; + } + } else if (key == "qnn_context_cache_enable") { + if (value != "1") { + std::cerr << "[ERROR]: Set to 1 to enable qnn_context_cache_enable." << std::endl; + return false; + } + } else if (key == "qnn_context_cache_path") { + // no validation + } else if (key == "profiling_level") { + std::unordered_set supported_profiling_level = {"off", "basic", "detailed"}; + if (supported_profiling_level.find(value) == supported_profiling_level.end()) { + std::cerr << "[ERROR]: Supported profiling_level: off, basic, detailed" << std::endl; + return false; + } + } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + // no validation + } else if (key == "htp_performance_mode") { + std::unordered_set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", + "high_power_saver", "low_balanced", "low_power_saver", + "power_saver", "sustained_high_performance"}; + if (supported_htp_perf_mode.find(value) == supported_htp_perf_mode.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_perf_mode.begin(), supported_htp_perf_mode.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + std::cerr << "[ERROR]: Supported htp_performance_mode: " << str << std::endl; + return false; + } + } else if (key == "qnn_saver_path") { + // no validation + } else if (key == "htp_graph_finalization_optimization_mode") { + std::unordered_set supported_htp_graph_final_opt_modes = {"0", "1", "2", "3"}; + if (supported_htp_graph_final_opt_modes.find(value) == supported_htp_graph_final_opt_modes.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_graph_final_opt_modes.begin(), supported_htp_graph_final_opt_modes.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + std::cerr << "[ERROR]: Wrong value for htp_graph_finalization_optimization_mode. 
select from: " << str << std::endl; + return false; + } + } else if (key == "qnn_context_priority") { + std::unordered_set supported_qnn_context_priority = {"low", "normal", "normal_high", "high"}; + if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) { + std::cerr << "[ERROR]: Supported qnn_context_priority: low, normal, normal_high, high" << std::endl; + return false; + } + } else { + std::cerr << R"([ERROR]: Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable', +'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])" + << std::endl; + return false; + } + + qnn_options.insert(std::make_pair(std::string(key), std::string(value))); + } + + return true; +} + +static bool ParseQnnArgs(AppArgs& app_args, CmdArgs& cmd_args) { + if (!cmd_args.HasNext()) { + std::cerr << "[ERROR]: Must specify at least a QNN backend path." << std::endl; + return false; + } + + std::string_view args = cmd_args.GetNext(); + std::unordered_map qnn_options; + + if (!ParseQnnRuntimeOptions(std::string(args), qnn_options)) { + return false; + } + + auto backend_iter = qnn_options.find("backend_path"); + if (backend_iter == qnn_options.end()) { + std::cerr << "[ERROR]: Must provide a backend_path for the QNN execution provider." << std::endl; + return false; + } + + app_args.session_options.AppendExecutionProvider("QNN", qnn_options); + app_args.session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // TODO: Parse config entries + app_args.uses_qdq_model = backend_iter->second.rfind("QnnHtp") != std::string::npos; + app_args.supports_multithread_inference = false; // TODO: Work on enabling multi-threaded inference. 
+ return true; +} + +static bool GetValidPath(std::string_view prog_name, std::string_view provided_path, bool is_dir, + std::filesystem::path& valid_path) { + std::filesystem::path path = provided_path; + std::error_code error_code; + + if (!std::filesystem::exists(path, error_code)) { + std::cerr << "[ERROR]: Invalid path " << provided_path << ": " + << error_code.message() << std::endl + << std::endl; + return false; + } + + std::error_code abs_error_code; + std::filesystem::path abs_path = std::filesystem::absolute(path, abs_error_code); + if (abs_error_code) { + std::cerr << "[ERROR]: Invalid path: " << abs_error_code.message() << std::endl + << std::endl; + return false; + } + + if (is_dir && !std::filesystem::is_directory(abs_path)) { + std::cerr << "[ERROR]: " << provided_path << " is not a directory" << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + if (!is_dir && !std::filesystem::is_regular_file(abs_path)) { + std::cerr << "[ERROR]: " << provided_path << " is not a regular file" << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + valid_path = std::move(abs_path); + + return true; +} + +bool ParseCmdLineArgs(AppArgs& app_args, int argc, char** argv) { + CmdArgs cmd_args(argc, argv); + std::string_view prog_name = cmd_args.GetNext(); + + // Parse command-line arguments. 
+ while (cmd_args.HasNext()) { + std::string_view arg = cmd_args.GetNext(); + + if (arg == "-h" || arg == "--help") { + PrintUsage(std::cout, prog_name); + return true; + } else if (arg == "-t" || arg == "--test_dir") { + if (!cmd_args.HasNext()) { + std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + arg = cmd_args.GetNext(); + if (!GetValidPath(prog_name, arg, true, app_args.test_dir)) { + return false; + } + } else if (arg == "-o" || arg == "--output_file") { + if (!cmd_args.HasNext()) { + std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + app_args.output_file = cmd_args.GetNext(); + } else if (arg == "-a" || arg == "--expected_accuracy_file") { + if (!cmd_args.HasNext()) { + std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + arg = cmd_args.GetNext(); + if (!GetValidPath(prog_name, arg, false, app_args.expected_accuracy_file)) { + return false; + } + } else if (arg == "-e" || arg == "--execution_provider") { + if (!cmd_args.HasNext()) { + std::cerr << "[ERROR]: Must provide an argument after the " << arg << " option" << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + arg = cmd_args.GetNext(); + if (arg == "qnn") { + if (!ParseQnnArgs(app_args, cmd_args)) { + return false; + } + } else { + std::cerr << "[ERROR]: Unsupported execution provider: " << arg << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + app_args.execution_provider = arg; + } else if (arg == "-s" || arg == "--save_expected_outputs") { + app_args.save_expected_outputs_to_disk = true; + } else if (arg == "-l" || arg == "--load_expected_outputs") { + app_args.load_expected_outputs_from_disk = true; + } else { + std::cerr << "[ERROR]: unknown command-line 
argument `" << arg << "`" << std::endl + << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + } + + // + // Final argument validation: + // + + if (app_args.test_dir.empty()) { + std::cerr << "[ERROR]: Must provide a test directory using the -t/--test_dir option." << std::endl + << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + if (app_args.execution_provider.empty()) { + std::cerr << "[ERROR]: Must provide an execution provider using the -e/--execution_provider option." << std::endl + << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + if (app_args.load_expected_outputs_from_disk && app_args.save_expected_outputs_to_disk) { + std::cerr << "[ERROR]: Cannot enable both -s/--save_expected_outputs and -l/--load_expected_outputs" << std::endl + << std::endl; + PrintUsage(std::cerr, prog_name); + return false; + } + + return true; +} \ No newline at end of file diff --git a/onnxruntime/test/acc_test/cmd_args.h b/onnxruntime/test/acc_test/cmd_args.h new file mode 100644 index 0000000000000..4b1718d52d3c9 --- /dev/null +++ b/onnxruntime/test/acc_test/cmd_args.h @@ -0,0 +1,18 @@ +#pragma once +#include +#include +#include + +struct AppArgs { + std::filesystem::path test_dir; + std::string output_file; + std::filesystem::path expected_accuracy_file; + std::string execution_provider; + bool uses_qdq_model = false; + bool supports_multithread_inference = true; + bool save_expected_outputs_to_disk = false; + bool load_expected_outputs_from_disk = false; + Ort::SessionOptions session_options; +}; + +bool ParseCmdLineArgs(AppArgs& app_args, int argc, char** argv); diff --git a/onnxruntime/test/acc_test/data_loader.cc b/onnxruntime/test/acc_test/data_loader.cc new file mode 100644 index 0000000000000..3426f4dc36896 --- /dev/null +++ b/onnxruntime/test/acc_test/data_loader.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "data_loader.h" +#include +#include +#include +#include + +namespace acctest { + +bool LoadIODataFromDisk(const std::vector& dataset_paths, + const std::vector& io_infos, + const char* data_file_prefix, + std::vector>& dataset_data) { + size_t total_data_size = 0; + for (const auto& io_info : io_infos) { + total_data_size += io_info.total_data_size; + } + + dataset_data.clear(); + dataset_data.reserve(dataset_paths.size()); + + for (const auto& dataset_path : dataset_paths) { + dataset_data.emplace_back(std::make_unique(total_data_size)); + + size_t num_files_loaded = 0; + + for (const auto& data_file_entry : std::filesystem::directory_iterator{dataset_path}) { + const std::filesystem::path& data_file_path = data_file_entry.path(); + + if (!std::filesystem::is_regular_file(data_file_path)) { + continue; + } + + std::string data_filename_wo_ext = data_file_path.stem().string(); + if (data_filename_wo_ext.rfind(data_file_prefix, 0) != 0) { + continue; + } + + const int64_t io_index = GetFileIndexSuffix(data_filename_wo_ext, data_file_prefix); + if (io_index < 0) { + std::cerr << "[ERROR]: The file " << data_file_path << " does not have a properly formatted name" + << " (e.g., " << data_file_prefix << "0.raw)" << std::endl; + return false; + } + + if (io_index >= static_cast(io_infos.size())) { + std::cerr << "[ERROR]: The input (or output) file index for file " << data_file_path + << " exceeds the number of inputs (or outputs) in the model (" + << io_infos.size() << ")" << std::endl; + return false; + } + + size_t offset = 0; + for (int64_t i = 0; i < io_index; i++) { + offset += io_infos[i].total_data_size; + } + assert(offset < total_data_size); + + Span span_to_fill(dataset_data.back().get() + offset, io_infos[io_index].total_data_size); + if (!FillBytesFromBinaryFile(span_to_fill, data_file_path.string())) { + std::cerr << "[ERROR]: Unable to read raw data from file " << data_file_path << std::endl; + return false; + } + + num_files_loaded += 1; + } + + 
if (num_files_loaded != io_infos.size()) { + std::cerr << "[ERROR]: " << dataset_path << " does not have the expected number of " + << data_file_prefix << ".raw files. Loaded " << num_files_loaded << "files, but expected " + << io_infos.size() << "files." << std::endl; + return false; + } + } + + return true; +} +} // namespace acctest diff --git a/onnxruntime/test/acc_test/data_loader.h b/onnxruntime/test/acc_test/data_loader.h new file mode 100644 index 0000000000000..3f6a3ac7ddf81 --- /dev/null +++ b/onnxruntime/test/acc_test/data_loader.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include +#include + +#include "basic_utils.h" +#include "model_io_utils.h" + +namespace acctest { + +bool LoadIODataFromDisk(const std::vector& dataset_paths, + const std::vector& io_infos, + const char* data_file_prefix, + std::vector>& dataset_data); +} // namespace acctest diff --git a/onnxruntime/test/acc_test/main.cc b/onnxruntime/test/acc_test/main.cc new file mode 100644 index 0000000000000..99138c3ef3d19 --- /dev/null +++ b/onnxruntime/test/acc_test/main.cc @@ -0,0 +1,375 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include + +#include +#include +#include +#include +#include +#include +#include // std::abort +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cmd_args.h" +#include "model_io_utils.h" +#include "data_loader.h" +#include "acc_task.h" +#include "task_thread_pool.h" + +static std::vector GetSortedDatasetPaths(const std::filesystem::path& model_dir) { + std::vector dataset_paths; + const char* dataset_prefix = "test_data_set_"; + + for (const auto& entry : std::filesystem::directory_iterator{model_dir}) { + std::filesystem::path entry_path = entry.path(); + std::string entry_filename = entry_path.filename().string(); + + if (std::filesystem::is_directory(entry_path) && entry_filename.rfind(dataset_prefix, 0) == 0) { + dataset_paths.push_back(std::move(entry_path)); + } + } + + auto cmp_indexed_paths = [dataset_prefix](const std::filesystem::path& a, + const std::filesystem::path& b) -> bool { + const int64_t a_index = GetFileIndexSuffix(a.filename().string(), dataset_prefix); + const int64_t b_index = GetFileIndexSuffix(b.filename().string(), dataset_prefix); + return a_index < b_index; + }; + + std::sort(dataset_paths.begin(), dataset_paths.end(), cmp_indexed_paths); + + return dataset_paths; +} + +static bool GetExpectedOutputsFromModel(Ort::Env& env, + TaskThreadPool& pool, + const AppArgs& args, + const std::filesystem::path& model_path, + const std::vector& dataset_paths, + std::vector>& all_inputs, + std::vector>& all_outputs) { + Ort::SessionOptions session_options; + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + + Ort::Session f32_cpu_sess(env, model_path.c_str(), session_options); + ModelIOInfo model_io_info; + + if (!ModelIOInfo::Init(model_io_info, f32_cpu_sess.GetConst())) { + std::cerr << "[ERROR]: Failed to query model I/O information." 
<< std::endl; + return false; + } + + if (!acctest::LoadIODataFromDisk(dataset_paths, model_io_info.inputs, "input_", all_inputs)) { + std::cerr << "[ERROR]: Failed to load test inputs for model directory " << model_path.parent_path() << std::endl; + return false; + } + + const size_t num_datasets = dataset_paths.size(); + std::vector tasks; + tasks.reserve(num_datasets); + + const size_t total_input_data_size = model_io_info.GetTotalInputSize(); + const size_t total_output_data_size = model_io_info.GetTotalOutputSize(); + + all_outputs.reserve(num_datasets); + + for (size_t i = 0; i < num_datasets; i++) { + all_outputs.emplace_back(std::make_unique(total_output_data_size)); + + Task task = Task::CreateInferenceTask(f32_cpu_sess, model_io_info, + Span(all_inputs[i].get(), total_input_data_size), + Span(all_outputs.back().get(), total_output_data_size)); + tasks.push_back(std::move(task)); + } + + pool.CompleteTasks(tasks); + + if (args.save_expected_outputs_to_disk) { + // Write outputs to disk: output_0.raw, output_1.raw, ... 
+ for (size_t dataset_index = 0; dataset_index < num_datasets; dataset_index++) { + const std::filesystem::path& dataset_dir = dataset_paths[dataset_index]; + Span dataset_output(all_outputs[dataset_index].get(), total_output_data_size); + const std::vector& output_infos = model_io_info.outputs; + const size_t num_outputs = output_infos.size(); + + for (size_t buf_offset = 0, i = 0; i < num_outputs; buf_offset += output_infos[i].total_data_size, i++) { + std::ostringstream oss; + oss << "output_" << i << ".raw"; + + std::filesystem::path output_filepath = dataset_dir / oss.str(); + std::ofstream ofs(output_filepath, std::ios::binary); + + assert(buf_offset < dataset_output.size()); + ofs.write(&dataset_output[buf_offset], output_infos[i].total_data_size); + } + } + } + return true; +} + +static bool RunTestModel(Ort::Env& env, + TaskThreadPool& pool, + const std::filesystem::path& model_path, + const std::vector& dataset_paths, + const Ort::SessionOptions& session_options, + std::vector>& all_inputs, + std::vector>& all_outputs, + std::vector>& test_accuracy_results) { + Ort::Session session(env, model_path.c_str(), session_options); + ModelIOInfo model_io_info; + + if (!ModelIOInfo::Init(model_io_info, session.GetConst())) { + std::cerr << "[ERROR]: Failed to query model I/O information " + << "for model " << model_path << std::endl; + return false; + } + + const size_t num_datasets = dataset_paths.size(); + + if (all_inputs.empty()) { + if (!acctest::LoadIODataFromDisk(dataset_paths, model_io_info.inputs, "input_", all_inputs)) { + std::cerr << "[ERROR]: Failed to load test inputs for model directory " + << model_path.parent_path() << std::endl; + return false; + } + } + + if (all_outputs.empty()) { + if (!acctest::LoadIODataFromDisk(dataset_paths, model_io_info.outputs, "output_", all_outputs)) { + std::cerr << "[ERROR]: Failed to load test outputs for model directory " + << model_path.parent_path() << std::endl; + return false; + } + } + + 
assert(all_inputs.size() == num_datasets); + assert(all_outputs.size() == num_datasets); + + std::vector tasks; + tasks.reserve(num_datasets); + + test_accuracy_results.resize(num_datasets, std::vector(model_io_info.outputs.size())); + + const size_t total_input_data_size = model_io_info.GetTotalInputSize(); + const size_t total_output_data_size = model_io_info.GetTotalOutputSize(); + + for (size_t i = 0; i < num_datasets; i++) { + Task task = Task::CreateAccuracyCheckTask(session, model_io_info, + Span(all_inputs[i].get(), total_input_data_size), + Span(all_outputs[i].get(), total_output_data_size), + Span(test_accuracy_results[i])); + tasks.push_back(std::move(task)); + } + + pool.CompleteTasks(tasks); + return true; +} + +static void PrintAccuracyResults(const std::vector>& test_accuracy_results, + const std::vector& dataset_paths, + const std::filesystem::directory_entry& model_dir, + const std::string& output_file, + std::unordered_map& test_name_to_acc_result_index) { + assert(test_accuracy_results.size() == dataset_paths.size()); + std::ostringstream oss; + for (size_t i = 0; i < test_accuracy_results.size(); i++) { + const std::filesystem::path& test_path = dataset_paths[i]; + const std::vector& metrics = test_accuracy_results[i]; + std::string key = model_dir.path().filename().string() + "/" + test_path.filename().string(); + test_name_to_acc_result_index[key] = i; + + oss << key << ","; + for (size_t j = 0; j < metrics.size(); j++) { + oss << std::setprecision(std::numeric_limits::max_digits10) << metrics[j].snr; + if (j < metrics.size() - 1) { + oss << ","; + } + } + oss << std::endl; + } + + if (output_file.empty()) { + std::cout << std::endl; + std::cout << "Accuracy Results:" << std::endl; + std::cout << "=================" << std::endl; + std::cout << oss.str() << std::endl; + } else { + std::ofstream out_fs(output_file); + out_fs << oss.str(); + out_fs.close(); + } +} + +static bool CompareAccuracyWithExpectedValues(const std::filesystem::path& 
expected_accuracy_file, + const std::vector>& test_accuracy_results, + const std::unordered_map& test_name_to_acc_result_index, + size_t& total_tests, + size_t& total_failed_tests) { + std::cout << std::endl; + std::cout << "[INFO]: Comparing accuracy with " << expected_accuracy_file.filename().string() << std::endl; + std::cout << "===============================================" << std::endl; + std::ifstream in_fs(expected_accuracy_file); + constexpr size_t N = 512; + std::array tmp_buf = {}; + + while (in_fs.getline(&tmp_buf[0], tmp_buf.size())) { + std::istringstream iss(tmp_buf.data()); + if (!iss.getline(&tmp_buf[0], tmp_buf.size(), ',')) { + std::cerr << "[ERROR]: Failed to parse expected accuracy file " << expected_accuracy_file << std::endl; + return false; + } + + std::string key(tmp_buf.data()); + auto it = test_name_to_acc_result_index.find(key); + if (it == test_name_to_acc_result_index.end()) { + std::cerr << "[ERROR]: " << key << " was not a test that was run."; + return false; + } + + std::vector expected_values; + while (iss.getline(&tmp_buf[0], tmp_buf.size(), ',')) { + expected_values.push_back(std::stod(tmp_buf.data())); + } + + const std::vector& actual_output_metrics = test_accuracy_results[it->second]; + if (actual_output_metrics.size() != expected_values.size()) { + std::cerr << "[ERROR]: test " << key << " does not have the expected number of outputs."; + return false; + } + + std::ostringstream oss; + bool passed = true; + for (size_t i = 0; i < expected_values.size(); i++) { + const auto& metrics = actual_output_metrics[i]; + + if (!(expected_values[i] - metrics.snr <= EPSILON_DBL)) { + passed = false; + oss << "\tOutput " << i << " SNR decreased: expected " + << std::setprecision(std::numeric_limits::max_digits10) << expected_values[i] << ", actual " + << metrics.snr << std::endl; + } + } + + std::cout << "[INFO]: Checking if " << key << " degraded ... 
"; + if (passed) { + std::cout << "PASSED" << std::endl; + } else { + std::cout << "FAILED" << std::endl; + std::cout << oss.str() << std::endl; + total_failed_tests += 1; + } + total_tests += 1; + } + + return true; +} + +int main(int argc, char** argv) { + try { + AppArgs args; + if (!ParseCmdLineArgs(args, argc, argv)) { + return 1; + } + + Ort::Env env; + + constexpr size_t num_pool_threads = 3; + TaskThreadPool pool(num_pool_threads); + TaskThreadPool dummy_pool(0); // For EPs that only support single-threaded inference (e.g., QNN). + size_t total_tests = 0; + size_t total_failed_tests = 0; + + for (const std::filesystem::directory_entry& model_dir : std::filesystem::directory_iterator{args.test_dir}) { + const std::filesystem::path& model_dir_path = model_dir.path(); + const std::vector dataset_paths = GetSortedDatasetPaths(model_dir_path); + + if (dataset_paths.empty()) { + continue; // Nothing to test. + } + + std::filesystem::path base_model_path = model_dir_path / "model.onnx"; + std::filesystem::path ep_model_path; + + // Some EPs will need to use a QDQ model instead of the original model. + if (args.uses_qdq_model) { + std::filesystem::path qdq_model_path = model_dir_path / "model.qdq.onnx"; + + if (!std::filesystem::is_regular_file(qdq_model_path)) { + std::cerr << "[ERROR]: Execution provider '" << args.execution_provider + << "' requires a QDQ model." << std::endl; + return 1; + } + ep_model_path = std::move(qdq_model_path); + } else { + ep_model_path = base_model_path; + } + + std::vector> all_inputs; + std::vector> all_outputs; + + // Load expected outputs from base model running on CPU EP (unless user wants to use outputs from disk). + if (!args.load_expected_outputs_from_disk) { + if (!std::filesystem::is_regular_file(base_model_path)) { + std::cerr << "[ERROR]: Cannot find ONNX model " << base_model_path + << " from which to get expected outputs."
<< std::endl; + return 1; + } + + if (!GetExpectedOutputsFromModel(env, pool, args, base_model_path, dataset_paths, all_inputs, all_outputs)) { + return 1; + } + } + + // Run accuracy measurements with the EP under test. + std::vector> test_accuracy_results; + TaskThreadPool& ep_pool = args.supports_multithread_inference ? pool : dummy_pool; + if (!RunTestModel(env, ep_pool, ep_model_path, dataset_paths, args.session_options, + all_inputs, all_outputs, test_accuracy_results)) { + return 1; + } + + // Print the accuracy results to file or stdout. + std::unordered_map test_name_to_acc_result_index; + PrintAccuracyResults(test_accuracy_results, + dataset_paths, + model_dir, + args.output_file, + test_name_to_acc_result_index); + + if (!args.expected_accuracy_file.empty()) { + if (!CompareAccuracyWithExpectedValues(args.expected_accuracy_file, test_accuracy_results, + test_name_to_acc_result_index, total_tests, total_failed_tests)) { + return 1; + } + } + } + + if (!args.expected_accuracy_file.empty()) { + const size_t total_tests_passed = total_tests - total_failed_tests; + std::cout << std::endl + << "[INFO]: " << total_tests_passed << "/" << total_tests << " tests passed." << std::endl + << "[INFO]: " << total_failed_tests << "/" << total_tests << " tests failed." << std::endl; + return 1; + } + } catch (const std::exception& e) { + std::cerr << "[ORT_QNN_APP EXCEPTION]: " << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/onnxruntime/test/acc_test/model_io_utils.cc b/onnxruntime/test/acc_test/model_io_utils.cc new file mode 100644 index 0000000000000..de89006f78eb7 --- /dev/null +++ b/onnxruntime/test/acc_test/model_io_utils.cc @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "model_io_utils.h" +#include + +bool GetTensorElemDataSize(ONNXTensorElementDataType data_type, size_t& size) { + switch (data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + size = sizeof(float); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + size = sizeof(uint8_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + size = sizeof(int8_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + size = sizeof(uint16_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + size = sizeof(int16_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + size = sizeof(int32_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + size = sizeof(int64_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + size = sizeof(bool); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + size = sizeof(double); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + size = sizeof(uint32_t); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + size = sizeof(uint64_t); + break; + default: + std::cerr << "[ERROR]: Unsupported tensor element data type: " << data_type << std::endl; + return false; + } + + return true; +} + +AccMetrics ComputeAccuracyMetric(Ort::ConstValue ort_output, Span raw_expected_output, + const IOInfo& output_info) { + AccMetrics metrics = {}; + switch (output_info.data_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span 
actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + 
Span expected_output = ReinterpretBytesAsSpan(raw_expected_output); + Span actual_output(ort_output.GetTensorData(), expected_output.size()); + GetAccuracy(expected_output, actual_output, metrics); + break; + } + default: + // Note: shouldn't get here because we've already validated expected output data types when loading model. + std::cerr << "[ERROR]: Unsupported tensor element data type: " << output_info.data_type << std::endl; + std::abort(); + } + + return metrics; +} + +bool ModelIOInfo::Init(ModelIOInfo& model_info, Ort::ConstSession session) { + Ort::AllocatorWithDefaultOptions allocator; + + // Get model input info (name, shape, type) + { + size_t num_inputs = session.GetInputCount(); + model_info.inputs.reserve(num_inputs); + + for (size_t i = 0; i < num_inputs; i++) { + Ort::TypeInfo type_info = session.GetInputTypeInfo(i); + if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) { + std::cerr << "[ERROR]: Only support models with tensor inputs" << std::endl; + return false; + } + + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + IOInfo input_info; + if (!IOInfo::Init(input_info, session.GetInputNameAllocated(i, allocator).get(), + tensor_info.GetElementType(), tensor_info.GetShape())) { + std::cerr << "[ERROR]: Unsupported tensor element type (" << tensor_info.GetElementType() + << ") for input at index " << i << std::endl; + return false; + } + + model_info.inputs.push_back(std::move(input_info)); + } + } + + // Get model output info (name, shape, type) + { + size_t num_outputs = session.GetOutputCount(); + model_info.outputs.reserve(num_outputs); + + for (size_t i = 0; i < num_outputs; i++) { + Ort::TypeInfo type_info = session.GetOutputTypeInfo(i); + if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) { + std::cerr << "[ERROR]: Only support models with tensor outputs" << std::endl; + return false; + } + + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + IOInfo output_info; + if (!IOInfo::Init(output_info, 
session.GetOutputNameAllocated(i, allocator).get(), + tensor_info.GetElementType(), tensor_info.GetShape())) { + std::cerr << "[ERROR]: Unsupported tensor element type (" << tensor_info.GetElementType() + << ") for output at index " << i << std::endl; + return false; + } + + model_info.outputs.push_back(std::move(output_info)); + } + } + + return true; +} + +size_t ModelIOInfo::GetTotalInputSize() const { + size_t total_size = 0; + + for (const auto& input_info : inputs) { + total_size += input_info.total_data_size; + } + + return total_size; +} + +size_t ModelIOInfo::GetTotalOutputSize() const { + size_t total_size = 0; + + for (const auto& output_info : outputs) { + total_size += output_info.total_data_size; + } + + return total_size; +} + diff --git a/onnxruntime/test/acc_test/model_io_utils.h b/onnxruntime/test/acc_test/model_io_utils.h new file mode 100644 index 0000000000000..7d23e7b23cbe4 --- /dev/null +++ b/onnxruntime/test/acc_test/model_io_utils.h @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once +#include +#include +#include + +#include "basic_utils.h" + +bool GetTensorElemDataSize(ONNXTensorElementDataType data_type, size_t& size); + +struct IOInfo { + IOInfo() = default; + IOInfo(IOInfo&& other) = default; + IOInfo(const IOInfo& other) = default; + + IOInfo& operator=(const IOInfo& other) = default; + IOInfo& operator=(IOInfo&& other) = default; + + static bool Init(IOInfo& io_info, const char* name, + ONNXTensorElementDataType data_type, std::vector shape) { + size_t elem_size = 0; + if (!GetTensorElemDataSize(data_type, elem_size)) { + return false; + } + + const size_t total_data_size = elem_size * GetShapeSize(Span(shape)); + + io_info.name = name; + io_info.shape = std::move(shape); + io_info.data_type = data_type; + io_info.total_data_size = total_data_size; + + return true; + } + + friend bool operator==(const IOInfo& l, const IOInfo& r) { + if (l.name != r.name || l.data_type != r.data_type || l.shape.size() != r.shape.size()) { + return false; + } + + for (size_t i = 0; i < l.shape.size(); i++) { + if (l.shape[i] != r.shape[i]) { + return false; + } + } + + return true; + } + + friend bool operator!=(const IOInfo& l, const IOInfo& r) { + return !(l == r); + } + + std::string name; + std::vector shape; + ONNXTensorElementDataType data_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + size_t total_data_size = 0; +}; + +struct ModelIOInfo { + ModelIOInfo() = default; + ModelIOInfo(ModelIOInfo&& other) = default; + ModelIOInfo(const ModelIOInfo& other) = default; + + ModelIOInfo& operator=(const ModelIOInfo& other) = default; + ModelIOInfo& operator=(ModelIOInfo&& other) = default; + + friend bool operator==(const ModelIOInfo& l, const ModelIOInfo& r) { + return l.inputs == r.inputs && l.outputs == r.outputs; + } + + friend bool operator!=(const ModelIOInfo& l, const ModelIOInfo& r) { + return !(l == r); + } + + static bool Init(ModelIOInfo& model_info, Ort::ConstSession session); + + size_t GetTotalInputSize() const; + size_t 
GetTotalOutputSize() const; + + std::vector inputs; + std::vector outputs; + +}; + +AccMetrics ComputeAccuracyMetric(Ort::ConstValue ort_output, Span raw_expected_output, + const IOInfo& output_info); diff --git a/onnxruntime/test/acc_test/task_thread_pool.cc b/onnxruntime/test/acc_test/task_thread_pool.cc new file mode 100644 index 0000000000000..3bd036f38f79e --- /dev/null +++ b/onnxruntime/test/acc_test/task_thread_pool.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "task_thread_pool.h" +#include +#include "acc_task.h" + +TaskThreadPool::TaskThreadPool(int num_threads) { + threads_.reserve(num_threads); + for (size_t i = 0; i < num_threads; i++) { + threads_.emplace_back(&TaskThreadPool::ThreadEntry, this); + } +} + +TaskThreadPool::~TaskThreadPool() { + { + // Acquire lock, set shutdown to true, and wake up all threads. + std::unique_lock lock(lock_); + shutdown_ = true; + signal_.notify_all(); + } + + // Wait for all threads to exit. + for (auto& thread : threads_) { + thread.join(); + } +} + +void TaskThreadPool::CompleteTasks(Span tasks) { + // Assert that it is only possible to call CompleteTasks() when either + // this is the first set of tasks or we've completely processed the previous tasks. + assert(tasks_completed_ == tasks_.size()); + + { + // Acquire lock, set new tasks, and wake up all threads. + std::unique_lock lock(lock_); + tasks_ = tasks; + tasks_completed_ = 0; + next_task_index_ = 0; + signal_.notify_all(); + } + + // The main thread (calling thread) can also help out until all tasks have been completed. + while (tasks_completed_ < tasks_.size()) { + while (RunNextTask()) { + // Keep helping out the pool threads. + } + } +} + +void TaskThreadPool::ThreadEntry() { + while (true) { + // Keep running tasks until they have *all* been claimed by some thread. + while (RunNextTask()) { + } + + { + // Get lock and sleep if all tasks have been taken by some thread. 
If shutdown_ flag is set, exit. + std::unique_lock lock(lock_); + while (!shutdown_ && (next_task_index_ >= tasks_.size())) { + signal_.wait(lock); // wait() may be unblocked spuriously (according to docs), so need to call it in a loop. + } + + if (shutdown_) { + return; + } + } + } +} + +bool TaskThreadPool::RunNextTask() { + if (tasks_.empty()) { + return false; + } + + const size_t task_index = std::atomic_fetch_add(&next_task_index_, 1); + if (task_index >= tasks_.size()) { + return false; + } + + tasks_[task_index].Run(); + + std::atomic_fetch_add(&tasks_completed_, 1); + return true; +} diff --git a/onnxruntime/test/acc_test/task_thread_pool.h b/onnxruntime/test/acc_test/task_thread_pool.h new file mode 100644 index 0000000000000..84413db15ebd4 --- /dev/null +++ b/onnxruntime/test/acc_test/task_thread_pool.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once +#include +#include +#include + +#include "basic_utils.h" +#include "acc_task.h" + +class TaskThreadPool { + public: + TaskThreadPool(int num_threads); + ~TaskThreadPool(); + + void CompleteTasks(Span tasks); + + private: + void ThreadEntry(); + bool RunNextTask(); + + std::mutex lock_; + std::condition_variable signal_; + bool shutdown_ = false; + Span tasks_; + std::atomic next_task_index_ = 0; + std::atomic tasks_completed_ = 0; + std::vector threads_; +}; From 2fddb1c8d4e808d61ce1ea2480a20b95e7dfb7df Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 17 Dec 2023 22:04:45 -0800 Subject: [PATCH 2/8] headers cleanup --- onnxruntime/test/acc_test/acc_task.cc | 9 +++++++++ onnxruntime/test/acc_test/acc_task.h | 2 ++ onnxruntime/test/acc_test/basic_utils.cc | 1 - onnxruntime/test/acc_test/cmd_args.cc | 5 ++++- onnxruntime/test/acc_test/cmd_args.h | 3 +++ onnxruntime/test/acc_test/data_loader.cc | 2 ++ onnxruntime/test/acc_test/main.cc | 12 ++---------- onnxruntime/test/acc_test/task_thread_pool.h | 2 ++ 8 files changed, 24 
insertions(+), 12 deletions(-) diff --git a/onnxruntime/test/acc_test/acc_task.cc b/onnxruntime/test/acc_test/acc_task.cc index 430d4703786e1..2939b67feafaf 100644 --- a/onnxruntime/test/acc_test/acc_task.cc +++ b/onnxruntime/test/acc_test/acc_task.cc @@ -1,4 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #include "acc_task.h" +#include +#include +#include +#include +#include static std::vector RunInference(Ort::Session& session, const ModelIOInfo& model_io_info, Span input_buffer) { @@ -73,5 +80,7 @@ void Task::Run() { return; } + // Should not reach this line unless we add a new (unhandled) std::variant type. + std::cerr << "[ERROR]: Unhandled std::variant type for Task::variant_ member." << std::endl; std::abort(); } diff --git a/onnxruntime/test/acc_test/acc_task.h b/onnxruntime/test/acc_test/acc_task.h index 980a90ac74fd2..51f472528c958 100644 --- a/onnxruntime/test/acc_test/acc_task.h +++ b/onnxruntime/test/acc_test/acc_task.h @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once #include #include diff --git a/onnxruntime/test/acc_test/basic_utils.cc b/onnxruntime/test/acc_test/basic_utils.cc index b874f4f02b84b..fcaf816b1b565 100644 --- a/onnxruntime/test/acc_test/basic_utils.cc +++ b/onnxruntime/test/acc_test/basic_utils.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "basic_utils.h" #include #include diff --git a/onnxruntime/test/acc_test/cmd_args.cc b/onnxruntime/test/acc_test/cmd_args.cc index b1b8286389249..b12459e02704f 100644 --- a/onnxruntime/test/acc_test/cmd_args.cc +++ b/onnxruntime/test/acc_test/cmd_args.cc @@ -1,12 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
#include "cmd_args.h" #include #include #include #include +#include #include -#include #include #include +#include struct CmdArgs { CmdArgs(int argc, char** argv) noexcept : argc_(argc), argv_(argv), index_(0) {} diff --git a/onnxruntime/test/acc_test/cmd_args.h b/onnxruntime/test/acc_test/cmd_args.h index 4b1718d52d3c9..b4ea7356b78c8 100644 --- a/onnxruntime/test/acc_test/cmd_args.h +++ b/onnxruntime/test/acc_test/cmd_args.h @@ -1,5 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #pragma once #include +#include #include #include diff --git a/onnxruntime/test/acc_test/data_loader.cc b/onnxruntime/test/acc_test/data_loader.cc index 3426f4dc36896..232bdc96898bc 100644 --- a/onnxruntime/test/acc_test/data_loader.cc +++ b/onnxruntime/test/acc_test/data_loader.cc @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace acctest { diff --git a/onnxruntime/test/acc_test/main.cc b/onnxruntime/test/acc_test/main.cc index 99138c3ef3d19..a92bd94a4eae5 100644 --- a/onnxruntime/test/acc_test/main.cc +++ b/onnxruntime/test/acc_test/main.cc @@ -1,28 +1,20 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include #include #include #include -#include #include -#include -#include -#include // std::abort +#include #include #include #include -#include #include +#include #include -#include #include #include -#include -#include -#include #include #include "cmd_args.h" diff --git a/onnxruntime/test/acc_test/task_thread_pool.h b/onnxruntime/test/acc_test/task_thread_pool.h index 84413db15ebd4..318f6c76fd713 100644 --- a/onnxruntime/test/acc_test/task_thread_pool.h +++ b/onnxruntime/test/acc_test/task_thread_pool.h @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#pragma once +#include #include #include +#include #include #include "basic_utils.h" From 75affa7663b867e0ec31cc9d5f698ece71e55d51 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 17 Dec 2023 22:12:57 -0800 Subject: [PATCH 3/8] Fixed signed comparison warning --- onnxruntime/test/acc_test/task_thread_pool.cc | 2 +- onnxruntime/test/acc_test/task_thread_pool.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/acc_test/task_thread_pool.cc b/onnxruntime/test/acc_test/task_thread_pool.cc index 3bd036f38f79e..8bc84b5c01cf3 100644 --- a/onnxruntime/test/acc_test/task_thread_pool.cc +++ b/onnxruntime/test/acc_test/task_thread_pool.cc @@ -4,7 +4,7 @@ #include #include "acc_task.h" -TaskThreadPool::TaskThreadPool(int num_threads) { +TaskThreadPool::TaskThreadPool(size_t num_threads) { threads_.reserve(num_threads); for (size_t i = 0; i < num_threads; i++) { threads_.emplace_back(&TaskThreadPool::ThreadEntry, this); diff --git a/onnxruntime/test/acc_test/task_thread_pool.h b/onnxruntime/test/acc_test/task_thread_pool.h index 318f6c76fd713..bd2f644cca9d1 100644 --- a/onnxruntime/test/acc_test/task_thread_pool.h +++ b/onnxruntime/test/acc_test/task_thread_pool.h @@ -12,7 +12,7 @@ class TaskThreadPool { public: - TaskThreadPool(int num_threads); + TaskThreadPool(size_t num_threads); ~TaskThreadPool(); void CompleteTasks(Span tasks); From dfa6b214e06684394e997a5f89828c1db5471500 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Sun, 17 Dec 2023 23:08:16 -0800 Subject: [PATCH 4/8] Use shared onnxruntime.dll --- cmake/onnxruntime_unittests.cmake | 54 +++++++++++++++++-------------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 757c1ed4c27dc..9162281e7059f 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1172,35 +1172,41 @@ endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) # Accuracy test 
runner - set(onnxruntime_acc_test_src_dir ${TEST_SRC_DIR}/acc_test) - set(onnxruntime_acc_test_src_patterns - "${onnxruntime_acc_test_src_dir}/*.cc" - "${onnxruntime_acc_test_src_dir}/*.h") + if (onnxruntime_BUILD_SHARED_LIB) + set(onnxruntime_acc_test_src_dir ${TEST_SRC_DIR}/acc_test) + set(onnxruntime_acc_test_src_patterns + "${onnxruntime_acc_test_src_dir}/*.cc" + "${onnxruntime_acc_test_src_dir}/*.h") - file(GLOB onnxruntime_acc_test_src CONFIGURE_DEPENDS - ${onnxruntime_acc_test_src_patterns} - ) - onnxruntime_add_executable(onnxruntime_acc_test ${onnxruntime_acc_test_src}) - target_include_directories(onnxruntime_acc_test PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) - if (WIN32) - target_compile_options(onnxruntime_acc_test PRIVATE ${disabled_warnings}) - endif() - if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") - set_target_properties(onnxruntime_acc_test PROPERTIES - XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" - ) - endif() + file(GLOB onnxruntime_acc_test_src CONFIGURE_DEPENDS + ${onnxruntime_acc_test_src_patterns} + ) + onnxruntime_add_executable(onnxruntime_acc_test ${onnxruntime_acc_test_src}) + target_include_directories(onnxruntime_acc_test PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) + if (WIN32) + target_compile_options(onnxruntime_acc_test PRIVATE ${disabled_warnings}) + endif() + if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") + set_target_properties(onnxruntime_acc_test PROPERTIES + XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" + ) + endif() - if (onnxruntime_BUILD_SHARED_LIB) set(onnxruntime_acc_test_libs onnxruntime) - target_link_libraries(onnxruntime_acc_test PRIVATE ${onnxruntime_acc_test_libs}) - endif() + if(NOT WIN32) + list(APPEND onnxruntime_acc_test_libs ${CMAKE_DL_LIBS}) + endif() + if (onnxruntime_LINK_LIBATOMIC) + list(APPEND onnxruntime_acc_test_libs atomic) + endif() + target_link_libraries(onnxruntime_acc_test PRIVATE ${onnxruntime_acc_test_libs} Threads::Threads) - set_target_properties(onnxruntime_acc_test PROPERTIES FOLDER 
"ONNXRuntimeTest") + set_target_properties(onnxruntime_acc_test PROPERTIES FOLDER "ONNXRuntimeTest") - if (onnxruntime_USE_TVM) - if (WIN32) - target_link_options(onnxruntime_acc_test PRIVATE "/STACK:4000000") + if (onnxruntime_USE_TVM) + if (WIN32) + target_link_options(onnxruntime_acc_test PRIVATE "/STACK:4000000") + endif() endif() endif() From b749cfd6e03103c577a55dd29f8e8a343092eaa2 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 18 Dec 2023 09:41:41 -0800 Subject: [PATCH 5/8] Fix linux warning/error: cast literal to unsigned for atomic_fetch_add --- onnxruntime/test/acc_test/task_thread_pool.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/acc_test/task_thread_pool.cc b/onnxruntime/test/acc_test/task_thread_pool.cc index 8bc84b5c01cf3..8252165fefa4a 100644 --- a/onnxruntime/test/acc_test/task_thread_pool.cc +++ b/onnxruntime/test/acc_test/task_thread_pool.cc @@ -72,13 +72,13 @@ bool TaskThreadPool::RunNextTask() { return false; } - const size_t task_index = std::atomic_fetch_add(&next_task_index_, 1); + const size_t task_index = std::atomic_fetch_add(&next_task_index_, static_cast(1)); if (task_index >= tasks_.size()) { return false; } tasks_[task_index].Run(); - std::atomic_fetch_add(&tasks_completed_, 1); + std::atomic_fetch_add(&tasks_completed_, static_cast(1)); return true; } From e28674bd83659b3ed5936a7c9c9cd5e6393b5e00 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 18 Dec 2023 09:49:26 -0800 Subject: [PATCH 6/8] Run lintrunner --- onnxruntime/test/acc_test/basic_utils.cc | 1 - onnxruntime/test/acc_test/model_io_utils.cc | 1 - onnxruntime/test/acc_test/model_io_utils.h | 1 - 3 files changed, 3 deletions(-) diff --git a/onnxruntime/test/acc_test/basic_utils.cc b/onnxruntime/test/acc_test/basic_utils.cc index fcaf816b1b565..dbf48af552af8 100644 --- a/onnxruntime/test/acc_test/basic_utils.cc +++ b/onnxruntime/test/acc_test/basic_utils.cc @@ -53,4 +53,3 @@ int64_t GetFileIndexSuffix(const 
std::string& filename_wo_ext, const char* prefi return index; } - diff --git a/onnxruntime/test/acc_test/model_io_utils.cc b/onnxruntime/test/acc_test/model_io_utils.cc index de89006f78eb7..9ea9882d8f00a 100644 --- a/onnxruntime/test/acc_test/model_io_utils.cc +++ b/onnxruntime/test/acc_test/model_io_utils.cc @@ -200,4 +200,3 @@ size_t ModelIOInfo::GetTotalOutputSize() const { return total_size; } - diff --git a/onnxruntime/test/acc_test/model_io_utils.h b/onnxruntime/test/acc_test/model_io_utils.h index 7d23e7b23cbe4..fc901a6689d8a 100644 --- a/onnxruntime/test/acc_test/model_io_utils.h +++ b/onnxruntime/test/acc_test/model_io_utils.h @@ -81,7 +81,6 @@ struct ModelIOInfo { std::vector inputs; std::vector outputs; - }; AccMetrics ComputeAccuracyMetric(Ort::ConstValue ort_output, Span raw_expected_output, From e8c11e136684927224ba94f83fd3b3f463d5f024 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 18 Dec 2023 10:30:04 -0800 Subject: [PATCH 7/8] Add std::filesystem library to linux builds --- cmake/onnxruntime_unittests.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 9162281e7059f..5f8c7bd084d02 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1194,7 +1194,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) set(onnxruntime_acc_test_libs onnxruntime) if(NOT WIN32) - list(APPEND onnxruntime_acc_test_libs ${CMAKE_DL_LIBS}) + list(APPEND onnxruntime_acc_test_libs stdc++fs ${CMAKE_DL_LIBS}) endif() if (onnxruntime_LINK_LIBATOMIC) list(APPEND onnxruntime_acc_test_libs atomic) From 00fa9aefac4831f0184fcee77d5b2b4b988dfa7f Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 18 Dec 2023 11:02:55 -0800 Subject: [PATCH 8/8] Remove Task() = default. Fix conversion that may cause data loss. 
--- onnxruntime/test/acc_test/acc_task.h | 1 - onnxruntime/test/acc_test/basic_utils.cc | 11 +++++------ onnxruntime/test/acc_test/basic_utils.h | 2 +- onnxruntime/test/acc_test/data_loader.cc | 9 +++++---- onnxruntime/test/acc_test/main.cc | 4 ++-- onnxruntime/test/acc_test/model_io_utils.h | 2 +- 6 files changed, 14 insertions(+), 15 deletions(-) diff --git a/onnxruntime/test/acc_test/acc_task.h b/onnxruntime/test/acc_test/acc_task.h index 51f472528c958..9fe62da562d1a 100644 --- a/onnxruntime/test/acc_test/acc_task.h +++ b/onnxruntime/test/acc_test/acc_task.h @@ -19,7 +19,6 @@ class Task { }; public: - Task() = default; Task(Task&& other) = default; Task(const Task& other) = default; Task(Ort::Session& session, const ModelIOInfo& model_io_info, diff --git a/onnxruntime/test/acc_test/basic_utils.cc b/onnxruntime/test/acc_test/basic_utils.cc index dbf48af552af8..d72d496af2bfc 100644 --- a/onnxruntime/test/acc_test/basic_utils.cc +++ b/onnxruntime/test/acc_test/basic_utils.cc @@ -11,12 +11,11 @@ bool FillBytesFromBinaryFile(Span array, const std::string& binary_filepat return false; } - size_t file_byte_size = 0; input_ifs.seekg(0, input_ifs.end); - file_byte_size = input_ifs.tellg(); + auto file_byte_size = input_ifs.tellg(); input_ifs.seekg(0, input_ifs.beg); - if (file_byte_size != array.size()) { + if (static_cast(file_byte_size) != array.size()) { return false; } @@ -24,8 +23,8 @@ bool FillBytesFromBinaryFile(Span array, const std::string& binary_filepat return static_cast(input_ifs); } -int64_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix) { - int64_t index = -1; +int32_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix) { + int32_t index = -1; const char* str = filename_wo_ext.c_str(); // Move past the prefix. @@ -41,7 +40,7 @@ int64_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefi // Parse the input index from file name. 
index = 0; while (*str) { - int64_t c = *str; + int32_t c = *str; if (!(c >= '0' && c <= '9')) { return -1; // Not a number. } diff --git a/onnxruntime/test/acc_test/basic_utils.h b/onnxruntime/test/acc_test/basic_utils.h index 885992d6f13e2..44c0f0e758b60 100644 --- a/onnxruntime/test/acc_test/basic_utils.h +++ b/onnxruntime/test/acc_test/basic_utils.h @@ -56,7 +56,7 @@ constexpr int64_t GetShapeSize(Span shape) { return size; } -int64_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix); +int32_t GetFileIndexSuffix(const std::string& filename_wo_ext, const char* prefix); bool FillBytesFromBinaryFile(Span array, const std::string& binary_filepath); constexpr double EPSILON_DBL = 2e-16; diff --git a/onnxruntime/test/acc_test/data_loader.cc b/onnxruntime/test/acc_test/data_loader.cc index 232bdc96898bc..222cb83705534 100644 --- a/onnxruntime/test/acc_test/data_loader.cc +++ b/onnxruntime/test/acc_test/data_loader.cc @@ -39,14 +39,15 @@ bool LoadIODataFromDisk(const std::vector& dataset_paths, continue; } - const int64_t io_index = GetFileIndexSuffix(data_filename_wo_ext, data_file_prefix); - if (io_index < 0) { + const int32_t io_index_s32 = GetFileIndexSuffix(data_filename_wo_ext, data_file_prefix); + if (io_index_s32 < 0) { std::cerr << "[ERROR]: The file " << data_file_path << " does not have a properly formatted name" << " (e.g., " << data_file_prefix << "0.raw)" << std::endl; return false; } - if (io_index >= static_cast(io_infos.size())) { + const size_t io_index = static_cast(io_index_s32); + if (io_index >= io_infos.size()) { std::cerr << "[ERROR]: The input (or output) file index for file " << data_file_path << " exceeds the number of inputs (or outputs) in the model (" << io_infos.size() << ")" << std::endl; @@ -54,7 +55,7 @@ bool LoadIODataFromDisk(const std::vector& dataset_paths, } size_t offset = 0; - for (int64_t i = 0; i < io_index; i++) { + for (size_t i = 0; i < io_index; i++) { offset += io_infos[i].total_data_size; } 
assert(offset < total_data_size); diff --git a/onnxruntime/test/acc_test/main.cc b/onnxruntime/test/acc_test/main.cc index a92bd94a4eae5..486367d38d61a 100644 --- a/onnxruntime/test/acc_test/main.cc +++ b/onnxruntime/test/acc_test/main.cc @@ -38,8 +38,8 @@ static std::vector GetSortedDatasetPaths(const std::files auto cmp_indexed_paths = [dataset_prefix](const std::filesystem::path& a, const std::filesystem::path& b) -> bool { - const int64_t a_index = GetFileIndexSuffix(a.filename().string(), dataset_prefix); - const int64_t b_index = GetFileIndexSuffix(b.filename().string(), dataset_prefix); + const int32_t a_index = GetFileIndexSuffix(a.filename().string(), dataset_prefix); + const int32_t b_index = GetFileIndexSuffix(b.filename().string(), dataset_prefix); return a_index < b_index; }; diff --git a/onnxruntime/test/acc_test/model_io_utils.h b/onnxruntime/test/acc_test/model_io_utils.h index fc901a6689d8a..27bf42cebe7bf 100644 --- a/onnxruntime/test/acc_test/model_io_utils.h +++ b/onnxruntime/test/acc_test/model_io_utils.h @@ -24,7 +24,7 @@ struct IOInfo { return false; } - const size_t total_data_size = elem_size * GetShapeSize(Span(shape)); + const size_t total_data_size = elem_size * static_cast(GetShapeSize(Span(shape))); io_info.name = name; io_info.shape = std::move(shape);