Skip to content

Commit

Permalink
Make session configuration options available to kernels via OpKernelI…
Browse files Browse the repository at this point in the history
…nfo (#18897)

### Description
<!-- Describe your changes. -->
Pass through the ConfigOptions from the session via OpKernelInfo so that
kernel behavior can be configured.

Initial usage would be to optionally enable a fast path for ARM64 bloat16 GEMM - see #17031
Other usages could be things like selected the exact implementations of the activation functions for RNN operators instead of the default approximations (e.g. use [sigmoid_exact instead of sigmoid](https://github.com/microsoft/onnxruntime/blob/2d6e2e243d1a1ab0486f4f191b61ac979c5b978e/onnxruntime/core/providers/cpu/rnn/rnn_helpers.h#L379-L382))

OpKernelInfo is already passing through things from the session state, and adding a new member of ConfigOptions
is the simpler update. It's also a more natural fit given it's providing state/info to the kernel.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
skottmckay authored and mszhanyi committed Jan 15, 2024
1 parent c80cb34 commit 746a04d
Show file tree
Hide file tree
Showing 26 changed files with 246 additions and 162 deletions.
6 changes: 5 additions & 1 deletion include/onnxruntime/core/framework/op_kernel_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& mlvalue_name_idx_map,
const DataTransferManager& data_transfer_mgr,
const AllocatorMap& allocators = {});
const AllocatorMap& allocators,
const ConfigOptions& config_options);

OpKernelInfo(const OpKernelInfo& other);

Expand All @@ -50,6 +51,8 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {

const AllocatorMap& GetAllocators() const { return allocators_; }

const ConfigOptions& GetConfigOptions() const { return config_options_; }

private:
ORT_DISALLOW_MOVE(OpKernelInfo);
ORT_DISALLOW_ASSIGNMENT(OpKernelInfo);
Expand All @@ -64,6 +67,7 @@ class OpKernelInfo : public OpNodeProtoHelper<ProtoHelperNodeContext> {
const DataTransferManager& data_transfer_mgr_;
ProtoHelperNodeContext proto_helper_context_;
const AllocatorMap& allocators_;
const ConfigOptions& config_options_;
};

} // namespace onnxruntime
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/kernel_registry_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ Status KernelRegistryManager::CreateKernel(const Node& node,
session_state.GetConstantInitializedTensors(),
session_state.GetOrtValueNameIdxMap(),
session_state.GetDataTransferMgr(),
session_state.GetAllocators());
session_state.GetAllocators(),
session_state.GetSessionOptions().config_options);

return kernel_create_info.kernel_create_func(session_state.GetMutableFuncMgr(), kernel_info, out);
}
Expand Down
16 changes: 12 additions & 4 deletions onnxruntime/core/framework/op_kernel_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
const std::unordered_map<int, OrtValue>& constant_initialized_tensors,
const OrtValueNameIdxMap& ort_value_name_idx_map,
const DataTransferManager& data_transfer_mgr,
const AllocatorMap& allocators)
const AllocatorMap& allocators,
const ConfigOptions& config_options)
: OpNodeProtoHelper(&proto_helper_context_),
node_(node),
kernel_def_(kernel_def),
Expand All @@ -24,15 +25,22 @@ OpKernelInfo::OpKernelInfo(const onnxruntime::Node& node,
ort_value_name_idx_map_(ort_value_name_idx_map),
data_transfer_mgr_(data_transfer_mgr),
proto_helper_context_(node),
allocators_(allocators) {}
allocators_(allocators),
config_options_(config_options) {
}

OpKernelInfo::OpKernelInfo(const OpKernelInfo& other)
: OpKernelInfo(other.node_, other.kernel_def_, *other.execution_provider_, other.constant_initialized_tensors_,
other.ort_value_name_idx_map_, other.data_transfer_mgr_, other.allocators_) {}
other.ort_value_name_idx_map_, other.data_transfer_mgr_,
other.allocators_, other.config_options_) {
}

AllocatorPtr OpKernelInfo::GetAllocator(OrtMemType mem_type) const {
auto it = allocators_.find(execution_provider_->GetOrtDeviceByMemType(mem_type));
if (it != allocators_.end()) return it->second;
if (it != allocators_.end()) {
return it->second;
}

return nullptr;
}

Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/optimizer/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ namespace onnxruntime {

ConstantFolding::ConstantFolding(const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers,
const InlinedHashSet<std::string>& excluded_initializers) noexcept
: GraphTransformer("ConstantFolding", compatible_execution_providers),
skip_dequantize_linear_(skip_dequantize_linear),
config_options_(config_options),
excluded_initializers_(excluded_initializers),
execution_provider_(execution_provider) {
}
Expand Down Expand Up @@ -250,12 +252,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
// override the EP assigned to the node so that it will use the CPU kernel for Compute.
node->SetExecutionProviderType(kCpuExecutionProvider);

kernel = info.CreateKernel(node);
kernel = info.CreateKernel(node, config_options_);

// undo the EP change to the value that was assigned at graph partitioning time
node->SetExecutionProviderType(ep_type);
} else {
kernel = info.CreateKernel(node);
kernel = info.CreateKernel(node, config_options_);
}

// We currently constant fold using the CPU EP only.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/optimizer/constant_folding.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ class ConstantFolding : public GraphTransformer {
*/
ConstantFolding(const IExecutionProvider& execution_provider,
bool skip_dequantize_linear,
const ConfigOptions& config_options,
const InlinedHashSet<std::string_view>& compatible_execution_providers = {},
const InlinedHashSet<std::string>& excluded_initializers = {}) noexcept;

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;

bool skip_dequantize_linear_;
const ConfigOptions& config_options_;
const InlinedHashSet<std::string> excluded_initializers_;
const IExecutionProvider& execution_provider_;
};
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/optimizer/graph_transformer_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
transformers.emplace_back(std::make_unique<ConstantSharing>(no_limit_empty_ep_list, excluded_initializers));

transformers.emplace_back(std::make_unique<CommonSubexpressionElimination>());
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq));
transformers.emplace_back(std::make_unique<ConstantFolding>(cpu_execution_provider, !disable_quant_qdq,
session_options.config_options));
transformers.emplace_back(std::make_unique<MatMulAddFusion>());
transformers.emplace_back(std::make_unique<ReshapeFusion>());
transformers.emplace_back(std::make_unique<FreeDimensionOverrideTransformer>(
Expand Down
14 changes: 11 additions & 3 deletions onnxruntime/core/optimizer/optimizer_execution_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,26 +128,34 @@ static Status TryCreateKernel(const Node& node,
const OrtValueNameIdxMap& ort_value_name_idx_map,
FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr,
const ConfigOptions& config_options,
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelCreateInfo* kernel_create_info = nullptr;
ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
&kernel_create_info));

static const AllocatorMap dummy_allocators;

OpKernelInfo kernel_info(node,
*kernel_create_info->kernel_def,
execution_provider,
constant_initialized_tensors,
ort_value_name_idx_map,
data_transfer_mgr);
data_transfer_mgr,
dummy_allocators,
config_options);

return kernel_create_info->kernel_create_func(funcs_mgr, kernel_info, op_kernel);
}

std::unique_ptr<const OpKernel> OptimizerExecutionFrame::Info::CreateKernel(const Node* node) const {
std::unique_ptr<const OpKernel>
OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOptions& config_options) const {
std::unique_ptr<OpKernel> op_kernel;
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
FuncManager func;
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
ort_value_name_idx_map_, func, data_transfer_mgr_,
ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
op_kernel);

// Kernel found in the CPU kernel registry
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/optimizer/optimizer_execution_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
const Path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func);

Info(const std::vector<const Node*>& nodes,
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
const Path& model_path,
const IExecutionProvider& execution_provider,
const std::function<bool(const std::string&)>& is_sparse_initializer_func);

~Info() = default;

const AllocatorPtr& GetAllocator() const {
Expand All @@ -52,7 +54,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
return -1;
}

std::unique_ptr<const OpKernel> CreateKernel(const Node* node) const;
std::unique_ptr<const OpKernel> CreateKernel(const Node* node, const ConfigOptions& config_options) const;

// Check if an kernel create info can be found in the registry.
Status TryFindKernel(const Node* node, const KernelCreateInfo** out) const;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ struct Logger;
struct Capture;
} // namespace logging
struct ComputeCapability;
struct ConfigOptions;
struct DataTransferManager;
struct IndexedSubGraph;
struct IndexedSubGraph_MetaDef;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <optional>

// Public wrappers around internal ort interfaces (currently)
#include "core/providers/shared_library/provider_host_api.h"

Expand Down Expand Up @@ -426,6 +428,9 @@ struct ProviderHost {

virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;

// ConfigOptions
virtual std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0;

// ComputeCapability
virtual std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) = 0;
virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0;
Expand Down Expand Up @@ -808,6 +813,7 @@ struct ProviderHost {
virtual uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) = 0;
virtual uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) = 0;
virtual const Node& OpKernelInfo__node(const OpKernelInfo* p) = 0;
virtual const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) = 0;

// SessionState
virtual const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) = 0;
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ struct DataTypeUtils final {

} // namespace Utils

struct ConfigOptions final {
std::optional<std::string> GetConfigEntry(const std::string& config_key) const {
return g_host->ConfigOptions__GetConfigEntry(this, config_key);
}

PROVIDER_DISALLOW_ALL(ConfigOptions)
};

struct ComputeCapability final {
static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
Expand Down Expand Up @@ -901,6 +909,8 @@ struct OpKernelInfo final {

const Node& node() const noexcept { return g_host->OpKernelInfo__node(this); }

const ConfigOptions& GetConfigOptions() const { return g_host->OpKernelInfo__GetConfigOptions(this); }

OpKernelInfo() = delete;
OpKernelInfo(const OpKernelInfo&) = delete;
void operator=(const OpKernelInfo&) = delete;
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "core/common/inlined_containers.h"
#include "core/framework/allocator_utils.h"
#include "core/framework/config_options.h"
#include "core/framework/compute_capability.h"
#include "core/framework/data_types.h"
#include "core/framework/data_transfer_manager.h"
Expand Down Expand Up @@ -529,6 +530,11 @@ struct ProviderHostImpl : ProviderHost {

const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; }

// ConfigOptions (wrapped)
std::optional<std::string> ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override {
return p->GetConfigEntry(config_key);
}

// ComputeCapability (wrapped)
std::unique_ptr<ComputeCapability> ComputeCapability__construct(std::unique_ptr<IndexedSubGraph> t_sub_graph) override { return std::make_unique<ComputeCapability>(std::move(t_sub_graph)); }
void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; }
Expand Down Expand Up @@ -934,6 +940,7 @@ struct ProviderHostImpl : ProviderHost {
uint32_t OpKernelInfo__GetInputCount(const OpKernelInfo* p) override { return p->GetInputCount(); }
uint32_t OpKernelInfo__GetOutputCount(const OpKernelInfo* p) override { return p->GetOutputCount(); }
const Node& OpKernelInfo__node(const OpKernelInfo* p) override { return p->node(); }
const ConfigOptions& OpKernelInfo__GetConfigOptions(const OpKernelInfo* p) override { return p->GetConfigOptions(); }

// SessionState (wrapped)
const DataTransferManager& SessionState__GetDataTransferMgr(const SessionState* p) override { return p->GetDataTransferMgr(); }
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/core/session/standalone_op_invoker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,10 @@ onnxruntime::Status CreateOp(_In_ const OrtKernelInfo* info,
static const OrtValueNameIdxMap kEmptyNameMap;

OpKernelInfo tmp_kernel_info(*node_ptr.get(), *kernel_def, *ep, kEmptyValueMap, kEmptyNameMap,
kernel_info->GetDataTransferManager(), kernel_info->GetAllocators());
kernel_info->GetDataTransferManager(),
kernel_info->GetAllocators(),
kernel_info->GetConfigOptions());

std::unique_ptr<onnxruntime::OpKernel> op_kernel;

auto& node_repo = NodeRepo::GetInstance();
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/framework/allocation_planner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class PlannerTest : public ::testing::Test {
ASSERT_NE(ep, nullptr);
auto info = std::make_unique<OpKernelInfo>(
*p_node, kernel_def, *ep, state_->GetInitializedTensors(), state_->GetOrtValueNameIdxMap(),
state_->GetDataTransferMgr());
state_->GetDataTransferMgr(), state_->GetAllocators(), state_->GetSessionOptions().config_options);

op_kernel_infos_.push_back(std::move(info));
const auto kernel_type_str_resolver = OpSchemaKernelTypeStrResolver{};
Expand Down
Loading

0 comments on commit 746a04d

Please sign in to comment.