From 4aca8f33df95f2326ea3f6b8337b5e2ca53f0b96 Mon Sep 17 00:00:00 2001 From: mingyue <131847423+mingyueliuh@users.noreply.github.com> Date: Fri, 20 Dec 2024 08:47:13 +0800 Subject: [PATCH] [Bug Fix] Missing CustomOp SchemaRegister when generator EPContext ONNX model (#23091) ### Description Enhancements to EPContext Operations: 1. Introduced support for the bfloat16 data type in EPContext operations. 2. Bug Fix: Missing Custom OP Schema Registration when generator EPContext ONNX model --------- Co-authored-by: mingyue Co-authored-by: Hector Li --- docs/ContribOperators.md | 2 +- onnxruntime/core/framework/graph_partitioner.cc | 2 +- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 6ea3f93cdea12..2290030073e5c 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1625,7 +1625,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double)
+
T : tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
Constrain input and output types.
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 406fc1b15effc..b97cf03e3bf59 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -681,7 +681,7 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers context_cache_path, "' exist already."); } - Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + Model ep_context_model(graph.Name(), false, graph.GetModel().MetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList{graph.GetSchemaRegistry()}, graph.DomainToVersionMap(), {}, logger); auto& ep_graph = ep_context_model.MainGraph(); ep_graph.SetDescription(graph.Description()); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index c7a0793c4748f..d78fe7111c9be 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3371,7 +3371,8 @@ void RegisterContribSchemas() { "tensor(uint64)", "tensor(float16)", "tensor(float)", - "tensor(double)"}, + "tensor(double)", + "tensor(bfloat16)"}, "Constrain input and output types."); static const char* BitmaskDropout_ver1_doc = R"DOC(