From fcdaf3bd0366c58434e08a284d3e2a7b3a539c14 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Tue, 23 Jul 2024 17:28:57 +0800 Subject: [PATCH] add registered custom op for perf test --- onnxruntime/test/perftest/command_args_parser.cc | 6 +++++- onnxruntime/test/perftest/ort_test_session.cc | 4 ++++ onnxruntime/test/perftest/test_configuration.h | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 84c3bc16346f3..7d06bbadbd645 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -144,6 +144,7 @@ namespace perftest { "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" "\t-l Provide file as binary in memory by using fopen before session creation.\n" + "\t-R [Register custom op]: allow user to register custom op by .so or .dll file.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -206,7 +207,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznl"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqznlR:"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -393,6 +394,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'l': test_config.model_info.load_via_path = true; break; + case 'R': + test_config.run_config.register_custom_op_path = optarg; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index fc1bdb10d7453..6fb999b3efc07 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -636,6 +636,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); session_options.AddConfigEntry(kOrtSessionOptionsConfigForceSpinningStop, "1"); } + if (!performance_test_config.run_config.register_custom_op_path.empty()) { + session_options.RegisterCustomOpsLibrary(performance_test_config.run_config.register_custom_op_path.c_str()); + } + if (performance_test_config.run_config.execution_mode == ExecutionMode::ORT_PARALLEL && performance_test_config.run_config.inter_op_num_threads > 0) { fprintf(stdout, "Setting inter_op_num_threads to %d\n", performance_test_config.run_config.inter_op_num_threads); session_options.SetInterOpNumThreads(performance_test_config.run_config.inter_op_num_threads); diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 209fb55fe93d4..90759a4d2f65a 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -65,6 +65,7 @@ struct RunConfig { bool disable_spinning = false; bool disable_spinning_between_run = false; bool exit_after_session_creation = false; + std::basic_string register_custom_op_path; }; struct PerformanceTestConfig {