VitisAI EP Context Model (microsoft#20926)
# Why so many commits
- Runtime debugging, which was necessary
- Three different approaches to the EP context model, which meant testing back and forth
- Windows compatibility issues, since development was done on Linux for convenience

# "Open" (?) questions
- Full offloading to a specific EP
- Dumping EP context models by EPs vs [by
ONNXRT](https://github.com/microsoft/onnxruntime/blob/e2abba18ea9370329ce6894a4eb3e98ad8f11cb6/onnxruntime/core/framework/graph_partitioner.cc#L725)
- [Node name to pick
nodes](https://github.com/microsoft/onnxruntime/blob/e2abba18ea9370329ce6894a4eb3e98ad8f11cb6/onnxruntime/core/framework/graph_partitioner.cc#L654)

# VitisAI EP implemented three variants, each with its own pros and cons (and they can be combined)
## Serialize and cache the list of compute capabilities and the original ONNX model itself
## In `ComputeCapability()`, serialize and cache the backend compilation cache and the related cache info, such as the cache dir and cache key
## In `Compile()`, serialize and cache the backend compilation cache and the related cache info, such as the cache dir and cache key
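
For the second and third variants, the cache payload is essentially an opaque backend blob addressed by a cache dir and a cache key. A minimal sketch of such a dump helper is shown below; the function name, file layout, and `.bin` suffix are assumptions for illustration, not ORT or VitisAI API.

```cpp
// Minimal sketch: persist a backend compilation cache blob under cache_dir/cache_key.
// The helper name and on-disk layout are assumptions, not part of ONNX Runtime.
#include <filesystem>
#include <fstream>
#include <string>

void DumpBackendCache(const std::string& cache_dir,
                      const std::string& cache_key,
                      const std::string& backend_blob) {
  std::filesystem::create_directories(cache_dir);
  std::ofstream out(std::filesystem::path(cache_dir) / (cache_key + ".bin"),
                    std::ios::binary);
  out.write(backend_blob.data(), static_cast<std::streamsize>(backend_blob.size()));
}
```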

# EP context model creation
- Precondition
  - The session option `kOrtSessionOptionEpContextEnable` (aka "ep.context_enable") is enabled.
- Approach 1
  - Steps (see the EPContext node sketch after this section)
    1. The EP creates an ONNX model whose main graph has EP context nodes (i.e., the node type is "EPContext").
    2. The EP implements/overrides the `IExecutionProvider::GetEpContextNodes()` method.
    3. ONNXRT core creates an EP context model and saves/dumps it.
       - `CreateEpContextModel()` in the file "graph_partitioner.cc"
       - In `get_ep_context_node()`, `Node::Name()` is used to check whether a node is an EP context node. This means EP context model creation can only happen in `IExecutionProvider::Compile()`.
       - The workaround is (1) not implementing `IExecutionProvider::GetEpContextNodes()` and (2) having the EP dump the EP context model itself.
    4. Optionally, the EP can also dump the EP context model it created by itself.
  - Examples
    - `QNNExecutionProvider`
    - `VitisAIExecutionProvider`
- Approach 2
  - Steps
    1. The EP creates an ONNX model whose main graph has EP context nodes (i.e., the node type is "EPContext").
    2. The EP does NOT implement `IExecutionProvider::GetEpContextNodes()` at all.
    3. The EP dumps the EP context model it created.
  - Examples
    - `TensorrtExecutionProvider`
      - UPDATE: the TRT EP is switching to leveraging `IExecutionProvider::GetEpContextNodes()`.
    - `OpenVINOExecutionProvider` (?)
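
For reference, an "EPContext" node is a plain ONNX node in the `com.microsoft` domain whose attributes carry the cache. The sketch below builds one with the ONNX protobuf API; the attribute names (`embed_mode`, `ep_cache_context`, `source`, `main_context`) follow the commonly used EPContext convention and should be treated as assumptions here rather than a definitive contract.

```cpp
// Minimal sketch: build one EPContext node with an embedded cache blob.
#include <cstdint>
#include <string>
#include "onnx/onnx_pb.h"

onnx::NodeProto MakeEpContextNode(const std::string& node_name,
                                  const std::string& cache_blob) {
  onnx::NodeProto node;
  node.set_name(node_name);
  node.set_op_type("EPContext");
  node.set_domain("com.microsoft");

  auto add_int_attr = [&node](const std::string& name, int64_t v) {
    onnx::AttributeProto* a = node.add_attribute();
    a->set_name(name);
    a->set_type(onnx::AttributeProto::INT);
    a->set_i(v);
  };
  auto add_str_attr = [&node](const std::string& name, const std::string& v) {
    onnx::AttributeProto* a = node.add_attribute();
    a->set_name(name);
    a->set_type(onnx::AttributeProto::STRING);
    a->set_s(v);
  };

  add_int_attr("embed_mode", 1);                     // 1: cache bytes embedded in the node
  add_str_attr("ep_cache_context", cache_blob);      // the EP-specific cache payload
  add_str_attr("source", "VitisAIExecutionProvider");
  add_int_attr("main_context", 1);
  return node;
}
```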

# What to cache in EP context nodes
- Non-compilation-based EPs
  - Examples
    - `VitisAIExecutionProvider`
  - Characteristics
    - The heavy-lifting work happens in `IExecutionProvider::GetCapability()`.
  - Preconditions
    - `IExecutionProvider::GetCapability()` is only called once by ONNXRT.
  - Cache content
    - Serialization of a list of `ComputeCapability`
      - Not EP-specific
      - Serialized using `onnx::FunctionProto` (see the sketch after this list)
    - EP-specific cache
- Compilation-based EPs
  - Examples
    - `QNNExecutionProvider`
    - `TensorrtExecutionProvider`
    - `MIGraphXExecutionProvider`
    - `OpenVINOExecutionProvider`
  - Cache content
    - EP-specific cache
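
To make the "serialization of a list of `ComputeCapability`" concrete: each capability's meta-def (name, inputs, outputs, claimed nodes) maps naturally onto an `onnx::FunctionProto`, which is what the new provider-bridge wrappers in this PR expose. A minimal sketch with plain ONNX protobufs follows; the stand-in struct and the domain string are assumptions, not the exact VitisAI encoding.

```cpp
// Minimal sketch: encode one fused-subgraph "capability" as a FunctionProto.
// MetaDefLike is a stand-in for IndexedSubGraph::MetaDef, not the ORT type.
#include <string>
#include <vector>
#include "onnx/onnx_pb.h"

struct MetaDefLike {
  std::string name;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  std::vector<onnx::NodeProto> nodes;  // nodes claimed by this capability
};

std::string SerializeCapability(const MetaDefLike& meta) {
  onnx::FunctionProto fp;
  fp.set_name(meta.name);
  fp.set_domain("com.ms.vitisai");               // assumed domain, for illustration only
  for (const auto& in : meta.inputs) fp.add_input(in);
  for (const auto& out : meta.outputs) fp.add_output(out);
  for (const auto& n : meta.nodes) *fp.add_node() = n;
  return fp.SerializeAsString();                 // bytes to store in the EP context cache
}
```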

# Requirements
- Offline / AOT compilation of ONNX models with EP context cache
- Compile somewhere, run everywhere
- Pseudocode with brief explanation (a C++ sketch follows the Python examples below)
  ```
  GenerateCache(original_onnx_file, cache_onnx_file)
    model_buffer = load(original_onnx_file)    --> Load the original ONNX model file
    model_buffer = decrypt(model_buffer)
    session_options = { kOrtSessionOptionEpContextEnable: true,
                        kOrtSessionOptionEpContextFilePath: temp_file }    --> Set the necessary configs
    Ort::CreateSessionFromArray(model_buffer, session_options)    --> The new ONNX model with EP context is created and dumped into the user-specified file "temp_file"
    temp_buffer = encrypt(temp_file)
    write(temp_buffer, cache_onnx_file)    --> Write the encrypted content of "temp_file" into the "cache_onnx_file" file


  InitializeInferenceSession(cache_onnx_file)
    model_buffer = load(cache_onnx_file)    --> Load the ONNX model with EP context from the file generated in the previous step
    model_buffer = decrypt(model_buffer)
    session_options = { }
    Ort::CreateSessionFromArray(model_buffer, session_options)    --> Create and initialize a session with the EP context model
  ```
- Python code with comments
  - EP context model creation
    ```python
    import onnxruntime as onnxrt


    # Session options for creating an ONNX model with EP context cache.
    sess_opts = onnxrt.SessionOptions()

    # Verbose.
    sess_opts.log_severity_level = 0

    # This is REQUIRED.
    sess_opts.add_session_config_entry("ep.context_enable", "1")
    # This is OPTIONAL.
    # Either an absolute path (preferred for now) or a relative path (WIP) is okay.
    # sess_opts.add_session_config_entry("ep.context_file_path", "/some/path/to/original_model_ctx.onnx")
    # This is OPTIONAL.
    sess_opts.add_session_config_entry("ep.context_embed_mode", "1")

    orig_model_location = "/some/path/to/original_model.onnx"
    sess = onnxrt.InferenceSession(orig_model_location, sess_opts,
                                   providers=["VitisAIExecutionProvider"], provider_options=[])
    ```
  - Inference run with an EP context model
    ```python
    import onnxruntime as onnxrt


    # Session options for running inference with an EP context model.
    sess_opts = onnxrt.SessionOptions()

    # Default EP context model path.
    # ep_ctx_model_location = "/some/path/to/original_model.onnx_ctx.onnx"
    # User configured EP context model path.
    ep_ctx_model_location = "/some/path/to/original_model_ctx.onnx"
    sess = onnxrt.InferenceSession(ep_ctx_model_location, sess_opts,
                                   providers=["VitisAIExecutionProvider"], provider_options=[])

    model_inputs = {}
    run_opts = onnxrt.RunOptions()
    # Info-level logging (0 would be verbose).
    run_opts.log_severity_level = 1
    sess.run(None, model_inputs, run_opts)
    ```
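
For completeness, the same two-step flow as a C++ sketch against the public C++ API (`Ort::SessionOptions::AddConfigEntry` and the in-memory `Ort::Session` constructor); `LoadFile`/`Decrypt` are placeholders for the user's own I/O and crypto, and EP registration is elided.

```cpp
// Sketch of "compile somewhere, run everywhere" with the ONNXRT C++ API.
#include <fstream>
#include <iterator>
#include <string>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

static std::string LoadFile(const std::string& path) {
  std::ifstream in(path, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(in), {});
}
static std::string Decrypt(std::string buf) { return buf; }  // placeholder: identity

void GenerateCache(Ort::Env& env, const std::string& original_onnx_file,
                   const std::string& temp_file) {
  const std::string model_buffer = Decrypt(LoadFile(original_onnx_file));

  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
  so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, temp_file.c_str());
  // Register the desired EP (e.g., VitisAI) here, as in the Python examples above.

  // Creating the session dumps the EP context model into temp_file;
  // encrypt temp_file and write it to the final cache file as needed.
  Ort::Session session(env, model_buffer.data(), model_buffer.size(), so);
}

Ort::Session InitializeInferenceSession(Ort::Env& env, const std::string& cache_onnx_file) {
  const std::string model_buffer = Decrypt(LoadFile(cache_onnx_file));
  Ort::SessionOptions so;  // no EP-context-creation configs needed at run time
  // Register the same EP here.
  return Ort::Session(env, model_buffer.data(), model_buffer.size(), so);
}
```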

---------

Co-authored-by: Glen Cao <[email protected]>
glen-amd and Glen Cao authored Jul 13, 2024
1 parent 92a8407 commit 281ed8c
Showing 13 changed files with 1,195 additions and 16 deletions.
2 changes: 2 additions & 0 deletions include/onnxruntime/core/graph/basic_types.h
@@ -19,6 +19,8 @@ class TensorProto;
class SparseTensorProto;
class TypeProto;
class AttributeProto;
class FunctionProto;
class OperatorSetIdProto;
// define types that would come from the ONNX library if we were building against it.
#if defined(ORT_MINIMAL_BUILD)
using OperatorSetVersion = int;
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_api.h
@@ -108,6 +108,7 @@ struct NodeProto;
struct SparseTensorProto;
struct StringStringEntryProto;
struct StringStringEntryProtos; // RepeatedPtrField
struct OperatorSetIdProto;
struct TensorProto;
struct TensorProtos; // RepeatedPtrField
struct TensorShapeProto_Dimension;
@@ -120,6 +121,7 @@ struct TypeProto_Sequence;
struct TypeProto;
struct ValueInfoProto;
struct ValueInfoProtos; // RepeatedPtrField
struct FunctionProto;
struct InferenceContext;
class GraphInferencer;
using InferenceFunction = std::function<void(InferenceContext&)>;
@@ -146,6 +148,7 @@ struct ConfigOptions;
struct DataTransferManager;
struct IndexedSubGraph;
struct IndexedSubGraph_MetaDef;
enum class IndexedSubGraph_SourceOfSchema : uint8_t;
struct KernelCreateInfo;
struct KernelDef;
struct KernelDefBuilder;
72 changes: 72 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -304,6 +304,11 @@ struct ProviderHost {
virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0;

// OperatorSetIdProto
virtual std::string* OperatorSetIdProto__mutable_domain(ONNX_NAMESPACE::OperatorSetIdProto* p) = 0;
virtual void OperatorSetIdProto__set_version(ONNX_NAMESPACE::OperatorSetIdProto* p, int64_t version) = 0;
virtual int64_t OperatorSetIdProto__version(const ONNX_NAMESPACE::OperatorSetIdProto* p) = 0;

#if !defined(DISABLE_OPTIONAL_TYPE)
// TypeProto_Optional
virtual const ONNX_NAMESPACE::TypeProto& TypeProto_Optional__elem_type(const ONNX_NAMESPACE::TypeProto_Optional* p) = 0;
@@ -420,13 +425,19 @@ struct ProviderHost {
virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0;

virtual const ONNX_NAMESPACE::OperatorSetIdProto& ModelProto__opset_import(const ONNX_NAMESPACE::ModelProto* p, int index) = 0;
virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__mutable_opset_import(ONNX_NAMESPACE::ModelProto* p, int index) = 0;
virtual int ModelProto__opset_import_size(const ONNX_NAMESPACE::ModelProto* p) = 0;
virtual ONNX_NAMESPACE::OperatorSetIdProto* ModelProto__add_opset_import(ONNX_NAMESPACE::ModelProto* p) = 0;

// NodeProto
virtual std::unique_ptr<ONNX_NAMESPACE::NodeProto> NodeProto__construct() = 0;
virtual void NodeProto__operator_delete(ONNX_NAMESPACE::NodeProto* p) = 0;
virtual void NodeProto__operator_assign(ONNX_NAMESPACE::NodeProto* p, const ONNX_NAMESPACE::NodeProto& v) = 0;
virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0;
virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0;
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0;
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__add_attribute(ONNX_NAMESPACE::NodeProto* p) = 0;

// TensorProto
virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
@@ -495,6 +506,64 @@ struct ProviderHost {

virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0;

// FunctionProto
virtual std::unique_ptr<ONNX_NAMESPACE::FunctionProto> FunctionProto__construct() = 0;
virtual void FunctionProto__operator_delete(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual bool FunctionProto__SerializeToString(const ONNX_NAMESPACE::FunctionProto* p, std::string& string) = 0;
virtual bool FunctionProto__SerializeToOstream(const ONNX_NAMESPACE::FunctionProto* p, std::ostream& output) = 0;
virtual bool FunctionProto__ParseFromString(ONNX_NAMESPACE::FunctionProto* p, const std::string& data) = 0;
virtual std::string FunctionProto__SerializeAsString(const ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual bool FunctionProto__has_name(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual const std::string& FunctionProto__name(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
virtual void FunctionProto__set_name(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& name) = 0;

virtual bool FunctionProto__has_doc_string(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual const std::string& FunctionProto__doc_string(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
virtual void FunctionProto__set_doc_string(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& doc_string) = 0;

virtual bool FunctionProto__has_domain(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual const std::string& FunctionProto__domain(const ONNX_NAMESPACE::FunctionProto* p) const = 0;
virtual void FunctionProto__set_domain(ONNX_NAMESPACE::FunctionProto* p, const ::std::string& domain) = 0;

virtual const std::string& FunctionProto__input(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual std::string* FunctionProto__mutable_input(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__input_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual void FunctionProto__add_input(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;

virtual const std::string& FunctionProto__output(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual std::string* FunctionProto__mutable_output(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__output_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual void FunctionProto__add_output(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;

virtual const std::string& FunctionProto__attribute(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual std::string* FunctionProto__mutable_attribute(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__attribute_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual void FunctionProto__add_attribute(ONNX_NAMESPACE::FunctionProto* p, const std::string& value) = 0;

virtual const ONNX_NAMESPACE::AttributeProto& FunctionProto__attribute_proto(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__mutable_attribute_proto(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__attribute_proto_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::AttributeProto* FunctionProto__add_attribute_proto(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual const ONNX_NAMESPACE::NodeProto& FunctionProto__node(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::NodeProto* FunctionProto__mutable_node(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__node_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* FunctionProto__add_node(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual const ONNX_NAMESPACE::ValueInfoProto& FunctionProto__value_info(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::ValueInfoProtos* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__mutable_value_info(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__value_info_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::ValueInfoProto* FunctionProto__add_value_info(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual const ONNX_NAMESPACE::StringStringEntryProto& FunctionProto__metadata_props(const ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProtos* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__mutable_metadata_props(ONNX_NAMESPACE::FunctionProto* p, int index) = 0;
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;

virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0;

// ConfigOptions
@@ -546,6 +615,9 @@ struct ProviderHost {
virtual void IndexedSubGraph__SetMetaDef(IndexedSubGraph* p, std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) = 0;
virtual const IndexedSubGraph_MetaDef* IndexedSubGraph__GetMetaDef(const IndexedSubGraph* p) = 0;

virtual void IndexedSubGraph__SetSchemaSource(IndexedSubGraph* p, IndexedSubGraph_SourceOfSchema schema_source) = 0;
virtual IndexedSubGraph_SourceOfSchema IndexedSubGraph__GetSchemaSource(const IndexedSubGraph* p) = 0;

// KernelDef
virtual void KernelDef__operator_delete(KernelDef* p) = 0;
virtual int KernelDef__ExecQueueId(const KernelDef* p) = 0;
87 changes: 87 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -80,6 +80,15 @@ struct StringStringEntryProtos final {

PROVIDER_DISALLOW_ALL(StringStringEntryProtos)
};

struct OperatorSetIdProto final {
std::string* mutable_domain() { return g_host->OperatorSetIdProto__mutable_domain(this); }
void set_version(int64_t version) { return g_host->OperatorSetIdProto__set_version(this, version); }
int64_t version() { return g_host->OperatorSetIdProto__version(this); }

PROVIDER_DISALLOW_ALL(OperatorSetIdProto)
};

struct AttributeProto final {
static std::unique_ptr<AttributeProto> Create() { return g_host->AttributeProto__construct(); }
void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); }
@@ -178,6 +187,11 @@ struct ModelProto final {

void set_ir_version(int64_t value) { return g_host->ModelProto__set_ir_version(this, value); }

const OperatorSetIdProto& opset_import(int index) const { return g_host->ModelProto__opset_import(this, index); }
OperatorSetIdProto* mutable_opset_import(int index) { return g_host->ModelProto__mutable_opset_import(this, index); }
int opset_import_size() const { return g_host->ModelProto__opset_import_size(this); }
OperatorSetIdProto* add_opset_import() { return g_host->ModelProto__add_opset_import(this); }

ModelProto() = delete;
ModelProto(const ModelProto&) = delete;
void operator=(const ModelProto&) = delete;
@@ -190,6 +204,7 @@ struct NodeProto final {
int attribute_size() { return g_host->NodeProto__attribute_size(this); }
const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); }
AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); }
AttributeProto* add_attribute() { return g_host->NodeProto__add_attribute(this); }

NodeProto() = delete;
NodeProto(const NodeProto&) = delete;
@@ -372,6 +387,69 @@ struct ValueInfoProtos final {

PROVIDER_DISALLOW_ALL(ValueInfoProtos)
};

struct FunctionProto final {
static std::unique_ptr<FunctionProto> Create() { return g_host->FunctionProto__construct(); }
static void operator delete(void* p) { g_host->FunctionProto__operator_delete(reinterpret_cast<FunctionProto*>(p)); }

bool SerializeToString(std::string& string) const { return g_host->FunctionProto__SerializeToString(this, string); }
bool SerializeToOstream(std::ostream& output) const { return g_host->FunctionProto__SerializeToOstream(this, output); }
bool ParseFromString(const std::string& data) { return g_host->FunctionProto__ParseFromString(this, data); }
std::string SerializeAsString() const { return g_host->FunctionProto__SerializeAsString(this); }

bool has_name() const { return g_host->FunctionProto__has_name(this); }
const std::string& name() const { return g_host->FunctionProto__name(this); }
void set_name(const std::string& name) { g_host->FunctionProto__set_name(this, name); }

bool has_doc_string() const { return g_host->FunctionProto__has_doc_string(this); }
const std::string& doc_string() const { return g_host->FunctionProto__doc_string(this); }
void set_doc_string(const std::string& doc_string) { g_host->FunctionProto__set_doc_string(this, doc_string); }

bool has_domain() const { return g_host->FunctionProto__has_domain(this); }
const std::string& domain() const { return g_host->FunctionProto__domain(this); }
void set_domain(const std::string& domain) { g_host->FunctionProto__set_domain(this, domain); }

const std::string& input(int index) const { return g_host->FunctionProto__input(this, index); }
std::string* mutable_input(int index) { return g_host->FunctionProto__mutable_input(this, index); }
int input_size() const { return g_host->FunctionProto__input_size(this); }
void add_input(const std::string& value) { g_host->FunctionProto__add_input(this, value); }

const std::string& output(int index) const { return g_host->FunctionProto__output(this, index); }
std::string* mutable_output(int index) { return g_host->FunctionProto__mutable_output(this, index); }
int output_size() const { return g_host->FunctionProto__output_size(this); }
void add_output(const std::string& value) { g_host->FunctionProto__add_output(this, value); }

const std::string& attribute(int index) const { return g_host->FunctionProto__attribute(this, index); }
std::string* mutable_attribute(int index) { return g_host->FunctionProto__mutable_attribute(this, index); }
int attribute_size() const { return g_host->FunctionProto__attribute_size(this); }
void add_attribute(const std::string& value) { g_host->FunctionProto__add_attribute(this, value); }

const AttributeProto& attribute_proto(int index) const { return g_host->FunctionProto__attribute_proto(this, index); }
AttributeProto* mutable_attribute_proto(int index) { return g_host->FunctionProto__mutable_attribute_proto(this, index); }
int attribute_proto_size() const { return g_host->FunctionProto__attribute_proto_size(this); }
AttributeProto* add_attribute_proto() { return g_host->FunctionProto__add_attribute_proto(this); }

const NodeProto& node(int index) const { return g_host->FunctionProto__node(this, index); }
NodeProto* mutable_node(int index) { return g_host->FunctionProto__mutable_node(this, index); }
int node_size() const { return g_host->FunctionProto__node_size(this); }
NodeProto* add_node() { return g_host->FunctionProto__add_node(this); }

const ValueInfoProto& value_info(int index) const { return g_host->FunctionProto__value_info(this, index); }
ValueInfoProtos* mutable_value_info() { return g_host->FunctionProto__mutable_value_info(this); }
ValueInfoProto* mutable_value_info(int index) { return g_host->FunctionProto__mutable_value_info(this, index); }
int value_info_size() const { return g_host->FunctionProto__value_info_size(this); }
ValueInfoProto* add_value_info() { return g_host->FunctionProto__add_value_info(this); }

const StringStringEntryProto& metadata_props(int index) const { return g_host->FunctionProto__metadata_props(this, index); }
StringStringEntryProtos* mutable_metadata_props() { return g_host->FunctionProto__mutable_metadata_props(this); }
StringStringEntryProto* mutable_metadata_props(int index) { return g_host->FunctionProto__mutable_metadata_props(this, index); }
int metadata_props_size() const { return g_host->FunctionProto__metadata_props_size(this); }
StringStringEntryProto* add_metadata_props() { return g_host->FunctionProto__add_metadata_props(this); }

FunctionProto() = delete;
FunctionProto(const FunctionProto&) = delete;
void operator=(const FunctionProto&) = delete;
};
} // namespace ONNX_NAMESPACE

namespace onnxruntime {
@@ -449,6 +527,12 @@ struct IndexedSubGraph_MetaDef final {
void operator=(const IndexedSubGraph_MetaDef&) = delete;
};

enum class IndexedSubGraph_SourceOfSchema : uint8_t {
CREATE,
REUSE_OR_CREATE,
EXISTING,
};

struct IndexedSubGraph final {
static std::unique_ptr<IndexedSubGraph> Create() { return g_host->IndexedSubGraph__construct(); }
static void operator delete(void* p) { g_host->IndexedSubGraph__operator_delete(reinterpret_cast<IndexedSubGraph*>(p)); }
@@ -458,6 +542,9 @@ struct IndexedSubGraph final {
void SetMetaDef(std::unique_ptr<IndexedSubGraph_MetaDef>&& meta_def_) { return g_host->IndexedSubGraph__SetMetaDef(this, std::move(*reinterpret_cast<std::unique_ptr<IndexedSubGraph_MetaDef>*>(&meta_def_))); }
const IndexedSubGraph_MetaDef* GetMetaDef() const { return reinterpret_cast<const IndexedSubGraph_MetaDef*>(g_host->IndexedSubGraph__GetMetaDef(this)); }

void SetSchemaSource(IndexedSubGraph_SourceOfSchema schema_source) { return g_host->IndexedSubGraph__SetSchemaSource(this, schema_source); }
IndexedSubGraph_SourceOfSchema GetSchemaSource() const { return g_host->IndexedSubGraph__GetSchemaSource(this); }

IndexedSubGraph() = delete;
IndexedSubGraph(const IndexedSubGraph&) = delete;
void operator=(const IndexedSubGraph&) = delete;