diff --git a/include/onnxruntime/core/framework/op_kernel_info.h b/include/onnxruntime/core/framework/op_kernel_info.h
index b31c85e32f80c..a0bbfe50a700b 100644
--- a/include/onnxruntime/core/framework/op_kernel_info.h
+++ b/include/onnxruntime/core/framework/op_kernel_info.h
@@ -28,7 +28,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
                const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
                const OrtValueNameIdxMap& mlvalue_name_idx_map,
                const DataTransferManager& data_transfer_mgr,
-               const AllocatorMap& allocators = {});
+               const AllocatorMap& allocators,
+               const ConfigOptions& config_options);
 
   OpKernelInfo(const OpKernelInfo& other);
 
@@ -50,6 +51,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
 
   const AllocatorMap& GetAllocators() const { return allocators_; }
 
+  const ConfigOptions& GetConfigOptions() const { return config_options_; }
+
 private:
  ORT_DISALLOW_MOVE(OpKernelInfo);
  ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
@@ -64,6 +67,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
   const DataTransferManager& data_transfer_mgr_;
   ProtoHelperNodeContext proto_helper_context_;
   const AllocatorMap& allocators_;
+  const ConfigOptions& config_options_;
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc
index b2ef853119588..f8ccdb8fb0238 100644
--- a/onnxruntime/core/framework/kernel_registry_manager.cc
+++ b/onnxruntime/core/framework/kernel_registry_manager.cc
@@ -24,7 +24,8 @@ Status KernelRegistryManager::CreateKernel(const Node& node,
                            session_state.GetConstantInitializedTensors(),
                            session_state.GetOrtValueNameIdxMap(),
                            session_state.GetDataTransferMgr(),
-                           session_state.GetAllocators());
+                           session_state.GetAllocators(),
+                           session_state.GetSessionOptions().config_options);
 
   return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out);
 }
diff --git a/onnxruntime/core/framework/op_kernel_info.cc b/onnxruntime/core/framework/op_kernel_info.cc
index 841fdb585f0d8..28793dae36d20 100644
--- a/onnxruntime/core/framework/op_kernel_info.cc
+++ b/onnxruntime/core/framework/op_kernel_info.cc
@@ -15,7 +15,8 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
                            const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
                            const OrtValueNameIdxMap& ort_value_name_idx_map,
                            const DataTransferManager& data_transfer_mgr,
-                           const AllocatorMap& allocators)
+                           const AllocatorMap& allocators,
+                           const ConfigOptions& config_options)
     : OpNodeProtoHelper(&proto_helper_context_),
       node_(node),
       kernel_def_(kernel_def),
@@ -24,15 +25,22 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
       ort_value_name_idx_map_(ort_value_name_idx_map),
       data_transfer_mgr_(data_transfer_mgr),
       proto_helper_context_(node),
-      allocators_(allocators) {}
+      allocators_(allocators),
+      config_options_(config_options) {
+}
 
 OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
     : OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_,
-                   other.ort_value_name_idx_map_, other.data_transfer_mgr_, other.allocators_) {}
+                   other.ort_value_name_idx_map_, other.data_transfer_mgr_,
+                   other.allocators_, other.config_options_) {
+}
 
 AllocatorPtr OpKernelInfo::GetAllocator(OrtMemType mem_type) const {
   auto it = allocators_.find(execution_provider_->GetOrtDeviceByMemType(mem_type));
-  if (it != allocators_.end()) return it->second;
+  if (it != allocators_.end()) {
+    return it->second;
+  }
+
   return nullptr;
 }
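With the core changes above, every kernel constructed through the session can see the session's ConfigOptions via its OpKernelInfo. A minimal sketch of the intended usage, assuming a hypothetical ExampleKernel and config key (neither is part of this change):

    #include <optional>
    #include <string>
    #include "core/framework/op_kernel.h"

    namespace onnxruntime {
    // Hypothetical kernel reading a session config entry at construction time.
    class ExampleKernel final : public OpKernel {
     public:
      explicit ExampleKernel(const OpKernelInfo& info) : OpKernel(info) {
        // GetConfigEntry returns std::optional<std::string>, so an unset key is
        // distinguishable from an empty value. "example.use_fast_path" is a made-up key.
        std::optional<std::string> entry = info.GetConfigOptions().GetConfigEntry("example.use_fast_path");
        use_fast_path_ = entry.has_value() && *entry == "1";
      }

      Status Compute(OpKernelContext* /*context*/) const override { return Status::OK(); }

     private:
      bool use_fast_path_{false};
    };
    }  // namespace onnxruntime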
diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc
index e3a2f2d74c0d4..9df300d6f4f88 100644
--- a/onnxruntime/core/optimizer/constant_folding.cc
+++ b/onnxruntime/core/optimizer/constant_folding.cc
@@ -18,10 +18,12 @@ namespace onnxruntime {
 
 ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
                                  bool skip_dequantize_linear,
+                                 const ConfigOptions& config_options,
                                  const InlinedHashSet<std::string_view>& compatible_execution_providers,
                                  const InlinedHashSet<std::string>& excluded_initializers) noexcept
     : GraphTransformer("ConstantFolding", compatible_execution_providers),
       skip_dequantize_linear_(skip_dequantize_linear),
+      config_options_(config_options),
       excluded_initializers_(excluded_initializers),
       execution_provider_(execution_provider) {
 }
@@ -250,12 +252,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
         // override the EP assigned to the node so that it will use the CPU kernel for Compute.
         node->SetExecutionProviderType(kCpuExecutionProvider);
 
-        kernel = info.CreateKernel(node);
+        kernel = info.CreateKernel(node, config_options_);
 
         // undo the EP change to the value that was assigned at graph partitioning time
         node->SetExecutionProviderType(ep_type);
       } else {
-        kernel = info.CreateKernel(node);
+        kernel = info.CreateKernel(node, config_options_);
       }
 
       // We currently constant fold using the CPU EP only.
diff --git a/onnxruntime/core/optimizer/constant_folding.h b/onnxruntime/core/optimizer/constant_folding.h
index 47934307e8930..14eb2a9c5f06b 100644
--- a/onnxruntime/core/optimizer/constant_folding.h
+++ b/onnxruntime/core/optimizer/constant_folding.h
@@ -24,6 +24,7 @@ class ConstantFolding : public GraphTransformer {
   */
   ConstantFolding(const IExecutionProvider& execution_provider,
                   bool skip_dequantize_linear,
+                  const ConfigOptions& config_options,
                   const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
                   const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;
 
@@ -31,6 +32,7 @@ class ConstantFolding : public GraphTransformer {
   Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
 
   bool skip_dequantize_linear_;
+  const ConfigOptions& config_options_;
   const InlinedHashSet<std::string> excluded_initializers_;
   const IExecutionProvider& execution_provider_;
 };
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 3d6251a694cfb..cd3c49be15aa4 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -223,7 +223,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
       transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));
       transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
-      transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
+      transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
+                                                                  session_options.config_options));
       transformers.emplace_back(std::make_unique<MatMulAddFusion>());
       transformers.emplace_back(std::make_unique<ReshapeFusion>());
       transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc
index 46041bca9dcc1..1eabc079f3a20 100644
--- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc
+++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc
@@ -128,26 +128,34 @@ static Status TryCreateKernel(const Node& node,
                               const OrtValueNameIdxMap& ort_value_name_idx_map,
                               FuncManager& funcs_mgr,
                               const DataTransferManager& data_transfer_mgr,
+                              const ConfigOptions& config_options,
                               /*out*/ std::unique_ptr<OpKernel>& op_kernel) {
   const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
   const KernelCreateInfo* kernel_create_info = nullptr;
   ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
                                                     &kernel_create_info));
+
+  static const AllocatorMap dummy_allocators;
+
   OpKernelInfo kernel_info(node, *kernel_create_info->kernel_def, execution_provider,
                            constant_initialized_tensors,
                            ort_value_name_idx_map,
-                           data_transfer_mgr);
+                           data_transfer_mgr,
+                           dummy_allocators,
+                           config_options);
+
   return kernel_create_info->kernel_create_func(funcs_mgr, kernel_info, op_kernel);
 }
 
-std::unique_ptr<OpKernel> OptimizerExecutionFrame::Info::CreateKernel(const Node* node) const {
+std::unique_ptr<OpKernel>
+OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOptions& config_options) const {
   std::unique_ptr<OpKernel> op_kernel;
   std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
   FuncManager func;
   auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
-                                ort_value_name_idx_map_, func, data_transfer_mgr_,
+                                ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
                                 op_kernel);
 
   // Kernel found in the CPU kernel registry
diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h
index 13cf9e652c404..3dbf6c1d97aa6 100644
--- a/onnxruntime/core/optimizer/optimizer_execution_frame.h
+++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h
@@ -27,11 +27,13 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
          const Path& model_path,
          const IExecutionProvider& execution_provider,
          const std::function<bool(const std::string&)>& is_sparse_initializer_func);
+
     Info(const std::vector<const Node*>& nodes,
         const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
         const Path& model_path,
        const IExecutionProvider& execution_provider,
        const std::function<bool(const std::string&)>& is_sparse_initializer_func);
+
    ~Info() = default;

    const AllocatorPtr& GetAllocator() const {
@@ -52,7 +54,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
       return -1;
     }
 
-    std::unique_ptr<OpKernel> CreateKernel(const Node* node) const;
+    std::unique_ptr<OpKernel> CreateKernel(const Node* node, const ConfigOptions& config_options) const;
 
     // Check if an kernel create info can be found in the registry.
     Status TryFindKernel(const Node* node, const KernelCreateInfo** out) const;
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index 76533a0061702..53ba4874c643c 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -132,6 +132,7 @@ struct Logger;
 struct Capture;
 }  // namespace logging
 struct ComputeCapability;
+struct ConfigOptions;
 struct DataTransferManager;
 struct IndexedSubGraph;
 struct IndexedSubGraph_MetaDef;
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index 2883d92e90dba..21c14ce784a38 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+#include <optional>
+
 // Public wrappers around internal ort interfaces (currently)
 #include "core/providers/shared_library/provider_host_api.h"
 
@@ -426,6 +428,9 @@ struct ProviderHost {
 
   virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;
 
+  // ConfigOptions
+  virtual std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0;
+
   // ComputeCapability
   virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
   virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
@@ -808,6 +813,7 @@ struct ProviderHost {
   virtual uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) = 0;
   virtual uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) = 0;
   virtual const Node& OpKernelInfo__node(const OpKernelInfo* p) = 0;
+  virtual const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) = 0;
 
   // SessionState
   virtual const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index 149a43222b445..eaf8ef459cf00 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -335,6 +335,14 @@ struct DataTypeUtils final {
 
 }  // namespace Utils
 
+struct ConfigOptions final {
+  std::optional<std::string> GetConfigEntry(const std::string& config_key) const {
+    return g_host->ConfigOptions__GetConfigEntry(this, config_key);
+  }
+
+  PROVIDER_DISALLOW_ALL(ConfigOptions)
+};
+
 struct ComputeCapability final {
   static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
   static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
@@ -901,6 +909,8 @@ struct OpKernelInfo final {
 
   const Node& node() const noexcept { return g_host->OpKernelInfo__node(this); }
 
+  const ConfigOptions& GetConfigOptions() const { return g_host->OpKernelInfo__GetConfigOptions(this); }
+
   OpKernelInfo() = delete;
   OpKernelInfo(const OpKernelInfo&) = delete;
   void operator=(const OpKernelInfo&) = delete;
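The wrapped ConfigOptions type above gives shared-library providers the same lookup. A hedged sketch of provider-side usage follows; the helper and config key are invented for illustration and are not part of this change:

    // Inside a shared provider, OpKernelInfo/ConfigOptions are the wrapped types from
    // provider_wrappedtypes.h, so this call round-trips through
    // g_host->ConfigOptions__GetConfigEntry into the host process.
    static bool ReadExampleFlag(const onnxruntime::OpKernelInfo& info) {
      std::optional<std::string> v = info.GetConfigOptions().GetConfigEntry("ep.example.flag");  // made-up key
      return v.has_value() && *v == "1";
    }

Note that GetConfigEntry returns std::optional<std::string> by value, so no reference to host-owned memory crosses the shared-library boundary.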
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 2df30ba2de391..b9fd79997a538 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -6,6 +6,7 @@
 
 #include "core/common/inlined_containers.h"
 #include "core/framework/allocator_utils.h"
+#include "core/framework/config_options.h"
 #include "core/framework/compute_capability.h"
 #include "core/framework/data_types.h"
 #include "core/framework/data_transfer_manager.h"
@@ -529,6 +530,11 @@ struct ProviderHostImpl : ProviderHost {
 
   const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; }
 
+  // ConfigOptions (wrapped)
+  std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override {
+    return p->GetConfigEntry(config_key);
+  }
+
   // ComputeCapability (wrapped)
   std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
   void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }
@@ -934,6 +940,7 @@ struct ProviderHostImpl : ProviderHost {
   uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) override { return p->GetInputCount(); }
   uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) override { return p->GetOutputCount(); }
   const Node& OpKernelInfo__node(const OpKernelInfo* p) override { return p->node(); }
+  const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) override { return p->GetConfigOptions(); }
 
   // SessionState (wrapped)
   const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }
diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc
index b3128571f16ff..9cbf01946e92b 100644
--- a/onnxruntime/core/session/standalone_op_invoker.cc
+++ b/onnxruntime/core/session/standalone_op_invoker.cc
@@ -421,7 +421,10 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info,
   static const OrtValueNameIdxMap kEmptyNameMap;
 
   OpKernelInfo tmp_kernel_info(*node_ptr.get(), *kernel_def, *ep, kEmptyValueMap, kEmptyNameMap,
-                               kernel_info->GetDataTransferManager(), kernel_info->GetAllocators());
+                               kernel_info->GetDataTransferManager(),
+                               kernel_info->GetAllocators(),
+                               kernel_info->GetConfigOptions());
+
   std::unique_ptr<OpKernel> op_kernel;
 
   auto& node_repo = NodeRepo::GetInstance();
diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc
index 2147a4253ef39..b174ee4138be3 100644
--- a/onnxruntime/test/framework/allocation_planner_test.cc
+++ b/onnxruntime/test/framework/allocation_planner_test.cc
@@ -254,7 +254,7 @@ class PlannerTest : public ::testing::Test {
     ASSERT_NE(ep, nullptr);
     auto info = std::make_unique<OpKernelInfo>(
         *p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(),
-        state_->GetDataTransferMgr());
+        state_->GetDataTransferMgr(), state_->GetAllocators(), state_->GetSessionOptions().config_options);
     op_kernel_infos_.push_back(std::move(info));
 
     const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{};
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index 2522ee3b496f6..60effda9ec772 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -82,6 +82,11 @@ ProviderInfo_ROCM& GetProviderInfo_ROCM();
 class FuseAdd : public OpKernel {
  public:
   explicit FuseAdd(const OpKernelInfo& info) : OpKernel(info) {
+    // logic for testing that a session options config value can be read here
+    auto test_throw_in_ctor = info.GetConfigOptions().GetConfigEntry("ThrowInKernelCtor");
+    if (test_throw_in_ctor == "1") {
+      ORT_THROW("Test exception in ctor");
+    };
   }
 
   Status Compute(OpKernelContext* context) const override {
@@ -96,6 +101,7 @@ class FuseAdd : public OpKernel {
     return Status::OK();
   }
 };
+
 constexpr const char* kFuseTest = "FuseTest";
 constexpr const char* kFuseExecutionProvider = "FuseExecutionProvider";
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kFuseExecutionProvider, kFuseTest, 1, FuseAdd);
@@ -1263,28 +1269,22 @@ TEST(InferenceSessionTests, TestOptionalInputs) {
         ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
       }
 
       // required, optional and invalid input
-      status = RunOptionalInputTest(true, true, true, version, sess_env);
-      ASSERT_FALSE(status.IsOK());
-      EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
+      ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(true, true, true, version, sess_env),
+                                          "Invalid input name");
 
       // missing required
-      status = RunOptionalInputTest(false, true, false, version, sess_env);
-      ASSERT_FALSE(status.IsOK());
-      if (version == 3) {
-        EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Invalid input name"));
-      } else {
-        EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Missing Input:"));
-      }
+      ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(false, true, false, version, sess_env),
+                                          (version == 3 ? "Invalid input name" : "Missing Input:"));
     }
   }
 }
 
-TEST(ExecutionProviderTest, FunctionTest) {
-  onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
+static void CreateFuseOpModel(const std::string& model_file_name) {
+  onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+                           {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
   auto& graph = model.MainGraph();
   std::vector<onnxruntime::NodeArg*> inputs;
   std::vector<onnxruntime::NodeArg*> outputs;
 
-  // FLOAT tensor.
   ONNX_NAMESPACE::TypeProto float_tensor;
   float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
   float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
   float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
@@ -1307,18 +1307,19 @@ TEST(ExecutionProviderTest, FunctionTest) {
   outputs.push_back(&output_arg_2);
   graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
 
-  auto status = graph.Resolve();
-  ASSERT_TRUE(status.IsOK());
+  ASSERT_STATUS_OK(graph.Resolve());
+  ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_file_name));
+}
+
+TEST(ExecutionProviderTest, FunctionTest) {
   std::string model_file_name = "execution_provider_test_graph.onnx";
-  status = onnxruntime::Model::Save(model, model_file_name);
+  CreateFuseOpModel(model_file_name);
 
   SessionOptions so;
   so.session_logid = "ExecutionProviderTest.FunctionTest";
-  InferenceSession session_object{so, GetEnvironment()};
-  status = session_object.Load(model_file_name);
-  ASSERT_TRUE(status.IsOK());
-  status = session_object.Initialize();
-  ASSERT_TRUE(status.IsOK());
+  InferenceSession session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.Load(model_file_name));
+  ASSERT_STATUS_OK(session.Initialize());
 
   RunOptions run_options;
   run_options.run_tag = so.session_logid;
@@ -1329,11 +1330,14 @@ TEST(ExecutionProviderTest, FunctionTest) {
   std::vector<int64_t> dims_mul_x = {3, 2};
   std::vector<float> values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
   OrtValue ml_value_x;
-  CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_x);
+  CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
+                       &ml_value_x);
   OrtValue ml_value_y;
-  CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_y);
+  CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
+                       &ml_value_y);
   OrtValue ml_value_z;
-  CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x, &ml_value_z);
+  CreateMLValue<float>(testCPUExecutionProvider->CreatePreferredAllocators()[0], dims_mul_x, values_mul_x,
+                       &ml_value_z);
   NameMLValMap feeds;
   feeds.insert(std::make_pair("X", ml_value_x));
   feeds.insert(std::make_pair("Y", ml_value_y));
@@ -1349,67 +1353,33 @@ TEST(ExecutionProviderTest, FunctionTest) {
   std::vector<float> expected_values_mul_m = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f};
 
   // Now run
-  status = session_object.Run(run_options, feeds, output_names, &fetches);
-  ASSERT_TRUE(status.IsOK());
+  ASSERT_STATUS_OK(session.Run(run_options, feeds, output_names, &fetches));
   VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 
-  InferenceSession session_object_2{so, GetEnvironment()};
-  ASSERT_STATUS_OK(
-      session_object_2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
-  ASSERT_STATUS_OK(session_object_2.Load(model_file_name));
-  ASSERT_STATUS_OK(session_object_2.Initialize());
-  ASSERT_STATUS_OK(session_object_2.Run(run_options, feeds, output_names, &fetches));
+  InferenceSession session2{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session2.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
+  ASSERT_STATUS_OK(session2.Load(model_file_name));
+  ASSERT_STATUS_OK(session2.Initialize());
+  ASSERT_STATUS_OK(session2.Run(run_options, feeds, output_names, &fetches));
   VerifyOutputs(fetches, expected_dims_mul_m, expected_values_mul_m);
 }
 
 TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
-  onnxruntime::Model model("graph_1", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
-  auto& graph = model.MainGraph();
-  std::vector<onnxruntime::NodeArg*> inputs;
-  std::vector<onnxruntime::NodeArg*> outputs;
-
-  // FLOAT tensor.
-  ONNX_NAMESPACE::TypeProto float_tensor;
-  float_tensor.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
-  float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
-  float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(2);
-
-  auto& input_arg_1 = graph.GetOrCreateNodeArg("X", &float_tensor);
-  auto& input_arg_2 = graph.GetOrCreateNodeArg("Y", &float_tensor);
-  inputs.push_back(&input_arg_1);
-  inputs.push_back(&input_arg_2);
-  auto& output_arg = graph.GetOrCreateNodeArg("node_1_out_1", &float_tensor);
-  outputs.push_back(&output_arg);
-  graph.AddNode("node_1", "Add", "node 1.", inputs, outputs);
-
-  auto& input_arg_3 = graph.GetOrCreateNodeArg("Z", &float_tensor);
-  inputs.clear();
-  inputs.push_back(&output_arg);
-  inputs.push_back(&input_arg_3);
-  auto& output_arg_2 = graph.GetOrCreateNodeArg("M", &float_tensor);
-  outputs.clear();
-  outputs.push_back(&output_arg_2);
-  graph.AddNode("node_2", "Add", "node 2.", inputs, outputs);
-
-  auto status = graph.Resolve();
-  ASSERT_TRUE(status.IsOK());
   std::string model_file_name = "fused_node_shape_inference_test_graph.onnx";
-  status = onnxruntime::Model::Save(model, model_file_name);
+
+  CreateFuseOpModel(model_file_name);
 
   SessionOptions so;
   so.session_logid = "ExecutionProviderTest.ShapeInferenceForFusedFunctionTest";
   InferenceSessionWrapper session{so, GetEnvironment()};
-  ASSERT_STATUS_OK(
-      session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
-  status = session.Load(model_file_name);
-  ASSERT_TRUE(status.IsOK());
-  status = session.Initialize();
-  ASSERT_TRUE(status.IsOK());
+  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
+  ASSERT_STATUS_OK(session.Load(model_file_name));
+  ASSERT_STATUS_OK(session.Initialize());
 
   Graph& fused_graph = session.GetMutableGraph();
-  ASSERT_TRUE(fused_graph.NumberOfNodes() == 1);
+  ASSERT_EQ(fused_graph.NumberOfNodes(), 1);
   auto& fused_node = *fused_graph.Nodes().begin();
-  ASSERT_TRUE(fused_node.NodeType() == Node::Type::Fused);
+  ASSERT_EQ(fused_node.NodeType(), Node::Type::Fused);
   ASSERT_TRUE(fused_node.Op()->has_type_and_shape_inference_function());
 
   // Clear shape inference data from output node to verify that assigned inference function is called
@@ -1419,7 +1389,25 @@ TEST(ExecutionProviderTest, ShapeInferenceForFusedFunctionTest) {
   ASSERT_STATUS_OK(fused_graph.Resolve());
 
   ASSERT_TRUE(fused_node_output.Shape() != nullptr);
-  ASSERT_TRUE(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()) == utils::GetTensorShapeFromTensorShapeProto(float_tensor.tensor_type().shape()));
+  ASSERT_EQ(utils::GetTensorShapeFromTensorShapeProto(*fused_node_output.Shape()), TensorShape({3, 2}));
+}
+
+TEST(ExecutionProviderTest, OpKernelInfoCanReadConfigOptions) {
+  std::string model_file_name = "OpKernelInfoCanReadConfigOptions.onnx";
+  CreateFuseOpModel(model_file_name);
+
+  SessionOptions so;
+  so.session_logid = "ExecutionProviderTest.OpKernelInfoCanReadConfigOptions";
+
+  // add a config key that if read causes the Fuse op kernel to throw in the ctor. this is just to test the value is
+  // passed through in the simplest way, as the kernel is constructed in InferenceSession::Initialize so we don't
+  // need to actually run the model.
+  ASSERT_STATUS_OK(so.config_options.AddConfigEntry("ThrowInKernelCtor", "1"));
+
+  InferenceSession session{so, GetEnvironment()};
+  ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::make_unique<::onnxruntime::FuseExecutionProvider>()));
+  ASSERT_STATUS_OK(session.Load(model_file_name));
+  ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Initialize(), "Test exception in ctor");
 }
 
 TEST(InferenceSessionTests, Test3LayerNestedSubgraph) {
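The new test sets the entry on SessionOptions directly because it lives inside the ORT codebase. Application code reaches the same config_options through the existing public session-config API (not added by this change); a minimal sketch with a placeholder model path:

    #include "onnxruntime_cxx_api.h"

    int main() {
      Ort::Env env;
      Ort::SessionOptions so;
      // Entries added here surface in every kernel's OpKernelInfo::GetConfigOptions().
      so.AddConfigEntry("ThrowInKernelCtor", "0");  // same key the FuseAdd test kernel reads
      Ort::Session session(env, ORT_TSTR("model.onnx"), so);  // "model.onnx" is a placeholder
      return 0;
    }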
diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc
index e1ce1d4abf81d..8990c23e4af39 100644
--- a/onnxruntime/test/framework/session_state_test.cc
+++ b/onnxruntime/test/framework/session_state_test.cc
@@ -84,9 +84,10 @@ TEST_P(SessionStateAddGetKernelTest, AddGetKernelTest) {
   auto kernel_def = KernelDefBuilder().SetName("Variable").Provider(kCpuExecutionProvider).SinceVersion(1, 10).Build();
 
   OpKernelInfo p_info(node, *kernel_def, *cpu_execution_provider, s.GetConstantInitializedTensors(),
-                      s.GetOrtValueNameIdxMap(), s.GetDataTransferMgr());
-  unique_ptr<OpKernel> p_kernel;
-  p_kernel.reset(new TestOpKernel(p_info));
+                      s.GetOrtValueNameIdxMap(), s.GetDataTransferMgr(), s.GetAllocators(),
+                      s.GetSessionOptions().config_options);
+
+  std::unique_ptr<OpKernel> p_kernel = std::make_unique<TestOpKernel>(p_info);
   size_t orig_num_outputs = p_kernel->Node().OutputDefs().size();
   std::cout << "node_idx: " << node.Index() << std::endl;
diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc
index 24f34492954aa..4b676021dfe6c 100644
--- a/onnxruntime/test/ir/graph_test.cc
+++ b/onnxruntime/test/ir/graph_test.cc
@@ -1503,10 +1503,8 @@ TEST_F(GraphTest, ShapeInferenceErrorHandling) {
 
   graph.AddNode("node_1", "ShapeInferenceThrowsOp", "node 1", {&input_arg1}, {&output_arg1});
 
-  auto status = graph.Resolve();
-  EXPECT_FALSE(status.IsOK());
-  EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Node (node_1) Op (ShapeInferenceThrowsOp) "
-                                                        "[ShapeInferenceError] try harder"));
+  EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(graph.Resolve(),
+                                      "Node (node_1) Op (ShapeInferenceThrowsOp) [ShapeInferenceError] try harder");
 }
 
 TEST_F(GraphTest, AddTensorAttribute) {
@@ -2024,10 +2022,9 @@ TEST_F(GraphTest, LoadModelMissingInput) {
   SetTypeAndShape(output->mutable_type()->mutable_tensor_type(), 1, {2, 2});
 
   std::shared_ptr<Model> model;
-  Status st = Model::Load(std::move(m), model, nullptr, *logger_);
-  ASSERT_FALSE(st.IsOK());
-  ASSERT_THAT(st.ErrorMessage(), testing::HasSubstr("Invalid model. Node input 'y' is not a graph input, "
-                                                    "initializer, or output of a previous node."));
+  ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(Model::Load(std::move(m), model, nullptr, *logger_),
+                                      "Invalid model. Node input 'y' is not a graph input, "
+                                      "initializer, or output of a previous node.");
 }
 
 // if an initializer is backing an optional graph input, it can't be removed even if unused in the graph.
diff --git a/onnxruntime/test/onnx/microbenchmark/activation.cc b/onnxruntime/test/onnx/microbenchmark/activation.cc
index 77590f5c0a304..cf859facf4765 100644
--- a/onnxruntime/test/onnx/microbenchmark/activation.cc
+++ b/onnxruntime/test/onnx/microbenchmark/activation.cc
@@ -69,7 +69,18 @@ struct KernelAndDef {
                   .SetDomain(domain)
                   .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())
                   .Build();
-    OpKernelInfo info(main_node, *out.def, *out.a, {}, {}, {});
+
+    // these usually come from the session state. OpKernelInfo stores references to them so we need a valid backing
+    // instance even though we don't use them in this test.
+    static const std::unordered_map<int, OrtValue> constant_initialized_tensors;
+    static const OrtValueNameIdxMap mlvalue_name_idx_map;
+    static const DataTransferManager data_transfer_mgr;
+    static const AllocatorMap allocators;
+    static const ConfigOptions config_options;
+    OpKernelInfo info(main_node, *out.def, *out.a,
+                      constant_initialized_tensors, mlvalue_name_idx_map, data_transfer_mgr, allocators,
+                      config_options);
+
     out.kernel = std::make_unique<KernelType>(info);
     return out;
   }
diff --git a/onnxruntime/test/optimizer/cse_test.cc b/onnxruntime/test/optimizer/cse_test.cc
index cccfc8d77fcea..bad96406df845 100644
--- a/onnxruntime/test/optimizer/cse_test.cc
+++ b/onnxruntime/test/optimizer/cse_test.cc
@@ -1,11 +1,12 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
-#include "test/framework/test_utils.h"
-#include "test/test_environment.h"
 #include "core/graph/model.h"
 #include "core/optimizer/common_subexpression_elimination.h"
 #include "core/optimizer/graph_transformer_mgr.h"
+#include "test/framework/test_utils.h"
+#include "test/test_environment.h"
+#include "test/util/include/asserts.h"
 
 #ifdef ENABLE_TRAINING
 #include "orttraining/core/optimizer/graph_transformer_utils.h"
@@ -272,20 +273,21 @@ TEST(CseTests, MergedValueAndGraphOutputAreOutputsOfSameNode) {
 TEST(CseTests, MergeConstants) {
   auto model_uri = ORT_TSTR("testdata/transform/cse/cse_merge_constants.onnx");
   std::shared_ptr<Model> model;
-  ASSERT_TRUE(Model::Load(model_uri, model, nullptr,
-                          DefaultLoggingManager().DefaultLogger())
-                  .IsOK());
+  ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, DefaultLoggingManager().DefaultLogger()));
+
   Graph& graph = model->MainGraph();
 
   GraphTransformerManager graph_transformation_mgr(1);
   // In current implementation, equal constants are not merged. So CSE must precede constant folding, otherwise we end up
   // with multiple copies of the same constant.
   std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
-  ASSERT_TRUE(
-      graph_transformation_mgr.Register(std::make_unique<CommonSubexpressionElimination>(), TransformerLevel::Level1).IsOK());
-  ASSERT_TRUE(
-      graph_transformation_mgr.Register(std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1).IsOK());
-  ASSERT_TRUE(
-      graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, DefaultLoggingManager().DefaultLogger()).IsOK());
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique<CommonSubexpressionElimination>(),
+                                                     TransformerLevel::Level1));
+  const ConfigOptions empty_config_options;
+  ASSERT_STATUS_OK(graph_transformation_mgr.Register(
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
+  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1,
+                                                              DefaultLoggingManager().DefaultLogger()));
 
   ASSERT_EQ(graph.GetAllInitializedTensors().size(), 1U);
 
   auto op_count = CountOpsInGraph(graph);
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 5adcb3c150b8d..bf02c1741725f 100755
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -575,12 +575,14 @@ TEST_F(GraphTransformationTests, ConstantFolding) {
   ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_));
   Graph& graph = model->MainGraph();
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  ASSERT_EQ(op_to_count["Unsqueeze"], 2);
+
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const ConfigOptions empty_config_options;
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
 
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
@@ -595,11 +597,13 @@ TEST_F(GraphTransformationTests, ConstantFoldingNodesOnDifferentEP) {
   Graph& graph = model->MainGraph();
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
   ASSERT_TRUE(op_to_count["Unsqueeze"] == 2);
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const ConfigOptions empty_config_options;
+
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
 
   // assign all nodes to CUDA. the constant folding should override this to perform the constant folding on cpu
   for (auto& node : graph.Nodes()) {
@@ -624,11 +628,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingUnsupportedFloat16) {
   Graph& graph = model->MainGraph();
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
   ASSERT_TRUE(op_to_count["Mul"] == 1);
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const ConfigOptions empty_config_options;
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
 
   // assign all nodes to CUDA. the constant folding should try folding the node on the CPU and fail, thus leaving the
   // EP as CUDA and not constant folding the node.
@@ -707,11 +712,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingSubgraph) {
   std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
   ASSERT_TRUE(op_to_count["Add"] == 2);  // one in each subgraph
 
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const ConfigOptions empty_config_options;
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
 
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
@@ -731,14 +737,15 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithShapeToInitializer) {
   ASSERT_TRUE(op_to_count["Unsqueeze"] == 3);
 
   InlinedHashSet<std::string_view> compatible_eps;
-  InlinedHashSet<std::string> excluded_initializers;
-  excluded_initializers.insert("matmul_weight");
+  InlinedHashSet<std::string> excluded_initializers = {"matmul_weight"};
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  const ConfigOptions empty_config_options;
+
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
       std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/,
+                                        empty_config_options,
                                         compatible_eps,
                                         excluded_initializers),
       TransformerLevel::Level1));
@@ -763,11 +770,11 @@ TEST_F(GraphTransformationTests, ConstantFoldingWithScalarShapeToInitializer) {
 
   InlinedHashSet<std::string_view> compatible_eps;
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  const ConfigOptions empty_config_options;
+
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(),
-                                        false /*skip_dequantize_linear*/,
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options,
                                         compatible_eps),
       TransformerLevel::Level1));
 
@@ -792,11 +799,11 @@ TEST_F(GraphTransformationTests, ConstantFoldingForOpsWithMissingOptionalInputs) {
 
   InlinedHashSet<std::string_view> compatible_eps;
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  const ConfigOptions empty_config_options;
+
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(),
-                                        false /*skip_dequantize_linear*/,
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options,
                                         compatible_eps),
       TransformerLevel::Level1));
 
@@ -965,11 +972,12 @@ TEST_F(GraphTransformationTests, ConstantFolding_RemoveDanglingInputNodesToConst
   ASSERT_TRUE(op_to_count["Add"] == 1);            // Input node to Shape
   ASSERT_TRUE(op_to_count["RandomUniform"] == 1);  // Input node to Add
 
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const ConfigOptions empty_config_options;
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
 
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
@@ -988,10 +996,13 @@ TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) {
   ASSERT_TRUE(op_to_count["Shape"] == 4);
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  std::unique_ptr<CPUExecutionProvider> e =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  const ConfigOptions empty_config_options;
+  std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
+
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
   op_to_count = CountOpsInGraph(graph);
@@ -1014,9 +1025,12 @@ TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) {
   ASSERT_EQ(op_to_count["Identity"], 1);
 
   onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
+  const ConfigOptions empty_config_options;
   std::unique_ptr<CPUExecutionProvider> e = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+
   ASSERT_STATUS_OK(graph_transformation_mgr.Register(
-      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1));
+      std::make_unique<ConstantFolding>(*e.get(), false /*skip_dequantize_linear*/, empty_config_options),
+      TransformerLevel::Level1));
 
   ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_));
 
   op_to_count = CountOpsInGraph(graph);
diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc
index 2ce1e3881d81d..79704f2cc79e3 100644
--- a/onnxruntime/test/optimizer/optimizer_test.cc
+++ b/onnxruntime/test/optimizer/optimizer_test.cc
@@ -27,7 +27,8 @@ namespace test {
 static const std::string MODEL_FOLDER = "testdata/transform/";
 
 TEST(OptimizerTest, Basic) {
-  Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
+  Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
+              {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger());
   auto& graph = model.MainGraph();
 
   constexpr int tensor_dim = 10;
@@ -65,8 +66,7 @@ TEST(OptimizerTest, Basic) {
     nodes.push_back(&node);
   }
 
-  std::unique_ptr<CPUExecutionProvider> cpu_execution_provider =
-      std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
+  auto cpu_execution_provider = std::make_unique<CPUExecutionProvider>(CPUExecutionProviderInfo());
 
 #if !defined(DISABLE_SPARSE_TENSORS)
   OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, graph.ModelPath(),
@@ -85,8 +85,10 @@ TEST(OptimizerTest, Basic) {
   OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs);
 
   const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();
+  const ConfigOptions empty_config_options;
+
   for (auto& node : graph.Nodes()) {
-    auto kernel = info.CreateKernel(&node);
+    auto kernel = info.CreateKernel(&node, empty_config_options);
 
     // kernel can only be a nullptr if a CPU kernel implementation has been removed,
     // if that is the case, OpKernelContext instance construction will throw in the next step
diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc
index 3d46893cdb82d..e5f3956438b7a 100644
--- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc
+++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc
@@ -248,10 +248,9 @@ static common::Status CreateSubgraph(Graph& graph, RunOptions& options, const st
   auto status = graph.Resolve();
 
   if (failure_message.empty()) {
-    EXPECT_EQ(status, Status::OK());
+    EXPECT_STATUS_OK(status);
   } else {
-    EXPECT_TRUE(!status.IsOK());
-    EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr(failure_message));
+    EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(status, failure_message);
   }
 
   return status;
diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc
index 8955a83e66c01..aba74484a644b 100644
--- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc
+++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc
@@ -153,9 +153,8 @@ TEST(InternalTestingEP, PreventSaveOfModelWithCompiledOps) {
                        std::make_unique<InternalTestingExecutionProvider>(supported_ops)));
 
   ASSERT_STATUS_OK(session->Load(ort_model_path));
-  auto status = session->Initialize();
-  ASSERT_FALSE(status.IsOK()) << "Initialize should have failed when trying to save model with compiled kernels";
-  ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Unable to serialize model as it contains compiled nodes"));
+  ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session->Initialize(),
+                                      "Unable to serialize model as it contains compiled nodes");
 }
 
 // the internal NHWC operators are only included as part of contrib ops currently. as the EP requests the NHWC
@@ -195,11 +194,10 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) {
   output_names.push_back("Z");
   std::vector<OrtValue> fetches;
 
-  auto status = session.Run(feeds, output_names, &fetches);
   // Error message should come from the Conv implementation with the statically registered kernel
-  ASSERT_THAT(status.ErrorMessage(),
-              ::testing::HasSubstr("Non-zero status code returned while running Conv node. Name:'Conv' "
-                                   "Status Message: TODO: add NHWC implementation here."));
+  ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
+                                      "Non-zero status code returned while running Conv node. Name:'Conv' "
+                                      "Status Message: TODO: add NHWC implementation here.");
 }
 
 TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
@@ -243,10 +241,9 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) {
   output_names.push_back("softmaxout_1");
   std::vector<OrtValue> fetches;
 
-  auto status = session.Run(feeds, output_names, &fetches);
-  ASSERT_THAT(status.ErrorMessage(),
-              ::testing::HasSubstr("Non-zero status code returned while running Conv node. Name:'Conv' "
-                                   "Status Message: TODO: add NHWC implementation here."));
+  ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
+                                      "Non-zero status code returned while running Conv node. Name:'Conv' "
+                                      "Status Message: TODO: add NHWC implementation here.");
 }
 
 // This test can be deprecated now as the code logic has been changed so the model is not applicable
diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc
index 977a5bd9ea7b8..23ec48fa649dd 100644
--- a/onnxruntime/test/providers/kernel_compute_test_utils.cc
+++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc
@@ -124,7 +124,8 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
     outputs.emplace_back(output);
   }
 
-  auto kernel = info.CreateKernel(&node);
+  static const ConfigOptions empty_config_options;
+  auto kernel = info.CreateKernel(&node, empty_config_options);
   ASSERT_TRUE(kernel);
 
   std::vector<int> fetch_mlvalue_idxs;
diff --git a/onnxruntime/test/util/include/asserts.h b/onnxruntime/test/util/include/asserts.h
index f6edb062f0706..02494951a06ba 100644
--- a/onnxruntime/test/util/include/asserts.h
+++ b/onnxruntime/test/util/include/asserts.h
@@ -6,6 +6,7 @@
 #include "core/common/status.h"
 #include "core/session/onnxruntime_c_api.h"
 #include "gtest/gtest.h"
+#include "gmock/gmock.h"
 
 // helpers to run a function and check the status, outputting any error if it fails.
 // note: wrapped in do{} while(false) so the _tmp_status variable has limited scope
@@ -33,6 +34,20 @@
     EXPECT_FALSE(_tmp_status.IsOK());                                   \
   } while (false)
 
+#define ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)              \
+  do {                                                                  \
+    Status _tmp_status = (function);                                    \
+    ASSERT_FALSE(_tmp_status.IsOK());                                   \
+    ASSERT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg)); \
+  } while (false)
+
+#define EXPECT_STATUS_NOT_OK_AND_HAS_SUBSTR(function, msg)              \
+  do {                                                                  \
+    Status _tmp_status = (function);                                    \
+    EXPECT_FALSE(_tmp_status.IsOK());                                   \
+    EXPECT_THAT(_tmp_status.ErrorMessage(), ::testing::HasSubstr(msg)); \
+  } while (false)
+
 // Same helpers for public API OrtStatus. Get the 'api' instance using:
 //   const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
 #define ASSERT_ORTSTATUS_OK(api, function) \
diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
index 6193a1d10c095..894fe3b052fb2 100644
--- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
+++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc
@@ -157,8 +157,10 @@ std::vector<std::unique_ptr<GraphTransformer>> GeneratePreTrainingTransformers(
         transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>(compatible_eps));
       }
       InlinedHashSet<std::string> excluded_initializers(weights_to_train.begin(), weights_to_train.end());
+      static const ConfigOptions empty_config_options;
       transformers.emplace_back(std::make_unique<ConstantFolding>(
-          execution_provider, false /*skip_dequantize_linear*/, compatible_eps, excluded_initializers));
+          execution_provider, false /*skip_dequantize_linear*/, empty_config_options, compatible_eps,
+          excluded_initializers));
       transformers.emplace_back(std::make_unique<ReshapeFusion>(compatible_eps));
       // Put fine-grained optimizer (e.g. ShapeOptimizer) after ReshapeFusion to avoid it breaks the strong patterns
      // it defines. ReshapeFusion depends on subgraph pattern matching and do replacement accordingly, ShapeOptimizer
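For reference, the new asserts.h helpers collapse the previous status-check-plus-HasSubstr pattern into a single statement. Typical usage in a test body (the substring below is illustrative):

    // before:
    //   auto status = session.Run(feeds, output_names, &fetches);
    //   ASSERT_FALSE(status.IsOK());
    //   ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("expected error substring"));
    // after:
    ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches),
                                        "expected error substring");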