diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index f504ca1255059..cc1f5fc546944 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -763,6 +763,10 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi */ bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. + */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const; + /** Gets all the initializer tensors in this Graph. */ const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return name_to_initial_tensor_; } @@ -1449,15 +1453,15 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi const OrtFormatLoadOptions& load_options, const logging::Logger& logger, std::unique_ptr& graph); - static Status LoadFromGraphApiModel(const OrtGraph& api_graph, - const Model& owning_model, - const std::unordered_map& domain_to_version, - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, - bool strict_shape_type_inference, - const logging::Logger& logger, - std::unique_ptr& graph); + static Status LoadFromModelBuilderApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph); - Status UpdateUsingGraphApiModel(const OrtModel& api_model); + Status UpdateUsingModelBuilderApiModel(const OrtModel& api_model); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const RuntimeOptimizationRecordContainer& RuntimeOptimizations() const { @@ -1699,7 +1703,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return nodes_[node_index].get(); } - Status 
LoadFromGraphApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false); + Status LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updating_existing_graph = false); const Model& owning_model_; @@ -1715,6 +1719,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi InitializedTensorSet name_to_initial_tensor_; + // Initializers that are external to the Graph. e.g. created using Model Builder API from existing memory. + // As we need to convert to TensorProto for the optimizers to work and keep the deleter information we store them + // in the Graph instance and retrieve during session state finalization. + std::unordered_map ortvalue_initializers_; + std::unordered_set, std::hash, std::equal_to> sparse_tensor_names_; @@ -1736,6 +1745,7 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi // in some case, a fused sub-graph will happens multiple times in one model, we use a map // to store reusable-schema in lookup. InlinedHashMap> reusable_fused_schema_map_; + #endif // !defined(ORT_MINIMAL_BUILD) // Graph nodes. diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 9385e2f092e58..6a664d8be9c05 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -193,6 +193,12 @@ class GraphViewer { IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const { return graph_->GetSchemaRegistry(); } #endif + /** Populate `value` if an externally allocated OrtValue exists for an initializer with the given name. 
+ */ + bool GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + return graph_->GetOrtValueInitializer(name, value); + } + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f99ea2bb719ec..2b72b780b4ca2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5160,22 +5160,23 @@ struct OrtModelBuilderApi { * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue * with a tensor that contains a pointer to the existing data. * User must keep pointer valid for lifetime of the inference session. + * Set `data_is_external` to true. * * Allocated memory: * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. - * ORT will own the memory. - * - * + * Set `data_is_external` to false. * * \param[in] graph The OrtGraph instance to update. * \param[in] name The value name for the initializer. * \param[in] tensor The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be copied. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.21. 
*/ - ORT_API2_STATUS(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor); + ORT_API2_STATUS(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, + bool data_is_external); /** \brief Add an OrtNode to an OrtGraph * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 7b6bc9ccefbc8..f099e415faefe 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1785,6 +1785,19 @@ struct Value : detail::ValueImpl { const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAndDeleterAsOrtValue. + * + * \param deleter OrtAllocator that will be used to free the buffer when no longer required. + * \param p_data Pointer to the data buffer. + * \param p_data_byte_count The number of bytes in the data buffer. + * \param shape Pointer to the tensor shape dimensions. + * \param shape_len The number of tensor shape dimensions. + * \param type The data type. + */ + static Value CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); + /** \brief Creates an OrtValue with a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue. * This overload will allocate the buffer for the tensor according to the supplied shape and data type. * The allocated buffer will be owned by the returned OrtValue and will be freed when the OrtValue is released. @@ -1810,7 +1823,8 @@ struct Value : detail::ValueImpl { * \param shape_len The number of tensor shape dimensions. * \param type The data type. 
*/ - static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); + static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type); /** \brief Creates an OrtValue with a Map Onnx type representation. * The API would ref-count the supplied OrtValues and they will be released @@ -2542,9 +2556,6 @@ struct ValueInfoImpl : Ort::detail::Base { std::string Name() const; ConstTypeInfo TypeInfo() const; - - template - bool operator==(const ValueInfoImpl& o) const; }; } // namespace detail @@ -2570,9 +2581,6 @@ template struct NodeImpl : Ort::detail::Base { using B = Ort::detail::Base; using B::B; - - template - bool operator==(const NodeImpl& o) const; }; } // namespace detail @@ -2619,11 +2627,8 @@ struct GraphImpl : Ort::detail::Base { void SetInputs(std::vector& inputs); void SetOutputs(std::vector& outputs); - void AddInitializer(const std::string& name, Value& initializer); // Graph takes ownership of Value - void AddNode(Node& node); // Graph takes ownership of Node - - template - bool operator==(const GraphImpl& o) const; + void AddInitializer(const std::string& name, Value& initializer, bool data_is_external); // Graph takes ownership of Value + void AddNode(Node& node); // Graph takes ownership of Node }; } // namespace detail @@ -2648,9 +2653,6 @@ struct ModelImpl : Ort::detail::Base { using B::B; void AddGraph(Graph& graph); - - template - bool operator==(const ModelImpl& o) const; }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index cab4dfb373318..09c8a71a464ae 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1707,23 +1707,35 @@ void ValueImpl::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_inf } // namespace detail template -inline 
Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) { +inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, + const int64_t* shape, size_t shape_len) { return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, +inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out)); return Value{out}; } +inline Value Value::CreateTensor(OrtAllocator* deleter, void* p_data, size_t p_data_byte_count, + const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { + OrtValue* out; + ThrowOnError(GetApi().CreateTensorWithDataAndDeleterAsOrtValue(deleter, p_data, p_data_byte_count, + shape, shape_len, type, &out)); + return Value{out}; +} + template inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) { return CreateTensor(allocator, shape, shape_len, TypeToTensorType::type); } -inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) { +inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type) { OrtValue* out; ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out)); return Value{out}; @@ -1741,7 +1753,8 @@ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& values_shape, ONNXTensorElementDataType type) { OrtValue* out; 
ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len, - values_shape.shape, values_shape.shape_len, type, &out)); + values_shape.shape, values_shape.shape_len, type, + &out)); return Value{out}; } @@ -2425,9 +2438,9 @@ inline void GraphImpl::SetOutputs(std::vector& outputs) { std::for_each(outputs.begin(), outputs.end(), [](ValueInfo& vi) { vi.release(); }); } -inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer) { +inline void GraphImpl::AddInitializer(const std::string& name, Value& initializer, bool data_is_external) { // Graph takes ownership of `initializer` - ThrowOnError(GetModelBuilderApi().AddInitializerToGraph(p_, name.c_str(), initializer.release())); + ThrowOnError(GetModelBuilderApi().AddInitializerToGraph(p_, name.c_str(), initializer.release(), data_is_external)); } inline void GraphImpl::AddNode(Node& node) { diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index 2c74805c57dce..3e77ed30fd620 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -200,13 +200,12 @@ static common::Status DeserializeTensorProto(const Env& env, const std::basic_st } } -common::Status AllocateTensor( - const onnxruntime::MemBuffer* m, - std::unique_ptr& p_tensor, - const onnxruntime::DataTypeImpl* const& type, - onnxruntime::TensorShape& tensor_shape, - bool use_device_allocator_for_initializers, - const onnxruntime::AllocatorPtr& alloc) { +common::Status AllocateTensor(const onnxruntime::MemBuffer* m, + std::unique_ptr& p_tensor, + const onnxruntime::DataTypeImpl* const& type, + onnxruntime::TensorShape& tensor_shape, + bool use_device_allocator_for_initializers, + const onnxruntime::AllocatorPtr& alloc) { if (m != nullptr) { p_tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); if (m->GetLen() < 
p_tensor->SizeInBytes()) { @@ -350,6 +349,7 @@ common::Status SaveInitializedTensors( } ORT_RETURN_IF_ERROR(planner.Trace(entry.first, entry.second)); } + // 2. allocate weight buffer on different locations // planned_initializers_memory_size_in_byte is not actual physical size. // It's the virtual size computed by planner. @@ -382,6 +382,9 @@ common::Status SaveInitializedTensors( if (user_supplied_initializer_ids.find(entry.first) != user_supplied_initializer_ids.end()) { ort_value = *(session_options.initializers_to_share_map.at(name)); LOGS(logger, INFO) << "Using user supplied initializer with name (" << name << ")."; + + } else if (graph.GetOrtValueInitializer(name, ort_value)) { + // populated OrtValue from the Graph instance } else { const ONNX_NAMESPACE::TensorProto& tensor_proto = *(entry.second); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 58449a2820507..87fec65c65908 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3487,6 +3487,11 @@ void Graph::RemoveInitializedTensor(const std::string& tensor_name) { #if !defined(DISABLE_SPARSE_TENSORS) sparse_tensor_names_.erase(tensor_name); #endif + + if (auto it = ortvalue_initializers_.find(tensor_name); it != ortvalue_initializers_.end()) { + ortvalue_initializers_.erase(it); + } + SetGraphResolveNeeded(); } else { #if !defined(DISABLE_SPARSE_TENSORS) @@ -3618,8 +3623,18 @@ Status Graph::InjectExternalInitializersFromFilesInMemory( return Status::OK(); } -#endif // DISABLE_EXTERNAL_INITIALIZERS +bool Graph::GetOrtValueInitializer(const std::string& name, OrtValue& value) const { + auto it = ortvalue_initializers_.find(name); + if (it == ortvalue_initializers_.end()) { + return false; + } + + value = it->second; + return true; +} + +#endif // DISABLE_EXTERNAL_INITIALIZERS #endif // !defined(ORT_MINIMAL_BUILD) bool Graph::GetInitializedTensor(const std::string& tensor_name, const TensorProto*& value) const { @@ -3646,6 +3661,8 @@ void 
Graph::CleanAllInitializedTensors() noexcept { for (int i = 0; i < num_cleared; i++) { delete graph_proto_->mutable_initializer()->ReleaseCleared(); } + + ortvalue_initializers_.clear(); } const ONNX_NAMESPACE::TensorProto* Graph::GetConstantInitializer(const std::string& initializer_name, @@ -5580,6 +5597,7 @@ common::Status Graph::LoadFromOrtFormat(const onnxruntime::fbs::Graph& fbs_graph return Status::OK(); } +#if !defined(ORT_MINIMAL_BUILD) namespace { ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { // the model builder API checks that the OrtValueInfo has a complete and valid OrtTypeInfo instance and that the @@ -5612,10 +5630,9 @@ ValueInfoProto OrtValueInfoToOnnx(const OrtValueInfo& vi) { return value_info_proto; } - } // namespace -Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph, bool updating_existing_graph) { +Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, bool updating_existing_graph) { ArgNameToTypeMap name_to_type_map; // NOTE: need to create NodeArgs as we go along @@ -5646,57 +5663,63 @@ Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph, bool updating_exi } }; - // process graph inputs first as we want the type/shape from them to be preferred if a graph input - // has a matching initializer - add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); - - // add initializers - for (const auto& name_and_ortvalue : api_graph.initializers) { - // convert from OrtValue to TensorProto - const std::string& name = name_and_ortvalue.first; - const OrtValue& v = *name_and_ortvalue.second; - - ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); - const Tensor& t = v.Get(); - TensorProto& tensor_proto = *graph_proto_->add_initializer(); - - tensor_proto.set_name(name); - tensor_proto.set_data_type(t.GetElementType()); - for (auto dim : t.Shape().GetDims()) { - tensor_proto.add_dims(dim); - } - - // we're assuming that CreateTensorWithDataAsOrtValue or CreateTensorAsOrtValue was used to create 
the OrtValue. - // based on that we're inferring whether the Tensor in the OrtValue owns the buffer. - // TODO: Is this robust? Do we need something more explicit? - const bool is_internal_data = t.OwnsBuffer(); + auto add_initializers = [this](const std::unordered_map>& initializers, + bool is_external) { + for (auto& name_and_ortvalue : initializers) { + // convert from OrtValue to TensorProto + const std::string& name = name_and_ortvalue.first; + OrtValue& v = *name_and_ortvalue.second; + + ORT_ENFORCE(v.IsTensor(), "Initializers must be Tensors"); + const Tensor& t = v.Get(); + TensorProto& tensor_proto = *graph_proto_->add_initializer(); + + tensor_proto.set_name(name); + tensor_proto.set_data_type(t.GetElementType()); + for (auto dim : t.Shape().GetDims()) { + tensor_proto.add_dims(dim); + } - if (is_internal_data) { - tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); - } else { - // pre-existing memory that we don't own. avoid a copy by storing the pointer in the ExternalDataInfo - tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + if (is_external) { + // pre-existing memory that we don't own. 
avoid a copy by storing the pointer in the ExternalDataInfo + tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + + const void* data_offset = t.DataRaw(); // address of memory not offset into file + auto offset = narrow(reinterpret_cast(data_offset)); + + ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("location"); + // magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory + entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("offset"); + entry->set_value(std::to_string(offset)); + entry = tensor_proto.mutable_external_data()->Add(); + entry->set_key("length"); + entry->set_value(std::to_string(t.SizeInBytes())); + + // copy OrtValue to keep it alive and to store the deleter if provided. + ortvalue_initializers_.emplace(name, v); + v = OrtValue{}; // reset as we have taken a copy + } else { + tensor_proto.set_raw_data(t.DataRaw(), t.SizeInBytes()); + } - const void* data_offset = t.DataRaw(); // address of memory not offset into file - auto offset = narrow(reinterpret_cast(data_offset)); + TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); - ONNX_NAMESPACE::StringStringEntryProto* entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("location"); - // magic tag for existing memory that causes 'offset' to be treated as a pointer to the memory - entry->set_value(ToUTF8String(onnxruntime::utils::kTensorProtoMemoryAddressTag)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("offset"); - entry->set_value(std::to_string(offset)); - entry = tensor_proto.mutable_external_data()->Add(); - entry->set_key("length"); - entry->set_value(std::to_string(t.SizeInBytes())); + name_to_initial_tensor_.emplace(name, &tensor_proto); } + }; 
- TypeProto type_proto{TypeProtoFromTensorProto(tensor_proto)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(name, &type_proto)); + // process graph inputs first as we want the type/shape from them to be preferred if a graph input + // has a matching initializer + add_graph_inputs_outputs(api_graph.inputs, /*input*/ true); - name_to_initial_tensor_.emplace(name, &tensor_proto); - } + // add initializers + ortvalue_initializers_.reserve(api_graph.external_initializers.size()); + add_initializers(api_graph.external_initializers, /*is_external*/ true); + add_initializers(api_graph.initializers, /*is_external*/ false); // add graph outputs add_graph_inputs_outputs(api_graph.outputs, /*input*/ false); @@ -5753,12 +5776,13 @@ Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph, bool updating_exi } // static -Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph, const Model& owning_model, - const std::unordered_map& domain_to_version, - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, - bool strict_shape_type_inference, - const logging::Logger& logger, - std::unique_ptr& graph) { +Status Graph::LoadFromModelBuilderApiModel(const OrtGraph& api_graph, + const Model& owning_model, + const std::unordered_map& domain_to_version, + IOnnxRuntimeOpSchemaCollectionPtr schema_registry, + bool strict_shape_type_inference, + const logging::Logger& logger, + std::unique_ptr& graph) { graph = std::make_unique(owning_model, domain_to_version, schema_registry, @@ -5766,10 +5790,10 @@ Status Graph::LoadFromGraphApiModel(const OrtGraph& api_graph, const Model& owni logger, strict_shape_type_inference); - return graph->LoadFromGraphApiModel(api_graph); + return graph->LoadFromModelBuilderApiModel(api_graph); } -Status Graph::UpdateUsingGraphApiModel(const OrtModel& api_model) { +Status Graph::UpdateUsingModelBuilderApiModel(const OrtModel& api_model) { for (auto& entry : api_model.domain_to_version) { if (auto it = domain_to_version_.find(entry.first); it != 
domain_to_version_.end()) { if (it->second != entry.second) { @@ -5782,8 +5806,9 @@ Status Graph::UpdateUsingGraphApiModel(const OrtModel& api_model) { } } - // this will replace all inputs/outputs and add nodes. - return LoadFromGraphApiModel(*api_model.graph, /*updating_existing_graph*/ true); + // this will replace inputs/outputs and add nodes. + return LoadFromModelBuilderApiModel(*api_model.graph, /*updating_existing_graph*/ true); } +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index b3ccef0dc51e9..58dcf7d6ab257 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -929,11 +929,11 @@ common::Status Model::LoadFromOrtFormat(const fbs::Model& fbs_model, } // static -common::Status Model::LoadFromGraphApiModel(const OrtModel& graph_api_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries, - const ModelOptions& options, - const logging::Logger& logger, - std::unique_ptr& model) { +common::Status Model::LoadFromModelBuilderApiModel(const OrtModel& graph_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model) { model = std::make_unique(); model->model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); // The optimizer Initializer class requires a path if external data is used, however in the Graph API usage the @@ -947,13 +947,13 @@ common::Status Model::LoadFromGraphApiModel(const OrtModel& graph_api_model, } } - ORT_RETURN_IF_ERROR(Graph::LoadFromGraphApiModel(*graph_api_model.graph, - *model, - graph_api_model.domain_to_version, - schema_registry, - options.strict_shape_type_inference, - logger, - model->graph_)); + ORT_RETURN_IF_ERROR(Graph::LoadFromModelBuilderApiModel(*graph_api_model.graph, + *model, + graph_api_model.domain_to_version, + schema_registry, + options.strict_shape_type_inference, + logger, + 
model->graph_)); return Status::OK(); } diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 5acbf634fc6ee..edb9cd1f2f918 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -305,11 +305,11 @@ class Model { const logging::Logger& logger, const ModelOptions& options = {}); - static common::Status LoadFromGraphApiModel(const OrtModel& graph_api_model, - const IOnnxRuntimeOpSchemaRegistryList* local_registries, - const ModelOptions& options, - const logging::Logger& logger, - std::unique_ptr& model); + static common::Status LoadFromModelBuilderApiModel(const OrtModel& graph_api_model, + const IOnnxRuntimeOpSchemaRegistryList* local_registries, + const ModelOptions& options, + const logging::Logger& logger, + std::unique_ptr& model); common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; diff --git a/onnxruntime/core/graph/model_builder_api_types.h b/onnxruntime/core/graph/model_builder_api_types.h index f1a634a591513..acc29beca0d8d 100644 --- a/onnxruntime/core/graph/model_builder_api_types.h +++ b/onnxruntime/core/graph/model_builder_api_types.h @@ -38,6 +38,7 @@ struct OrtGraph { std::vector> inputs; std::vector> outputs; std::unordered_map> initializers; + std::unordered_map> external_initializers; std::vector> nodes; }; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 4c21aa75e5c64..01b3a43dca951 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1214,10 +1214,10 @@ common::Status InferenceSession::Load(const OrtModel& graph_api_model) { // need to go from unique_ptr to shared_ptr when moving into model_ std::unique_ptr tmp_model; - ORT_RETURN_IF_ERROR(Model::LoadFromGraphApiModel(graph_api_model, - HasLocalSchema() ? 
&custom_schema_registries_ : nullptr, - ModelOptions(true, strict_shape_type_inference), - *session_logger_, tmp_model)); + ORT_RETURN_IF_ERROR(Model::LoadFromModelBuilderApiModel(graph_api_model, + HasLocalSchema() ? &custom_schema_registries_ : nullptr, + ModelOptions(true, strict_shape_type_inference), + *session_logger_, tmp_model)); model_ = std::move(tmp_model); @@ -1241,7 +1241,7 @@ common::Status InferenceSession::ApplyUpdates(const OrtModel& graph_api_model) { return status; } - return model_->MainGraph().UpdateUsingGraphApiModel(graph_api_model); + return model_->MainGraph().UpdateUsingModelBuilderApiModel(graph_api_model); } common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { diff --git a/onnxruntime/core/session/model_builder_api.h b/onnxruntime/core/session/model_builder_api.h index 7da22909858a5..7c03c7f05e887 100644 --- a/onnxruntime/core/session/model_builder_api.h +++ b/onnxruntime/core/session/model_builder_api.h @@ -21,7 +21,8 @@ ORT_API_STATUS_IMPL(SetGraphInputs, _In_ OrtGraph* graph, _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); ORT_API_STATUS_IMPL(SetGraphOutputs, _In_ OrtGraph* graph, _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); -ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor); +ORT_API_STATUS_IMPL(AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor, + bool data_is_external); ORT_API_STATUS_IMPL(AddNodeToGraph, _In_ OrtGraph* graph, _Inout_ OrtNode* node); ORT_API(void, ReleaseGraph, _Frees_ptr_opt_ OrtGraph* graph); diff --git a/onnxruntime/core/session/model_builder_c_api.cc b/onnxruntime/core/session/model_builder_c_api.cc index 083f7d25b881f..7fd5b897367cb 100644 --- a/onnxruntime/core/session/model_builder_c_api.cc +++ b/onnxruntime/core/session/model_builder_c_api.cc @@ -112,7 +112,8 @@ 
ORT_API_STATUS_IMPL(OrtModelBuilderAPI::CreateGraph, _Outptr_ OrtGraph** graph) // do some reserves to reduce reallocation. if we had a hint about sizes upfront that would be optimal g->inputs.reserve(8); g->outputs.reserve(8); - g->initializers.reserve(64); + g->initializers.reserve(32); + g->external_initializers.reserve(32); g->nodes.reserve(64); *graph = g.release(); @@ -152,9 +153,19 @@ ORT_API_STATUS_IMPL(OrtModelBuilderAPI::SetGraphOutputs, _In_ OrtGraph* graph, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, _Inout_ OrtValue* tensor) { +ORT_API_STATUS_IMPL(OrtModelBuilderAPI::AddInitializerToGraph, _In_ OrtGraph* graph, _In_ const char* name, + _Inout_ OrtValue* tensor, bool data_is_external) { API_IMPL_BEGIN - graph->initializers[name] = std::unique_ptr(tensor); // take ownership + if (data_is_external) { +#if !defined(DISABLE_EXTERNAL_INITIALIZERS) + graph->external_initializers[name] = std::unique_ptr(tensor); // take ownership +#else + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External initializers are not supported in this build"); +#endif + } else { + graph->initializers[name] = std::unique_ptr(tensor); // take ownership + } + return nullptr; API_IMPL_END } diff --git a/onnxruntime/test/shared_lib/test_model_builder_api.cc b/onnxruntime/test/shared_lib/test_model_builder_api.cc index 26e0165775e20..b6f8eb04e303f 100644 --- a/onnxruntime/test/shared_lib/test_model_builder_api.cc +++ b/onnxruntime/test/shared_lib/test_model_builder_api.cc @@ -32,7 +32,7 @@ Ort::Session CreateSession(Ort::Env& env, : default_session_options; // Set this to save the model if you want to debug. 
- // session_options.SetOptimizedModelFilePath(ORT_TSTR("graph_api_model.onnx")); + // session_options.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx")); Ort::Session session(env, graph_api_model, session_options); @@ -75,17 +75,51 @@ OrtNode* CreateNode(const OrtModelBuilderApi& api, &node)); return node; } - } // namespace +struct TestAllocator : public OrtAllocator { + TestAllocator() { + version = ORT_API_VERSION; + Info = [](const struct OrtAllocator* this_ptr) -> const struct OrtMemoryInfo* { + auto* test_allocator = static_cast(this_ptr); + return test_allocator->memory_info; + }; + + Free = [](struct OrtAllocator* allocator, void* p) -> void { + auto* test_allocator = static_cast(allocator); + // find the matching pointer and remove it + auto it = std::find_if(test_allocator->weights.begin(), test_allocator->weights.end(), + [p](const std::unique_ptr>& v) { return v->data() == p; }); + if (it == test_allocator->weights.end()) { + throw std::runtime_error("Free called with unknown pointer"); + } + + test_allocator->weights.erase(it); + }; + + Alloc = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + + Reserve = [](struct OrtAllocator* /*this*/, size_t /*size*/) -> void* { + throw std::runtime_error("This should not be used"); + }; + } + + // initializers that are used directly by the model. as there's no copy they must remain valid. + // we store them in the test allocator so we can validate that Free is called + std::vector>> weights; + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, + OrtMemType::OrtMemTypeDefault); +}; + // Test the ModelBuilderAPI C api // Uses the ORT C++ api for the rest for simplicity TEST(ModelBuilderAPITest, Basic_CApi) { const auto& api = Ort::GetApi(); const auto& graph_api = Ort::GetModelBuilderApi(); - // initializers that are used directly by the model. 
as there's no copy they must remain valid - std::vector>> weights; + TestAllocator deleter; // return void so we can use ASSERT_* in the lambda const auto build_model = [&](bool use_constant_node, OrtModel*& model) -> void { @@ -153,7 +187,7 @@ TEST(ModelBuilderAPITest, Basic_CApi) { std::vector node_attributes{alpha_attr}; OrtNode* node = CreateNode(graph_api, "Gemm", "Gemm1", node_input_names, node_output_names, node_attributes); - api.ReleaseOpAttr(alpha_attr); // CreateNode copies an OrtOpAttr instances + api.ReleaseOpAttr(alpha_attr); // CreateNode copies all OrtOpAttr instances Ort::ThrowOnError(graph_api.AddNodeToGraph(graph, node)); node = nullptr; // graph now owns node @@ -166,18 +200,21 @@ TEST(ModelBuilderAPITest, Basic_CApi) { // create an initializer for the Y input. add to `weights` so the memory remains valid OrtValue* y_tensor = nullptr; std::vector y_dims = {2, 3}; - weights.emplace_back(std::make_unique>(std::initializer_list{1.0f, 2.0f, 3.0f, - 4.0f, 5.0f, 6.0f})); - auto& y_values = *weights.back(); + deleter.weights.emplace_back( + std::make_unique>(std::initializer_list{1.0f, 2.0f, 3.0f, + 4.0f, 5.0f, 6.0f})); + auto& y_values = *deleter.weights.back(); auto info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession - Ort::ThrowOnError(api.CreateTensorWithDataAsOrtValue(info, - y_values.data(), y_values.size() * sizeof(y_values[0]), - y_dims.data(), y_dims.size(), - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, - &y_tensor)); - Ort::ThrowOnError(graph_api.AddInitializerToGraph(graph, "Y", y_tensor)); + Ort::ThrowOnError( + api.CreateTensorWithDataAndDeleterAsOrtValue(&deleter, + y_values.data(), y_values.size() * sizeof(y_values[0]), + y_dims.data(), y_dims.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, + &y_tensor)); + + Ort::ThrowOnError(graph_api.AddInitializerToGraph(graph, "Y", y_tensor, /*data is external*/ true)); y_tensor = 
nullptr; // graph now owns std::vector<const char*> domain_names = {onnxruntime::kOnnxDomain}; @@ -208,6 +245,10 @@ TEST(ModelBuilderAPITest, Basic_CApi) { {18.0f, 24.0f, 30.0f, 38.0f, 52.0f, 66.0f, 58.0f, 80.0f, 102.0f}); + + api.ReleaseSession(session.release()); + + ASSERT_EQ(deleter.weights.size(), 0) << "All weights should have been freed"; } TEST(ModelBuilderAPITest, Basic_CxxApi) { @@ -225,7 +266,7 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { std::vector<ValueInfo> graph_inputs; std::vector<ValueInfo> graph_outputs; - // model input + // model input. it's {3, 2} but use a symbolic dim to test that works. std::vector<int64_t> input_dims({-1, 2}); std::vector input_symbolic_dims({"multiple_of_3", ""}); TensorTypeAndShapeInfo input_tensor_info(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, @@ -268,7 +309,7 @@ TEST(ModelBuilderAPITest, Basic_CxxApi) { // if you use this API the initializer data MUST remain valid for the lifetime of the InferenceSession auto y_tensor = Value::CreateTensor<float>(info, y_values.data(), y_values.size(), y_dims.data(), y_dims.size()); - graph.AddInitializer("Y", y_tensor); + graph.AddInitializer("Y", y_tensor, /*data is external*/ true); std::vector opsets{{onnxruntime::kOnnxDomain, 18}}; ModelBuilderAPI::Model model(opsets); @@ -298,6 +339,10 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { // SessionOptions so; + + // Set this to save the model if you want to debug. 
+ so.SetOptimizedModelFilePath(ORT_TSTR("model_builder_output.onnx")); + Session session = Session::CreateModelBuilderSession(*ort_env, TSTR("testdata/mnist.onnx"), so); ASSERT_EQ(session.GetOpset(""), 8); // ONNX domain is empty string @@ -322,7 +367,8 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { int64_t to = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; attributes.push_back(OpAttr("to", &to, 1, OrtOpAttrType::ORT_OP_ATTR_INT)); - ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]}, attributes); + ModelBuilderAPI::Node node("Cast", onnxruntime::kOnnxDomain, new_input_name, {"Int64Input"}, {input_names[0]}, + attributes); // we're replacing the only input, so we don't need to call session.GetInputTypeInfo(x) to copy other inputs // in order to preserve them @@ -352,16 +398,33 @@ TEST(ModelBuilderAPITest, BasicModelEdit_CxxApi) { std::iota(input.values.begin(), input.values.end(), 1); std::vector<int64_t> expected_dims = {1, 10}; - TestInference(session, inputs, session.GetOutputNames()[0].c_str(), expected_dims, - {-48.5088f, -1040.2948f, -347.0959f, 101.7392f, 421.3352f, 750.92145f, - 231.5060f, -1694.4152f, 681.5623f, 378.1689f}); + std::vector<float> expected_output = {-48.5088f, -1040.2948f, -347.0959f, 101.7392f, 421.3352f, + 750.92145f, 231.5060f, -1694.4152f, 681.5623f, 378.1689f}; + + TestInference(session, inputs, session.GetOutputNames()[0].c_str(), expected_dims, expected_output); + + // double check with original model + { + SessionOptions expected_so; + Session expected_session = Session(*ort_env, TSTR("testdata/mnist.onnx"), expected_so); + std::vector<Input<float>> expected_inputs(1); + auto& expected_input = expected_inputs[0]; + expected_input.name = input_names[0].c_str(); + expected_input.dims = orig_input.GetTensorTypeAndShapeInfo().GetShape(); + expected_input.values.reserve(size_t(num_values)); + std::transform(input.values.begin(), input.values.end(), std::back_inserter(expected_input.values), + [&](int64_t value) 
{ return float(value); }); + + TestInference(expected_session, expected_inputs, session.GetOutputNames()[0].c_str(), + expected_dims, expected_output); + } } /* Tests required +- Constant node is converted to initializer - Attempt to create invalid model -- Create symbolic dims for model input or output - Edit and change outputs - Invalid edit - Edit where we change a subset of inputs or outputs.