Skip to content

Commit

Permalink
We have use cases where multiple sessions are created concurrently. Minimizing the usage of the default logger is important for these scenarios.
Browse files Browse the repository at this point in the history

Wire through the session logger to as many places as possible. The EP logger can also be used once the session is created (can't be used during EP construction/kernel registration but can be used in GetCapability and Compile).
  • Loading branch information
skottmckay committed Dec 5, 2024
1 parent d27fecd commit 1c62d21
Show file tree
Hide file tree
Showing 62 changed files with 388 additions and 241 deletions.
12 changes: 10 additions & 2 deletions include/onnxruntime/core/framework/kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace logging {
class Logger;
}

using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
Expand All @@ -33,6 +36,7 @@ class KernelRegistry {
// Kernel matching uses the types from the node and the kernel_type_str_resolver.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver,
const logging::Logger& logger,
const KernelCreateInfo** out) const;

// map of type constraint name to required type
Expand All @@ -42,6 +46,7 @@ class KernelRegistry {
// Kernel matching uses the explicit type constraint name to required type map in type_constraints.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const TypeConstraintMap& type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const;

/**
Expand All @@ -61,13 +66,15 @@ class KernelRegistry {
std::string_view domain,
int version,
const KernelRegistry::TypeConstraintMap& type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const;

static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver) {
const IKernelTypeStrResolver& kernel_type_str_resolver,
const logging::Logger& logger) {
const KernelCreateInfo* info;
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info);
return st.IsOK();
}

Expand All @@ -83,6 +90,7 @@ class KernelRegistry {
Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver* kernel_type_str_resolver,
const TypeConstraintMap* type_constraints,
const logging::Logger& logger,
const KernelCreateInfo** out) const;

// Check whether the types of inputs/outputs of the given node match the extra
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
Expand Down Expand Up @@ -84,6 +85,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
Expand Down
79 changes: 43 additions & 36 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ class PlannerImpl {
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const InlinedHashMap<OrtValueName, OrtDevice>& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
const ISequentialPlannerContext& context, SequentialExecutionPlan& plan,
const logging::Logger& logger)
: context_(&context),
plan_(plan),
parent_node_(parent_node),
Expand All @@ -148,14 +149,15 @@ class PlannerImpl {
kernel_create_info_map_(kernel_create_info_map),
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
ort_value_name_idx_map_(ort_value_name_idx_map) {}
ort_value_name_idx_map_(ort_value_name_idx_map),
logger_(logger) {
}

Status CreatePlan(
#ifdef ORT_ENABLE_STREAM
const IStreamCommandHandleRegistry& stream_handle_registry,
#endif
const PathString& partition_config_file,
const logging::Logger& logger);
const PathString& partition_config_file);

private:
gsl::not_null<const ISequentialPlannerContext*> context_;
Expand Down Expand Up @@ -183,6 +185,12 @@ class PlannerImpl {
InlinedHashMap<onnxruntime::NodeIndex, InlinedHashSet<onnxruntime::NodeIndex>> dependence_graph_;
InlinedHashMap<onnxruntime::OrtValueIndex, onnxruntime::NodeIndex> value_node_map_;

// logger_ is not currently used in a minimal build
#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
[[maybe_unused]]
#endif
const logging::Logger& logger_;

// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
struct OrtValueInfo {
const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue
Expand Down Expand Up @@ -213,6 +221,7 @@ class PlannerImpl {
FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point)
: ml_value(ort_value), deallocate_point(dealloc_point) {}
};

// freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when
// they became free (more recently freed earlier in the list).
std::list<FreeBufferInfo> freelist_;
Expand All @@ -225,7 +234,8 @@ class PlannerImpl {
}

int& UseCount(OrtValueIndex n) {
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size());
ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(),
"invalid value index: ", n, " against size ", ort_value_info_.size());
return ort_value_info_[n].usecount;
}
int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }
Expand Down Expand Up @@ -335,9 +345,9 @@ class PlannerImpl {
// we cannot.
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif
*reusable_input = Index(p_input_arg->Name());
Expand All @@ -361,9 +371,9 @@ class PlannerImpl {
// we cannot.
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
<< producer_node->Name() << " which has external outputs. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif
*reusable_input = Index(p_input_arg->Name());
Expand Down Expand Up @@ -397,8 +407,8 @@ class PlannerImpl {
}
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
#endif
}
}
Expand Down Expand Up @@ -448,8 +458,8 @@ class PlannerImpl {
return true;
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
<< producer_node->Name() << " as it has external outputs.";
LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
<< producer_node->Name() << " as it has external outputs.";
#endif
}
}
Expand Down Expand Up @@ -1198,9 +1208,9 @@ class PlannerImpl {
// Otherwise, we cannot reuse the buffer.
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif

Expand Down Expand Up @@ -1241,9 +1251,9 @@ class PlannerImpl {
// Otherwise, we cannot reuse the buffer.
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
<< producer_node->Name() << " which has external outputs is reused. "
<< "Be cautious the reuse MUST be a read-only usage.";
}
#endif

Expand Down Expand Up @@ -1290,8 +1300,8 @@ class PlannerImpl {
}
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
<< producer_node->Name() << " as it has external outputs";
#endif
}
}
Expand Down Expand Up @@ -1869,8 +1879,7 @@ class PlannerImpl {
}

#ifndef ORT_ENABLE_STREAM
void PartitionIntoStreams(const logging::Logger& /*logger*/,
const ExecutionProviders& /*execution_providers*/,
void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/,
const PathString& /*partition_config_file*/) {
if (graph_viewer_.NumberOfNodes() > 0) {
stream_nodes_.push_back({});
Expand Down Expand Up @@ -1915,11 +1924,11 @@ class PlannerImpl {

#else

void
PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
const PathString& partition_config_file) {
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
void PartitionIntoStreams(const ExecutionProviders& execution_providers,
const PathString& partition_config_file) {
auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file);
auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_,
context_->GetExecutionOrder());
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
plan_.node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
Expand Down Expand Up @@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan(
#ifdef ORT_ENABLE_STREAM
const IStreamCommandHandleRegistry& stream_handle_registry,
#endif
const PathString& partition_config_file,
const logging::Logger& logger) {
const PathString& partition_config_file) {
// 1. partition graph into streams
PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file);
PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file);

// 2. initialize the plan based on stream partition result
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
Expand Down Expand Up @@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan(
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
kernel_create_info_map, subgraphs_kernel_create_info_maps,
outer_scope_node_arg_to_location_map,
ort_value_name_idx_map, context, *plan);
ort_value_name_idx_map, context, *plan, logger);

return planner.CreatePlan(
#ifdef ORT_ENABLE_STREAM
stream_handle_registry,
#endif
partition_config_file,
logger);
partition_config_file);
}

#ifdef ORT_ENABLE_STREAM
Expand Down
11 changes: 6 additions & 5 deletions onnxruntime/core/framework/fallback_cpu_capability.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node

std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes) {
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger) {
// automatic conversion from const std::vector&
const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
InlinedVector<size_t> node_id_to_order_map(graph.MaxNodeIndex());
Expand Down Expand Up @@ -83,7 +84,7 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
for (auto& consumer_node : consumer_nodes) {
candidates.push(consumer_node->Index());
LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
}
}
return Status::OK();
Expand Down Expand Up @@ -159,9 +160,9 @@ std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewe

if (place_in_cpu) {
cpu_nodes.insert(cur);
LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
<< " capable of executing this node";
LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
<< " because the CPU execution path is deemed faster than overhead involved with execution "
"on other EPs capable of executing this node";
for (auto* output : node->OutputDefs()) {
cpu_output_args.insert(output);
}
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/framework/fallback_cpu_capability.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
#include "core/graph/graph_viewer.h"

namespace onnxruntime {
namespace logging {
class Logger;
}

/**
Returns a list of nodes that are preferred on CPU.
Expand All @@ -19,6 +22,7 @@ namespace onnxruntime {
*/
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
gsl::span<const NodeIndex> tentative_nodes);
gsl::span<const NodeIndex> tentative_nodes,
const logging::Logger& logger);

} // namespace onnxruntime
Loading

0 comments on commit 1c62d21

Please sign in to comment.