From 708ee8556e73c76993aaddf5558da6118e8aa510 Mon Sep 17 00:00:00 2001
From: Scott McKay
Date: Tue, 10 Dec 2024 11:54:14 +1000
Subject: [PATCH] Reduce default logger usage (#23030)

### Description
We have use cases where multiple sessions are created concurrently, and minimizing use of the default logger matters in those scenarios. Wire the session logger through to as many places as possible. The EP logger can also be used once the session is created: it can't be used during EP construction or kernel registration, but it can be used in GetCapability and Compile. (A short usage sketch follows the diff below.)

### Motivation and Context
Improve logging when there are concurrent sessions.

---
 .../core/framework/kernel_registry.h | 12 ++-
 .../core/optimizer/graph_transformer_utils.h | 2 +
 .../core/framework/allocation_planner.cc | 79 ++++++++++---------
 .../core/framework/fallback_cpu_capability.cc | 11 +--
 .../core/framework/fallback_cpu_capability.h | 6 +-
 .../core/framework/graph_partitioner.cc | 51 +++++++-----
 onnxruntime/core/framework/kernel_lookup.h | 9 ++-
 onnxruntime/core/framework/kernel_registry.cc | 12 ++-
 .../core/framework/kernel_registry_manager.cc | 14 ++--
 .../core/framework/kernel_registry_manager.h | 5 +-
 onnxruntime/core/framework/session_state.cc | 4 +-
 .../core/optimizer/constant_folding.cc | 7 +-
 .../core/optimizer/graph_transformer_utils.cc | 8 +-
 .../core/optimizer/insert_cast_transformer.cc | 16 ++--
 .../core/optimizer/nhwc_transformer.cc | 16 ++--
 onnxruntime/core/optimizer/nhwc_transformer.h | 3 +-
 .../optimizer/optimizer_execution_frame.cc | 19 +++--
 .../optimizer/optimizer_execution_frame.h | 7 +-
 .../qdq_transformer/avx2_weight_s8_to_u8.cc | 8 +-
 .../selectors_actions/shared/utils.cc | 9 ++-
 .../selectors_actions/shared/utils.h | 8 +-
 .../core/optimizer/transformer_memcpy.cc | 47 +++++++----
 .../providers/cann/cann_execution_provider.cc | 6 +-
 .../providers/cuda/cuda_execution_provider.cc | 2 +-
 .../src/DmlGraphFusionTransformer.cpp | 3 +-
 .../src/DmlRuntimeGraphFusionTransformer.cpp | 3 +-
 .../src/ExecutionProvider.cpp | 7 +-
 .../src/ExecutionProvider.h | 3 +-
 .../providers/js/js_execution_provider.cc | 2 +-
 .../nnapi_builtin/builders/model_builder.cc | 14 +++-
 .../nnapi_builtin/builders/model_builder.h | 10 ++-
 .../nnapi_builtin/nnapi_execution_provider.cc | 17 ++--
 .../qdq_transformations/qdq_stripping.cc | 2 +-
 .../core/providers/qnn/builder/qnn_model.cc | 2 +-
 .../providers/qnn/qnn_execution_provider.cc | 2 +-
 .../providers/rocm/rocm_execution_provider.cc | 2 +-
 .../providers/shared_library/provider_api.h | 7 +-
 .../provider_bridge_provider.cc | 5 +-
 .../shared_library/provider_interfaces.h | 5 +-
 .../core/providers/vsinpu/vsinpu_ep_graph.cc | 5 +-
 .../core/providers/vsinpu/vsinpu_ep_graph.h | 3 +-
 .../vsinpu/vsinpu_execution_provider.cc | 26 +++---
 .../webgpu/webgpu_execution_provider.cc | 3 +-
 .../xnnpack/xnnpack_execution_provider.cc | 3 +-
 onnxruntime/core/session/inference_session.cc | 13 +--
 onnxruntime/core/session/inference_session.h | 8 +-
 .../core/session/inference_session_utils.cc | 5 +-
 .../core/session/inference_session_utils.h | 7 +-
 .../core/session/provider_bridge_ort.cc | 9 ++-
 .../core/session/standalone_op_invoker.cc | 7 +-
 .../test/framework/allocation_planner_test.cc | 8 +-
 .../test/optimizer/graph_transform_test.cc | 12 ++-
 .../optimizer/graph_transform_utils_test.cc | 11 ++-
 onnxruntime/test/optimizer/optimizer_test.cc | 22 +++---
 .../test/optimizer/qdq_transformer_test.cc | 7 +-
 onnxruntime/test/providers/base_tester.cc | 7 +-
 .../providers/kernel_compute_test_utils.cc | 8 +-
.../test/providers/partitioning_utils_test.cc | 4 +- .../core/session/training_session.cc | 3 +- .../core/session/training_session.h | 3 +- .../test/gradient/gradient_op_test_utils.cc | 3 +- .../optimizer/graph_transformer_utils_test.cc | 10 ++- 62 files changed, 389 insertions(+), 243 deletions(-) diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h index 7b3d04ee66d9e..aaf533135429c 100644 --- a/include/onnxruntime/core/framework/kernel_registry.h +++ b/include/onnxruntime/core/framework/kernel_registry.h @@ -8,6 +8,9 @@ #include "core/framework/op_kernel.h" namespace onnxruntime { +namespace logging { +class Logger; +} using KernelCreateMap = std::multimap; using KernelDefHashes = std::vector>; @@ -33,6 +36,7 @@ class KernelRegistry { // Kernel matching uses the types from the node and the kernel_type_str_resolver. Status TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const; // map of type constraint name to required type @@ -42,6 +46,7 @@ class KernelRegistry { // Kernel matching uses the explicit type constraint name to required type map in type_constraints. Status TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; /** @@ -61,13 +66,15 @@ class KernelRegistry { std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; static bool HasImplementationOf(const KernelRegistry& r, const Node& node, ProviderType exec_provider, - const IKernelTypeStrResolver& kernel_type_str_resolver) { + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) { const KernelCreateInfo* info; - Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info); + Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info); return st.IsOK(); } @@ -83,6 +90,7 @@ class KernelRegistry { Status TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const; // Check whether the types of inputs/outputs of the given node match the extra diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h index 6cff153c336f0..31b0f22340510 100644 --- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h +++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h @@ -53,6 +53,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& execution_provider /*required by constant folding*/, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); @@ -84,6 +85,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable = {}, concurrency::ThreadPool* 
intra_op_thread_pool = nullptr, std::unordered_map>* p_buffered_tensors = nullptr); diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 5dca4cf6c165b..ecd3960107926 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -138,7 +138,8 @@ class PlannerImpl { const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps, const InlinedHashMap& outer_scope_node_arg_to_location_map, const OrtValueNameIdxMap& ort_value_name_idx_map, - const ISequentialPlannerContext& context, SequentialExecutionPlan& plan) + const ISequentialPlannerContext& context, SequentialExecutionPlan& plan, + const logging::Logger& logger) : context_(&context), plan_(plan), parent_node_(parent_node), @@ -148,14 +149,15 @@ class PlannerImpl { kernel_create_info_map_(kernel_create_info_map), subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps), outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map), - ort_value_name_idx_map_(ort_value_name_idx_map) {} + ort_value_name_idx_map_(ort_value_name_idx_map), + logger_(logger) { + } Status CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger); + const PathString& partition_config_file); private: gsl::not_null context_; @@ -183,6 +185,12 @@ class PlannerImpl { InlinedHashMap> dependence_graph_; InlinedHashMap value_node_map_; + // logger_ is not currently used in a minimal build +#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) + [[maybe_unused]] +#endif + const logging::Logger& logger_; + // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: struct OrtValueInfo { const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue @@ -213,6 +221,7 @@ class PlannerImpl { FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point) : ml_value(ort_value), deallocate_point(dealloc_point) {} }; + // freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when // they became free (more recently freed earlier in the list). std::list freelist_; @@ -225,7 +234,8 @@ class PlannerImpl { } int& UseCount(OrtValueIndex n) { - ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size()); + ORT_ENFORCE(n >= 0 && static_cast(n) < ort_value_info_.size(), + "invalid value index: ", n, " against size ", ort_value_info_.size()); return ort_value_info_[n].usecount; } int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); } @@ -335,9 +345,9 @@ class PlannerImpl { // we cannot. const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -361,9 +371,9 @@ class PlannerImpl { // we cannot. 
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " - << producer_node->Name() << " which has external outputs. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node " + << producer_node->Name() << " which has external outputs. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif *reusable_input = Index(p_input_arg->Name()); @@ -397,8 +407,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -448,8 +458,8 @@ class PlannerImpl { return true; } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " - << producer_node->Name() << " as it has external outputs."; + LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node " + << producer_node->Name() << " as it has external outputs."; #endif } } @@ -1198,9 +1208,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. " + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1241,9 +1251,9 @@ class PlannerImpl { // Otherwise, we cannot reuse the buffer. const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name()); if (producer_node && HasExternalOutputs(*producer_node)) { - LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " - << producer_node->Name() << " which has external outputs is reused. " - << "Be cautious the reuse MUST be a read-only usage."; + LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node " + << producer_node->Name() << " which has external outputs is reused. 
" + << "Be cautious the reuse MUST be a read-only usage."; } #endif @@ -1290,8 +1300,8 @@ class PlannerImpl { } } else { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " - << producer_node->Name() << " as it has external outputs"; + LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node " + << producer_node->Name() << " as it has external outputs"; #endif } } @@ -1869,8 +1879,7 @@ class PlannerImpl { } #ifndef ORT_ENABLE_STREAM - void PartitionIntoStreams(const logging::Logger& /*logger*/, - const ExecutionProviders& /*execution_providers*/, + void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/, const PathString& /*partition_config_file*/) { if (graph_viewer_.NumberOfNodes() > 0) { stream_nodes_.push_back({}); @@ -1915,11 +1924,11 @@ class PlannerImpl { #else - void - PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers, - const PathString& partition_config_file) { - auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file); - auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder()); + void PartitionIntoStreams(const ExecutionProviders& execution_providers, + const PathString& partition_config_file) { + auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file); + auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, + context_->GetExecutionOrder()); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); plan_.node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); for (size_t i = 0; i < stream_nodes_.size(); ++i) { @@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan( #ifdef ORT_ENABLE_STREAM const IStreamCommandHandleRegistry& stream_handle_registry, #endif - const PathString& partition_config_file, - const logging::Logger& logger) { + const PathString& partition_config_file) { // 1. partition graph into streams - PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file); + PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file); // 2. 
initialize the plan based on stream partition result int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; @@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan( PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_create_info_map, subgraphs_kernel_create_info_maps, outer_scope_node_arg_to_location_map, - ort_value_name_idx_map, context, *plan); + ort_value_name_idx_map, context, *plan, logger); return planner.CreatePlan( #ifdef ORT_ENABLE_STREAM stream_handle_registry, #endif - partition_config_file, - logger); + partition_config_file); } #ifdef ORT_ENABLE_STREAM diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc index ef68b88187e08..1eb7420b44d2c 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.cc +++ b/onnxruntime/core/framework/fallback_cpu_capability.cc @@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { + gsl::span tentative_nodes, + const logging::Logger& logger) { // automatic conversion from const std::vector& const auto& ordered_nodes = graph.GetNodesInTopologicalOrder(); InlinedVector node_id_to_order_map(graph.MaxNodeIndex()); @@ -83,7 +84,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name()); for (auto& consumer_node : consumer_nodes) { candidates.push(consumer_node->Index()); - LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); + LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name(); } } return Status::OK(); @@ -159,9 +160,9 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe if (place_in_cpu) { cpu_nodes.insert(cur); - LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() - << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs " - << " capable of executing this node"; + LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name() + << " because the CPU execution path is deemed faster than overhead involved with execution " + "on other EPs capable of executing this node"; for (auto* output : node->OutputDefs()) { cpu_output_args.insert(output); } diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h index c5bcd22888b7c..bca75adbfd5a7 100644 --- a/onnxruntime/core/framework/fallback_cpu_capability.h +++ b/onnxruntime/core/framework/fallback_cpu_capability.h @@ -9,6 +9,9 @@ #include "core/graph/graph_viewer.h" namespace onnxruntime { +namespace logging { +class Logger; +} /** Returns a list of nodes that are preferred on CPU. 
@@ -19,6 +22,7 @@ namespace onnxruntime { */ std::unordered_set GetCpuPreferredNodes(const GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); } // namespace onnxruntime diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 6174122cf3cb4..406fc1b15effc 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep, }; } // namespace -static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { +static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) { auto& current_ep = params.current_ep.get(); const auto& ep_type = current_ep.Type(); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) { - LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " + LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by " "the layout transformer."; return Status::OK(); } @@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; auto& graph = params.graph.get(); auto& capabilities = params.capabilities.get(); @@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, const KernelRegistryManager& kernel_registry_mgr, const IExecutionProvider& current_ep, + const logging::Logger& logger, std::vector>& capabilities) { const auto& ep_type = current_ep.Type(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, kernel_registries_for_ep, - kernel_registry_mgr.GetKernelTypeStrResolver()}; + kernel_registry_mgr.GetKernelTypeStrResolver(), + logger}; // TODO: Provide EP with a capability to look inside the functions. capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); @@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, GraphPartitioner::Mode mode, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, - const layout_transformation::DebugGraphFn& debug_graph_fn) { + const layout_transformation::DebugGraphFn& debug_graph_fn, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. 
// doing it here saves all providers checking for this in GetCapability if (graph.NumberOfNodes() == 0) { @@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn)); + transform_layout_fn, debug_graph_fn, logger)); } } @@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn)}; - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); if (n != nullptr) { // searching in kernel registries, if no kernel registered for the fused_node, use compile approach - if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) { + if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) { nodes_to_compile.push_back(n); capabilities_to_compile.push_back(std::move(capability)); } else { @@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, const KernelRegistryManager& kernel_registry_mgr, Graph& graph, + const logging::Logger& logger, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. 
@@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_mgr, *subgraph, + logger, not_inlined, inlined_count)); } @@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger, + capabilities)); for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { bool modified_graph = false; auto& graph = partition_params.graph.get(); @@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, - partition_params.debug_graph_fn)); + partition_params.debug_graph_fn, + logger)); } // expand any nodes that have an ONNX function definition but no matching ORT kernel. @@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params, KernelRegistryManager& kernel_registry_mgr, - IExecutionProvider& current_ep) { + IExecutionProvider& current_ep, + const logging::Logger& logger) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. // doing it here saves all providers checking for this in GetCapability auto& graph = partition_params.graph.get(); @@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param auto& subgraph = *entry.second; PartitionParams subgraph_partition_params = partition_params; subgraph_partition_params.graph = std::ref(subgraph); - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep, + logger)); } } @@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param }; // clang-format on - ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params)); + ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { return Status::OK(); } @@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param // Simplified partitioning where custom EPs may produce compiled nodes. 
static Status PartitionOrtFormatModel(const PartitionParams& partition_params, const ExecutionProviders& execution_providers, - KernelRegistryManager& kernel_registry_manager) { + KernelRegistryManager& kernel_registry_manager, + const logging::Logger& logger) { // process full graph with each EP for (const auto& ep : execution_providers) { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger)); } return Status::OK(); @@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, kernel_registry_manager, graph, + logger, not_inlined, inlined_count)); @@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, if (mode == Mode::kNormal || mode == Mode::kAssignOnly) { #if !defined(ORT_MINIMAL_BUILD) - ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger)); bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); @@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { - ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, - providers_, kernel_registry_mgr_)); + ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger)); } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h index 0dd17d2f4a624..fac43bad0fefb 100644 --- a/onnxruntime/core/framework/kernel_lookup.h +++ b/onnxruntime/core/framework/kernel_lookup.h @@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { public: KernelLookup(ProviderType provider_type, gsl::span> kernel_registries, - const IKernelTypeStrResolver& kernel_type_str_resolver) + const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger) : provider_type_{provider_type}, kernel_registries_{kernel_registries}, - kernel_type_str_resolver_{kernel_type_str_resolver} { + kernel_type_str_resolver_{kernel_type_str_resolver}, + logger_{logger} { ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified."); } const KernelCreateInfo* LookUpKernel(const Node& node) const override { const KernelCreateInfo* kernel_create_info{}; for (const auto& registry : kernel_registries_) { - const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, + const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return kernel_create_info; @@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup { ProviderType provider_type_; const gsl::span> kernel_registries_; const IKernelTypeStrResolver& kernel_type_str_resolver_; + const logging::Logger& logger_; }; } // namespace onnxruntime diff --git 
a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc index cbbe0f86b8b7e..8602a3b4004ff 100644 --- a/onnxruntime/core/framework/kernel_registry.cc +++ b/onnxruntime/core/framework/kernel_registry.cc @@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver* kernel_type_str_resolver, const TypeConstraintMap* type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { const auto& node_provider = node.GetExecutionProviderType(); const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider); @@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } @@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node, Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const IKernelTypeStrResolver& kernel_type_str_resolver, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out); + return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out); } Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider, const TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { - return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out); + return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out); } static bool KernelDefCompatible(int version, const KernelDef& kernel_def, @@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::string_view domain, int version, const KernelRegistry::TypeConstraintMap& type_constraints, + const logging::Logger& logger, const KernelCreateInfo** out) const { auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider)); if (out) *out = nullptr; @@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider, std::ostream_iterator(oss, "\n")); oss << ")"; - VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str(); + VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str(); return Status(common::ONNXRUNTIME, common::FAIL, oss.str()); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index f8ccdb8fb0238..721353854a474 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptrTryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if (status.IsOK()) { return status; } @@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, } if (p != nullptr) { - status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info); + status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info); if 
(status.IsOK()) { return status; } @@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for ")); } -bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) { +bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, + const Node& node, + const std::string& provider_type, + const logging::Logger& logger) { const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type); return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) { - return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver()); + return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(), + logger); }); } diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h index 1da73208cb536..72f0ed3c6268a 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.h +++ b/onnxruntime/core/framework/kernel_registry_manager.h @@ -67,13 +67,14 @@ class KernelRegistryManager { // This function assumes the node is already assigned to an execution provider // Don't call this function before graph partition is done - Status SearchKernelRegistry(const Node& node, + Status SearchKernelRegistry(const Node& node, const logging::Logger& logger, /*out*/ const KernelCreateInfo** kernel_create_info) const; /** * Whether this node can be run on this provider */ - static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type); + static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type, + const logging::Logger& logger); Status CreateKernel(const Node& node, const IExecutionProvider& execution_provider, diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0d0b22ff61e01..0ac2271ba09f1 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne bool saving_ort_format) { for (auto& node : graph_.Nodes()) { const KernelCreateInfo* kci = nullptr; - auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); if (!status.IsOK() && saving_ort_format) { // if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled. // in that case we assigned the node to that EP but do not compile it into a fused node. @@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne // at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible. // if that's not possible for some reason we can fallback to the CPU EP implementation. 
node.SetExecutionProviderType(kCpuExecutionProvider); - status = kernel_registry_manager.SearchKernelRegistry(node, &kci); + status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci); } ORT_RETURN_IF_ERROR(status); diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc index 1466de51d0b99..e755b4bfa6364 100644 --- a/onnxruntime/core/optimizer/constant_folding.cc +++ b/onnxruntime/core/optimizer/constant_folding.cc @@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level, #if !defined(DISABLE_SPARSE_TENSORS) // Create execution frame for executing constant nodes. OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - is_sparse_initializer_check); + is_sparse_initializer_check, logger); #else // Create execution frame for executing constant nodes. - OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_, - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; }, + logger); #endif std::vector fetch_mlvalue_idxs; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 2f2524420dc44..ba2b87b5aa0ca 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -190,6 +190,7 @@ InlinedVector> GenerateTransformers( TransformerLevel level, const SessionOptions& session_options, const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/ + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -404,7 +405,8 @@ InlinedVector> GenerateTransformers( } auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } @@ -437,6 +439,7 @@ InlinedVector> GenerateTransformersForMinimalB const SessionOptions& session_options, const SatApplyContextVariant& apply_context, const IExecutionProvider& cpu_execution_provider, + const logging::Logger& logger, const InlinedHashSet& rules_and_transformers_to_disable, [[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool, std::unordered_map>* p_buffered_tensors) { @@ -490,7 +493,8 @@ InlinedVector> GenerateTransformersForMinimalB #ifndef DISABLE_CONTRIB_OPS AllocatorPtr cpu_allocator = std::make_shared(); auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); - auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); + auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry), + logger); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc index 67ebc22dab41d..b1665c7172549 100644 --- a/onnxruntime/core/optimizer/insert_cast_transformer.cc +++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc @@ 
-84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) { // going to a node that will need a Cast. // // Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32. -static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, + const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { // we can check if it's an isolated fp16 node // if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't), // does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node), @@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: const KernelCreateInfo* kernel_create_info{}; const auto lookup_status = cpu_kernel_registry.TryFindKernel( kCpuExecutionProvider, node.OpType(), node.Domain(), - node.SinceVersion(), type_constraint_map, &kernel_create_info); + node.SinceVersion(), type_constraint_map, logger, &kernel_create_info); if (lookup_status.IsOK() && kernel_create_info != nullptr) { return true; } @@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime:: return false; } -static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) { +static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry, + const logging::Logger& logger) { for (auto& node : graph.Nodes()) { - if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) { + if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) { // unassign the node so that NeedInsertCast will return true for it, forcing it to fp32 node.SetExecutionProviderType(""); } @@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { return dst_bit_length <= src_bit_length; } - if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { + if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || + (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) { return true; } @@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer { Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { if (force_cpu_fp32_) - ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_)); + ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger)); GraphViewer graph_viewer(graph); auto& order = graph_viewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index ee79fa620374e..cd654991c92d5 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -44,7 +44,9 @@ NhwcConvLookup( return &(iter->second); } -NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept +NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, + std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept : GraphTransformer("NhwcTransformer"), 
cpu_allocator_(std::move(cpu_allocator)) { if (!cpu_kernel_registry) { // This is a CPU op nodes optimizer, not useful if cpu EP is not available. @@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_, - qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info); + qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_, - qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info); + qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_, - nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info); + nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_, - nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info); + nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_, - nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( @@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel( kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_, - nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info); + nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info); if (status.IsOK() && kernel_create_info != nullptr) { kernel_create_info = nullptr; conv_table_.emplace( diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h index 000732060b889..c65f851fdab9d 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.h +++ b/onnxruntime/core/optimizer/nhwc_transformer.h @@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed. 
class NhwcTransformer : public GraphTransformer { private: public: - explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry) noexcept; + explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr cpu_kernel_registry, + const logging::Logger& logger) noexcept; /** * @brief Usually called right after constructor, it shows whether diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index ed7d5feb2beb3..b2e8e491c361c 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& /* model_path */, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func) + const std::function& is_sparse_initializer_func, + const logging::Logger& logger) : execution_provider_(execution_provider), - is_sparse_initializer_func_(is_sparse_initializer_func) { + is_sparse_initializer_func_(is_sparse_initializer_func), + logger_(logger) { allocator_ptr_ = std::make_shared(); ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer"); @@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const { std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; - return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out); + return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out); } static Status TryCreateKernel(const Node& node, @@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node, FuncManager& funcs_mgr, const DataTransferManager& data_transfer_mgr, const ConfigOptions& config_options, + const logging::Logger& logger, /*out*/ std::unique_ptr& op_kernel) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; const KernelCreateInfo* kernel_create_info = nullptr; - ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, + ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger, &kernel_create_info)); static const AllocatorMap dummy_allocators; @@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption std::shared_ptr kernel_registry = execution_provider_.GetKernelRegistry(); FuncManager func; auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_, - ort_value_name_idx_map_, func, 
data_transfer_mgr_, config_options, + ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_, op_kernel); // Kernel found in the CPU kernel registry diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index b0f7f461661b5..24a23312feba9 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame { const InitializedTensorSet& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); Info(const std::vector& nodes, const std::unordered_map& initialized_tensor_set, const std::filesystem::path& model_path, const IExecutionProvider& execution_provider, - const std::function& is_sparse_initializer_func); + const std::function& is_sparse_initializer_func, + const logging::Logger& logger); ~Info() = default; @@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame { std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; + const logging::Logger& logger_; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info); }; diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc index 18e462c04dff3..5538aa54801cc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc @@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion( return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end(); } -static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { +static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) { constexpr size_t w_idx = 1; constexpr size_t w_zp_idx = 9; constexpr size_t r_idx = 2; @@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) || !graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) || r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " cannot locate recurrence tensor of const int8 type," << " int8 overflow might impact precision !"; return false; @@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) { if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) || !graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) || r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) { - LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator," + LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator," << " unable to locate recurrence tensor or its zero point value," << " int8 overflow might impact precision !"; return false; @@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int if (graph_utils::IsSupportedOptypeVersionAndDomain( op_node, 
"DynamicQuantizeLSTM", {1}, kMSDomain)) { // This one has two set of quantized arguments - modified |= TryConvertDynamicQuantizeLSTM(op_node, graph); + modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger); continue; // go on to next operator node } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index d2240b5d50194..81305f7effa16 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -291,7 +291,8 @@ SelectorManager::SelectorManager() { InitializeSelectorsMap(); } -std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const { +std::vector SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer, + const logging::Logger& logger) const { std::vector qdq_selections; for (auto index : graph_viewer.GetNodesInTopologicalOrder()) { const auto* node = graph_viewer.GetNode(index); @@ -313,7 +314,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second; if (!versions.empty()) { if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) { - LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType(); + LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType(); continue; } } @@ -329,7 +330,7 @@ std::vector SelectorManager::GetQDQSelections(const GraphViewer& grap } std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer) { +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) { std::vector> node_unit_holder; std::unordered_map node_unit_map; @@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) { // Get QDQ NodeUnits first QDQ::SelectorManager selector_mgr; - const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer); + const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger); for (const auto& qdq_selection : qdq_selections) { auto qdq_unit = std::make_unique(graph_viewer, qdq_selection); diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h index f388206551172..ccc1844e3e985 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h @@ -15,7 +15,9 @@ #endif namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; class Node; @@ -65,7 +67,7 @@ class SelectorManager { // Methods that finds and returns a vector of QDQ::NodeGroup in a given graph // Can be used in QDQ support in different EPs - std::vector GetQDQSelections(const GraphViewer& graph_viewer) const; + std::vector GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const; private: Selectors qdq_selectors_; @@ -88,7 +90,7 @@ class SelectorManager { // We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer // library whereas it should be able to be used by an EP with no dependency on optimizers. 
std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer& graph_viewer); +GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger); } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 6a5a85ce0ff31..8c0136c495403 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -17,13 +17,22 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); + bool ModifyGraph(const KernelRegistryManager& schema_registries, + const logging::Logger& logger, + int& copy_node_counter); private: - void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); - void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); + void ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); + void BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger); void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); - bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); + bool ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger); private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl); @@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // find defs that require copy for (auto& node : graph_.Nodes()) { // as we process the defs, collect all the initializers consumed at the current graph level - ProcessDefs(node, kernel_registries, initializers_consumed); + ProcessDefs(node, kernel_registries, initializers_consumed, logger); } // for initializers shared by different providers, create dups - if (ProcessInitializers(kernel_registries, initializers_consumed)) + if (ProcessInitializers(kernel_registries, initializers_consumed, logger)) modified = true; for (auto arg : graph_.GetInputs()) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_input_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : non_provider_output_defs_) - BuildDefsMapping(arg, kernel_registries); + BuildDefsMapping(arg, kernel_registries, logger); for (auto arg : graph_.GetInputs()) // For inputs we need to create a copy node only when the input is connected to both provider @@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi return modified; } -void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, - InitializedTensorSet& initializers_consumed) { +void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, + const KernelRegistryManager& kernel_registries, + InitializedTensorSet& initializers_consumed, + const logging::Logger& 
logger) { auto node_provider_type = node.GetExecutionProviderType(); if ((node_provider_type == provider_) || (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || @@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg provider_nodes_.insert(&node); // note KernelCreateInfo might be nullptr for custom kernel const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci)); bool is_implicit_input = false; auto process_inputs = @@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg } // for non_provider defs, collect the nodes that expect it is provider tensor as input/output. -void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) { +void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, + const KernelRegistryManager& kernel_registries, + const logging::Logger& logger) { for (auto& it : graph_.Nodes()) { if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue; auto input_it = @@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co (node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) || (node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) { const KernelCreateInfo* kci = nullptr; - ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci)); + ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci)); if (arg_input_index != -1) { if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it); } @@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co // We duplicate any initializer that is used by both provider nodes and non-provider nodes // to ensure that provider nodes and non-provider nodes don't share initializers, as they // need to stay in different memory locations. 
-bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) { +bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, + const InitializedTensorSet& initializers_consumed, + const logging::Logger& logger) { std::map replacements; for (const auto& pair : initializers_consumed) { const auto& name = pair.first; @@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker auto dup_replacements = replacements; const KernelCreateInfo* kci = nullptr; - auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci); + auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); if (kci == nullptr) continue; if (kci->kernel_def == nullptr) continue; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index a799ed743ef52..f954baf3eabae 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node); if (cann_kernel_def == nullptr) { - LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() - << " node name: " << node.Name(); + LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType() + << " node name: " << node.Name(); continue; } candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger()); for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) continue; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8396e2629d2bf..d4013a7dc3d57 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2693,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For CUDA EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 35a2c451a49a5..9f95818501dac 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -62,7 +62,8 @@ namespace Dml const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; const auto kernel_lookup = onnxruntime::KernelLookup{provider_type, gsl::make_span(®istry, 1), - kernel_type_str_resolver}; + kernel_type_str_resolver, + logger}; std::vector> 
compiledPartitionInfos; std::vector additionalSplittingNodes; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp index 6318b0d5e2865..b9b90d6bc17bd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -54,7 +54,8 @@ namespace Dml const auto kernelLookup = onnxruntime::KernelLookup( providerType, gsl::make_span(®istry, 1), - kernelTypeStrResolver); + kernelTypeStrResolver, + logger); onnxruntime::GraphViewer graphViewer(graph); const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 228dfeb123175..826f48b5f7a68 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -95,7 +95,7 @@ namespace Dml const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const { #ifdef ENABLE_GRAPH_COMPILATION - return m_impl->GetCapability(graph, kernel_lookup); + return m_impl->GetCapability(graph, kernel_lookup, *GetLogger()); #else return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup); #endif @@ -876,7 +876,8 @@ namespace Dml std::vector> ExecutionProviderImpl::GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger) const { uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
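Note (illustrative sketch, not part of the diff): the EP-side hunks in this patch all follow one recurring pattern: GetCapability() resolves the session/EP logger once via GetLogger() and passes it to helpers such as GetCpuPreferredNodes() instead of relying on the default logger. A minimal sketch of that pattern against the signatures introduced here; MyExecutionProvider is hypothetical, and the real EPs differ only in how they select tentative nodes:

    // Assumes ORT internal headers (execution_provider.h, fallback_cpu_capability.h, compute_capability.h).
    std::vector<std::unique_ptr<ComputeCapability>>
    MyExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
                                       const IKernelLookup& kernel_lookup) const {
      // Session logger; usable once the session owns this EP (i.e. in GetCapability/Compile).
      const logging::Logger& logger = *GetLogger();

      std::vector<NodeIndex> tentative_nodes;
      for (const auto& node : graph.Nodes()) {
        if (kernel_lookup.LookUpKernel(node) != nullptr) {
          tentative_nodes.push_back(node.Index());
        }
      }

      // Shape-related subgraphs preferred on CPU are excluded; the logger is now threaded through.
      const auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);

      std::vector<std::unique_ptr<ComputeCapability>> result;
      for (NodeIndex index : tentative_nodes) {
        if (cpu_nodes.count(index) == 0) {
          auto sub_graph = std::make_unique<IndexedSubGraph>();
          sub_graph->nodes.push_back(index);
          result.push_back(std::make_unique<ComputeCapability>(std::move(sub_graph)));
        }
      }
      return result;
    }
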
@@ -900,7 +901,7 @@ namespace Dml } // Get the list of nodes that should stay on the CPU - auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes); + auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger); for (size_t nodeIndex : toplogicalOrder) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 32a5b9add35a0..e7d859c5764de 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -88,7 +88,8 @@ namespace Dml std::vector> GetCapability( const onnxruntime::GraphViewer& graph, - const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup + const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup, + const onnxruntime::logging::Logger& logger ) const; uint32_t GetSupportedDeviceDataTypeMask() const; diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 4cb40ec8bf5fd..c1a8b373bed84 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -818,7 +818,7 @@ std::vector> JsExecutionProvider::GetCapabili candidates.push_back(node.Index()); tenative_candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index 12416ea0c121b..e4bee6f959a01 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -32,8 +32,16 @@ namespace nnapi { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, gsl::span nnapi_target_devices, - TargetDeviceOption target_device_option) - : nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) { + TargetDeviceOption target_device_option, + const logging::Logger& logger) + : nnapi_(nnapi_handle), + graph_viewer_(graph_viewer), + nnapi_model_{std::make_unique(nnapi_handle)}, + shaper_{graph_viewer}, + nnapi_target_devices_(nnapi_target_devices), + target_device_option_(target_device_option), + nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)), + logger_(logger) { nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_; } @@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const { } void ModelBuilder::PreprocessNodeUnits() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); } // Help to get all quantized operators' input and the NodeUnit(s) using the input diff --git 
a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h index b2118150dd304..4db335afa98b0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h @@ -14,7 +14,9 @@ struct NnApi; namespace onnxruntime { - +namespace logging { +class Logger; +} class GraphViewer; enum class DataLayout; class NodeUnit; @@ -31,7 +33,8 @@ class ModelBuilder { using Shape = Shaper::Shape; ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle, - gsl::span nnapi_target_devices, TargetDeviceOption target_device_option); + gsl::span nnapi_target_devices, TargetDeviceOption target_device_option, + const logging::Logger& logger); common::Status Compile(std::unique_ptr& model); @@ -173,6 +176,9 @@ class ModelBuilder { // <1,1> <1,2> <1,3> InlinedVector> operations_recorder_; #endif + + const logging::Logger& logger_; + // Convert the ONNX model to ANeuralNetworksModel common::Status Prepare(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc index fca52396a190c..f92c9592742d5 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc @@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {} std::vector> NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; // TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators) @@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL; #endif }(); - LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; + LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level; const nnapi::OpSupportCheckParams params{ android_feature_level, @@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view }; if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) { - LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" + LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level [" << params.android_feature_level << "] is lower than minimal supported NNAPI API feature level [" << ORT_NNAPI_MIN_API_LEVEL @@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view node_unit_supported_result[node_unit] = supported; } - LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported + LOGS(logger, VERBOSE) << "Node supported: [" << supported << "] Operator type: [" << node.OpType() << "] index: [" << node.Index() << "] name: [" << node.Name() @@ 
-224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view // If the graph is partitioned in multiple subgraphs, and this may impact performance, // we want to give users a summary message at warning level. if (num_of_partitions > 1) { - LOGS_DEFAULT(WARNING) << summary_msg; + LOGS(logger, WARNING) << summary_msg; } else { - LOGS_DEFAULT(INFO) << summary_msg; + LOGS(logger, INFO) << summary_msg; } return result; @@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context, common::Status NnapiExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { using namespace android::nn::wrapper; + const auto& logger = *GetLogger(); + for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_); + nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger); builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW); builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16); diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index f1df1abf4c49a..decfe91c598be 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -687,7 +687,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger); std::unordered_set seen_node_units; const auto& node_indices = src_graph.GetNodesInTopologicalOrder(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 88fa6429fc01e..75973c7031d62 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // valid throughout the lifetime of the ModelBuilder std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3bb069196e31c..060bbd4f79bf2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -718,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger); // remove is_qnn_ctx_model related code const auto supported_nodes 
= GetSupportedNodes(graph_viewer, node_unit_map, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 75b8ac7e134f3..0a427b146dcaa 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -2493,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // For ROCM EP, exclude the subgraph that is preferred to be placed in CPU // These are usually shape related computation subgraphs // Following logic can be extended for other EPs - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index b84825236a453..45f81ed22b7f7 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -294,7 +294,8 @@ std::unique_ptr CreateGPUDataTransfer(); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger); std::string GetEnvironmentVar(const std::string& var_name); @@ -371,8 +372,8 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { namespace QDQ { inline std::pair>, std::unordered_map> -GetAllNodeUnits(const GraphViewer* graph_viewer) { - return g_host->QDQ__GetAllNodeUnits(graph_viewer); +GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) { + return g_host->QDQ__GetAllNodeUnits(graph_viewer, logger); } } // namespace QDQ diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d3b12f9728135..aa8c367d25d51 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -369,8 +369,9 @@ std::string GetEnvironmentVar(const std::string& var_name) { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) { - return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) { + return g_host->GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } namespace profiling { diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f9f2bb69a9d1a..7ab93d56cfe26 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -202,7 +202,8 @@ struct ProviderHost { virtual std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) = 0; + gsl::span tentative_nodes, + const logging::Logger& logger) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, 
size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0; @@ -890,7 +891,7 @@ struct ProviderHost { virtual std::unique_ptr NodeUnit__OutputEdgesEnd(const NodeUnit* p) = 0; virtual std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) = 0; + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) = 0; // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc index bbf8255ac2940..db8a87d9eaf24 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.cc @@ -34,7 +34,8 @@ namespace onnxruntime { namespace vsi { namespace npu { -GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(graph_viewer) { +GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger) + : graph_viewer_(graph_viewer), logger_(logger) { Prepare(); context_ = tim::vx::Context::Create(); graph_ = context_->CreateGraph(); @@ -42,7 +43,7 @@ GraphEP::GraphEP(const onnxruntime::GraphViewer& graph_viewer) : graph_viewer_(g } bool GraphEP::Prepare() { - std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_); + std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_); for (const auto& node_unit : node_unit_holder_) { auto quant_op_type = util::GetQuantizedOpType(*node_unit); diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h index 49344770d060e..5bb332fad0177 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h +++ b/onnxruntime/core/providers/vsinpu/vsinpu_ep_graph.h @@ -51,7 +51,7 @@ struct NodeIOInfo { class GraphEP { public: - explicit GraphEP(const GraphViewer& graph_viewer); + explicit GraphEP(const GraphViewer& graph_viewer, const logging::Logger& logger); ~GraphEP() {} bool Prepare(); @@ -104,6 +104,7 @@ class GraphEP { // In the form of {input_name, [NodeUnit(s) using the input]} std::unordered_map> all_quantized_op_inputs_; const GraphViewer& graph_viewer_; + const logging::Logger& logger_; // Holder for the NodeUnits in the graph, this will guarantee the NodeUnits is // valid throughout the lifetime of the ModelBuilder diff --git a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc index 669c702544de8..7da7cc6cb63ba 100644 --- a/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc +++ b/onnxruntime/core/providers/vsinpu/vsinpu_execution_provider.cc @@ -62,6 +62,7 @@ VSINPUExecutionProvider::~VSINPUExecutionProvider() {} std::vector> VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> result; if (graph_viewer.IsSubgraph()) { @@ -82,7 +83,7 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); + std::tie(node_unit_holder, node_unit_map) = 
QDQ::GetAllNodeUnits(graph_viewer, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit to be checked for multiple times @@ -174,7 +175,8 @@ VSINPUExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, - OrtKernelContext* context) { + OrtKernelContext* context, + const logging::Logger& logger) { Ort::KernelContext ctx(context); size_t num_in = ctx.GetInputCount(); const size_t num_inputs = graph_ep->GetGraphInputs().size(); @@ -192,7 +194,7 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, } if (!graph_ep->GetGraph()->Run()) { - LOGS_DEFAULT(ERROR) << "Failed to run graph."; + LOGS(logger, ERROR) << "Failed to run graph."; } for (size_t i = 0; i < ctx.GetOutputCount(); i++) { auto timvx_tensor = graph_ep->GetGraphOutputs()[i]->tensor; @@ -207,12 +209,14 @@ Status ComputeStateFunc(vsi::npu::GraphEP* graph_ep, Status VSINPUExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { + const auto& logger = *GetLogger(); + for (const auto& fused_node_graph : fused_nodes_and_graphs) { const GraphViewer& graph_viewer = fused_node_graph.filtered_graph; - std::shared_ptr graph_ep = std::make_shared(graph_viewer); + std::shared_ptr graph_ep = std::make_shared(graph_viewer, logger); for (auto tensor : graph_viewer.GetInputsIncludingInitializers()) { - LOGS_DEFAULT(VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" + LOGS(logger, VERBOSE) << "subgraph input init:" << vsi::npu::util::PrintNode(*tensor) << "#" << graph_viewer.IsInitializedTensor(tensor->Name()); auto input = std::make_shared(); input->name = tensor->Name(); @@ -220,7 +224,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu graph_ep->GetGraphInputs().push_back(input); } for (auto tensor : graph_viewer.GetOutputs()) { - LOGS_DEFAULT(VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); + LOGS(logger, VERBOSE) << "subgraph output:" << vsi::npu::util::PrintNode(*tensor); auto output = std::make_shared(); output->name = tensor->Name(); output->is_initializer = false; @@ -236,16 +240,16 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu if (node != &node_unit.GetNode()) { continue; } - LOGS_DEFAULT(VERBOSE) << "Adding node: [" << node->OpType() << "]"; + LOGS(logger, VERBOSE) << "Adding node: [" << node->OpType() << "]"; vsi::npu::SupportedBuiltinOps().at(node->OpType())->BuildOp(graph_ep.get(), graph_viewer, node_unit); } - LOGS_DEFAULT(INFO) << "Verifying graph"; + LOGS(logger, INFO) << "Verifying graph"; graph_ep->GetCompiled() = graph_ep->GetGraph()->Compile(); if (!graph_ep->GetCompiled()) { - LOGS_DEFAULT(ERROR) << "Failed to verify graph."; + LOGS(logger, ERROR) << "Failed to verify graph."; } else { - LOGS_DEFAULT(INFO) << "Graph has been verified successfully."; + LOGS(logger, INFO) << "Graph has been verified successfully."; } NodeComputeInfo compute_info; @@ -259,7 +263,7 @@ Status VSINPUExecutionProvider::Compile(const std::vector& fu [graph_ep, this](FunctionState /*state*/, const OrtApi* /* api */, OrtKernelContext* context) { std::lock_guard lock(this->GetMutex()); - Status res = ComputeStateFunc(graph_ep.get(), context); + Status res = ComputeStateFunc(graph_ep.get(), context, *GetLogger()); return res; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 
90b6862758bd7..66209adf6f1a9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -798,7 +798,8 @@ std::vector> WebGpuExecutionProvider::GetCapa candidates.push_back(node.Index()); tenative_candidates.push_back(node.Index()); } - auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger()); std::vector> result; for (auto& node_index : candidates) { if (cpu_nodes.count(node_index) > 0) { diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index 12e567e7080b3..ee4e7be0f1f49 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -258,6 +258,7 @@ static void AddComputeCapabilityForEachNodeInNodeUnit( std::vector> XnnpackExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { + const auto& logger = *GetLogger(); std::vector> capabilities; std::shared_ptr registry = GetKernelRegistry(); @@ -268,7 +269,7 @@ std::vector> XnnpackExecutionProvider::GetCap // Get all the NodeUnits in the GraphViewer so we can check if something is in a QDQ node group std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph, logger); // This holds the result of whether a NodeUnit is supported or not, // to prevent nodes in a NodeUnit being checked for multiple times diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 2ff9fa525fa3b..a60ee500a9898 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1644,7 +1644,7 @@ Status ApplyOrtFormatModelRuntimeOptimizations( level <= static_cast(session_options.graph_optimization_level); ++level) { const auto transformers = optimizer_utils::GenerateTransformersForMinimalBuild( - static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, + static_cast(level), session_options, SatRuntimeOptimizationLoadContext{}, cpu_ep, logger, optimizers_to_disable, intra_op_thread_pool, p_buffered_tensors); for (const auto& transformer : transformers) { @@ -1840,7 +1840,8 @@ common::Status InferenceSession::Initialize() { ORT_RETURN_IF_ERROR_SESSIONID_(AddPredefinedTransformers(graph_transformer_mgr_, session_options_.graph_optimization_level, minimal_build_optimization_handling, - record_runtime_optimization_produced_op_schema)); + record_runtime_optimization_produced_op_schema, + *session_logger_)); #ifdef USE_DML const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); @@ -2112,7 +2113,7 @@ common::Status InferenceSession::Initialize() { std::vector tuning_results; bool found_tuning_results = false; ORT_RETURN_IF_ERROR_SESSIONID_(inference_session_utils::ParseTuningResultsFromModelMetadata( - model_metadata_, tuning_results, found_tuning_results)); + model_metadata_, tuning_results, found_tuning_results, *session_logger_)); if (found_tuning_results) { ORT_RETURN_IF_ERROR_SESSIONID_(SetTuningResults(tuning_results, /*error_on_invalid*/ false, /*auto_enable*/ true)); } @@ -3233,7 +3234,8 @@ 
common::Status InferenceSession::AddPredefinedTransformers( GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const { + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); @@ -3245,7 +3247,7 @@ common::Status InferenceSession::AddPredefinedTransformers( minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; if (use_full_build_optimizations) { - return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, + return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); @@ -3257,6 +3259,7 @@ common::Status InferenceSession::AddPredefinedTransformers( record_runtime_optimization_produced_op_schema_fn}} : SatApplyContextVariant{SatDirectApplicationContext{}}; return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, + logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse(), session_state_->GetMutableBufferedTensors()); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 0675f64848fd0..e28ff75345785 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -690,8 +690,9 @@ class InferenceSession { * If we encounter an invalid request, we return an error * back to the user. 
*/ - [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list, - /*out*/ InlinedVector& arenas_to_shrink) const; + [[nodiscard]] common::Status ValidateAndParseShrinkArenaString( + const std::string& ort_device_list, + /*out*/ InlinedVector& arenas_to_shrink) const; /* * Performs the shrinkage of arenas requested to be shrunk by the user @@ -708,7 +709,8 @@ class InferenceSession { GraphTransformerManager& transformer_manager, TransformerLevel graph_optimization_level, MinimalBuildOptimizationHandling minimal_build_optimization_handling, - RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const; + RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, + const logging::Logger& logger) const; common::Status TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format); diff --git a/onnxruntime/core/session/inference_session_utils.cc b/onnxruntime/core/session/inference_session_utils.cc index 3436eebda3819..8b9de0c604441 100644 --- a/onnxruntime/core/session/inference_session_utils.cc +++ b/onnxruntime/core/session/inference_session_utils.cc @@ -236,7 +236,8 @@ Status JsonConfigParser::ParseRunOptionsFromModelProto(RunOptions& /*run_options Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, std::vector& results, - bool& key_found) { + bool& key_found, + const logging::Logger& logger) { results.clear(); key_found = false; auto it = metadata.custom_metadata_map.find(kTuningResultsKeys); @@ -245,7 +246,7 @@ Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& met } key_found = true; - LOGS_DEFAULT(INFO) << "Found tuning results in the model file to be used while loading the model"; + LOGS(logger, INFO) << "Found tuning results in the model file to be used while loading the model"; Status status; ORT_TRY { diff --git a/onnxruntime/core/session/inference_session_utils.h b/onnxruntime/core/session/inference_session_utils.h index a0bcdb9013bf0..f297d928f8a0d 100644 --- a/onnxruntime/core/session/inference_session_utils.h +++ b/onnxruntime/core/session/inference_session_utils.h @@ -19,7 +19,9 @@ using json = nlohmann::json; #endif namespace onnxruntime { - +namespace logging { +class Logger; +} namespace inference_session_utils { // need this value to be accessible in all builds in order to report error for attempted usage in a minimal build @@ -60,7 +62,8 @@ class JsonConfigParser { Status ParseTuningResultsFromModelMetadata(const onnxruntime::ModelMetadata& metadata, /*out*/ std::vector& results, - /*out*/ bool& key_found); + /*out*/ bool& key_found, + const logging::Logger& logger); #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d55fd34d5a8f2..c3832498af584 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -279,8 +279,9 @@ struct ProviderHostImpl : ProviderHost { std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, const IExecutionProvider::IKernelLookup& kernel_lookup, - gsl::span tentative_nodes) override { - return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes); + gsl::span tentative_nodes, + const logging::Logger& logger) override { + return onnxruntime::GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger); } Status UnpackTensor(const 
ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ bool* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } @@ -1057,8 +1058,8 @@ struct ProviderHostImpl : ProviderHost { } std::pair>, std::unordered_map> - QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer) override { - return QDQ::GetAllNodeUnits(*graph_viewer); + QDQ__GetAllNodeUnits(const GraphViewer* graph_viewer, const logging::Logger& logger) override { + return QDQ::GetAllNodeUnits(*graph_viewer, logger); } // Model (wrapped) diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc index 9cbf01946e92b..2706448d831cc 100644 --- a/onnxruntime/core/session/standalone_op_invoker.cc +++ b/onnxruntime/core/session/standalone_op_invoker.cc @@ -314,7 +314,8 @@ class StandAloneKernelContext : public OpKernelContext { AllocatorPtr allocator_; }; // StandAloneKernelContext -onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, OrtOpAttr** op_attr) { +onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, OrtOpAttrType type, + OrtOpAttr** op_attr) { auto attr = std::make_unique(); onnxruntime::Status status = onnxruntime::Status::OK(); attr->set_name(std::string{name}); @@ -410,7 +411,9 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info, node_ptr->SetSinceVersion(version); - auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, &kernel_create_info); + auto status = kernel_registry->TryFindKernel(*node_ptr, ep->Type(), type_constraint_map, + logging::LoggingManager::DefaultLogger(), // no other logger available + &kernel_create_info); ORT_RETURN_IF_ERROR(status); auto& kernel_def = kernel_create_info->kernel_def; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index a7f8a6424aa50..adab93908cdc4 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -252,6 +252,7 @@ class PlannerTest : public ::testing::Test { void BindKernel(onnxruntime::Node* p_node, ::onnxruntime::KernelDef& kernel_def, KernelRegistry* reg, std::unordered_map>& kernel_create_info_map) { + const auto& logger = DefaultLoggingManager().DefaultLogger(); const IExecutionProvider* ep = execution_providers_.Get(*p_node); ASSERT_NE(ep, nullptr); auto info = std::make_unique( @@ -261,7 +262,7 @@ class PlannerTest : public ::testing::Test { op_kernel_infos_.push_back(std::move(info)); const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{}; if (!KernelRegistry::HasImplementationOf(*reg, *p_node, onnxruntime::kCpuExecutionProvider, - kernel_type_str_resolver)) { + kernel_type_str_resolver, logger)) { ASSERT_STATUS_OK(reg->Register( KernelCreateInfo(std::make_unique(kernel_def), [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { @@ -271,7 +272,7 @@ class PlannerTest : public ::testing::Test { } const KernelCreateInfo* kci; - ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, &kci)); + ASSERT_STATUS_OK(reg->TryFindKernel(*p_node, "", kernel_type_str_resolver, logger, &kci)); kernel_create_info_map.insert({p_node->Index(), gsl::not_null(kci)}); } @@ -283,7 +284,8 @@ class PlannerTest : public ::testing::Test { } } - void CreatePlan(const std::vector& outer_scope_node_args = {}, bool 
invoke_createPlan_explicityly = true) { + void CreatePlan(const std::vector& outer_scope_node_args = {}, + bool invoke_createPlan_explicityly = true) { state_.reset(new SessionState(graph_, execution_providers_, tp_.get(), nullptr, dtm_, edlm_, DefaultLoggingManager().DefaultLogger(), profiler_, *sess_options_)); EXPECT_EQ(graph_.Resolve(), Status::OK()); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 67d60ea3a4ff6..2ff0b599beebf 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -831,7 +831,8 @@ static void VerifyConstantFoldingWithDequantizeLinear(const std::unordered_mapName() == "ConstantFolding") { ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(transformer), TransformerLevel::Level1)); @@ -4704,7 +4705,8 @@ TEST_F(GraphTransformationTests, BiasGeluSwitchedInputOrder) { // Compare results double per_sample_tolerance = 1e-3; double relative_per_sample_tolerance = 0.0; - auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], per_sample_tolerance, relative_per_sample_tolerance, false); + auto ret = CompareOrtValue(optimized_fetches[0], unoptimized_fetches[0], + per_sample_tolerance, relative_per_sample_tolerance, false); EXPECT_EQ(ret.first, COMPARE_RESULT::SUCCESS) << ret.second; } @@ -4713,7 +4715,8 @@ static void VerifyGeluApproximation(bool is_enabled, SessionOptions& session_opt std::make_unique(CPUExecutionProviderInfo()); bool has_gelu_approximation = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, session_options, *e.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "GeluApproximation") { has_gelu_approximation = true; @@ -4728,7 +4731,8 @@ TEST_F(GraphTransformationTests, DoubleQDQRemover_SessionOptionConfig) { auto verify_session_config = [&](bool is_enabled, SessionOptions& session_option) { std::unique_ptr cpu_ep = std::make_unique(CPUExecutionProviderInfo()); bool has_double_qdq_remover = false; - auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), {}); + auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_option, *cpu_ep.get(), + DefaultLoggingManager().DefaultLogger(), {}); for (auto& transformer : transformers) { if (transformer->Name() == "DoubleQDQPairsRemover") { has_double_qdq_remover = true; diff --git a/onnxruntime/test/optimizer/graph_transform_utils_test.cc b/onnxruntime/test/optimizer/graph_transform_utils_test.cc index 66b74641e41d3..caa64560426af 100644 --- a/onnxruntime/test/optimizer/graph_transform_utils_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_utils_test.cc @@ -36,9 +36,11 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { std::string l2_transformer = "ConvActivationFusion"; InlinedHashSet disabled = {l1_rule1, l1_transformer, l2_transformer}; CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{}); + const auto& logger = DefaultLoggingManager().DefaultLogger(); - auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep); - auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled); + auto all_transformers = 
optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger); + auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger, + disabled); // check ConstantFolding transformer was removed ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); @@ -61,8 +63,9 @@ TEST(GraphTransformerUtilsTests, TestGenerateGraphTransformers) { #ifndef DISABLE_CONTRIB_OPS // check that ConvActivationFusion was removed - all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep); - filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled); + all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger); + filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, + disabled); ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1); #endif } diff --git a/onnxruntime/test/optimizer/optimizer_test.cc b/onnxruntime/test/optimizer/optimizer_test.cc index 81c1a4ace1e33..b306f026b2dfd 100644 --- a/onnxruntime/test/optimizer/optimizer_test.cc +++ b/onnxruntime/test/optimizer/optimizer_test.cc @@ -27,6 +27,7 @@ namespace test { TEST(OptimizerTest, Basic) { Model model("OptimizerBasic", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {{kOnnxDomain, 12}}, {}, DefaultLoggingManager().DefaultLogger()); + const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); auto& graph = model.MainGraph(); constexpr int tensor_dim = 10; @@ -66,22 +67,21 @@ TEST(OptimizerTest, Basic) { auto cpu_execution_provider = std::make_unique(CPUExecutionProviderInfo()); #if !defined(DISABLE_SPARSE_TENSORS) - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [&graph](const std::string& name) -> bool { - return graph.IsSparseInitializer(name); - }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [&graph](const std::string& name) -> bool { + return graph.IsSparseInitializer(name); + }, + logger); #else - OptimizerExecutionFrame::Info info(nodes, initialized_tensor_set, - graph.ModelPath(), - *cpu_execution_provider.get(), - [](std::string const&) { return false; }); + OptimizerExecutionFrame::Info info( + nodes, initialized_tensor_set, graph.ModelPath(), *cpu_execution_provider.get(), + [](std::string const&) { return false; }, + logger); #endif //! 
defined(DISABLE_SPARSE_TENSORS) std::vector fetch_mlvalue_idxs{info.GetMLValueIndex("out")}; OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs); - const logging::Logger& logger = DefaultLoggingManager().DefaultLogger(); const ConfigOptions empty_config_options; diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index cfee4a83a4292..043b92d7ef121 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3928,6 +3928,7 @@ TEST(QDQTransformerTests, QDQPropagation_DQForward_SliceMultipleConsumers) { TEST(QDQTransformerTests, QDQ_Selector_Test) { const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/transform/qdq_conv.onnx"); + const auto& logger = DefaultLoggingManager().DefaultLogger(); SessionOptions so; // We want to keep the graph un-optimized to prevent QDQ transformer to kick in @@ -3962,7 +3963,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check if SelectorManager get a conv qdq group selection as expected { - const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(whole_graph_viewer, logger); ASSERT_FALSE(result.empty()); const auto& qdq_group = result.at(0); ASSERT_EQ(std::vector({0, 1, 2}), qdq_group.dq_nodes); @@ -3977,7 +3978,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { std::vector> node_unit_holder; std::unordered_map node_unit_map; - std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer); + std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(whole_graph_viewer, logger); // We should get a single QDQ Node unit in the result ASSERT_EQ(1, node_unit_holder.size()); @@ -4045,7 +4046,7 @@ TEST(QDQTransformerTests, QDQ_Selector_Test) { // Check SelectorManager will get empty result { - const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer); + const auto result = selector_mgr.GetQDQSelections(partial_graph_viewer, logger); ASSERT_TRUE(result.empty()); } } diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 9d83c789c5124..b0958e05dc373 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -420,6 +420,7 @@ bool SetEpsForAllNodes(Graph& graph, continue; bool found = false; + const auto& logger = DefaultLoggingManager().DefaultLogger(); for (const auto& ep : execution_providers) { auto provider_type = ep->Type(); @@ -438,7 +439,8 @@ bool SetEpsForAllNodes(Graph& graph, } // Check the EP has an impl for the node from builtin registry. 
-      if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver)) {
+      if (KernelRegistry::HasImplementationOf(*ep->GetKernelRegistry(), node, ep->Type(), kernel_type_str_resolver,
+                                              logger)) {
         found = true;
         break;
       }
@@ -451,6 +453,7 @@ bool SetEpsForAllNodes(Graph& graph,
                                               std::string_view(kMSInternalNHWCDomain),
                                               node.SinceVersion(),
                                               type_constraint_map,
+                                              logger,
                                               &kci);
       if (status.IsOK() && kci != nullptr) {
         found = true;
@@ -463,7 +466,7 @@ bool SetEpsForAllNodes(Graph& graph,
           std::any_of(custom_registries->cbegin(), custom_registries->cend(),
                       [&](auto reg) {
                         return KernelRegistry::HasImplementationOf(*reg->GetKernelRegistry(), node, ep->Type(),
-                                                                   kernel_type_str_resolver);
+                                                                   kernel_type_str_resolver, logger);
                       })) {
         found = true;
         break;
diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc
index 23ec48fa649dd..93e688570631e 100644
--- a/onnxruntime/test/providers/kernel_compute_test_utils.cc
+++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc
@@ -42,8 +42,9 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
   }
 #endif
+  const auto& logger = DefaultLoggingManager().DefaultLogger();
   Model model("test", false, ModelMetaData(), ORT_TSTR(""), IOnnxRuntimeOpSchemaRegistryList(),
-              {{domain_, opset_version_}}, {}, DefaultLoggingManager().DefaultLogger());
+              {{domain_, opset_version_}}, {}, logger);
   std::vector<NodeArg*> input_args;
   std::unordered_map<std::string, OrtValue> initializer_map;
@@ -89,8 +90,7 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
   ASSERT_STATUS_OK(graph.Resolve());
   node.SetExecutionProviderType(ep_type);
-  OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type),
-                                     [](std::string const&) { return false; });
+  OptimizerExecutionFrame::Info info({&node}, initializer_map, graph.ModelPath(), *execution_providers.Get(ep_type), [](std::string const&) { return false; }, logger);
   const KernelCreateInfo* kernel_create_info = nullptr;
   ASSERT_STATUS_OK(info.TryFindKernel(&node, &kernel_create_info));
   ASSERT_TRUE(kernel_create_info);
@@ -139,7 +139,7 @@ void KernelComputeTester::Run(std::unordered_set<int> strided_outputs) {
 #pragma warning(disable : 6387)
 #endif
   OptimizerExecutionFrame frame(info, fetch_mlvalue_idxs, outputs);
-  OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, DefaultLoggingManager().DefaultLogger());
+  OpKernelContext op_kernel_context(&frame, kernel.get(), nullptr, nullptr, logger);
 #ifdef _WIN32
 #pragma warning(pop)
 #endif
diff --git a/onnxruntime/test/providers/partitioning_utils_test.cc b/onnxruntime/test/providers/partitioning_utils_test.cc
index 5db69489afaef..f1fbb1cea7ea2 100644
--- a/onnxruntime/test/providers/partitioning_utils_test.cc
+++ b/onnxruntime/test/providers/partitioning_utils_test.cc
@@ -51,7 +51,7 @@ TEST(PartitioningUtilsTest, TestQDQHandling) {
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
-  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
   auto result = utils::CreateSupportedPartitions(graph_viewer, is_node_supported, on_group_closed,
                                                  gen_metadef_name, "TEST", kCpuExecutionProvider, &node_unit_map,
@@ -82,7 +82,7 @@ static void CheckAllNodesProcessed(const std::function<void(ModelTestBuilder&)>&
   std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
   std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
-  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
+  std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
   const auto is_node_supported = [&](const Node& /*node*/) -> bool {
     return true;
diff --git a/orttraining/orttraining/core/session/training_session.cc b/orttraining/orttraining/core/session/training_session.cc
index 87a7cbc0375a4..f1545e96481fa 100644
--- a/orttraining/orttraining/core/session/training_session.cc
+++ b/orttraining/orttraining/core/session/training_session.cc
@@ -758,7 +758,8 @@ Status TrainingSession::AddPredefinedTransformers(
     GraphTransformerManager& transformer_manager,
     TransformerLevel graph_optimization_level,
     MinimalBuildOptimizationHandling minimal_build_optimization_handling,
-    RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/) const {
+    RecordRuntimeOptimizationProducedNodeOpSchemaFn /*record_runtime_optimization_produced_op_schema_fn*/,
+    const logging::Logger& /*logger*/) const {
   ORT_RETURN_IF_NOT(
       minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations,
       "Only applying full build optimizations is supported by TrainingSession.");
diff --git a/orttraining/orttraining/core/session/training_session.h b/orttraining/orttraining/core/session/training_session.h
index 765f88e1c992e..58492dc62400f 100644
--- a/orttraining/orttraining/core/session/training_session.h
+++ b/orttraining/orttraining/core/session/training_session.h
@@ -489,7 +489,8 @@ class TrainingSession : public InferenceSession {
       GraphTransformerManager& transformer_manager,
       TransformerLevel graph_optimization_level,
       MinimalBuildOptimizationHandling minimal_build_optimization_handling,
-      RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn) const override;
+      RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn,
+      const logging::Logger& logger) const override;
   /** Perform auto-diff to add backward graph into the model.
   @param weights_to_train a set of weights to be training.
diff --git a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc
index 0944e46ff8eaf..58c173ed90277 100644
--- a/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc
+++ b/orttraining/orttraining/test/gradient/gradient_op_test_utils.cc
@@ -139,7 +139,8 @@ void GradientOpTester::Run(int output_index_to_use_as_loss,
       auto reg = execution_provider->GetKernelRegistry();
       const KernelCreateInfo* kci;
-      auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver, &kci);
+      auto st = reg->TryFindKernel(node, execution_provider->Type(), kernel_type_str_resolver,
+                                   DefaultLoggingManager().DefaultLogger(), &kci);
       if (!st.IsOK()) {
         // The goal here is unclear. It seems best to leave it to the Session
         // creation to figure out whether the model can be executed using some
diff --git a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc
index 548f39bb0150c..1b8699d1de497 100644
--- a/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc
+++ b/orttraining/orttraining/test/optimizer/graph_transformer_utils_test.cc
@@ -23,8 +23,10 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) {
   InlinedHashSet<std::string> disabled = {l1_rule1, l1_transformer, l2_transformer};
   CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{});
-  auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep);
-  auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, disabled);
+  const auto& logger = DefaultLoggingManager().DefaultLogger();
+  auto all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger);
+  auto filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, {}, cpu_ep, logger,
+                                                                     disabled);
   // check ConstantFolding transformer was removed
   ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1);
@@ -47,8 +49,8 @@ TEST(GraphTransformerUtilsTestsForTraining, TestGenerateGraphTransformers) {
 #ifndef DISABLE_CONTRIB_OPS
   // check that ConvActivationFusion was removed
-  all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep);
-  filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, disabled);
+  all_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger);
+  filtered_transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level2, {}, cpu_ep, logger, disabled);
   ASSERT_TRUE(filtered_transformers.size() == all_transformers.size() - 1);
 #endif
 }
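
The hunks above update test call sites to resolve a logger once and pass it explicitly. The sketch below (not part of the patch) illustrates that pattern against the new GenerateTransformers signature; the helper's name and the include paths for the test utilities are assumptions, not code from this change.

// Sketch only: assumes these headers resolve on the test include path and that
// DefaultLoggingManager() is declared in test_environment.h as in other tests.
#include "core/common/logging/logging.h"
#include "core/framework/session_options.h"
#include "core/optimizer/graph_transformer_utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
#include "test/util/include/test_environment.h"

namespace onnxruntime {
namespace test {

// Hypothetical helper: obtain the logger once and pass it through explicitly,
// mirroring the updated call sites in the diffs above.
void GenerateLevel1TransformersWithExplicitLogger() {
  CPUExecutionProvider cpu_ep(CPUExecutionProviderInfo{});
  const logging::Logger& logger = DefaultLoggingManager().DefaultLogger();

  SessionOptions session_options{};
  auto transformers = optimizer_utils::GenerateTransformers(TransformerLevel::Level1, session_options,
                                                            cpu_ep, logger);
  // In a real test the transformers would be registered with a GraphTransformerManager
  // or inspected here; this sketch only exercises the signature.
  static_cast<void>(transformers);
}

}  // namespace test
}  // namespace onnxruntime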