diff --git a/CMakeLists.txt b/CMakeLists.txt index d484d73..fe2fbe8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,10 +23,10 @@ target_link_libraries(onnxruntime PRIVATE ${CMAKE_JS_LIB}) if (WIN32) target_link_libraries(onnxruntime PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/win-x64/onnxruntime.lib) elseif (APPLE) - target_link_libraries(onnxruntime PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/darwin-x64/libonnxruntime.0.2.1.dylib) + target_link_libraries(onnxruntime PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/darwin-x64/libonnxruntime.0.3.0.dylib) set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "@loader_path") else() - target_link_libraries(onnxruntime PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/linux-x64/libonnxruntime.so.0.2.1) + target_link_libraries(onnxruntime PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/linux-x64/libonnxruntime.so.0.3.0) set_target_properties(onnxruntime PROPERTIES INSTALL_RPATH "$ORIGIN/") endif() @@ -42,7 +42,7 @@ if (NOT APPLE) if (WIN32) target_link_libraries(onnxruntime_gpu PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/win_gpu-x64/onnxruntime.lib) else() - target_link_libraries(onnxruntime_gpu PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.2.1) + target_link_libraries(onnxruntime_gpu PRIVATE ${CMAKE_SOURCE_DIR}/onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.3.0) set_target_properties(onnxruntime_gpu PROPERTIES INSTALL_RPATH "$ORIGIN/") endif() endif() diff --git a/README.md b/README.md index 34ede86..42ea1b4 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,11 @@ npm install onnxjs-node ## Supported Platforms OS |Arch |CPU/GPU |NAPI version |ONNXRuntime version ---------|-----|--------|-------------|-------------------- - Windows | x64 | CPU | v3 | v0.2.1 - Linux | x64 | CPU | v3 | v0.2.1 - macOS | x64 | CPU | v3 | v0.2.1 - Windows | x64 | GPU | v3 | v0.2.1 - Linux | x64 | GPU | v3 | v0.2.1 + Windows | x64 | CPU | v3 | v0.3.0 + Linux | x64 | CPU | v3 | v0.3.0 + macOS | x64 | CPU | v3 | v0.3.0 + Windows | x64 | GPU | v3 | v0.3.0 + Linux | x64 | GPU | v3 | v0.3.0 ## Usage There are 2 options to import `onnxjs-node`. @@ -54,6 +54,8 @@ session = new onnx.InferenceSession({backendHint: 'wasm'}); // use WebAssembly ## Documentation - [ONNX.js Home](https://github.com/Microsoft/onnxjs) - [ONNXRuntime](https://github.com/Microsoft/onnxruntime) +- [Nuget package: Microsoft.ML.OnnxRuntime](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime/) +- [Nuget package: Microsoft.ML.OnnxRuntime.Gpu](https://www.nuget.org/packages/Microsoft.ML.OnnxRuntime.Gpu/) # License Copyright (c) fs-eire. All rights reserved. diff --git a/onnxruntime/README.md b/onnxruntime/README.md deleted file mode 100644 index 92b2a2b..0000000 --- a/onnxruntime/README.md +++ /dev/null @@ -1,130 +0,0 @@ -# ONNX Runtime C# API -The ONNX runtime provides a C# .Net binding for running inference on ONNX models in any of the .Net standard platforms. The API is .Net standard 1.1 compliant for maximum portability. This document describes the API. - -## NuGet Package -The Microsoft.ML.OnnxRuntime Nuget package includes the precompiled binaries for ONNX runtime, and includes libraries for Windows and Linux platforms with X64 CPUs. The APIs conform to .Net Standard 1.1. - -## Getting Started -Here is simple tutorial for getting started with running inference on an existing ONNX model for a given input data. The model is typically trained using any of the well-known training frameworks and exported into the ONNX format. To start scoring using the model, open a session using the `InferenceSession` class, passing in the file path to the model as a parameter. - - var session = new InferenceSession("model.onnx"); - -Once a session is created, you can execute queries using the `Run` method of the `InferenceSession` object. Currently, only `Tensor` type of input and outputs are supported. The results of the `Run` method are represented as a collection of .Net `Tensor` objects (as defined in [System.Numerics.Tensor](https://www.nuget.org/packages/System.Numerics.Tensors)). - - Tensor t1, t2; // let's say data is fed into the Tensor objects - var inputs = new List() - { - NamedOnnxValue.CreateFromTensor("name1", t1), - NamedOnnxValue.CreateFromTensor("name2", t2) - }; - using (var results = session.Run(inputs)) - { - // manipulate the results - } - -You can load your input data into Tensor objects in several ways. A simple example is to create the Tensor from arrays. - - float[] sourceData; // assume your data is loaded into a flat float array - int[] dimensions; // and the dimensions of the input is stored here - Tensor t1 = new DenseTensor(sourceData, dimensions); - -Here is a [complete sample code](https://github.com/Microsoft/onnxruntime/tree/master/csharp/sample/Microsoft.ML.OnnxRuntime.InferenceSample) that runs inference on a pretrained model. - -## Running on GPU (Optional) -If using the GPU package, simply use the appropriate SessionOptions when creating an InferenceSession. - - int gpuDeviceId = 0; // The GPU device ID to execute on - var session = new InferenceSession("model.onnx", SessionOptions.MakeSessionOptionWithCudaProvider(gpuDeviceId)); - -## API Reference -### InferenceSession - class InferenceSession: IDisposable -The runtime representation of an ONNX model - -#### Constructor - InferenceSession(string modelPath); - InferenceSession(string modelPath, SesionOptions options); - -#### Properties - IReadOnlyDictionary InputMetadata; -Data types and shapes of the input nodes of the model. - IReadOnlyDictionary OutputMetadata; -Data types and shapes of the output nodes of the model. - -#### Methods - IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs); -Runs the model with the given input data to compute all the output nodes and returns the output node values. Both input and output are collection of NamedOnnxValue, which in turn is a name-value pair of string names and Tensor values. The outputs are IDisposable variant of NamedOnnxValue, since they wrap some unmanaged objects. - - IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection desiredOutputNodes); -Runs the model on given inputs for the given output nodes only. - -### System.Numerics.Tensor -The primary .Net object that is used for holding input-output of the model inference. Details on this newly introduced data type can be found in its [open-source implementation](https://github.com/dotnet/corefx/tree/master/src/System.Numerics.Tensors). The binaries are available as a [.Net NuGet package](https://www.nuget.org/packages/System.Numerics.Tensors). - -### NamedOnnxValue - class NamedOnnxValue; -Represents a name-value pair of string names and any type of value that ONNX runtime supports as input-output data. Currently, only Tensor objects are supported as input-output values. - -#### Constructor - No public constructor available. - -#### Properties - string Name; // read only - -#### Methods - static NamedOnnxValue CreateFromTensor(string name, Tensor); -Creates a NamedOnnxValue from a name and a Tensor object. - - Tensor AsTensor(); -Accesses the value as a Tensor. Returns null if the value is not a Tensor. - -### DisposableNamedOnnxValue - class DisposableNamedOnnxValue: NamedOnnxValue, IDisposable; -This is a disposable variant of NamedOnnxValue, used for holding output values which contains objects allocated in unmanaged memory. - -### IDisposableReadOnlyCollection - interface IDisposableReadOnlyCollection: IReadOnlyCollection, IDisposable -Collection interface to hold disposable values. Used for output of Run method. - -### SessionOptions - class SessionOptions: IDisposable; -A collection of properties to be set for configuring the OnnxRuntime session - -#### Constructor - SessionOptions(); -Constructs a SessionOptions will all options at default/unset values. - -#### Properties - static SessionOptions Default; //read-only -Accessor to the default static option object - -#### Methods - AppendExecutionProvider(ExecutionProvider provider); -Appends execution provider to the session. For any operator in the graph the first execution provider that implements the operator will be user. ExecutionProvider is defined as the following enum. - - enum ExecutionProvider - { - Cpu, - MklDnn - } - -### NodeMetadata -Container of metadata for a model graph node, used for communicating the shape and type of the input and output nodes. - -#### Properties - int[] Dimensions; -Read-only shape of the node, when the node is a Tensor. Undefined if the node is not a Tensor. - - System.Type ElementType; -Type of the elements of the node, when node is a Tensor. Undefined for non-Tensor nodes. - - bool IsTensor; -Whether the node is a Tensor - -### Exceptions - class OnnxRuntimeException: Exception; - -The type of Exception that is thrown in most of the error conditions related to Onnx Runtime. - - - diff --git a/onnxruntime/ThirdPartyNotices.txt b/onnxruntime/ThirdPartyNotices.txt index fd95e05..4ae8ac3 100644 --- a/onnxruntime/ThirdPartyNotices.txt +++ b/onnxruntime/ThirdPartyNotices.txt @@ -205,38 +205,6 @@ external contributions to this project including patches, pull requests, etc. _____ -NVlabs/cub - -Copyright (c) 2010-2011, Duane Merrill. All rights reserved. -Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - * Neither the name of the NVIDIA CORPORATION nor the - names of its contributors may be used to endorse or promote products - derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY -DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -_____ - onnx Open Neural Network Exchange @@ -1913,35 +1881,9 @@ DLPack: Open In Memory Tensor Structure See the License for the specific language governing permissions and limitations under the License. -_____ - -JSON for Modern C++ - -MIT License - -Copyright (c) 2013-2018 Niels Lohmann - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - ____ -Date +HowardHinnant/date The source code in this project is released using the MIT License. There is no global license for the project because each file is licensed individually with @@ -2815,55 +2757,32 @@ Apache License _____ -dotnet/corefxlab - -The MIT License (MIT) - -Copyright (c) .NET Foundation and Contributors - -All rights reserved. +google/re2 -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +Copyright (c) 2009 The RE2 Authors. All rights reserved. -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -_____ - -dotnet/standard - -The MIT License (MIT) - -Copyright (c) .NET Foundation and Contributors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/onnxruntime/bin/darwin-x64/libonnxruntime.0.2.1.dylib b/onnxruntime/bin/darwin-x64/libonnxruntime.0.3.0.dylib old mode 100755 new mode 100644 similarity index 50% rename from onnxruntime/bin/darwin-x64/libonnxruntime.0.2.1.dylib rename to onnxruntime/bin/darwin-x64/libonnxruntime.0.3.0.dylib index 0a1ece5..9e065b4 Binary files a/onnxruntime/bin/darwin-x64/libonnxruntime.0.2.1.dylib and b/onnxruntime/bin/darwin-x64/libonnxruntime.0.3.0.dylib differ diff --git a/onnxruntime/bin/linux-x64/libmkldnn.so.0 b/onnxruntime/bin/linux-x64/libmkldnn.so.0 deleted file mode 100644 index 28e050c..0000000 Binary files a/onnxruntime/bin/linux-x64/libmkldnn.so.0 and /dev/null differ diff --git a/onnxruntime/bin/linux-x64/libonnxruntime.so.0.2.1 b/onnxruntime/bin/linux-x64/libonnxruntime.so.0.3.0 similarity index 55% rename from onnxruntime/bin/linux-x64/libonnxruntime.so.0.2.1 rename to onnxruntime/bin/linux-x64/libonnxruntime.so.0.3.0 index d28669a..836dbd3 100644 Binary files a/onnxruntime/bin/linux-x64/libonnxruntime.so.0.2.1 and b/onnxruntime/bin/linux-x64/libonnxruntime.so.0.3.0 differ diff --git a/onnxruntime/bin/linux_gpu-x64/libmkldnn.so.0 b/onnxruntime/bin/linux_gpu-x64/libmkldnn.so.0 deleted file mode 100644 index 28e050c..0000000 Binary files a/onnxruntime/bin/linux_gpu-x64/libmkldnn.so.0 and /dev/null differ diff --git a/onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.2.1 b/onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.3.0 similarity index 77% rename from onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.2.1 rename to onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.3.0 index 87c00c6..b573dd7 100644 Binary files a/onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.2.1 and b/onnxruntime/bin/linux_gpu-x64/libonnxruntime.so.0.3.0 differ diff --git a/onnxruntime/bin/win-x64/mkldnn.dll b/onnxruntime/bin/win-x64/mkldnn.dll deleted file mode 100644 index ee74786..0000000 Binary files a/onnxruntime/bin/win-x64/mkldnn.dll and /dev/null differ diff --git a/onnxruntime/bin/win-x64/onnxruntime.dll b/onnxruntime/bin/win-x64/onnxruntime.dll index 96146ce..ca8e7dc 100644 Binary files a/onnxruntime/bin/win-x64/onnxruntime.dll and b/onnxruntime/bin/win-x64/onnxruntime.dll differ diff --git a/onnxruntime/bin/win-x64/onnxruntime.lib b/onnxruntime/bin/win-x64/onnxruntime.lib index c82fbe7..bba0e56 100644 Binary files a/onnxruntime/bin/win-x64/onnxruntime.lib and b/onnxruntime/bin/win-x64/onnxruntime.lib differ diff --git a/onnxruntime/bin/win_gpu-x64/mkldnn.dll b/onnxruntime/bin/win_gpu-x64/mkldnn.dll deleted file mode 100644 index 7a8aa12..0000000 Binary files a/onnxruntime/bin/win_gpu-x64/mkldnn.dll and /dev/null differ diff --git a/onnxruntime/bin/win_gpu-x64/onnxruntime.dll b/onnxruntime/bin/win_gpu-x64/onnxruntime.dll index 87c2111..a492b18 100644 Binary files a/onnxruntime/bin/win_gpu-x64/onnxruntime.dll and b/onnxruntime/bin/win_gpu-x64/onnxruntime.dll differ diff --git a/onnxruntime/bin/win_gpu-x64/onnxruntime.lib b/onnxruntime/bin/win_gpu-x64/onnxruntime.lib index a033a3b..f5801ea 100644 Binary files a/onnxruntime/bin/win_gpu-x64/onnxruntime.lib and b/onnxruntime/bin/win_gpu-x64/onnxruntime.lib differ diff --git a/onnxruntime/inc/core/common/code_location.h b/onnxruntime/inc/core/common/code_location.h deleted file mode 100644 index ff6506c..0000000 --- a/onnxruntime/inc/core/common/code_location.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -namespace onnxruntime { -/** - CodeLocation captures information on where in the source code a message came from. -*/ -struct CodeLocation { - /** - @param file_path Usually the value of __FILE__ - @param line Usually the value of __LINE__ - @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ - */ - CodeLocation(const char* file_path, const int line, const char* func) - : file_and_path{file_path}, line_num{line}, function{func} { - } - - /** - @param file_path Usually the value of __FILE__ - @param line Usually the value of __LINE__ - @param func Usually the value of __PRETTY_FUNCTION__ or __FUNCTION__ - @param stacktrace Stacktrace from source of message. - */ - CodeLocation(const char* file_path, const int line, const char* func, const std::vector& stacktrace) - : file_and_path{file_path}, line_num{line}, function{func}, stacktrace(stacktrace) { - } - - std::string FileNoPath() const { - // assuming we always have work to do, so not trying to avoid creating a new string if - // no path was removed. - return file_and_path.substr(file_and_path.find_last_of("/\\") + 1); - } - - enum Format { - kFilename, - kFilenameAndPath - }; - - std::string ToString(Format format = Format::kFilename) const { - std::ostringstream out; - out << (format == Format::kFilename ? FileNoPath() : file_and_path) << ":" << line_num << " " << function; - return out.str(); - } - - const std::string file_and_path; - const int line_num; - const std::string function; - const std::vector stacktrace; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/common.h b/onnxruntime/inc/core/common/common.h deleted file mode 100644 index f6356f4..0000000 --- a/onnxruntime/inc/core/common/common.h +++ /dev/null @@ -1,217 +0,0 @@ -/** - * Copyright (c) 2016-present, Facebook, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -// Portions Copyright (c) Microsoft Corporation - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "core/common/code_location.h" -#include "core/common/exceptions.h" -#include "core/common/status.h" - -namespace onnxruntime { - -using TimePoint = std::chrono::high_resolution_clock::time_point; - -// Using statements for common classes that we refer to in ONNXRuntime very often. -// TODO(Task:137) Remove 'using' statements from header files -using common::Status; - -#ifdef _WIN32 -#define ORT_UNUSED_PARAMETER(x) (x) -#else -#define ORT_UNUSED_PARAMETER(x) (void)(x) -#endif - -#ifndef ORT_HAVE_ATTRIBUTE -#ifdef __has_attribute -#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define ORT_HAVE_ATTRIBUTE(x) 0 -#endif -#endif - -// ORT_ATTRIBUTE_UNUSED -// -// Prevents the compiler from complaining about or optimizing away variables -// that appear unused on Linux -#if ORT_HAVE_ATTRIBUTE(unused) || (defined(__GNUC__) && !defined(__clang__)) -#undef ORT_ATTRIBUTE_UNUSED -#define ORT_ATTRIBUTE_UNUSED __attribute__((__unused__)) -#else -#define ORT_ATTRIBUTE_UNUSED -#endif - -// macro to explicitly ignore the return value from a function call so Code Analysis doesn't complain -#define ORT_IGNORE_RETURN_VALUE(fn) \ - static_cast(fn) - -std::vector GetStackTrace(); - -// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER -// so we only define it as one for MSVC -#if (_MSC_VER && !defined(__PRETTY_FUNCTION__)) -#define __PRETTY_FUNCTION__ __FUNCTION__ -#endif - -// Capture where a message is coming from. Use __FUNCTION__ rather than the much longer __PRETTY_FUNCTION__ -#define ORT_WHERE \ - ::onnxruntime::CodeLocation(__FILE__, __LINE__, __FUNCTION__) - -#define ORT_WHERE_WITH_STACK \ - ::onnxruntime::CodeLocation(__FILE__, __LINE__, __PRETTY_FUNCTION__, ::onnxruntime::GetStackTrace()) - -// Throw an exception with optional message. -// NOTE: The arguments get streamed into a string via ostringstream::operator<< -// DO NOT use a printf format string, as that will not work as you expect. -#define ORT_THROW(...) \ - throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, ::onnxruntime::MakeString(__VA_ARGS__)) - -// Just in order to mark things as not implemented. Do not use in final code. -#define ORT_NOT_IMPLEMENTED(...) \ - throw ::onnxruntime::NotImplementedException(::onnxruntime::MakeString(__VA_ARGS__)) - -// Check condition. -// NOTE: The arguments get streamed into a string via ostringstream::operator<< -// DO NOT use a printf format string, as that will not work as you expect. -#define ORT_ENFORCE(condition, ...) \ - if (!(condition)) \ - throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, #condition, \ - ::onnxruntime::MakeString(__VA_ARGS__)) - -#define ORT_MAKE_STATUS(category, code, ...) \ - ::onnxruntime::common::Status(::onnxruntime::common::category, \ - ::onnxruntime::common::code, \ - ::onnxruntime::MakeString(__VA_ARGS__)) - -// Check condition. if not met, return status. -#define ORT_RETURN_IF_NOT(condition, ...) \ - if (!(condition)) { \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Not satsified: " #condition "\n", \ - ORT_WHERE.ToString(), ::onnxruntime::MakeString(__VA_ARGS__)); \ - } - -// Macros to disable the copy and/or move ctor and assignment methods -// These are usually placed in the private: declarations for a class. - -#define ORT_DISALLOW_COPY(TypeName) TypeName(const TypeName&) = delete - -#define ORT_DISALLOW_ASSIGNMENT(TypeName) TypeName& operator=(const TypeName&) = delete - -#define ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName) \ - ORT_DISALLOW_COPY(TypeName); \ - ORT_DISALLOW_ASSIGNMENT(TypeName) - -#define ORT_DISALLOW_MOVE(TypeName) \ - TypeName(TypeName&&) = delete; \ - TypeName& operator=(TypeName&&) = delete - -#define ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TypeName) \ - ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \ - ORT_DISALLOW_MOVE(TypeName) - -#define ORT_RETURN_IF_ERROR(expr) \ - do { \ - auto _status = (expr); \ - if ((!_status.IsOK())) return _status; \ - } while (0) - -// use this macro when cannot early return -#define ORT_CHECK_AND_SET_RETVAL(expr) \ - do { \ - if (retval.IsOK()) { \ - retval = (expr); \ - } \ - } while (0) - -// C++ Core Guideline check suppression. -#if defined(_MSC_VER) && !defined(__NVCC__) -#define GSL_SUPPRESS(tag) [[gsl::suppress(tag)]] -#else -#define GSL_SUPPRESS(tag) -#endif - -inline void MakeStringInternal(std::ostringstream& /*ss*/) noexcept { -} - -template -inline void MakeStringInternal(std::ostringstream& ss, const T& t) noexcept { - ss << t; -} - -template -inline void MakeStringInternal(std::ostringstream& ss, const T& t, const Args&... args) noexcept { - ::onnxruntime::MakeStringInternal(ss, t); - ::onnxruntime::MakeStringInternal(ss, args...); -} - -template -std::string MakeString(const Args&... args) { - std::ostringstream ss; - ::onnxruntime::MakeStringInternal(ss, args...); - return std::string(ss.str()); -} - -// Specializations for already-a-string types. -template <> -inline std::string MakeString(const std::string& str) { - return str; -} -inline std::string MakeString(const char* p_str) { - return p_str; -} - -inline long long TimeDiffMicroSeconds(TimePoint start_time) { - auto end_time = std::chrono::high_resolution_clock::now(); - return std::chrono::duration_cast(end_time - start_time).count(); -} - -inline long long TimeDiffMicroSeconds(TimePoint start_time, TimePoint end_time) { - return std::chrono::duration_cast(end_time - start_time).count(); -} - -inline std::string GetCurrentTimeString() { - auto now = std::chrono::system_clock::now(); - auto in_time_t = std::chrono::system_clock::to_time_t(now); - std::tm local_tm; //NOLINT - -#ifdef _WIN32 - localtime_s(&local_tm, &in_time_t); -#else - localtime_r(&in_time_t, &local_tm); -#endif - - char time_str[32]; - strftime(time_str, sizeof(time_str), "%Y-%m-%d_%H-%M-%S", &local_tm); - return std::string(time_str); -} - -struct null_type {}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/const_pointer_container.h b/onnxruntime/inc/core/common/const_pointer_container.h deleted file mode 100644 index bfc873f..0000000 --- a/onnxruntime/inc/core/common/const_pointer_container.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace onnxruntime { -/** - Container has T* entries. e.g. std::vector, and this class provides const access to those - via iterators and direct access, as the standard behavior only makes the pointer constant, - and not what is pointed too. i.e. you get a const pointer to T not a pointer to const T without this wrapper. - See https://stackoverflow.com/questions/8017036/understanding-const-iterator-with-pointers -*/ -template -class ConstPointerContainer { - public: - using T = typename std::remove_pointer::type; - - class ConstIterator { - public: - using const_iterator = typename Container::const_iterator; - - /** Construct iterator for container that will return const T* entries.*/ - explicit ConstIterator(const_iterator position) noexcept : current_(position) {} - - bool operator==(const ConstIterator& other) const noexcept { return current_ == other.current_; } - bool operator!=(const ConstIterator& other) const noexcept { return current_ != other.current_; } - void operator++() { ++current_; } - const T* operator*() { return *current_; } - - private: - const_iterator current_; - }; - - /** - Construct wrapper class that will provide const access to the pointers in a container of non-const pointers. - @param data Container with non-const pointers. e.g. std::vector - */ - explicit ConstPointerContainer(const Container& data) noexcept : data_(data) {} - - size_t size() const noexcept { return data_.size(); } - - ConstIterator begin() const noexcept { return ConstIterator(data_.cbegin()); } - ConstIterator end() const noexcept { return ConstIterator(data_.cend()); } - - const T* operator[](size_t index) const { return data_[index]; } - - const T* at(size_t index) const { - ORT_ENFORCE(index < data_.size()); - return data_[index]; - } - - private: - const Container& data_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/exceptions.h b/onnxruntime/inc/core/common/exceptions.h deleted file mode 100644 index 31e7a9f..0000000 --- a/onnxruntime/inc/core/common/exceptions.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/code_location.h" - -namespace onnxruntime { - -class NotImplementedException : public std::logic_error { - public: - explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; - explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message){}; -}; - -class TypeMismatchException : public std::logic_error { - public: - TypeMismatchException() noexcept : logic_error("Type mismatch"){}; -}; - -class OnnxRuntimeException : public std::exception { - public: - OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept - : OnnxRuntimeException(location, nullptr, msg) { - } - - /** - Create a new exception that captures the location it was thrown from. - @param location Location in the source code the exception is being thrown from - @param failed_condition Optional string containing the condition that failed. - e.g. "tensor.Size() == input.Size()". May be nullptr. - @param msg Message containing additional information about the exception cause. - */ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { - std::ostringstream ss; - - ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous - if (failed_condition != nullptr) { - ss << " " << failed_condition << " was false."; - } - - ss << " " << msg << "\n"; - if (!location.stacktrace.empty()) { - ss << "Stacktrace:\n"; - // skip the first entry in the stacktrace as we have that information from location.ToString() - std::copy(++location.stacktrace.begin(), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); - } - - what_ = ss.str(); - } - - const char* what() const noexcept override { - return what_.c_str(); - } - - private: - const CodeLocation location_; - const std::vector stacktrace_; - std::string what_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/logging/capture.h b/onnxruntime/inc/core/common/logging/capture.h deleted file mode 100644 index d59f7ac..0000000 --- a/onnxruntime/inc/core/common/logging/capture.h +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/common/common.h" -#include "core/common/code_location.h" -#include "core/common/logging/severity.h" - -namespace onnxruntime { -namespace logging { - -class Logger; -enum class DataType; - -/** - Class to capture the details of a log message. -*/ -class Capture { - public: - /** - Initializes a new instance of the Capture class. - @param logger The logger. - @param severity The severity. - @param category The category. - @param dataType Type of the data. - @param location The file location the log message is coming from. - */ - Capture(const Logger& logger, logging::Severity severity, const char* category, - logging::DataType dataType, const CodeLocation& location) - : logger_{&logger}, severity_{severity}, category_{category}, data_type_{dataType}, location_{location} { - } - - /** - The stream that can capture the message via operator<<. - @returns Output stream. - */ - std::ostream& Stream() noexcept { - return stream_; - } - -#ifdef _MSC_VER - // add SAL annotation for printf format string. requires Code Analysis to run to validate usage. -#define msvc_printf_check _Printf_format_string_ -#define __attribute__(x) // Disable for MSVC. Supported by GCC and CLang. -#else -#define msvc_printf_check -#endif - - /** - Captures a printf style log message. - @param name="format">The printf format. - @param name="">Arguments to the printf format if needed. - @remarks - A maximum of 2K of output will be captured currently. - Non-static method, so 'this' is implicit first arg, and we use format(printf(2,3) - */ - void CapturePrintf(msvc_printf_check const char* format, ...) __attribute__((format(printf, 2, 3))); - - /** - Process a printf style log message. - @param format The printf format. - @param ... Arguments to the printf format if needed. - @remarks - A maximum of 2K of output will be captured currently. - Note: As va_list is 'char *', we have to disambiguate this from CapturePrintf - so that something like "One string: %s", "the string" does not consider "the string" - to be the va_list. - */ - void ProcessPrintf(msvc_printf_check const char* format, va_list args); - - logging::Severity Severity() const noexcept { - return severity_; - } - - char SeverityPrefix() const noexcept { - // Carefully setup so severity_ is a valid index - GSL_SUPPRESS(bounds .2) { - return logging::SEVERITY_PREFIX[static_cast(severity_)]; - } - } - - const char* Category() const noexcept { - return category_; - } - - logging::DataType DataType() const noexcept { - return data_type_; - } - - const CodeLocation& Location() const noexcept { - return location_; - } - - std::string Message() const noexcept { - return stream_.str(); - } - - ~Capture(); - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Capture); - - const Logger* logger_; - const logging::Severity severity_; - const char* category_; - const logging::DataType data_type_; - const CodeLocation location_; - - std::ostringstream stream_; -}; -} // namespace logging -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/logging/isink.h b/onnxruntime/inc/core/common/logging/isink.h deleted file mode 100644 index a67777d..0000000 --- a/onnxruntime/inc/core/common/logging/isink.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/common/logging/logging.h" - -namespace onnxruntime { -namespace logging { -class ISink { - public: - ISink() = default; - - /** - Sends the message to the sink. - @param timestamp The timestamp. - @param logger_id The logger identifier. - @param message The captured message. - */ - void Send(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) { - SendImpl(timestamp, logger_id, message); - } - - /** - Sends a Profiling Event Record to the sink. - @param Profiling Event Record - */ - virtual void SendProfileEvent(profiling::EventRecord&) const {}; - - virtual ~ISink() = default; - - private: - // Make Code Analysis happy by disabling all for now. Enable as needed. - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink); - - virtual void SendImpl(const Timestamp& timestamp, const std::string& logger_id, const Capture& message) = 0; -}; -} // namespace logging -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/logging/logging.h b/onnxruntime/inc/core/common/logging/logging.h deleted file mode 100644 index 3c808b9..0000000 --- a/onnxruntime/inc/core/common/logging/logging.h +++ /dev/null @@ -1,318 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/logging/capture.h" -#include "core/common/logging/severity.h" - -#include "core/common/logging/macros.h" - -/* - - Logging overview and expected usage: - - At program startup: - * Create one or more ISink instances. If multiple, combine using composite_sink. - * Create a LoggingManager instance with the sink/s with is_default_instance set to true - * Only one instance should be created in this way, and it should remain valid for - until the program no longer needs to produce log output. - - You can either use the static default Logger which LoggingManager will create when constructed - via LoggingManager::DefaultLogger(), or separate Logger instances each with different log ids - via LoggingManager::CreateLogger. - - The log id is passed to the ISink instance with the sink determining how the log id is used - in the output. - - LoggingManager - * creates the Logger instances used by the application - * provides a static default logger instance - * owns the log sink instance - * applies checks on severity and output of user data - - The log macros create a Capture instance to capture the information to log. - If the severity and/or user filtering settings would prevent logging, no evaluation - of the log arguments will occur, so no performance cost beyond the severity and user - filtering check. - - A sink can do further filter as needed. - -*/ - -namespace onnxruntime { -namespace profiling { - -enum EventCategory { - SESSION_EVENT = 0, - NODE_EVENT, - EVENT_CATEGORY_MAX -}; - -/* -Event descriptions for the above session events. -*/ -static constexpr const char* event_categor_names_[EVENT_CATEGORY_MAX] = { - "Session", - "Node"}; - -/* -Timing record for all events. -*/ -struct EventRecord { - EventRecord(EventCategory category, - int process_id, - int thread_id, - std::string event_name, - long long time_stamp, - long long duration, - std::unordered_map&& event_args) : cat(category), - pid(process_id), - tid(thread_id), - name(std::move(event_name)), - ts(time_stamp), - dur(duration), - args(event_args) {} - EventCategory cat; - int pid; - int tid; - std::string name; - long long ts; - long long dur; - std::unordered_map args; -}; -} // namespace profiling -namespace logging { - -using Timestamp = std::chrono::time_point; - -#ifndef NDEBUG -ORT_ATTRIBUTE_UNUSED static bool vlog_enabled = true; // Set directly based on your needs. -#else -constexpr bool vlog_enabled = false; // no VLOG output -#endif - -enum class DataType { - SYSTEM = 0, ///< System data. - USER = 1 ///< Contains potentially sensitive user data. -}; - -// Internal log categories. -// Logging interface takes const char* so arbitrary values can also be used. -struct Category { - static const char* onnxruntime; ///< General output - static const char* System; ///< Log output regarding interactions with the host system - // TODO: What other high level categories are meaningful? Model? Optimizer? Execution? -}; - -class ISink; -class Logger; -class Capture; - -/// -/// The logging manager. -/// Owns the log sink and potentially provides a default Logger instance. -/// Provides filtering based on a minimum LogSeverity level, and of messages with DataType::User if enabled. -/// -class LoggingManager final { - public: - enum InstanceType { - Default, ///< Default instance of LoggingManager that should exist for the lifetime of the program - Temporal ///< Temporal instance. CreateLogger(...) should be used, however DefaultLogger() will NOT be provided via this instance. - }; - - /** - Initializes a new instance of the LoggingManager class. - @param sink The sink to write to. Use CompositeSink if you need to write to multiple places. - @param default_min_severity The default minimum severity. Messages with lower severity will be ignored unless - overridden in CreateLogger. - @param default_filter_user_data If set to true ignore messages with DataType::USER unless overridden in CreateLogger. - @param instance_type If InstanceType::Default, this is the default instance of the LoggingManager - and is expected to exist for the lifetime of the program. - It creates and owns the default logger that calls to the static DefaultLogger method return. - @param default_logger_id Logger Id to use for the default logger. nullptr/ignored if instance_type == Temporal. - @param default_max_vlog_level Default maximum level for VLOG messages to be created unless overridden in CreateLogger. - Requires a severity of kVERBOSE for VLOG messages to be logged. - */ - LoggingManager(std::unique_ptr sink, Severity default_min_severity, bool default_filter_user_data, - InstanceType instance_type, - const std::string* default_logger_id = nullptr, - int default_max_vlog_level = -1); - - /** - Creates a new logger instance which will use the provided logger_id and default severity and vlog levels. - @param logger_id The log identifier. - @returns A new Logger instance that the caller owns. - */ - std::unique_ptr CreateLogger(const std::string& logger_id); - - /** - Creates a new logger instance which will use the provided logger_id, severity and vlog levels. - @param logger_id The log identifier. - @param min_severity The minimum severity. Requests to create messages with lower severity will be ignored. - @param filter_user_data If set to true ignore messages with DataType::USER. - @param max_vlog_level Maximum level for VLOG messages to be created. - @returns A new Logger instance that the caller owns. - */ - std::unique_ptr CreateLogger(const std::string& logger_id, - Severity min_severity, bool filter_user_data, int max_vlog_level = -1); - - /** - Gets the default logger instance if set. Throws if no default logger is currently registered. - @remarks - Creating a LoggingManager instance with is_default_instance == true registers a default logger. - Note that the default logger is only valid until the LoggerManager that registered it is destroyed. - @returns The default logger if available. - */ - static const Logger& DefaultLogger(); - - /** - Logs a FATAL level message and creates an exception that can be thrown with error information. - @param category The log category. - @param location The location the log message was generated. - @param format_str The printf format string. - @param ... The printf arguments. - @returns A new Logger instance that the caller owns. - */ - static std::exception LogFatalAndCreateException(const char* category, - const CodeLocation& location, - const char* format_str, ...); - - /** - Logs the message using the provided logger id. - @param logger_id The log identifier. - @param message The log message. - */ - void Log(const std::string& logger_id, const Capture& message) const; - - /** - Sends a Profiling Event Record to the sink. - @param Profiling Event Record - */ - void SendProfileEvent(profiling::EventRecord& eventRecord) const; - ~LoggingManager(); - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(LoggingManager); - - Timestamp GetTimestamp() const noexcept; - void CreateDefaultLogger(const std::string& logger_id); - - std::unique_ptr sink_; - const Severity default_min_severity_; - const bool default_filter_user_data_; - const int default_max_vlog_level_; - bool owns_default_logger_; - static Logger* s_default_logger_; - - struct Epochs { - const std::chrono::time_point high_res; - const std::chrono::time_point system; - const std::chrono::minutes localtime_offset_from_utc; - }; - - static const Epochs& GetEpochs() noexcept; -}; - -/** - Logger provides a per-instance log id. Everything else is passed back up to the LoggingManager -*/ -class Logger { - public: - /** - Initializes a new instance of the Logger class. - @param loggingManager The logging manager. - @param id The identifier for messages coming from this Logger. - @param severity Minimum severity for messages to be created and logged. - @param filter_user_data Should USER data be filtered from output. - @param vlog_level Minimum level for VLOG messages to be created. Note that a severity of kVERBOSE must be provided - for VLOG messages to be logged. - */ - Logger(const LoggingManager& loggingManager, std::string id, - Severity severity, bool filter_user_data, int vlog_level) - : logging_manager_{&loggingManager}, - id_{id}, - min_severity_{severity}, - filter_user_data_{filter_user_data}, - max_vlog_level_{severity > Severity::kVERBOSE ? -1 : vlog_level} { // disable unless logging VLOG messages - } - - /** - Check if output is enabled for the provided LogSeverity and DataType values. - @param severity The severity. - @param data_type Type of the data. - @returns True if a message with these values will be logged. - */ - bool OutputIsEnabled(Severity severity, DataType data_type) const noexcept { - return (severity >= min_severity_ && (data_type != DataType::USER || !filter_user_data_)); - } - - /** - Return the maximum VLOG level allowed. - */ - int VLOGMaxLevel() const noexcept { - return max_vlog_level_; - } - - /** - Logs the captured message. - @param message The log message. - */ - void Log(const Capture& message) const { - logging_manager_->Log(id_, message); - } - - /** - Sends a Profiling Event Record to the sink. - @param Profiling Event Record - */ - void SendProfileEvent(profiling::EventRecord& eventRecord) const { - logging_manager_->SendProfileEvent(eventRecord); - } - - private: - const LoggingManager* logging_manager_; - const std::string id_; - const Severity min_severity_; - const bool filter_user_data_; - const int max_vlog_level_; -}; - -inline const Logger& LoggingManager::DefaultLogger() { - if (s_default_logger_ == nullptr) { - // fail early for attempted misuse. don't use logging macros as we have no logger. - throw std::logic_error("Attempt to use DefaultLogger but none has been registered."); - } - - return *s_default_logger_; -} - -inline Timestamp LoggingManager::GetTimestamp() const noexcept { - static const Epochs& epochs = GetEpochs(); - - const auto high_res_now = std::chrono::high_resolution_clock::now(); - return std::chrono::time_point_cast( - epochs.system + (high_res_now - epochs.high_res) + epochs.localtime_offset_from_utc); -} - -/** - Return the current thread id. -*/ -unsigned int GetThreadId(); - -/** - Return the current process id. -*/ -unsigned int GetProcessId(); - -} // namespace logging -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/logging/macros.h b/onnxruntime/inc/core/common/logging/macros.h deleted file mode 100644 index 570bc14..0000000 --- a/onnxruntime/inc/core/common/logging/macros.h +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -// NOTE: Don't include this file directly. Include logging.h - -#define CREATE_MESSAGE(logger, severity, category, datatype) \ - ::onnxruntime::logging::Capture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE) - -/* - Both printf and stream style logging are supported. - Not that printf currently has a 2K limit to the message size. - - LOGS_* macros are for stream style - LOGF_* macros are for printf style - - The Message class captures the log input, and pushes it through the logger in its destructor. - - Use the *FATAL* macros if you want a Severity::kFatal message to also throw. - - There are a few variants to minimize the length of the macro name required in the calling code. - They are optimized so the shortest names are for the (expected) most common usage. This can be - tweaked if needed. - - Explicit logger vs LoggingManager::DefaulLogger() - Default is for a logger instance to be explicitly passed in. - The logger instance provides an identifier so that log messages from different runs can be separated. - - Variants with DEFAULT in the macro name use the default logger provided by logging manager. This is - static so accessible from any code, provided a LoggingManager instance created with InstanceType::Default - exists somewhere. See logging.h for further explanation of the expected setup. - - DataType - Default uses DataType::SYSTEM. - - Variants with USER in the macro name use DataType::USER. This is data that could be PII, and may need to - be filtered from output. LoggingManager applies this filtering. - - Category - Default category is ::onnxruntime::Logging::Category::onnxruntime. - - If you wish to provide a different category, use variants with CATEGORY in the macro name - -*/ - -// Logging with explicit category - -// iostream style logging. Capture log info in Message, and push to the logger in ~Message. -#define LOGS_CATEGORY(logger, severity, category) \ - if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \ - CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).Stream() - -#define LOGS_USER_CATEGORY(logger, severity, category) \ - if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::USER)) \ - CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).Stream() - - // printf style logging. Capture log info in Message, and push to the logger in ~Message. -#define LOGF_CATEGORY(logger, severity, category, format_str, ...) \ - if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::SYSTEM)) \ - CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::SYSTEM).CapturePrintf(format_str, ##__VA_ARGS__) - -#define LOGF_USER_CATEGORY(logger, severity, category, format_str, ...) \ - if ((logger).OutputIsEnabled(::onnxruntime::logging::Severity::k##severity, ::onnxruntime::logging::DataType::USER)) \ - CREATE_MESSAGE(logger, severity, category, ::onnxruntime::logging::DataType::USER).CapturePrintf(format_str, ##__VA_ARGS__) - - // Logging with category of "onnxruntime" - -#define LOGS(logger, severity) \ - LOGS_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGS_USER(logger, severity) \ - LOGS_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime) - - // printf style logging. Capture log info in Message, and push to the logger in ~Message. -#define LOGF(logger, severity, format_str, ...) \ - LOGF_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) - -#define LOGF_USER(logger, severity, format_str, ...) \ - LOGF_USER_CATEGORY(logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) - - /* - - Macros that use the default logger. - A LoggingManager instance must be currently valid for the default logger to be available. - - */ - - // Logging with explicit category - -#define LOGS_DEFAULT_CATEGORY(severity, category) \ - LOGS_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category) - -#define LOGS_USER_DEFAULT_CATEGORY(severity, category) \ - LOGS_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category) - -#define LOGF_DEFAULT_CATEGORY(severity, category, format_str, ...) \ - LOGF_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__) - -#define LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ...) \ - LOGF_USER_CATEGORY(::onnxruntime::logging::LoggingManager::DefaultLogger(), severity, category, format_str, ##__VA_ARGS__) - -// Logging with category of "onnxruntime" - -#define LOGS_DEFAULT(severity) \ - LOGS_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGS_USER_DEFAULT(severity) \ - LOGS_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGF_DEFAULT(severity, format_str, ...) \ - LOGF_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) - -#define LOGF_USER_DEFAULT(severity, format_str, ...) \ - LOGF_USER_DEFAULT_CATEGORY(severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) - - /* - - Conditional logging - - */ - - // Logging with explicit category - -#define LOGS_CATEGORY_IF(boolean_expression, logger, severity, category) \ - if ((boolean_expression) == true) LOGS_CATEGORY(logger, severity, category) - -#define LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \ - if ((boolean_expression) == true) LOGS_DEFAULT_CATEGORY(severity, category) - -#define LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, category) \ - if ((boolean_expression) == true) LOGS_USER_CATEGORY(logger, severity, category) - -#define LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category) \ - if ((boolean_expression) == true) LOGS_USER_DEFAULT_CATEGORY(severity, category) - -#define LOGF_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \ - if ((boolean_expression) == true) LOGF_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__) - -#define LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \ - if ((boolean_expression) == true) LOGF_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__) - -#define LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, category, format_str, ...) \ - if ((boolean_expression) == true) LOGF_USER_CATEGORY(logger, severity, category, format_str, ##__VA_ARGS__) - -#define LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, category, format_str, ...) \ - if ((boolean_expression) == true) LOGF_USER_DEFAULT_CATEGORY(severity, category, format_str, ##__VA_ARGS__) - - // Logging with category of "onnxruntime" - -#define LOGS_IF(boolean_expression, logger, severity) \ - LOGS_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGS_DEFAULT_IF(boolean_expression, severity) \ - LOGS_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGS_USER_IF(boolean_expression, logger, severity) \ - LOGS_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGS_USER_DEFAULT_IF(boolean_expression, severity) \ - LOGS_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime) - -#define LOGF_IF(boolean_expression, logger, severity, format_str, ...) \ - LOGF_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) - -#define LOGF_DEFAULT_IF(boolean_expression, severity, format_str, ...) \ - LOGF_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, format_str, ##__VA_ARGS__) - -#define LOGF_USER_IF(boolean_expression, logger, severity, format_str, ...) \ - LOGF_USER_CATEGORY_IF(boolean_expression, logger, severity, ::onnxruntime::logging::Category::onnxruntime, \ - format_str, ##__VA_ARGS__) - -#define LOGF_USER_DEFAULT_IF(boolean_expression, severity, format_str, ...) \ - LOGF_USER_DEFAULT_CATEGORY_IF(boolean_expression, severity, ::onnxruntime::logging::Category::onnxruntime, \ - format_str, ##__VA_ARGS__) - -/* - - Debug verbose logging of caller provided level. - Disabled in Release builds. - Use the _USER variants for VLOG statements involving user data that may need to be filtered. -*/ -#define VLOGS(logger, level) \ - if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ - LOGS_CATEGORY(logger, VERBOSE, "VLOG" #level) - -#define VLOGS_USER(logger, level) \ - if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ - LOGS_USER_CATEGORY(logger, VERBOSE, "VLOG" #level) - -#define VLOGF(logger, level, format_str, ...) \ - if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ - LOGF_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__) - -#define VLOGF_USER(logger, level, format_str, ...) \ - if (::onnxruntime::logging::vlog_enabled && level <= (logger).VLOGMaxLevel()) \ - LOGF_USER_CATEGORY(logger, VERBOSE, "VLOG" #level, format_str, ##__VA_ARGS__) - - // Default logger variants -#define VLOGS_DEFAULT(level) \ - VLOGS(::onnxruntime::logging::LoggingManager::DefaultLogger(), level) - -#define VLOGS_USER_DEFAULT(level) \ - VLOGS_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level) - -#define VLOGF_DEFAULT(level, format_str, ...) \ - VLOGF(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__) - -#define VLOGF_USER_DEFAULT(level, format_str, ...) \ - VLOGF_USER(::onnxruntime::logging::LoggingManager::DefaultLogger(), level, format_str, ##__VA_ARGS__) diff --git a/onnxruntime/inc/core/common/logging/severity.h b/onnxruntime/inc/core/common/logging/severity.h deleted file mode 100644 index e43f192..0000000 --- a/onnxruntime/inc/core/common/logging/severity.h +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -namespace onnxruntime { -namespace logging { -// mild violation of naming convention. the 'k' lets us use token concatenation in the macro -// ::onnxruntime::Logging::Severity::k##severity. It's not legal to have ::onnxruntime::Logging::Severity::##severity -// the uppercase makes the LOG macro usage look as expected for passing an enum value as it will be LOGS(logger, ERROR) -enum class Severity { - kVERBOSE = 0, - kINFO = 1, - kWARNING = 2, - kERROR = 3, - kFATAL = 4 -}; - -constexpr const char* SEVERITY_PREFIX = "VIWEF"; - -} // namespace logging -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/ml_status.h b/onnxruntime/inc/core/common/ml_status.h deleted file mode 100644 index 9f597c8..0000000 --- a/onnxruntime/inc/core/common/ml_status.h +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -namespace onnxruntime { - -enum class MLStatus : uint32_t { - OK = 0, - FAIL = 1, - INVALID_ARGUMENT = 2, - NO_SUCHFILE = 3, - NO_MODEL = 4, - ENGINE_ERROR = 5, - RUNTIME_EXCEPTION = 6, - INVALID_PROTOBUF = 7, - MODEL_LOADED = 8, - NOT_IMPLEMENTED = 9, - INVALID_GRAPH = 10, - SHAPE_INFERENCE_NOT_REGISTERED = 11, - REQUIREMENT_NOT_REGISTERED = 12 -}; - -inline const char* MLStatusToString(MLStatus status) noexcept { - switch (status) { - case MLStatus::OK: - return "SUCCESS"; - case MLStatus::INVALID_ARGUMENT: - return "INVALID_ARGUMENT"; - case MLStatus::NO_SUCHFILE: - return "NO_SUCHFILE"; - case MLStatus::NO_MODEL: - return "NO_MODEL"; - case MLStatus::ENGINE_ERROR: - return "ENGINE_ERROR"; - case MLStatus::RUNTIME_EXCEPTION: - return "RUNTIME_EXCEPTION"; - case MLStatus::INVALID_PROTOBUF: - return "INVALID_PROTOBUF"; - case MLStatus::MODEL_LOADED: - return "MODEL_LOADED"; - case MLStatus::NOT_IMPLEMENTED: - return "NOT_IMPLEMENTED"; - case MLStatus::INVALID_GRAPH: - return "INVALID_GRAPH"; - case MLStatus::SHAPE_INFERENCE_NOT_REGISTERED: - return "SHAPE_INFERENCE_NOT_REGISTERED"; - case MLStatus::REQUIREMENT_NOT_REGISTERED: - return "REQUIREMENT_NOT_REGISTERED"; - default: - return "GENERAL ERROR"; - } -} - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/common/status.h b/onnxruntime/inc/core/common/status.h deleted file mode 100644 index ad0d3ef..0000000 --- a/onnxruntime/inc/core/common/status.h +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Modifications Copyright (c) Microsoft. - -#pragma once - -#include -#include -#include "core/common/ml_status.h" - -namespace onnxruntime { -namespace common { - -enum StatusCategory { - NONE = 0, - SYSTEM = 1, - ONNXRUNTIME = 2, -}; - -/** - Error code for ONNXRuntime. -*/ -enum StatusCode { - OK = static_cast(MLStatus::OK), - FAIL = static_cast(MLStatus::FAIL), - INVALID_ARGUMENT = static_cast(MLStatus::INVALID_ARGUMENT), - NO_SUCHFILE = static_cast(MLStatus::NO_SUCHFILE), - NO_MODEL = static_cast(MLStatus::NO_MODEL), - ENGINE_ERROR = static_cast(MLStatus::ENGINE_ERROR), - RUNTIME_EXCEPTION = static_cast(MLStatus::RUNTIME_EXCEPTION), - INVALID_PROTOBUF = static_cast(MLStatus::INVALID_PROTOBUF), - MODEL_LOADED = static_cast(MLStatus::MODEL_LOADED), - NOT_IMPLEMENTED = static_cast(MLStatus::NOT_IMPLEMENTED), - INVALID_GRAPH = static_cast(MLStatus::INVALID_GRAPH), - SHAPE_INFERENCE_NOT_REGISTERED = static_cast(MLStatus::SHAPE_INFERENCE_NOT_REGISTERED), - REQUIREMENT_NOT_REGISTERED = static_cast(MLStatus::REQUIREMENT_NOT_REGISTERED), -}; - -class Status { - public: - Status() noexcept = default; - - Status(StatusCategory category, int code, const std::string& msg); - - Status(StatusCategory category, int code); - - Status(const Status& other) - : state_((other.state_ == nullptr) ? nullptr : std::make_unique(*other.state_)) {} - - Status& operator=(const Status& other) { - if (state_ != other.state_) { - if (other.state_ == nullptr) { - state_.reset(); - } else { - state_ = std::make_unique(*other.state_); - } - } - return *this; - } - - Status(Status&& other) = default; - Status& operator=(Status&& other) = default; - ~Status() = default; - - bool IsOK() const noexcept; - - int Code() const noexcept; - - StatusCategory Category() const noexcept; - - const std::string& ErrorMessage() const noexcept; - - std::string ToString() const; - - bool operator==(const Status& other) const { - return (this->state_ == other.state_) || (ToString() == other.ToString()); - } - - bool operator!=(const Status& other) const { - return !(*this == other); - } - - static const Status& OK() noexcept; - - private: - static const std::string& EmptyString() noexcept; - - struct State { - State(StatusCategory cat0, int code0, const std::string& msg0) - : category(cat0), code(code0), msg(msg0) {} - - const StatusCategory category; - const int code; - const std::string msg; - }; - - // As long as Code() is OK, state_ == nullptr. - std::unique_ptr state_; -}; - -inline std::ostream& operator<<(std::ostream& out, const Status& status) { - return out << status.ToString(); -} - -} // namespace common -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/alloc_kind.h b/onnxruntime/inc/core/framework/alloc_kind.h deleted file mode 100644 index 7ccb012..0000000 --- a/onnxruntime/inc/core/framework/alloc_kind.h +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include - -namespace onnxruntime { -// The ml-Values fall into the following categories with respect to their -// memory management: -// - inference inputs: owned (allocated and freed) by caller, and is by -// default read-only by the runtime. -// - inference outputs: allocated by runtime, ownership transferred to -// caller. TODO: Make sure this semantics is clear in InferenceSession API. -// - weights (constant tensors): can be allocated once (statically), and -// reused by all inference calls within an InferenceSession. -// - tensor values: The lifetimes of these tensor-values are statically -// determined, which is used for memory reuse/sharing optimizations. The -// runtime allocates/frees these values at the right time (as determined -// by the static allocation plan). Note that this is simplified since we -// do not try to optimize for "slice" like ops, where we may be able to -// conditionally reuse memory/data in some cases but not others. -// Generalizing this is future work. - -enum class AllocKind { - kAllocate = 0, - kReuse = 1, - kPreExisting = 2, - kAllocateStatically = 3, - kAllocateOutput = 4 -}; - -std::ostream& operator<<(std::ostream& out, AllocKind alloc_kind); -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/allocator.h b/onnxruntime/inc/core/framework/allocator.h deleted file mode 100644 index 085284d..0000000 --- a/onnxruntime/inc/core/framework/allocator.h +++ /dev/null @@ -1,192 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/exceptions.h" -#include "core/common/status.h" -#include "core/framework/fence.h" -#include "core/session/onnxruntime_c_api.h" - -struct OrtAllocatorInfo { - // use string for name, so we could have customized allocator in execution provider. - const char* name; - int id; - OrtMemType mem_type; - OrtAllocatorType type; - - constexpr OrtAllocatorInfo(const char* name1, OrtAllocatorType type, int id1 = 0, OrtMemType mem_type1 = OrtMemTypeDefault) -#if (defined(__GNUC__) || defined(__clang__)) - __attribute__((nonnull)) -#endif - : name(name1), - id(id1), - mem_type(mem_type1), - type(type) { - } - - inline bool operator==(const OrtAllocatorInfo& other) const { - return mem_type == other.mem_type && type == other.type && id == other.id && strcmp(name, other.name) == 0; - } - - // To make OrtAllocatorInfo become a valid key in std map - inline bool operator<(const OrtAllocatorInfo& other) const { - if (type != other.type) - return type < other.type; - if (mem_type != other.mem_type) - return mem_type < other.mem_type; - if (id != other.id) - return id < other.id; - - return strcmp(name, other.name) < 0; - } - - inline std::string ToString() const { - std::ostringstream ostr; - ostr << "OrtAllocatorInfo: [" - << " name:" << name - << " id:" << id - << " mem_type:" << mem_type - << " type:" << type - << "]"; - return ostr.str(); - } -}; - -std::ostream& operator<<(std::ostream& out, const OrtAllocatorInfo& info); - -namespace onnxruntime { -constexpr const char* CPU = "Cpu"; - -// forward declaration -class SessionState; - -template -using IAllocatorUniquePtr = std::unique_ptr>; - -class IAllocator { - public: - virtual ~IAllocator() = default; - virtual void* Alloc(size_t size) = 0; - virtual void Free(void* p) = 0; - virtual const OrtAllocatorInfo& Info() const = 0; - - /** - optional CreateFence interface, as provider like DML has its own fence - */ - virtual FencePtr CreateFence(const SessionState* /*unused*/) { return nullptr; } - - static bool CalcMemSizeForArray(size_t nmemb, size_t size, size_t* out) noexcept { - return CalcMemSizeForArrayWithAlignment<0>(nmemb, size, out); - } - - /** - * https://cwe.mitre.org/data/definitions/190.html - * \tparam alignment must be power of 2 - * \param nmemb - * \param size - * \param out - * \return true, successful. false, overflow - */ - template - static bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept ORT_MUST_USE_RESULT; - /** - * allocate memory for an array which has nmemb items of data, each size bytes long - */ - void* AllocArray(size_t nmemb, size_t size) { - size_t len; - if (!CalcMemSizeForArray(nmemb, size, &len)) - return nullptr; - return Alloc(len); - } - - /** - * allocate memory for an array which has nmemb items of data, each size bytes long - */ - template - void* AllocArrayWithAlignment(size_t nmemb, size_t size) { - size_t len; - if (!CalcMemSizeForArrayWithAlignment(nmemb, size, &len)) - return nullptr; - return Alloc(len); - } - - /** - Create a std::unique_ptr that is allocated and freed by the provided IAllocator. - @param allocator The allocator. - @param count_or_bytes The exact bytes to allocate if T is void, otherwise the number of elements to allocate. - @returns std::unique_ptr with allocated memory and deleter. - */ - template - static IAllocatorUniquePtr MakeUniquePtr(std::shared_ptr allocator, size_t count_or_bytes) { - if (allocator == nullptr) return nullptr; - // for now limit to fundamental types. we could support others, but to do so either we or the caller - // needs to call the dtor for the objects, for buffers allocated on device we don't have destructor - //static_assert(std::is_fundamental::value, "Fundamental type required as no destructors are called."); - - size_t alloc_size = count_or_bytes; - - // if T is not void, 'count_or_bytes' == number of items so allow for that - if (!std::is_void::value) { - // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't - // reachable if T is void. use std::conditional to 'use' void* in the sizeof call - if (!CalcMemSizeForArray(count_or_bytes, sizeof(typename std::conditional::value, void*, T>::type), - &alloc_size)) return nullptr; - } - return IAllocatorUniquePtr{ - static_cast(allocator->Alloc(alloc_size)), // allocate - [=](T* ptr) { allocator->Free(ptr); }}; // capture IAllocator so it's always valid, and use as deleter - } -}; - -template -bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t* out) noexcept { - static constexpr size_t max_allowed = (static_cast(1) << (static_cast(std::numeric_limits::digits >> 1))) - alignment; - static constexpr size_t max_size = std::numeric_limits::max() - alignment; - static constexpr size_t alignment_mask = alignment - 1; - //Indeed, we only need to check if max_size / nmemb < size - //max_allowed is for avoiding unnecessary DIV. - if (nmemb >= max_allowed && max_size / nmemb < size) { - return false; - } - if (size >= max_allowed && - nmemb > 0 && max_size / nmemb < size) { - return false; - } - if (alignment == 0) - *out = size * nmemb; - else - *out = (size * nmemb + alignment_mask) & ~static_cast(alignment_mask); - return true; -} - -/** - The resource allocator on a physical device. - This allocator will directly allocate resource from system call -*/ -class IDeviceAllocator : public IAllocator { - public: - ~IDeviceAllocator() override = default; - void* Alloc(size_t size) override = 0; - void Free(void* p) override = 0; - const OrtAllocatorInfo& Info() const override = 0; - virtual bool AllowsArena() const { return true; } -}; - -class CPUAllocator : public IDeviceAllocator { - public: - void* Alloc(size_t size) override; - void Free(void* p) override; - const OrtAllocatorInfo& Info() const override; -}; - -using AllocatorPtr = std::shared_ptr; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/custom_ops_author.h b/onnxruntime/inc/core/framework/custom_ops_author.h deleted file mode 100644 index f5a03a3..0000000 --- a/onnxruntime/inc/core/framework/custom_ops_author.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -/** - this header should include all the headers that are required to build a custom op so that - custom op developers don't have to worry about which headers to include, etc. -*/ -#include "core/framework/op_kernel.h" - -struct KernelsContainer { - std::vector<::onnxruntime::KernelCreateInfo> kernels_list; -}; - -struct SchemasContainer { - std::vector schemas_list; - std::string domain; - int baseline_opset_version; - int opset_version; -}; - -extern "C" { - KernelsContainer* GetAllKernels(); - SchemasContainer* GetAllSchemas(); - void FreeKernelsContainer(KernelsContainer*); - void FreeSchemasContainer(SchemasContainer*); -} diff --git a/onnxruntime/inc/core/framework/customregistry.h b/onnxruntime/inc/core/framework/customregistry.h deleted file mode 100644 index ffa7212..0000000 --- a/onnxruntime/inc/core/framework/customregistry.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/status.h" -#include "core/common/logging/logging.h" -#include "core/graph/schema_registry.h" -#include "core/framework/op_kernel.h" -#include "core/framework/kernel_def_builder.h" -#include "core/framework/kernel_registry.h" - -namespace onnxruntime { - -/** - Represents a registry that contains both custom kernels and custom schemas. -*/ -class CustomRegistry : public KernelRegistry, public onnxruntime::OnnxRuntimeOpSchemaRegistry { - public: - CustomRegistry() = default; - ~CustomRegistry() override = default; - - /** - * Register a kernel definition together with kernel factory method to this session. - * If any conflict happened between registered kernel def and built-in kernel def, - * registered kernel will have higher priority. - * Call this before invoking Initialize(). - * @return OK if success. - */ - common::Status RegisterCustomKernel(KernelDefBuilder& kernel_def_builder, const KernelCreateFn& kernel_creator); - - common::Status RegisterCustomKernel(KernelCreateInfo&); - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(CustomRegistry); -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/data_types.h b/onnxruntime/inc/core/framework/data_types.h deleted file mode 100644 index 69c44a6..0000000 --- a/onnxruntime/inc/core/framework/data_types.h +++ /dev/null @@ -1,603 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/exceptions.h" - -namespace ONNX_NAMESPACE { -class TypeProto; -} // namespace ONNX_NAMESPACE -namespace onnxruntime { -/// Predefined registered types -//maps -using MapStringToString = std::map; -using MapStringToInt64 = std::map; -using MapStringToFloat = std::map; -using MapStringToDouble = std::map; -using MapInt64ToString = std::map; -using MapInt64ToInt64 = std::map; -using MapInt64ToFloat = std::map; -using MapInt64ToDouble = std::map; - -//vectors/sequences -using VectorString = std::vector; -using VectorInt64 = std::vector; -using VectorFloat = std::vector; -using VectorDouble = std::vector; -using VectorMapStringToFloat = std::vector; -using VectorMapInt64ToFloat = std::vector; - -class DataTypeImpl; -class TensorTypeBase; - -// MLFloat16 -union MLFloat16 { - uint16_t val; - - explicit MLFloat16(uint16_t x) : val(x) {} - MLFloat16() : val(0) {} -}; - -inline bool operator==(const MLFloat16& left, const MLFloat16& right) { - return left.val == right.val; -} - -inline bool operator!=(const MLFloat16& left, const MLFloat16& right) { - return left.val != right.val; -} - -struct ort_endian { - union q { - uint16_t v_; - uint8_t b_[2]; - constexpr explicit q(uint16_t v) noexcept : v_(v) {} - }; - static constexpr bool is_little() { - return q(0x200).b_[0] == 0x0; - } - static constexpr bool is_big() { - return q(0x200).b_[0] == 0x2; - } -}; - -//BFloat16 -struct BFloat16 { - uint16_t val{0}; - explicit BFloat16() {} - explicit BFloat16(uint16_t v) : val(v) {} - explicit BFloat16(float v) { - uint16_t* dst = reinterpret_cast(&v); - if (ort_endian::is_little()) { - val = dst[1]; - } else { - val = dst[0]; - } - } - float ToFloat() const { - float result; - uint16_t* dst = reinterpret_cast(&result); - if (ort_endian::is_little()) { - dst[1] = val; - dst[0] = 0; - } else { - dst[0] = val; - dst[1] = 0; - } - return result; - } -}; - -inline void BFloat16ToFloat(const BFloat16* blf, float* flt, size_t size) { - auto src = blf; - auto d = flt; - for (; size != 0; ++src, ++d, --size) { - *d = src->ToFloat(); - } -} - -inline void FloatToBFloat16(const float* flt, BFloat16* blf, size_t size) { - auto src = flt; - auto d = blf; - for (; size != 0; ++src, ++d, --size) { - new (d) BFloat16(*src); - } -} - -inline bool operator==(const BFloat16& left, const BFloat16& right) { - return left.val == right.val; -} - -inline bool operator!=(const BFloat16& left, const BFloat16& right) { - return left.val != right.val; -} - -// DataTypeImpl pointer as unique DataTypeImpl identifier. -using MLDataType = const DataTypeImpl*; -// be used with class MLValue -using DeleteFunc = void (*)(void*); -using CreateFunc = std::function; - -/** - * \brief Base class for MLDataType - * - */ -class DataTypeImpl { - public: - virtual ~DataTypeImpl() = default; - - /** - * \brief this API will be used to check type compatibility at runtime - * - * \param type_proto a TypeProto instance that is constructed for a specific type - * will be checked against a TypeProto instance contained within a corresponding - * MLDataType instance. - */ - virtual bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const = 0; - - virtual size_t Size() const = 0; - - virtual DeleteFunc GetDeleteFunc() const = 0; - - /** - * \brief Retrieves an instance of TypeProto for - * a given MLDataType - * \returns optional TypeProto. Only ONNX types - has type proto, non-ONNX types will return nullptr. - */ - virtual const ONNX_NAMESPACE::TypeProto* GetTypeProto() const = 0; - - virtual bool IsTensorType() const { - return false; - } - - // Returns this if this is of tensor-type and null otherwise - virtual const TensorTypeBase* AsTensorType() const { - return nullptr; - } - - // Return the type meta that we are using in the runtime. - template - static MLDataType GetType(); - - // Return the types for a concrete tensor type, like Tensor_Float - template - static MLDataType GetTensorType(); - - /** - * Convert an ONNX TypeProto to onnxruntime DataTypeImpl. - * However, this conversion is lossy. Don't try to use 'this->GetTypeProto()' converting it back - * Don't pass the returned value to MLValue::MLValue(...) function - * \param proto - */ - static MLDataType TypeFromProto(const ONNX_NAMESPACE::TypeProto& proto); - - // Registers ONNX_NAMESPACE::DataType (internalized string) with - // MLDataType. DataType is produced by internalizing an instance of - // TypeProto contained within MLDataType - static void RegisterDataType(MLDataType); - - static const std::vector& AllTensorTypes(); - static const std::vector& AllFixedSizeTensorTypes(); -}; - -std::ostream& operator<<(std::ostream& out, MLDataType data_type); - -/* - * Type registration helpers - */ -namespace data_types_internal { -/// TensorType helpers -/// - -// There is a specialization only for one -// type argument. -template -struct TensorContainedTypeSetter { - static void SetTensorElementType(ONNX_NAMESPACE::TypeProto&); - static void SetMapKeyType(ONNX_NAMESPACE::TypeProto&); -}; - -/// Is a given type on the list of types? -/// Accepts a list of types and the first argument is the type -/// We are checking if it is listed among those that follow -template -struct IsAnyOf; - -/// Two types remaining, end of the list -template -struct IsAnyOf : public std::is_same { -}; - -template -struct IsAnyOf { - static constexpr bool value = (std::is_same::value || - IsAnyOf::value); -}; - -/// Tells if the specified type is one of fundamental types -/// that can be contained within a tensor. -/// We do not have raw fundamental types, rather a subset -/// of fundamental types is contained within tensors. -template -struct IsTensorContainedType : public IsAnyOf { -}; - -/// This template's Get() returns a corresponding MLDataType -/// It dispatches the call to either GetTensorType<>() or -/// GetType<>() -template -struct GetMLDataType; - -template -struct GetMLDataType { - static MLDataType Get() { - return DataTypeImpl::GetTensorType(); - } -}; - -template -struct GetMLDataType { - static MLDataType Get() { - return DataTypeImpl::GetType(); - } -}; - -/// MapTypes helper API -/// K should always be one of the primitive data types -/// V can be either a primitive type (in which case it is a tensor) -/// or other preregistered types - -void CopyMutableMapValue(const ONNX_NAMESPACE::TypeProto&, - ONNX_NAMESPACE::TypeProto&); - -template -struct SetMapTypes { - static void Set(ONNX_NAMESPACE::TypeProto& proto) { - TensorContainedTypeSetter::SetMapKeyType(proto); - MLDataType dt = GetMLDataType::value>::Get(); - const auto* value_proto = dt->GetTypeProto(); - ORT_ENFORCE(value_proto != nullptr, typeid(V).name(), - " expected to be a registered ONNX type"); - CopyMutableMapValue(*value_proto, proto); - } -}; - -/// Sequence helpers -/// -// Element type is a primitive type so we set it to a tensor -void CopyMutableSeqElement(const ONNX_NAMESPACE::TypeProto&, - ONNX_NAMESPACE::TypeProto&); - -template -struct SetSequenceType { - static void Set(ONNX_NAMESPACE::TypeProto& proto) { - MLDataType dt = GetMLDataType::value>::Get(); - const auto* elem_proto = dt->GetTypeProto(); - ORT_ENFORCE(elem_proto != nullptr, typeid(T).name(), - " expected to be a registered ONNX type"); - CopyMutableSeqElement(*elem_proto, proto); - } -}; - -/// OpaqueTypes helpers -/// -void AssignOpaqueDomainName(const char* domain, const char* name, - ONNX_NAMESPACE::TypeProto& proto); - -} // namespace data_types_internal - -/// All tensors base -class TensorTypeBase : public DataTypeImpl { - public: - static MLDataType Type(); - - /// We first compare type_proto pointers and then - /// if they do not match try to account for the case - /// where TypeProto was created ad-hoc and not queried from MLDataType - bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override; - - bool IsTensorType() const override { - return true; - } - - const TensorTypeBase* AsTensorType() const override { - return this; - } - - size_t Size() const override; - - DeleteFunc GetDeleteFunc() const override; - - const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; - - virtual MLDataType GetElementType() const { - // should never reach here. - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); - } - - TensorTypeBase(const TensorTypeBase&) = delete; - TensorTypeBase& operator=(const TensorTypeBase&) = delete; - - protected: - ONNX_NAMESPACE::TypeProto& mutable_type_proto(); - - TensorTypeBase(); - ~TensorTypeBase() override; - - private: - struct Impl; - Impl* impl_; -}; - -/** - * \brief Tensor type. This type does not have a C++ type associated with - * it at registration time except the element type. One of the types mentioned - * above at IsTensorContainedType<> list is acceptable. - * - * \details - * Usage: - * ORT_REGISTER_TENSOR(ELEMENT_TYPE) - * Currently all of the Tensors irrespective of the dimensions are mapped to Tensor - * type. IsCompatible() currently ignores shape. - */ - -template -class TensorType : public TensorTypeBase { - public: - static_assert(data_types_internal::IsTensorContainedType::value, - "Requires one of the tensor fundamental types"); - - static MLDataType Type(); - - /// Tensors only can contain basic data types - /// that have been previously registered with ONNXRuntime - MLDataType GetElementType() const override { - return DataTypeImpl::GetType(); - } - - private: - TensorType() { - using namespace data_types_internal; - TensorContainedTypeSetter::SetTensorElementType(this->mutable_type_proto()); - } -}; - -/** - * \brief Base type for all non-tensors, maps, sequences and opaques - */ -class NonTensorTypeBase : public DataTypeImpl { - public: - size_t Size() const override = 0; - - DeleteFunc GetDeleteFunc() const override = 0; - - virtual CreateFunc GetCreateFunc() const = 0; - - const ONNX_NAMESPACE::TypeProto* GetTypeProto() const override; - - NonTensorTypeBase(const NonTensorTypeBase&) = delete; - NonTensorTypeBase& operator=(const NonTensorTypeBase&) = delete; - - protected: - NonTensorTypeBase(); - ~NonTensorTypeBase() override; - - ONNX_NAMESPACE::TypeProto& mutable_type_proto(); - - bool IsMapCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const; - - bool IsSequenceCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const; - - bool IsOpaqueCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const; - - private: - struct Impl; - Impl* impl_; -}; - -// This is where T is the actual CPPRuntimeType -template -class NonTensorType : public NonTensorTypeBase { - private: - static void Delete(void* p) { - delete static_cast(p); - } - - public: - size_t Size() const override { - return sizeof(T); - } - - DeleteFunc GetDeleteFunc() const override { - return &Delete; - } - - CreateFunc GetCreateFunc() const override { - return []() { return new T(); }; - } - - protected: - NonTensorType() = default; -}; - -/** - * \brief MapType. Use this type to register - * mapping types. - * - * \param T - cpp type that you wish to register as runtime MapType - * - * \details Usage: ORT_REGISTER_MAP(C++Type) - * The type is required to have mapped_type and - * key_type defined - */ -template -class MapType : public NonTensorType { - public: - static_assert(data_types_internal::IsTensorContainedType::value, - "Requires one of the tensor fundamental types as key"); - - static MLDataType Type(); - - bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override { - return this->IsMapCompatible(type_proto); - } - - private: - MapType() { - using namespace data_types_internal; - SetMapTypes::Set(this->mutable_type_proto()); - } -}; - -/** - * \brief SequenceType. Use to register sequences. - * - * \param T - CPP type that you wish to register as Sequence - * runtime type. - * - * \details Usage: ORT_REGISTER_SEQ(C++Type) - * The type is required to have value_type defined - */ -template -class SequenceType : public NonTensorType { - public: - static MLDataType Type(); - - bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override { - return this->IsSequenceCompatible(type_proto); - } - - private: - SequenceType() { - data_types_internal::SetSequenceType::Set(this->mutable_type_proto()); - } -}; - -/** - * \brief OpaqueType - * - * \param T - cpp runtume that implements the Opaque type - * - * \param const char D[] - domain must be extern to be unique - * - * \param const char N[] - name must be extern to be unique - * - * \details Only one CPP type can be associated with a particular - * OpaqueType registration - * - */ -template -class OpaqueType : public NonTensorType { - public: - static MLDataType Type(); - - bool IsCompatible(const ONNX_NAMESPACE::TypeProto& type_proto) const override { - return this->IsOpaqueCompatible(type_proto); - } - - private: - OpaqueType() { - data_types_internal::AssignOpaqueDomainName(D, N, this->mutable_type_proto()); - } -}; - -template -class NonOnnxType : public DataTypeImpl { - private: - static void Delete(void* p) { - delete static_cast(p); - } - - public: - bool IsCompatible(const ONNX_NAMESPACE::TypeProto&) const override { - return false; - } - - static MLDataType Type(); - - size_t Size() const override { - return sizeof(T); - } - - DeleteFunc GetDeleteFunc() const override { - return &Delete; - } - - const ONNX_NAMESPACE::TypeProto* GetTypeProto() const final { - return nullptr; - } - - private: - NonOnnxType() = default; -}; - -// Explicit specialization of base class template function -// is only possible within the enclosing namespace scope, -// thus a simple way to pre-instantiate a given template -// at a registration time does not currently work and the macro -// is needed. -#define ORT_REGISTER_TENSOR_TYPE(ELEM_TYPE) \ - template <> \ - MLDataType TensorType::Type() { \ - static TensorType tensor_type; \ - return &tensor_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetTensorType() { \ - return TensorType::Type(); \ - } - -#define ORT_REGISTER_MAP(TYPE) \ - template <> \ - MLDataType MapType::Type() { \ - static MapType map_type; \ - return &map_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetType() { \ - return MapType::Type(); \ - } - -#define ORT_REGISTER_SEQ(TYPE) \ - template <> \ - MLDataType SequenceType::Type() { \ - static SequenceType sequence_type; \ - return &sequence_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetType() { \ - return SequenceType::Type(); \ - } - -#define ORT_REGISTER_NON_ONNX_TYPE(TYPE) \ - template <> \ - MLDataType NonOnnxType::Type() { \ - static NonOnnxType non_onnx_type; \ - return &non_onnx_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetType() { \ - return NonOnnxType::Type(); \ - } - -#define ORT_REGISTER_OPAQUE_TYPE(CPPType, Domain, Name) \ - template <> \ - MLDataType OpaqueType::Type() { \ - static OpaqueType opaque_type; \ - return &opaque_type; \ - } \ - template <> \ - MLDataType DataTypeImpl::GetType() { \ - return OpaqueType::Type(); \ - } -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/environment.h b/onnxruntime/inc/core/framework/environment.h deleted file mode 100644 index f36ebb6..0000000 --- a/onnxruntime/inc/core/framework/environment.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/common/common.h" -#include "core/common/status.h" - -namespace onnxruntime { -/** - Provides the runtime environment for onnxruntime. - Create one instance for the duration of execution. -*/ -class Environment { - public: - /** - Create and initialize the runtime environment. - */ - static Status Create(std::unique_ptr& environment); - - /** - This function will call ::google::protobuf::ShutdownProtobufLibrary - */ - ~Environment(); - - /** - Returns whether any runtime environment instance has been initialized. - */ - static bool IsInitialized() { return is_initialized_; } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment); - - Environment() = default; - Status Initialize(); - - static std::atomic is_initialized_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/execution_provider.h b/onnxruntime/inc/core/framework/execution_provider.h deleted file mode 100644 index 48fbc43..0000000 --- a/onnxruntime/inc/core/framework/execution_provider.h +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/common/status.h" -#include "core/framework/tensor.h" -#include "core/framework/func_api.h" - -namespace onnxruntime { -class GraphViewer; -class Node; -} // namespace onnxruntime -namespace onnxruntime { - -struct ComputeCapability; -class KernelRegistry; -class KernelRegistryManager; - -/** - Logical device representation. -*/ -typedef std::map AllocatorMap; - -// if we are export the fused function to dll, the function will still in the same binary as lotus -// use std function to give execution provider some chance to capture some state. -using CreateFunctionStateFunc = std::function; -using ComputeFunc = std::function; -using DestroyFunctionStateFunc = std::function; - -struct NodeComputeInfo { - CreateFunctionStateFunc create_state_func; - ComputeFunc compute_func; - DestroyFunctionStateFunc release_state_func; -}; - -class IExecutionProvider { - public: - virtual ~IExecutionProvider() = default; - - /** - Get all IAllocators for <*this> execution provider. - */ - std::vector GetAllocatorMap() const { - std::vector values; - for (auto& kv : allocators_) { - values.push_back(kv.second); - } - return values; - } - - /** - Get allocator with specified MemType - */ - virtual AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const; - - /** - Get execution provider's capability for the specified . - Return a bunch of IndexedSubGraphs <*this> execution provider can run if - the sub-graph contains only one node or can fuse to run if the sub-graph - contains more than one node. The node indexes contained in sub-graphs may - have overlap, and it's ONNXRuntime's responsibility to do the partition - and decide whether a node will be assigned to <*this> execution provider. - */ - virtual std::vector> - GetCapability(const onnxruntime::GraphViewer& graph_viewer, - const std::vector& kernel_registries) const; - - /** - Get kernel registry per execution provider type. - The KernelRegistry share pointer returned is shared across sessions. - - NOTE: this is a tricky but final solution to achieve following goals, - 1. The execution provider type based kernel registry should be shared - across sessions. - Only one copy of this kind of kernel registry exists in ONNXRuntime - with multiple sessions/models. - 2. Adding an execution provider into ONNXRuntime does not need to touch ONNXRuntime - frameowrk/session code. - 3. onnxruntime (framework/session) does not depend on any specific - execution provider lib. - */ - virtual std::shared_ptr GetKernelRegistry() const = 0; - - /** - Copy tensor between execution providers - */ - virtual common::Status CopyTensor(const Tensor& src, Tensor& dst) const = 0; - - /** - Copy tensor between execution providers on specified exec queue - */ - virtual common::Status CopyTensor(const Tensor& src, Tensor& dst, - int exec_queue_id) const; - - /** - Returns an opaque handle whose exact type varies based on the provider - and is interpreted accordingly by the corresponding kernel implementation. - For Direct3D operator kernels, this may return an IUnknown supporting - QueryInterface to ID3D12GraphicsCommandList1. - */ - virtual const void* GetExecutionHandle() const noexcept = 0; - - /** - @return type of the execution provider; should match that set in the node - through the SetExecutionProvider API. Example valid return values are: - kCpuExecutionProvider, kCudaExecutionProvider - */ - virtual std::string Type() const = 0; - - /** - Blocks until the device has completed all preceding requested tasks. - Currently this is primarily used by the IOBinding object to ensure that all - inputs have been copied to the device before execution begins. - */ - virtual common::Status Sync() const; - - /** - Called when InferenceSession::Run started - NOTE that due to async execution in provider, the actual work of previous - Run may not be finished on device This function should be regarded as the - point after which a new Run would start to submit commands from CPU - */ - virtual common::Status OnRunStart(); - - /** - Called when InferenceSession::Run ended - NOTE that due to async execution in provider, the actual work of this Run - may not be finished on device This function should be regarded as the point - that all commands of current Run has been submmited by CPU - */ - virtual common::Status OnRunEnd(); - - void InsertAllocator(AllocatorPtr allocator); - - /** - Given a list of fused_node, return create_state/compute/release_state func for each node. - */ - virtual common::Status Compile(const std::vector& fused_node, - std::vector& node_compute_funcs); - - /** - Given a list of fused_node, return a dll that expose functions for each node. - For each node, there should be three symbols: - Create_State_${node_name} - Compute_${node_name} - Release_State_${node_name} - */ - virtual common::Status Compile(const std::vector& fused_node, - std::string& dll_path); - - private: - AllocatorMap allocators_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/fence.h b/onnxruntime/inc/core/framework/fence.h deleted file mode 100644 index 2f103fc..0000000 --- a/onnxruntime/inc/core/framework/fence.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/graph/basic_types.h" - -namespace onnxruntime { - -/* - We use a simple fence mechanism for async compute. Assumptions in this fence mechanism: - * Execution provider command queues, which execute in the same order of submit - * No fence needed for kernels within one execution provider command queue - * Fence is used to synchronize between command queues, and execution providers - - Fence usage: - 1. Fence object would be created by allocation planer for input/output when KernelDef::ExecQueueId() is not zero - 2. If fence object exists, executor would call BeforeUsingAs* prior to kernel::Compute(), and AfterUsedAs* afterwards -*/ -class IFence { - public: - virtual ~IFence() = default; - - /** - Called by executor before MLValue is used as input in a compute kernel in provider_type and exec queue_id - This should wait in the specified provider's exec queue for previous write to MLValue to finish - */ - virtual void BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int queue_id) = 0; - - /** - Called by executor before MLValue is used as output in a compute kernel in provider_type and exec queue_id - This should wait in the specified provider's exec queue for previous read to MLValue to finish - */ - virtual void BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) = 0; - - /** - Called by executor after MLValue is used as input in a compute kernel in provider_type and exec queue_id - This should update the read fence of the MLValue - */ - virtual void AfterUsedAsInput(int queue_id) = 0; - - /** - Called by executor after MLValue is used as output in a compute kernel in provider_type and exec queue_id - This should update the write fence of the MLValue - */ - virtual void AfterUsedAsOutput(int queue_id) = 0; -}; -using Fence_t = IFence*; -using FencePtr = std::shared_ptr; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/framework_common.h b/onnxruntime/inc/core/framework/framework_common.h deleted file mode 100644 index c3336f2..0000000 --- a/onnxruntime/inc/core/framework/framework_common.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include "run_options.h" - -namespace onnxruntime { // forward declarations -class Model; -class GraphTransformer; -class NodeArg; -} // namespace onnxruntime - -namespace onnxruntime { -class MLValue; -using InputDefList = std::vector; -using OutputDefList = std::vector; - -using NameMLValMap = std::unordered_map; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/func_api.h b/onnxruntime/inc/core/framework/func_api.h deleted file mode 100644 index 6066453..0000000 --- a/onnxruntime/inc/core/framework/func_api.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once -#include "core/common/common.h" -namespace onnxruntime { -//TODO: Should use the lotus cpi element type definition. -enum DType { - TFloat32 = 0, - TInt32 = 1, - TDouble = 2 - //TODO: more types -}; - -typedef struct { - void* data; - /*! \brief Number of dimensions */ - size_t ndim; - /*! \brief The data type of the pointer*/ - DType dtype; - /*! \brief The shape of the tensor */ - int64_t* shape; -} ONNXRunTimeTensor; - -// AllocateFunc(void* handle, size_t alignment, size_t size) -using AllocateFunc = void* (*)(void*, size_t, size_t); -using DestroyFunc = void (*)(void*, void*); -using AllocatorHandle = void*; - -typedef struct { - //right now we only include allocation for host memory - AllocateFunc allocate_func; - DestroyFunc release_func; - AllocatorHandle allocator_handle; - const char* node_name; -} ComputeContext; - -using FunctionState = void*; -// take the ComputeContext, and create a function state. -using CreateFunctionStateC = int (*)(ComputeContext*, FunctionState*); -// pass in the function state and input/output tensors, perform compute and return status code, 0 - succeed. -using ComputeFuncC = int (*)(FunctionState, ONNXRunTimeTensor*, size_t, ONNXRunTimeTensor*, size_t); -// release the function state. -using DestroyFunctionStateC = void (*)(FunctionState); -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/kernel_def_builder.h b/onnxruntime/inc/core/framework/kernel_def_builder.h deleted file mode 100644 index 90edd3b..0000000 --- a/onnxruntime/inc/core/framework/kernel_def_builder.h +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/graph/basic_types.h" -#include "core/framework/data_types.h" -#include "core/framework/allocator.h" - -namespace onnxruntime { -class KernelDefBuilder; - -typedef std::map MemTypeMap; - -// note that input/output might be on CPU implicitly when the node is from CPU execution provider -inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) { - return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput; -} - -class KernelDef { - public: - explicit KernelDef() { - } - - const std::string& OpName() const { - return op_name_; - } - - const std::string& Domain() const { - return op_domain_; - } - - void SinceVersion(/*out*/ int* start, /*out*/ int* end) const { - *start = op_since_version_start_; - *end = op_since_version_end_; - } - - onnxruntime::ProviderType Provider() const { - return provider_type_; - } - - const std::unordered_map>& TypeConstraints() const { - return type_constraints_; - } - - const std::vector>& MayInplace() const { - return inplace_map_; - } - - const std::vector>& Alias() const { - return alias_map_; - } - - OrtMemType InputMemoryType(size_t input_index) const { - auto it = input_memory_type_args_.find(input_index); - if (it == input_memory_type_args_.end()) - return default_inputs_mem_type_; - return it->second; - } - - OrtMemType OutputMemoryType(size_t output_index) const { - auto it = output_memory_type_args_.find(output_index); - if (it == output_memory_type_args_.end()) - return default_outputs_mem_type_; - return it->second; - } - - int ExecQueueId() const { - return exec_queue_id_; - } - - bool IsConflict(const KernelDef& other) const; - - private: - friend class KernelDefBuilder; - - // The operator name supported by <*this> kernel.. - std::string op_name_; - - // The operator since_version range supported by <*this> kernel. - // A kernel could support an operator definition between - // and (inclusive). - int op_since_version_start_ = 1; - int op_since_version_end_ = INT_MAX; - - // The operator domain supported by <*this> kernel. - // Default to 'onnxruntime::kOnnxDomain'. - // Please note the behavior of std::string("") and std::string() are different - std::string op_domain_; - - // The type of the execution provider. - std::string provider_type_; - - // The supported data types for inputs/outputs. - // Key is input/output name defined in op schema, Value are supported types. - std::unordered_map> type_constraints_; - - // An element means that output j reuses the memory of input i. - std::vector> inplace_map_; - - // An element means that output j is an alias of input i. - std::vector> alias_map_; - - // The memory types of inputs/outputs of this kernel - MemTypeMap input_memory_type_args_; - MemTypeMap output_memory_type_args_; - - // execution command queue id, 0 for default queue in execution provider - int exec_queue_id_ = 0; - // Default memory type for all inputs - OrtMemType default_inputs_mem_type_{OrtMemTypeDefault}; - // Default memory type for all outputs - OrtMemType default_outputs_mem_type_{OrtMemTypeDefault}; -}; - -class KernelDefBuilder { - public: - explicit KernelDefBuilder() - : kernel_def_(new KernelDef()) {} - - KernelDefBuilder& SetName(const std::string& op_name); - KernelDefBuilder& SetName(const char* op_name); - - KernelDefBuilder& SetDomain(const std::string& domain); - KernelDefBuilder& SetDomain(const char* domain); - - /** - This kernel supports operator definition since (to latest). - */ - KernelDefBuilder& SinceVersion(int since_version) { - kernel_def_->op_since_version_start_ = since_version; - return *this; - } - - /** - The start and end version should be set accordingly per version range for - each domain registered in OpSchemaRegistry::DomainToVersionRange in - \onnxruntime\onnxruntime\core\graph\op.h as below. - Key: domain. Value: pair. - std::unordered_map> map_; - */ - KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) { - kernel_def_->op_since_version_start_ = since_version_start; - kernel_def_->op_since_version_end_ = since_version_end; - return *this; - } - - /** - The execution provider type of the kernel. - */ - KernelDefBuilder& Provider(onnxruntime::ProviderType provider_type); - KernelDefBuilder& Provider(const char* provider_type); - - /** - Specify the set of types that this kernel supports. A further restriction - of the set of types specified in the op schema. - The arg name could be either op formal parameter name, say "X", or type - argument name specified in op schema, say "T". - */ - KernelDefBuilder& TypeConstraint(const std::string& arg_name, - const std::vector& supported_types); - KernelDefBuilder& TypeConstraint(const char* arg_name, - const std::vector& supported_types); - - /** - Like TypeConstraint but supports just a single type. - */ - KernelDefBuilder& TypeConstraint(const std::string& arg_name, MLDataType supported_type); - KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType supported_type); - - /** - Inplace mapping from inputs to outputs allowed. - It means that uplayer runtime could do memory in-place optimization - as it will not impact the correctness of this kernel. - */ - KernelDefBuilder& MayInplace(const std::vector>& inplaces); - KernelDefBuilder& MayInplace(int input_index, int output_index); - - /** - Alias mapping from inputs to outputs. Different from Inplace that the - content of the tensor is not changed. This is to take care of operators - such as Identity and Reshape. - */ - KernelDefBuilder& Alias(const std::vector>& aliases); - KernelDefBuilder& Alias(int input_index, int output_index); - - /** - Specify that this kernel requires an input arg - in certain memory type (instead of the default, device memory). - */ - template - KernelDefBuilder& InputMemoryType(int input_index) { - kernel_def_->input_memory_type_args_.insert(std::make_pair(input_index, T)); - return *this; - } - - /** - Specify that this kernel provides an output arg - in certain memory type (instead of the default, device memory). - */ - template - KernelDefBuilder& OutputMemoryType(int output_index) { - kernel_def_->output_memory_type_args_.insert(std::make_pair(output_index, T)); - return *this; - } - - /** - Specify that this kernel runs on which execution queue in the provider - */ - KernelDefBuilder& ExecQueueId(int queue_id) { - kernel_def_->exec_queue_id_ = queue_id; - return *this; - } - - /** - Specify the default inputs memory type, if not specified, it is DefaultMemory - */ - KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) { - kernel_def_->default_inputs_mem_type_ = mem_type; - return *this; - } - - /** - Specify the default outputs memory type, if not specified, it is DefaultMemory - */ - KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) { - kernel_def_->default_outputs_mem_type_ = mem_type; - return *this; - } - - /** - Return the kernel definition, passing ownership of the KernelDef to the caller - */ - std::unique_ptr Build() { - return std::move(kernel_def_); - } - - private: - // we own the KernelDef until Build() is called. - std::unique_ptr kernel_def_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/kernel_registry.h b/onnxruntime/inc/core/framework/kernel_registry.h deleted file mode 100644 index b9f94e5..0000000 --- a/onnxruntime/inc/core/framework/kernel_registry.h +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/op_kernel.h" - -namespace onnxruntime { -class KernelRegistry { - public: - KernelRegistry() = default; - - // Register a kernel with kernel definition and function to create the kernel. - Status Register(KernelDefBuilder& kernel_def_builder, - const KernelCreateFn& kernel_creator); - - Status Register(KernelCreateInfo&& create_info); - - // Mainly for provide debug info - std::vector GetAllRegisteredOpNames() const; - - // factory functions should always return a unique_ptr for maximum flexibility - // for its clients unless the factory is managing the lifecycle of the pointer - // itself. - // TODO(Task:132) Make usage of unique_ptr/shared_ptr as out param consistent - Status CreateKernel(const onnxruntime::Node& node, - const IExecutionProvider& execution_provider, - const SessionState& session_state, - std::unique_ptr& op_kernel) const; - - // Check if an execution provider can create kernel for a node and return - // the kernel if so - const KernelCreateInfo* TryFindKernel(const onnxruntime::Node& node, - onnxruntime::ProviderType exec_provider) const; - - private: - // Check if the node's input/outpuData/attributes are compatible with this - // kernel_def, If so, the kernel defined by the kernel_def is used to - // execute this node. exec_provider is used to match kernel when node has no provider - static bool VerifyKernelDef(const onnxruntime::Node& node, - const KernelDef& kernel_def, - std::string& error_str, - onnxruntime::ProviderType exec_provider = ""); - - // Kernel create function map from op name to kernel creation info. - KernelCreateMap kernel_creator_fn_map_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/ml_value.h b/onnxruntime/inc/core/framework/ml_value.h deleted file mode 100644 index b545686..0000000 --- a/onnxruntime/inc/core/framework/ml_value.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "core/common/common.h" -#include "core/common/exceptions.h" -#include "core/framework/allocator.h" -#include "core/framework/data_types.h" -#include "core/framework/tensor.h" - -namespace onnxruntime { -/** - Represents both tensors and non-tensors. -*/ -class MLValue { - public: - MLValue() : data_(nullptr) {} - virtual ~MLValue() = default; - - MLValue(void* pData, MLDataType type, DeleteFunc deleter) { - Init(pData, type, deleter); - } - - void Init(void* pData, MLDataType type, DeleteFunc deleter) { - data_.reset(pData, deleter); - type_ = type; - } - - bool IsAllocated() const { - return data_ && type_; - } - - template - const T& Get() const { - ORT_ENFORCE(DataTypeImpl::GetType() == type_, DataTypeImpl::GetType(), " != ", type_); - return *static_cast(data_.get()); - } - - template - T* GetMutable() { - ORT_ENFORCE(DataTypeImpl::GetType() == type_, DataTypeImpl::GetType(), " != ", type_); - return static_cast(data_.get()); - } - - bool IsTensor() const noexcept { - return DataTypeImpl::GetType() == type_; - } - - MLDataType Type() const { - return type_; - } - - Fence_t Fence() const { - return fence_.get(); - } - - void SetFence(FencePtr fence) { - fence_ = fence; - } - - void ShareFenceWith(MLValue& v) { - fence_ = v.fence_; - } - - private: - std::shared_ptr data_; - MLDataType type_{nullptr}; - FencePtr fence_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/op_kernel.h b/onnxruntime/inc/core/framework/op_kernel.h deleted file mode 100644 index 98b2fb2..0000000 --- a/onnxruntime/inc/core/framework/op_kernel.h +++ /dev/null @@ -1,313 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/common/exceptions.h" -#include "core/common/logging/logging.h" -#include "core/common/status.h" -#include "core/framework/execution_provider.h" -#include "core/framework/kernel_def_builder.h" -#include "core/framework/ml_value.h" -#include "core/framework/op_kernel_info.h" -#include "core/framework/op_node_proto_helper.h" -#include "core/framework/tensor.h" -#include "core/graph/constants.h" -#include "core/graph/graph_viewer.h" -#include "gsl/span" -#include "onnx/defs/schema.h" - -namespace onnxruntime { -class ExecutionFrame; -class OpKernelContext; -class OpKernelWrapper; - -class OpKernel { - public: - using DoneCallback = std::function; - - explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_(info) {} - virtual ~OpKernel() = default; - - const onnxruntime::Node& Node() const { - return op_kernel_info_.node(); - } - - const ::onnxruntime::KernelDef& KernelDef() const { - return op_kernel_info_.GetKernelDef(); - } - - virtual Status Compute(OpKernelContext* context) const = 0; - - virtual Status ComputeAsync(OpKernelContext*, - DoneCallback) const { - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); - } - - const OrtAllocatorInfo& Allocator(int id, OrtMemType mem_type) const { - return op_kernel_info_.GetAllocatorInfo(id, mem_type); - } - - const OpKernelInfo& Info() const { return op_kernel_info_; } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel); - OpKernelInfo op_kernel_info_; -}; - -class OpKernelContext { - public: - using ArgMap = std::unordered_map; - - explicit OpKernelContext(ExecutionFrame* frame, - const OpKernel* kernel, - const logging::Logger& logger); - - virtual ~OpKernelContext() = default; - - /** - Return the number of inputs for a variadic argument. - @param arg_num The operator argument number. - @returns Number of inputs the argument has. - */ - int NumVariadicInputs(size_t arg_num) const; - - MLDataType InputType(int index) const; - MLDataType OutputType(int index) const; - - template - const T* Input(int index) const { - const MLValue* p_ml_value = GetInputMLValue(index); - return p_ml_value ? &(p_ml_value->Get()) : nullptr; - } - - // Fetch output (non-tensor) with specified index. - template - T* Output(int index) { - if (index < 0 || index >= OutputCount()) - return nullptr; - - MLValue* p_ml_value = nullptr; - ORT_ENFORCE(GetOrCreateOutputMLValue(index, p_ml_value).IsOK()); - return p_ml_value ? p_ml_value->GetMutable() : nullptr; - } - - // In the case that memory allocation has not been done for an output tensor, - // The memory allocation will be done on-the-fly with given tensor shape. - // Return nullptr if the output is an unused optional output. - Tensor* Output(int index, const TensorShape& shape); - - const logging::Logger& Logger() const { - return *logger_; - } - - int InputCount() const { - return static_cast(kernel_->Node().InputDefs().size()); - } - - int ImplicitInputCount() const { - return static_cast(kernel_->Node().ImplicitInputDefs().size()); - } - - int OutputCount() const { - return static_cast(kernel_->Node().OutputDefs().size()); - } - - /** - * return an allocator on device 0, with memtype of OrtMemTypeDefault - * - */ - Status GetTempSpaceAllocator(AllocatorPtr* output) const; - - /** - Return the fence of current node's input. - @param index The index of the input. - @returns Point to the Fence of the input MLValue. - It is null if the input MLValue doesn't have fence or the input is optional. - */ - Fence_t InputFence(int index) const; - - /** - Return the fence of current node's implicit input. - @param index The index of the implicit input. - @returns Point to the Fence of the implicit input MLValue. - It is null if the input MLValue doesn't have fence or the input is optional. - */ - Fence_t ImplicitInputFence(int index) const; - - /** - Return the fence of current node's output identifed by index. - @param index The index of the output. - @returns Point to the Fence of the output MLValue. - It is null if the output MLValue doesn't have fence or the output is optional. - */ - Fence_t OutputFence(int index) const; - - protected: - onnxruntime::NodeIndex GetNodeIndex() const; - const SessionState& GetSessionState() const; - - const MLValue* GetInputMLValue(int index) const; - const MLValue* GetImplicitInputMLValue(int index) const; - MLValue* GetOutputMLValue(int index); - - private: - Status GetOrCreateOutputMLValue(int index, MLValue*& value); - - int GetInputArgIndex(int index) const; - int GetImplicitInputArgIndex(int index) const; - int GetOutputArgIndex(int index) const; - - ExecutionFrame* execution_frame_{nullptr}; - const OpKernel* kernel_{nullptr}; - const logging::Logger* logger_{nullptr}; - - // The argument starting index in ExecutionFrame. - int node_input_start_index_{-1}; - int node_implicit_input_start_index_{-1}; - int node_output_start_index_{-1}; -}; - -// Fetching output tensor without shape is not allowed except when it already exists -template <> -inline Tensor* OpKernelContext::Output(int index) { - MLValue* p_ml_value = GetOutputMLValue(index); - ORT_ENFORCE(p_ml_value, "Please fetch output tensor with specified shape."); - return p_ml_value->GetMutable(); -} - -using KernelCreateFn = std::function; - -struct KernelCreateInfo { - std::unique_ptr kernel_def; // Owned and stored in the global kernel registry. - KernelCreateFn kernel_create_func; - Status status; - - KernelCreateInfo(std::unique_ptr definition, - KernelCreateFn create_func) - : kernel_def(std::move(definition)), - kernel_create_func(create_func) {} - - KernelCreateInfo(KernelCreateInfo&& other) - : kernel_def(std::move(other.kernel_def)), - kernel_create_func(other.kernel_create_func) {} -}; - -using KernelCreateMap = std::multimap; - -// Forward declarations for the non-specialized BuildKernelCreateInfo method. -template -KernelCreateInfo BuildKernelCreateInfo(); - -namespace ml { -template -KernelCreateInfo BuildKernelCreateInfo(); -} // namespace ml - -namespace contrib { -template -KernelCreateInfo BuildKernelCreateInfo(); -} // namespace contrib - -// Naming convention for operator kernel classes -#define ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name) \ - provider##_##name##_##domain##_ver##ver - -#define ONNX_CPU_OPERATOR_KERNEL(name, ver, builder, ...) \ - ONNX_OPERATOR_KERNEL_EX(name, kOnnxDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_CPU_OPERATOR_ML_KERNEL(name, ver, builder, ...) \ - ONNX_OPERATOR_KERNEL_EX(name, kMLDomain, ver, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_OPERATOR_KERNEL_EX(name, domain, ver, provider, builder, ...) \ - class ONNX_OPERATOR_KERNEL_CLASS_NAME(provider, domain, ver, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name) \ - .SetDomain(domain) \ - .SinceVersion(ver) \ - .Provider(provider) \ - .Build(), \ - [](const OpKernelInfo& info) -> OpKernel* { return new __VA_ARGS__(info); }); \ - } - -#define ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name) \ - provider##_##name##_##domain##_ver##startver##_##endver - -#define ONNX_CPU_OPERATOR_VERSIONED_KERNEL(name, startver, endver, builder, ...) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kOnnxDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_CPU_OPERATOR_VERSIONED_ML_KERNEL(name, startver, endver, builder, ...) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, kMLDomain, startver, endver, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_OPERATOR_VERSIONED_KERNEL_EX(name, domain, startver, endver, provider, builder, ...) \ - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(provider, domain, startver, endver, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name) \ - .SetDomain(domain) \ - .SinceVersion(startver, endver) \ - .Provider(provider) \ - .Build(), \ - [](const OpKernelInfo& info) -> OpKernel* { return new __VA_ARGS__(info); }); \ - } - -#define ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name) \ - provider##_##name##_##domain##_ver##ver##_##type - -#define ONNX_CPU_OPERATOR_TYPED_KERNEL(name, ver, type, builder, ...) \ - ONNX_OPERATOR_TYPED_KERNEL_EX(name, kOnnxDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_CPU_OPERATOR_TYPED_ML_KERNEL(name, ver, type, builder, ...) \ - ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMLDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(name, ver, type, builder, ...) \ - ONNX_OPERATOR_TYPED_KERNEL_EX(name, kMSDomain, ver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_OPERATOR_TYPED_KERNEL_EX(name, domain, ver, type, provider, builder, ...) \ - class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(provider, domain, ver, type, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name) \ - .SetDomain(domain) \ - .SinceVersion(ver) \ - .Provider(provider) \ - .Build(), \ - [](const OpKernelInfo& info) -> OpKernel* { return new __VA_ARGS__(info); }); \ - } - -#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name) \ - provider##_##name##_##domain##_ver##startver##_##endver##_##type - -#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(name, startver, endver, type, builder, ...) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kOnnxDomain, startver, endver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_ML_KERNEL(name, startver, endver, type, builder, ...) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMLDomain, startver, endver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_CPU_OPERATOR_VERSIONED_TYPED_MS_KERNEL(name, startver, endver, type, builder, ...) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, kMSDomain, startver, endver, type, kCpuExecutionProvider, builder, __VA_ARGS__) - -#define ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(name, domain, startver, endver, type, provider, builder, ...) \ - class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(provider, domain, startver, endver, type, name); \ - template <> \ - KernelCreateInfo \ - BuildKernelCreateInfo() { \ - return KernelCreateInfo( \ - builder.SetName(#name) \ - .SetDomain(domain) \ - .SinceVersion(startver, endver) \ - .Provider(provider) \ - .Build(), \ - [](const OpKernelInfo& info) -> OpKernel* { return new __VA_ARGS__(info); }); \ - } - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/op_kernel_info.h b/onnxruntime/inc/core/framework/op_kernel_info.h deleted file mode 100644 index 47bb0e9..0000000 --- a/onnxruntime/inc/core/framework/op_kernel_info.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/framework/execution_provider.h" -#include "core/framework/kernel_def_builder.h" -#include "core/framework/ml_value.h" -#include "core/framework/op_node_proto_helper.h" -#include "core/graph/graph_viewer.h" -#include "gsl/span" -#include "gsl/gsl_util" - -namespace onnxruntime { - -class SessionState; - -/** - A very light-weight class, which works as an aggregated - view of all data needed for constructing a Kernel instance. - NOTE: it does not own/hold any objects. -*/ -class OpKernelInfo : public OpNodeProtoHelper { - public: - explicit OpKernelInfo(const onnxruntime::Node& node, - const KernelDef& kernel_def, - const IExecutionProvider& execution_provider, - const SessionState& session_state); - - OpKernelInfo(const OpKernelInfo& other); - - const OrtAllocatorInfo& GetAllocatorInfo(int device_id, OrtMemType mem_type) const; - - const AllocatorPtr GetAllocator(int device_id, OrtMemType mem_type) const; - - const KernelDef& GetKernelDef() const; - - const IExecutionProvider* GetExecutionProvider() const noexcept; - - const onnxruntime::Node& node() const noexcept; - - bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const; - - common::Status GetFusedFuncs(ComputeFunc* compute, CreateFunctionStateFunc* create, DestroyFunctionStateFunc* release) const; - - private: - ORT_DISALLOW_MOVE(OpKernelInfo); - ORT_DISALLOW_ASSIGNMENT(OpKernelInfo); - - const onnxruntime::Node& node_; - const KernelDef& kernel_def_; - // For non cpu/cuda case, this pointer should be set so that function kernel - // will delegate kernel compute call to compute call. - gsl::not_null execution_provider_; - ProtoHelperNodeContext proto_helper_context_; - const SessionState& session_state_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/op_node_proto_helper.h b/onnxruntime/inc/core/framework/op_node_proto_helper.h deleted file mode 100644 index 492dd09..0000000 --- a/onnxruntime/inc/core/framework/op_node_proto_helper.h +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/status.h" -#include "core/graph/graph_viewer.h" -#include "gsl/span" - -#ifdef __has_attribute -#define ORT_HAVE_ATTRIBUTE(x) __has_attribute(x) -#else -#define ORT_HAVE_ATTRIBUTE(x) 0 -#endif - -#if ORT_HAVE_ATTRIBUTE(nodiscard) -#define MUST_USE_RESULT [[nodiscard]] -#elif defined(__clang__) && ORT_HAVE_ATTRIBUTE(warn_unused_result) -#define MUST_USE_RESULT __attribute__((warn_unused_result)) -#else -#define MUST_USE_RESULT -#endif - -class IMLOpKernel; - -namespace onnxruntime { - -/** - A set of wrappers with common signatures for use with both OpKernelInfo - (as its base class) and InferenceContext. Used by ABI kernels for both - shape / type inference and kernel construction -*/ -template -class OpNodeProtoHelper { - public: - explicit OpNodeProtoHelper(const Impl_t* impl) : impl_(impl) {} - - /** - Get a single attribute - */ - template - MUST_USE_RESULT Status GetAttr(const std::string& name, T* value) const; - - /** - Get a single attribute - Call this function only when onnx doesn't have default value - */ - template - T GetAttrOrDefault(const std::string& name, const T& default_value) const { - T tmp; - return GetAttr(name, &tmp).IsOK() ? tmp : default_value; - } - - /** - Get a single attribute - Call this function only when onnx doesn't have default value - */ - template - void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const { - if (!GetAttr(name, value).IsOK()) - *value = default_value; - } - - /** - Get repeated attributes - Call this function only when onnx doesn't have default value - */ - template - MUST_USE_RESULT std::vector GetAttrsOrDefault(const std::string& name, const std::vector& default_value = std::vector{}) const { - std::vector tmp; - return GetAttrs(name, tmp).IsOK() ? tmp : default_value; - } - - /** - Get repeated attributes - */ - template - MUST_USE_RESULT Status GetAttrs(const std::string& name, std::vector& values) const; - - template - MUST_USE_RESULT Status GetAttrs(const std::string& name, gsl::span values) const; - - uint32_t GetPrimitiveAttrElementCount(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; - - bool HasPrimitiveAttribute(ONNX_NAMESPACE::AttributeProto_AttributeType type, - const std::string& name) const noexcept; - - uint32_t GetInputCount() const { - return gsl::narrow_cast(impl_->getNumInputs()); - } - - uint32_t GetOutputCount() const { - return gsl::narrow_cast(impl_->getNumOutputs()); - } - - const ONNX_NAMESPACE::TypeProto* GetInputType(size_t index) const { - return impl_->getInputType(index); - } - - const ONNX_NAMESPACE::TypeProto* GetOutputType(size_t index) const { - // Work around lack of a const method from the onnx InferenceContext interface - return const_cast(impl_)->getOutputType(index); - } - - // Try to query an attribute, returning nullptr if it doesn't exist - const ONNX_NAMESPACE::AttributeProto* TryGetAttribute(const std::string& name) const { - return impl_->getAttribute(name); - } - - const ONNX_NAMESPACE::AttributeProto* GetAttribute(const std::string& name) const { - const ONNX_NAMESPACE::AttributeProto* attr = TryGetAttribute(name); - ORT_ENFORCE(attr != nullptr); - return attr; - } - - private: - OpNodeProtoHelper() = delete; - const Impl_t* impl_ = nullptr; -}; - -// The methods on the following class are called by OpNodeProtoHelper, implementing -// the same signatures as InferenceContext other than const-ness. -class ProtoHelperNodeContext { - public: - explicit ProtoHelperNodeContext(const onnxruntime::Node& node) : node_(node) {} - ProtoHelperNodeContext() = delete; - - const ONNX_NAMESPACE::AttributeProto* getAttribute(const std::string& name) const; - size_t getNumInputs() const; - const ONNX_NAMESPACE::TypeProto* getInputType(size_t index) const; - size_t getNumOutputs() const; - const ONNX_NAMESPACE::TypeProto* getOutputType(size_t index) const; - - private: - const onnxruntime::Node& node_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/run_options.h b/onnxruntime/inc/core/framework/run_options.h deleted file mode 100644 index f10f161..0000000 --- a/onnxruntime/inc/core/framework/run_options.h +++ /dev/null @@ -1,34 +0,0 @@ - -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include "core/session/onnxruntime_c_api.h" - -/** - * Configuration information for a single Run. - */ -struct OrtRunOptions { - unsigned run_log_verbosity_level = 0; ///< applies to a particular Run() invocation - std::string run_tag; ///< to identify logs generated by a particular Run() invocation - - /// set to 'true' to terminate any currently executing Run() calls that are using this - /// OrtRunOptions instance. the individual calls will exit gracefully and return an error status. - bool terminate = false; - OrtRunOptions() = default; - ~OrtRunOptions() = default; - - // disable copy, move and assignment. we don't want accidental copies, to ensure that the instance provided to - // the Run() call never changes and the terminate mechanism will work. - OrtRunOptions(const OrtRunOptions&) = delete; - OrtRunOptions(OrtRunOptions&&) = delete; - OrtRunOptions& operator=(const OrtRunOptions&) = delete; - OrtRunOptions& operator=(OrtRunOptions&&) = delete; -}; - -namespace onnxruntime { -using RunOptions = OrtRunOptions; -} diff --git a/onnxruntime/inc/core/framework/tensor.h b/onnxruntime/inc/core/framework/tensor.h deleted file mode 100644 index f2ab6f2..0000000 --- a/onnxruntime/inc/core/framework/tensor.h +++ /dev/null @@ -1,201 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "gsl/span" - -#include "core/framework/allocator.h" -#include "core/framework/data_types.h" -#include "core/framework/tensor_shape.h" -#include "onnxruntime_config.h" - -namespace onnxruntime { - -// TODO: Do we need this class or is IAllocator::MakeUniquePtr sufficient/better -class BufferDeleter { - public: - BufferDeleter() : alloc_(nullptr) {} - BufferDeleter(AllocatorPtr alloc) - : alloc_(alloc) {} - - void operator()(void* p) const { - if (alloc_) - alloc_->Free(p); - } - - private: - // TODO: we may need consider the lifetime of alloc carefully - // The alloc_ here is the allocator that used to allocate the buffer - // And need go with the unique_ptr together. If it is using our internal - // allocator, it is ok as our allocators are global managed. But if it - // is provide by user, user need to be very careful about it. - // A weak_ptr may be a choice to reduce the impact, but that require to - // change our current allocator mgr to use shared_ptr. Will revisit it - // later. - AllocatorPtr alloc_; -}; - -typedef std::unique_ptr BufferUniquePtr; -using BufferNakedPtr = void*; -//TODO:ensure dtype_!=nullptr -#ifdef __GNUC__ -#pragma GCC diagnostic push -#ifdef HAS_NULL_DEREFERENCE -#pragma GCC diagnostic ignored "-Wnull-dereference" -#endif -#endif -/* - We want to keep tensor as simple as possible, it is just a placeholder - for a piece of memory, with additional shape information. - Memory is owned and managed by Executor / Workspace, so Tensor just uses - it, and won't do any allocation / release. -*/ -class Tensor final { - public: - /** - Create tensor with given type, shape, pre-allocate memory and allocator info. - */ - Tensor(MLDataType p_type, - const TensorShape& shape, - BufferNakedPtr p_data, - const OrtAllocatorInfo& alloc, - AllocatorPtr deleter = nullptr, - int64_t offset = 0); - - ~Tensor(); - - /** - Copy constructor and assign op will just pass the shape and memory - reference to another tensor. Not deep clone/copy. - */ - Tensor(const Tensor& src); - - ///requires other.buffer_deleter_ == nullptr - Tensor& ShallowCopy(const Tensor& other); - - Tensor(Tensor&& other); - - Tensor& operator=(Tensor&& other); - - /** - Returns the data type. - */ - MLDataType DataType() const { return dtype_; } - - /** - Returns the shape of the tensor. - */ - const TensorShape& Shape() const noexcept { return shape_; } - - /** - Returns the location of the tensor's memory - */ - const OrtAllocatorInfo& Location() const { return alloc_info_; } - - /** - May return nullptr if tensor size is zero - */ - template - T* MutableData() { - // Type check - ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", - DataTypeImpl::GetType(), "!=", dtype_); - return reinterpret_cast(static_cast(p_data_) + byte_offset_); - } - - /** - May return nullptr if tensor size is zero - */ - template - gsl::span MutableDataAsSpan() { - // Type check - ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", - DataTypeImpl::GetType(), "!=", dtype_); - T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, shape_.Size()); - } - - template - const T* Data() const { - // Type check - ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", - DataTypeImpl::GetType(), "!=", dtype_); - return reinterpret_cast(static_cast(p_data_) + byte_offset_); - } - - template - gsl::span DataAsSpan() const { - // Type check - ORT_ENFORCE(DataTypeImpl::GetType() == dtype_, "Tensor type mismatch. ", - DataTypeImpl::GetType(), "!=", dtype_); - const T* data = reinterpret_cast(static_cast(p_data_) + byte_offset_); - return gsl::make_span(data, shape_.Size()); - } - - void* MutableDataRaw(MLDataType type) { - ORT_ENFORCE(type == dtype_, "Tensor type mismatch.", type, "!=", dtype_); - return p_data_; - } - - const void* DataRaw(MLDataType type) const { - ORT_ENFORCE(type == dtype_, "Tensor type mismatch.", type, "!=", dtype_); - return p_data_; - } - - void* MutableDataRaw() noexcept { - return p_data_; - } - - const void* DataRaw() const noexcept { - return p_data_; - } - - /** - * Resizes the tensor without touching underlying storage. - * This requires the total size of the tensor to remains constant. - * @warning this function is NOT thread-safe. - */ - inline void Reshape(const TensorShape& new_shape) { - ORT_ENFORCE(shape_.Size() == new_shape.Size(), - "Tensor size (" + std::to_string(shape_.Size()) + - ") != new size (" + std::to_string(new_shape.Size()) + ")"); - shape_ = new_shape; - } - - size_t Size() const noexcept { - return shape_.Size() * dtype_->Size(); - } - - // More API methods. - private: - void Init(MLDataType p_type, - const TensorShape& shape, - void* p_raw_data, - const OrtAllocatorInfo& alloc, - AllocatorPtr deleter, - int64_t offset = 0); - - void ReleaseBuffer(); - - void* p_data_; - /** - if buffer_deleter_ is null, it means tensor does not own the buffer. - otherwise tensor will use the deleter to release the buffer when - tensor is released. - */ - AllocatorPtr buffer_deleter_; - - TensorShape shape_; - MLDataType dtype_; - OrtAllocatorInfo alloc_info_; - int64_t byte_offset_; -}; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/framework/tensor_shape.h b/onnxruntime/inc/core/framework/tensor_shape.h deleted file mode 100644 index 1007601..0000000 --- a/onnxruntime/inc/core/framework/tensor_shape.h +++ /dev/null @@ -1,146 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include -#include -#include -#include -#include -#include "onnxruntime_config.h" - -namespace ONNX_NAMESPACE { -class TensorShapeProto; -} - -namespace onnxruntime { -#ifdef __GNUC__ -#pragma GCC diagnostic push -#ifdef HAS_NULL_DEREFERENCE -#pragma GCC diagnostic ignored "-Wnull-dereference" -#endif -#endif -class TensorShape : private std::vector { - // TODO - Use a custom STL allocator to avoid heap allocations in the common case. - // We use negative numbers for unknown symbolic dimension. Each negative - // number represents a unique symbolic dimension. - // Private inheritance is used to prevent ambiguity of element versus dimension size - public: - TensorShape() = default; - - TensorShape(const TensorShape& /*other*/) = default; - TensorShape& operator=(const TensorShape& /*other*/) = default; - - TensorShape(TensorShape&& /*other*/) = default; - TensorShape& operator=(TensorShape&& /*other*/) = default; - - TensorShape(const int64_t* dimension_sizes, size_t dimension_count); - - TensorShape(const std::vector& dims); - - TensorShape(const std::initializer_list& dims); - - TensorShape(const std::vector& dims, size_t start, size_t end); - - /** - Return the dimension specified by . - */ - const int64_t& operator[](size_t idx) const { - return std::vector::operator[](static_cast(idx)); - } - - int64_t& operator[](size_t idx) { - return std::vector::operator[](static_cast(idx)); - } - - bool operator==(const TensorShape& other) const noexcept { - auto thisVector = static_cast*>(this); - auto otherVector = static_cast*>(&other); - return *thisVector == *otherVector; - } - - bool operator!=(const TensorShape& other) const noexcept { - return !(*this == other); - } - - size_t NumDimensions() const noexcept { - return size(); - } - - /** - Copy dims into an array with given size - */ - void CopyDims(int64_t* dims, size_t num_dims) const { - memcpy(dims, data(), sizeof(value_type) * std::min(num_dims, NumDimensions())); - } - - /** - Return underlying vector representation. - */ - const std::vector& GetDims() const { return *this; } - - /** - * Return the total number of elements. Returns 1 for an empty (rank 0) TensorShape. - * - * May return -1 - */ - int64_t Size() const; - - /** - Return the total number of elements up to the specified dimension. - If the dimension interval is empty (dimension == 0), return 1. - @param dimension Return size up to this dimension. Value must be between 0 and this->NumDimensions(), inclusive. - */ - int64_t SizeToDimension(size_t dimension) const; - - /** - Return the total number of elements from the specified dimension to the end of the tensor shape. - If the dimension interval is empty (dimension == this->NumDimensions()), return 1. - @param dimension Return size from this dimension to the end. Value must be between 0 and this->NumDimensions(), - inclusive. - */ - int64_t SizeFromDimension(size_t dimension) const; - - /** - Return a new TensorShape of the dimensions from dimstart to dimend. - */ - TensorShape Slice(size_t dimstart, size_t dimend) const; - - /** - Return a new TensorShape of the dimensions from dimstart to end. - */ - TensorShape Slice(size_t dimstart) const; - - /** - output dimensions nicely formatted - */ - std::string ToString() const; - - /** - Calculate size between start and end. - Assumes start and end are between 0 and this->NumDimensions(), inclusive, and that - start < end. - */ - int64_t SizeHelper(size_t start, size_t end) const; - - /** - empty shape or 1D shape (1) is regarded as scalar tensor - */ - bool IsScalar() const { - return size() == 0 || (size() == 1 && at(0) == 1); - } - - static const TensorShape& ReinterpretBaseType(const std::vector& dimensions) { - static_assert(sizeof(TensorShape) == sizeof(std::vector), "Size of TensorShape prevents safe casting from vector"); - return *static_cast(&dimensions); - } -}; -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif -// operator<< to nicely output to a stream -std::ostream& operator<<(std::ostream& out, const ::onnxruntime::TensorShape& shape); - -std::ostream& operator<<(std::ostream& out, const ONNX_NAMESPACE::TensorShapeProto& shape_proto); - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/basic_types.h b/onnxruntime/inc/core/graph/basic_types.h deleted file mode 100644 index 24702b4..0000000 --- a/onnxruntime/inc/core/graph/basic_types.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include -#include -#include - -namespace ONNX_NAMESPACE { -class ValueInfoProto; -class TensorProto; -class TypeProto; -class AttributeProto; -} // namespace ONNX_NAMESPACE - -namespace onnxruntime { -using NodeIndex = size_t; -using Version = int64_t; -using NodeArgInfo = ONNX_NAMESPACE::ValueInfoProto; -using InitializedTensorSet = std::unordered_map; -using ArgNameToTypeMap = std::unordered_map; -using ProviderType = const std::string&; -// TODO - Evaluate switching the types below to support transparent comparators and enable -// lookups based on gsl::cstring_span<> and std::string_view. This would reduces allocations -// converting to std::string, but requires conversion to std::map> -// instead of std::unordered_map]>. - -using NodeAttributes = std::unordered_map; -class IOnnxRuntimeOpSchemaCollection; -using IOnnxRuntimeOpSchemaCollectionPtr = std::shared_ptr; -} // namespace onnxruntime - -namespace onnxruntime { -class OpKernel; -class OpKernelInfo; - -using KernelCreateFn = std::function; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/constants.h b/onnxruntime/inc/core/graph/constants.h deleted file mode 100644 index 25b6f36..0000000 --- a/onnxruntime/inc/core/graph/constants.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/common/common.h" - -namespace onnxruntime { -constexpr const char* kNoOp = "NoOp"; -constexpr const char* kConstant = "Constant"; -constexpr const char* kFunctionOp = "_kFunctionOp"; -constexpr const char* kConstantValue = "value"; -constexpr const char* kOnnxDomain = ""; -constexpr const char* kOnnxDomainAlias = "ai.onnx"; -constexpr const char* kMLDomain = "ai.onnx.ml"; -constexpr const char* kMSDomain = "com.microsoft"; -constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; -constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; -constexpr const char* kMklDnnExecutionProvider = "MKLDNNExecutionProvider"; -constexpr const char* kNupharExecutionProvider = "NupharExecutionProvider"; -constexpr const char* kBrainSliceExecutionProvider = "BrainSliceExecutionProvider"; -constexpr const char* kTRTExecutionProvider = "TRTExecutionProvider"; -} // namespace onnxruntime - diff --git a/onnxruntime/inc/core/graph/function.h b/onnxruntime/inc/core/graph/function.h deleted file mode 100644 index 959af66..0000000 --- a/onnxruntime/inc/core/graph/function.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/graph/indexed_sub_graph.h" - -namespace onnxruntime { -class Graph; -class Node; -} // namespace onnxruntime - -namespace onnxruntime { - -/** -@class Function -Class representing a Function. -*/ -class Function { - public: - virtual ~Function() = default; - - /** Gets the OpSchema for the Function. */ - virtual const ONNX_NAMESPACE::OpSchema& OpSchema() const = 0; - - /** Gets the Graph instance for the Function body subgraph. */ - virtual const onnxruntime::Graph& Body() const = 0; - - /** Gets the IndexedSubGraph for the Function. */ - virtual const IndexedSubGraph& GetIndexedSubGraph() const = 0; -}; - -/** -Create a new Function instance. -@param graph The graph containing the Function. -@param customized_func the IndexedSubGraph to use for the Function. -*/ -std::unique_ptr MakeFunction(const onnxruntime::Graph& graph, - std::unique_ptr customized_func); -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/graph.h b/onnxruntime/inc/core/graph/graph.h deleted file mode 100644 index 15ff88f..0000000 --- a/onnxruntime/inc/core/graph/graph.h +++ /dev/null @@ -1,985 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/const_pointer_container.h" -#include "core/common/status.h" -#include "core/graph/basic_types.h" -#include "core/graph/constants.h" -#include "core/graph/graph_nodes.h" -#include "core/graph/node_arg.h" -#include "core/graph/onnx_protobuf.h" -#include "core/graph/function.h" -#include "gsl/gsl_util" -#include "gsl/pointers" - -namespace onnxruntime { -class Graph; -struct IndexedSubGraph; -class Node; -class OpSignature; - -/** -@class Node -Class representing a node in the graph. -*/ -class Node { - public: - /** Node types */ - enum class Type { - Primitive = 0, ///< The node refers to a primitive operator. - Fused = 1, ///< The node refers to a function. - }; - - ~Node() = default; - - /** - @class EdgeEnd - Class representing the end of an edge. It could be an input or output edge end of a node. - For the node's input edge end, it's the source end, as the destination end is the node itself. - For the node's output edge end, it's the destination end, as the source end is the node itself. - */ - class EdgeEnd { - public: - /** - Construct an EdgeEnd - @param node The source node if this is an input edge to the current node, - or the destination node if this is an output edge from the current node. - @param src_arg_index The node arg index of source node of the edge. - @param dst_arg_index The node arg index of destination node of the edge. - */ - EdgeEnd(const Node& node, int src_arg_index, int dst_arg_index) noexcept; - - /** Construct a control edge. - @param node The node the edge joins to the current node. - */ - explicit EdgeEnd(const Node& node) noexcept; - - /** Gets the Node that this EdgeEnd refers to. */ - const Node& GetNode() const noexcept; - - /** Gets the source arg index. - @returns the source arg index of <*this> edge.*/ - int GetSrcArgIndex() const; - - /** Gets the destination arg index. - @returns the destination arg index of <*this> edge.*/ - int GetDstArgIndex() const; - - private: - const Node* node_; - const int src_arg_index_; - const int dst_arg_index_; - }; - - /** Gets the Node's NodeIndex. */ - NodeIndex Index() const noexcept; - - /** Gets the Node's name. */ - const std::string& Name() const noexcept; - - /** Gets the Node's operator type. */ - const std::string& OpType() const noexcept; - - /** Gets the domain of the OperatorSet that specifies the operator returned by #OpType. */ - const std::string& Domain() const noexcept; - - /** Gets the Node's OpSchema. - @remarks The graph containing this node must be resolved, otherwise nullptr will be returned. */ - const ONNX_NAMESPACE::OpSchema* Op() const noexcept; - - /** Gets the Node's Node::Type. */ - Node::Type NodeType() const noexcept; - - /** Gets the function body if the #NodeType is fused, or nullptr if not. */ - const Function* GetFunctionBody() const noexcept; - - /** Gets the node description. */ - const std::string& Description() const noexcept; - - /** - Helper to iterate through the container returned by #InputDefs() or #OutputDefs() and call the provided function. - @param node_args Collection of NodeArgs returned by #InputDefs() or #OutputDefs() - @param func Function to call for each valid NodeArg in the node_args. The function is called with the NodeArg - and the index number in the container. - @returns common::Status with success or error information. - @remarks Returns immediately on error. - */ - static common::Status ForEachWithIndex(const ConstPointerContainer>& node_args, - std::function func) { - for (size_t index = 0; index < node_args.size(); ++index) { - auto arg = node_args[index]; - if (!arg->Exists()) - continue; - ORT_RETURN_IF_ERROR(func(*arg, index)); - } - return common::Status::OK(); - } - - /** Gets the Node's input definitions. - @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */ - const ConstPointerContainer> InputDefs() const noexcept { - return ConstPointerContainer>(definitions_.input_defs); - } - - /** Gets a modifiable collection of the Node's input definitions. */ - std::vector& MutableInputDefs() noexcept { - return definitions_.input_defs; - } - - /** Gets a modifiable collection of the Node's output definitions. */ - std::vector& MutableOutputDefs() noexcept { - return definitions_.output_defs; - } - - /** Gets the count of arguments for each of the Node's explicit inputs. */ - const std::vector& InputArgCount() const noexcept { return definitions_.input_arg_count; } - - /** Gets a modifiable count of arguments for each of the Node's explicit inputs. - @todo This should be removed in favor of a method that updates the input args and the count. - Currently these operations are separate which is not a good setup. */ - std::vector& MutableInputArgsCount() { return definitions_.input_arg_count; } - - /** Gets the implicit inputs to this Node. - If this Node contains a subgraph, these are the NodeArg's that are implicitly consumed by Nodes within that - subgraph. e.g. If and Loop operators.*/ - const std::vector& ImplicitInputDefs() const noexcept { - return definitions_.implicit_input_defs; - } - - /** Gets the Node's output definitions. - @remarks requires ConstPointerContainer wrapper to apply const to the NodeArg pointers so access is read-only. */ - const ConstPointerContainer> OutputDefs() const noexcept { - return ConstPointerContainer>(definitions_.output_defs); - } - - /** Struct to provide sorting between EdgeEnd instances based on NodeIndex first, and NodeArg::Name second. */ - struct EdgeEndCompare { - bool operator()(const EdgeEnd& lhs, const EdgeEnd& rhs) const { - if (lhs.GetNode().Index() == rhs.GetNode().Index()) { - if (lhs.GetSrcArgIndex() == rhs.GetSrcArgIndex()) { - return lhs.GetDstArgIndex() < rhs.GetDstArgIndex(); - } - return lhs.GetSrcArgIndex() < rhs.GetSrcArgIndex(); - } - return lhs.GetNode().Index() < rhs.GetNode().Index(); - } - }; - - using EdgeSet = std::set; - using EdgeConstIterator = EdgeSet::const_iterator; - - /** - @class NodeConstIterator - Class to provide const access to Node instances iterated via an EdgeConstIterator. */ - class NodeConstIterator { - public: - NodeConstIterator(EdgeConstIterator p_iter); - - bool operator==(const NodeConstIterator& p_other) const; - - bool operator!=(const NodeConstIterator& p_other) const; - - void operator++(); - void operator--(); - - const Node& operator*(); - - private: - EdgeConstIterator m_iter; - }; - - // Functions defined to traverse a Graph as below. - - /** Gets an iterator to the beginning of the input nodes to this Node. */ - NodeConstIterator InputNodesBegin() const noexcept { return NodeConstIterator(relationships_.input_edges.cbegin()); }; - /** Gets an iterator to the end of the input nodes to this Node. */ - NodeConstIterator InputNodesEnd() const noexcept { return NodeConstIterator(relationships_.input_edges.cend()); } - - /** Gets an iterator to the beginning of the output nodes from this Node. */ - NodeConstIterator OutputNodesBegin() const noexcept { return NodeConstIterator(relationships_.output_edges.cbegin()); } - /** Gets an iterator to the end of the output nodes from this Node. */ - NodeConstIterator OutputNodesEnd() const noexcept { return NodeConstIterator(relationships_.output_edges.cend()); } - - /** Gets an iterator to the beginning of the input edges to this Node. - @remarks There are no nullptr entries in this collection. */ - EdgeConstIterator InputEdgesBegin() const noexcept { return relationships_.input_edges.cbegin(); } - - /** Gets an iterator to the end of the input edges to this Node. */ - EdgeConstIterator InputEdgesEnd() const noexcept { return relationships_.input_edges.cend(); } - - /** Gets an iterator to the beginning of the output edges from this Node. - @remarks There are no nullptr entries in this collection. */ - EdgeConstIterator OutputEdgesBegin() const noexcept { return relationships_.output_edges.cbegin(); } - - /** Gets an iterator to the end of the output edges from this Node. */ - EdgeConstIterator OutputEdgesEnd() const noexcept { return relationships_.output_edges.cend(); } - - /** Gets the Node's control inputs. */ - const std::set& ControlInputs() const noexcept { return relationships_.control_inputs; } - - /** Gets the number of input edges to this Node */ - size_t GetInputEdgesCount() const noexcept { return relationships_.input_edges.size(); } - - /** Gets the number of output edges from this Node */ - size_t GetOutputEdgesCount() const noexcept { return relationships_.output_edges.size(); } - - /** Add an attribute to this Node with specified attribute name and value. */ - void AddAttribute(const std::string& attr_name, const ONNX_NAMESPACE::AttributeProto& value); - -#define ADD_ATTR_INTERFACES(TypeName) \ - void AddAttribute(const std::string& attr_name, const TypeName& value); \ - void AddAttribute(const std::string& attr_name, \ - const std::vector& values); - - ADD_ATTR_INTERFACES(int64_t) - ADD_ATTR_INTERFACES(float) - ADD_ATTR_INTERFACES(std::string) - ADD_ATTR_INTERFACES(ONNX_NAMESPACE::TensorProto) - ADD_ATTR_INTERFACES(ONNX_NAMESPACE::GraphProto) - - /** Remove the specified attribute from this Node */ - bool ClearAttribute(const std::string& attr_name); - - /** Gets the Node's attributes. */ - const NodeAttributes& GetAttributes() const noexcept; - - /** Gets the Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve. - @param attr_name Attribute name for the GraphProto attribute. - @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto. - */ - const Graph* GetGraphAttribute(const std::string& attr_name) const; - - /** Gets the mutable Graph instance that is instantiated from a GraphProto attribute during Graph::Resolve. - @param attr_name Attribute name for the GraphProto attribute. - @returns nullptr if the Graph instance has not been instantiated or attribute does not contain a GraphProto. - */ - Graph* GetMutableGraphAttribute(const std::string& attr_name); - - /** Gets a map of attribute name to the mutable Graph instances for all subgraphs of the Node. - @returns Map of the attribute name that defines the subgraph to the subgraph's Graph instance. - nullptr if the Node has no subgraphs. - */ - const std::unordered_map>& GetAttributeNameToMutableSubgraphMap() { - return attr_to_subgraph_map_; - } - - /** Gets the execution ProviderType that this node will be executed by. */ - ProviderType GetExecutionProviderType() const noexcept; - - /** Sets the execution ProviderType that this Node will be executed by. */ - void SetExecutionProviderType(ProviderType execution_provider_type); - - /** Gets the NodeProto representation of this Node. */ - void ToProto(ONNX_NAMESPACE::NodeProto& proto) const; - - /** Call the provided function for all explicit inputs, implicit inputs, and outputs of this Node. - If the NodeArg is an explicit or implicit input, is_input will be true when func is called. */ - void ForEachDef(std::function func) const; - - /** Replaces any matching definitions in the Node's explicit inputs or explicit outputs. - @param replacements Map of current NodeArg to replacement NodeArg. - */ - void ReplaceDefs(const std::map& replacements); - - /** - @class Definitions - The input and output definitions for this Node. - */ - class Definitions { - public: - Definitions() noexcept = default; - - /** The Node's explicit input definitions. */ - std::vector input_defs; - - /** - The number of inputs for each argument of the operator or function which this node refers. - @remarks For example, #input_defs has 10 elements (inputs), and #input_arg_count is {4, 6}. - This means that 4 elements (inputs) of input_defs map to the first argument of the operator or function, and - the other 6 map to the second argument. - */ - std::vector input_arg_count; - - /** The Node's output definitions. */ - std::vector output_defs; - - /** The Node's implicit input definitions if the Node contains one or more subgraphs - (i.e. GraphProto attributes) and the subgraph/s implicitly consume these values. - @remarks For example, a subgraph in an 'If' node gets all its input values via this mechanism rather than - there being explicit inputs to the 'If' node that are passed to the subgraph. - They are pseudo-inputs to this Node as it has an implicit dependency on them. */ - std::vector implicit_input_defs; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Definitions); - }; - - /** - @class Relationships - Defines the relationships between this Node and other Nodes in the Graph. - */ - class Relationships { - public: - Relationships() = default; - - void Clear() noexcept { - input_edges.clear(); - output_edges.clear(); - control_inputs.clear(); - } - - /** The edges for Nodes that provide inputs to this Node. */ - EdgeSet input_edges; - - /** The edges for Nodes that receive outputs from this Node. */ - EdgeSet output_edges; - - /** The Node names of the control inputs to this Node. */ - std::set control_inputs; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Relationships); - }; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node); - - // NOTE: This friendship relationship should ONLY be used for calling methods of the Node class and not accessing - // the data members directly, so that the Node can maintain its internal invariants. - friend class Graph; - - Node(NodeIndex index, Graph& graph) : index_(index), graph_(&graph) {} - - void Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes* attributes, - const std::string& domain); - - // create a Graph instance for an attribute that contains a GraphProto - void CreateSubgraph(const std::string& attr_name); - - // internal only method to allow selected classes to directly alter the input/output definitions and arg counts - Definitions& MutableDefinitions() noexcept; - - // internal only method to allow selected classes to directly alter the links between nodes. - Relationships& MutableRelationships() noexcept; - - const std::vector>& MutableSubgraphs() noexcept { return subgraphs_; } - - const Definitions& GetDefinitions() const noexcept { return definitions_; } - const Relationships& GetRelationships() const noexcept { return relationships_; } - - void SetNodeType(Node::Type node_type) noexcept; - - void SetFunctionBody(const Function& func); - - // validate and update the input arg count - common::Status UpdateInputArgCount(); - - // Node index. Default to impossible value rather than 0. - NodeIndex index_ = std::numeric_limits::max(); - - // Node name. - std::string name_; - - // Node operator type. - std::string op_type_; - - // OperatorSet domain of op_type_. - std::string domain_; - - // OperatorSchema that <*this> node refers to. - const ONNX_NAMESPACE::OpSchema* op_ = nullptr; - Node::Type node_type_ = Node::Type::Primitive; - - // The function body is owned by graph_ - const Function* func_body_ = nullptr; - - // Node doc string. - std::string description_; - - // input/output defs and arg count - Definitions definitions_; - - // Relationships between this node and others in the graph - Relationships relationships_; - - // Device. - std::string execution_provider_type_; - - // Map from attribute name to attribute. - // This allows attribute adding and removing. - NodeAttributes attributes_; - - // Graph that contains this Node - Graph* graph_; - - // Map of attribute name to the Graph instance created from the GraphProto attribute - std::unordered_map> attr_to_subgraph_map_; - - // Graph instances for subgraphs that are owned by this Node - std::vector> subgraphs_; -}; - -/** -@class Graph -The Graph representation containing the graph inputs and outputs, the Node instances, -and the edges connecting the nodes. -*/ -class Graph { - public: - /** - Resolve this Graph to ensure it is completely valid, fully initialized, and able to be executed. - 1. Run through all validation rules. - a. Node name and node output's names should be unique. - b. Attribute match between node and op definition. - c. Input/Output match between node and op definition. - d. Graph is acyclic and sort nodes in topological order. - 2. Check & Setup inner nodes' dependency. - 3. Cleanup function definition lists. - @returns common::Status with success or error information. - */ - common::Status Resolve(); - - /** Gets the Graph name. */ - const std::string& Name() const noexcept; - /** Sets the Graph name. */ - void SetName(const std::string& name); - - /** Gets the Graph description. */ - const std::string& Description() const noexcept; - /** Gets the Graph description. */ - void SetDescription(const std::string& description); - - /** Add an initializer tensor to the Graph. */ - void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto); - - /** Remove the initializer tensor with the provided name from the Graph. */ - void RemoveInitializedTensor(const std::string& tensor_name); - - /** Gets an initializer tensor with the provided name. - @param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not. - @returns True if found. - */ - bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; - - /** Gets all the initializer tensors in this Graph. */ - const InitializedTensorSet& GetAllInitializedTensors() const noexcept; - - /** Removes all initializer tensors from this Graph and releases the memory they were using. */ - void CleanAllInitializedTensors() noexcept; - - /** Gets the Graph inputs excluding initializers. - These are the required inputs to the Graph as the initializers can be optionally overridden via graph inputs. - @remarks Contains no nullptr values. */ - const std::vector& GetInputs() const noexcept { return graph_inputs_excluding_initializers_; } - - /** Gets the Graph inputs including initializers. - This is the full set of inputs, in the same order as defined in the GraphProto. - @remarks Contains no nullptr values. */ - const std::vector& GetInputsIncludingInitializers() const noexcept { - return graph_inputs_including_initializers_; - } - - /** Gets the Graph outputs. - @remarks Contains no nullptr values.*/ - const std::vector& GetOutputs() const noexcept { return graph_outputs_; } - - /** Returns true if a Node output is a Graph output. */ - bool IsNodeOutputsInGraphOutputs(const Node& node) { - for (auto output_def : node.OutputDefs()) { - if (std::find(GetOutputs().cbegin(), GetOutputs().cend(), output_def) != GetOutputs().cend()) { - return true; - } - } - return false; - } - - /** Gets the NodeArgs that represent value_info instances in the Graph. - These are the values that are neither Graph inputs nor outputs. - @remarks Contains no nullptr values. */ - const std::vector& GetValueInfo() const noexcept; - - /** Gets the Node with the specified node index. - @returns Node instance if found. nullptr if node_index is invalid or node has been freed. - */ - const Node* GetNode(NodeIndex node_index) const { return NodeAtIndexImpl(node_index); } - - /** Gets the mutable Node with the specified node index. - @returns Mutable Node instance if found. nullptr if node_index is invalid or node has been freed. - */ - Node* GetNode(NodeIndex node_index) { return NodeAtIndexImpl(node_index); } - - /** Get a GraphNodes instance that provides mutable access to all valid Nodes in the Graph. */ - GraphNodes& Nodes() noexcept { return iterable_nodes_; } - - /** Get a GraphNodes instance that provides const access to all valid Nodes in the Graph. */ - const GraphNodes& Nodes() const noexcept { return iterable_nodes_; } - - /** Gets the maximum NodeIndex value used in the Graph. */ - int MaxNodeIndex() const noexcept { return static_cast(nodes_.size()); } //assume the casting won't overflow - - /** Gets the number of valid Nodes in the Graph. - @remarks This may be smaller than MaxNodeIndex(), as Nodes may be removed during optimization. - */ - int NumberOfNodes() const noexcept { return num_of_nodes_; } - - /** Gets the mutable NodeArg with the provided name. - @returns Pointer to NodeArg if found, nullptr if not. */ - NodeArg* GetNodeArg(const std::string& name) { - auto iter = node_args_.find(name); - if (iter != node_args_.end()) { - return iter->second.get(); - } - return nullptr; - } - - /** Gets the const NodeArg with the provided name. - @returns Pointer to const NodeArg if found, nullptr if not. */ - const NodeArg* GetNodeArg(const std::string& name) const { - return const_cast(this)->GetNodeArg(name); - } - - /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. - @param name The NodeArg name. - @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. - @returns NodeArg reference. - */ - NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) { - auto iter = node_args_.find(name); - if (iter != node_args_.end()) { - return *(iter->second); - } - - auto result = node_args_.insert(std::make_pair(name, std::make_unique(name, p_arg_type))); - return *(result.first->second); - } - - /** Generate a unique name.in this Graph for a NodeArg */ - std::string GenerateNodeArgName(const std::string& base_name); - - /** Generate a unique name.in this Graph for a Node */ - std::string GenerateNodeName(const std::string& base_name); - - /** Add a Node to this Graph. - @param name The Node name. Must be unique in this Graph. - @param op_type The operator type. e.g. ONNX operator name. - @param description Arbitrary description of the Node. - @param input_args The explicit inputs to this Node. - @param output_args The outputs from this Node. - @param attributes Optional NodeAttributes to add. - @param domain The domain for the op_type. - @returns Reference to the new Node. - @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe. - */ - Node& AddNode(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes* attributes = nullptr, - const std::string& domain = ""); - - /** Copy a Node and add it to this Graph. - @param other Node to copy - @returns Reference to the Node that was created and added to this Graph. - @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe. - */ - Node& AddNode(const Node& other); - - /** Remove a Node from this Graph and free it. - The output edges of this specified node MUST have been removed before removing the node. - The input edges of this specified node is removed while removing the node. The process of - removing a node from a graph should be, - 1. Remove out edges of this specified node. - 2. Remove this specified node. - 3. Add new input edges connected with all out nodes. - @returns true if the node_index was valid - @remarks Do not call AddNode and Remove Node concurrently as they are not thread-safe. - */ - bool RemoveNode(NodeIndex node_index); - - /** Add an edge between two Nodes. - @param src_node_index NodeIndex of source Node that is providing output to the destination Node. - @param dst_node_index NodeIndex of destination Node that is receiving input from the source Node. - @param src_arg_index node arg index of source node. - @param dst_arg_index node arg index of destination node. - */ - void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index); - - /** Remove an edge between two Nodes. - @param src_node_index NodeIndex of source Node to remove an output edge from. - @param dst_node_index NodeIndex of destination Node to remove an input edge from. - @param src_arg_index node arg index of source node. - @param dst_arg_index node arg index of destination node. - */ - void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index); - - /** - Add a control edge between two Nodes in this Graph. - The source Node does not produce output that is directly consumed by the destination Node, however the - destination Node must execute after the source node. The control edge allows this ordering to occur. - */ - bool AddControlEdge(NodeIndex src_node_index, NodeIndex dst_node_index); - - /** Mark the Graph as needing Resolve() to be called. - This should be done after modifying any aspect of the Graph that changes the Nodes or relationships between them. */ - Graph& SetGraphResolveNeeded() noexcept { - graph_resolve_needed_ = true; - return *this; - } - - /** Gets flag indicating whether Graph::Resolve needs to be called before using the Graph. */ - bool GraphResolveNeeded() const noexcept { - return graph_resolve_needed_; - } - - /** Sets flag that Graph::graph_proto_ needs to be updated to reflect changes in the Graph. */ - Graph& SetGraphProtoSyncNeeded() noexcept { - graph_proto_sync_needed_ = true; - return *this; - } - - /** Gets flag indicating whether Graph::graph_proto_ needs to be synchronized with this Graph instance. */ - bool GraphProtoSyncNeeded() const noexcept { - return graph_proto_sync_needed_; - } - - /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs, - up to their source node/s. - @param from NodeIndex values for a set of Nodes to traverse from. - @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been. - @param leave Visit function invoked on the node after its parents have all been visited. - @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic. - */ - void ReverseDFSFrom(const std::vector& from, - const std::function& enter, - const std::function& leave, - const std::function& comp = {}) const; - - /** Performs a reverse depth-first search (DFS) traversal from a set of nodes, via their inputs, - up to their source node/s. - @param from Set of Nodes to traverse from. - @param enter Visit function that will be invoked on a node when it is visited but its parents haven't been. - @param leave Visit function invoked on the node after its parents have all been visited. - @param comp Comparison function to stabilize the traversal order by making Node ordering deterministic. - */ - void ReverseDFSFrom(const std::vector& from, - const std::function& enter, - const std::function& leave, - const std::function& comp = {}) const; - - /** Gets the map of operator domains to their opset versions. */ - const std::unordered_map& DomainToVersionMap() const noexcept { - return domain_to_version_; - } - - /** Gets the GraphProto representation of this Graph. */ - const ONNX_NAMESPACE::GraphProto& ToGraphProto(); - - /** Gets the ISchemaRegistry instances being used with this Graph. */ - IOnnxRuntimeOpSchemaCollectionPtr GetSchemaRegistry() const; - - /** - Create a single Node that is the result of the a fusion of multiple nodes in this Graph. - @param sub_graph A IndexSubGraph instance with details of the nodes to fuse. - @param fused_node_name The name for the new Node. - @returns Node with fused subgraph. - */ - Node& FuseSubGraph(std::unique_ptr sub_graph, const std::string& fused_node_name); - - /** - Directly insert the nodes in the function Node provided into this Graph. - @param node Node with Node::Type of Node::Type::Fused - @returns Status indicating success or providing an error message. - */ - Status InlineFunction(Node& node); - - /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will - be used as a GraphProto attribute in another Node.. - e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to - define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs - when the Graph is resolved. - */ - void AddOuterScopeNodeArg(const std::string& name) { - ORT_IGNORE_RETURN_VALUE(outer_scope_node_arg_names_.insert(name)); - } - - /** When programmatically constructing a Graph, explicitly set the order to use for graph inputs when the graph is - resolved. - This will determine the graph input order when the Graph is converted to a GraphProto by Graph::ToGraphProto. - @param inputs NodeArgs that represent graph inputs which need to be explicitly ordered. - Any graph inputs not in this list will be appended to the ordered graph input list, in the order that they were first - used by Nodes (i.e. the order of Node creation implicitly determines the ordering). - @remarks If the Graph was loaded from a GraphProto this has no effect.*/ - void SetInputOrder(const std::vector inputs) { - graph_input_order_ = inputs; - } - - /** When programmatically constructing a Graph, explicitly set the order to use for graph outputs when the graph is - resolved. - This will determine the graph output order when the Graph is converted to a GraphProto by Graph::ToGraphProto. - @param outputs NodeArgs that represent graph outputs which need to be explicitly ordered. - Any graph outputs not in this list will be appended to the ordered graph output list, in the order that they were first - produced by Nodes (i.e. the order of Node creation implicitly determines the ordering). - @remarks If the Graph was loaded from a GraphProto this has no effect.*/ - void SetOutputOrder(const std::vector outputs) { - graph_output_order_ = outputs; - } - - /** Construct a Graph instance for a subgraph that is created from a GraphProto attribute in a Node. - Inherits some properties from the parent graph. - @param parent_graph The Graph containing the Node which has a GraphProto attribute. - @param subgraph_proto The GraphProto from the Node attribute. - */ - Graph(Graph& parent_graph, ONNX_NAMESPACE::GraphProto& subgraph_proto); - - virtual ~Graph(); - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph); - - // This friendship relationship should only be used to call Graph::Graph and - // Graph::LoadGraph All other access should be via the public API. - friend class Model; - - Graph() = delete; - - // Constructor: Given a loaded from model file, construct - // a object. Used by Model to create a Graph instance. - Graph(ONNX_NAMESPACE::GraphProto* graph_proto, - const std::unordered_map& domain_to_version, - Version ir_version, - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, - const std::unordered_map& model_functions = {}); - - // internal use by the Graph class only - Graph(ONNX_NAMESPACE::GraphProto* graph_proto, - const std::unordered_map& domain_to_version, - Version ir_version, - IOnnxRuntimeOpSchemaCollectionPtr schema_registry, - Graph* parent_graph, - const std::unordered_map& model_functions = {}); - - // Add node with specified . - Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, - const ArgNameToTypeMap& name_to_type); - - Version IrVersion() const noexcept { - return ir_version_; - } - - Graph& GraphResolveNeeded(bool needed) noexcept { - graph_resolve_needed_ = needed; - return *this; - } - - Graph& GraphProtoSyncNeeded(bool needed) noexcept { - graph_proto_sync_needed_ = needed; - return *this; - } - - // During the Resolve of a Graph it is necessary to recursively descend into subgraphs (created from GraphProto - // Node attributes in the Graph) if present. - // The ResolveContext holds the collection of values for the current Graph instance, be it the main graph - // or a subgraph, so that the various operations that are part of the Resolve can work iteratively or - // recursively as needed. - struct ResolveContext { - ResolveContext() = default; - - std::unordered_map> output_args; - std::unordered_set inputs_and_initializers; - std::unordered_set outer_scope_node_args; - std::unordered_map node_name_to_index; - std::unordered_set nodes_with_subgraphs; - - void Clear() { - output_args.clear(); - inputs_and_initializers.clear(); - outer_scope_node_args.clear(); - node_name_to_index.clear(); - nodes_with_subgraphs.clear(); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ResolveContext); - }; - - // search this and up through any parent_graph_ instance for a NodeArg - NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); - - // Initialize all the graph inputs, initializers and outputs - common::Status InitInputsInitializersOutputs(); - - // recursively accumulate and set the outer scope node args in the resolve context for all subgraphs - // so they can be used to resolve outer scope dependencies when running BuildConnections for the subgraphs. - common::Status SetOuterScopeNodeArgs(const std::unordered_set& outer_scope_node_args); - - // Build and verify node connection (edges). - // Verify NodeArg name/type/shape matching correctly. - common::Status BuildConnections(std::vector& outer_scope_node_args_consumed); - - common::Status VerifyNoDuplicateName(); - - // Check whether <*this> graph is acyclic while performing a topological sort. - // Depth-first going from bottom up through the graph and checking whether there are any back edges. - // NodesInTopologicalOrder is updated with the nodes' indexes in topological - // order if returned is "OK", otherwise it's undefined. - common::Status PerformTopologicalSortAndCheckIsAcyclic(); - - common::Status PerformTypeAndShapeInferencing(); - - enum class Type { - // A main graph. - Main = 1, - // A sub graph (function). - Sub = 2, - }; - - common::Status Resolve(bool no_proto_sync_required); - - // Recursively find all subgraphs including nested subgraphs - void FindAllSubgraphs(std::vector& subgraphs); - - // Iterate this Graph instance and all subgraphs, calling the provided function for each. - common::Status ForThisAndAllSubgraphs(const std::vector& subgraphs, std::function func); - - common::Status InferAndVerifyTypeMatch(Node& node, const ONNX_NAMESPACE::OpSchema& op); - - // perform type and shape inferencing on the subgraph and Resolve to validate - static common::Status InferAndVerifySubgraphTypes(const Node& node, Graph& subgraph, - const std::vector& input_types, - std::vector& output_types); - - // Apply type-inference and type-checking to all inputs and initializers: - common::Status TypeCheckInputsAndInitializers(); - - // Compute set of input and initializer names and checking for duplicate names - common::Status VerifyInputAndInitializerNames(); - - // Infer and set type information across <*this> graph if needed, and verify type/attribute - // information matches between node and op. - common::Status VerifyNodeAndOpMatch(); - - // Set graph inputs/outputs when resolving a graph.. - common::Status SetGraphInputsOutputs(); - - // Sync graph inputs/outputs when serializing to proto. - void SyncGraphInputsOutputs(); - - // Clear all unused initializers - void CleanUnusedInitializers(); - - gsl::not_null AllocateNode(); - - // Release the node. - // @returns false if node_index was invalid. - bool ReleaseNode(NodeIndex node_index); - - Node* NodeAtIndexImpl(NodeIndex node_index) const { - // if we are trying to access a node that doesn't exist there's (most - // likely) either a logic issue or a graph consistency/correctness issue. - // use ORT_ENFORCE to prove that or uncover scenarios where we actually - // expect attempts to retrieve a non-existent node. - ORT_ENFORCE(node_index < nodes_.size(), "Validating no unexpected access using an invalid node_index."); - return nodes_[node_index].get(); - } - - std::vector CreateNodeArgs(const google::protobuf::RepeatedPtrField& names, - const ArgNameToTypeMap& name_to_type_map); - - bool IsSubgraph() const { return parent_graph_ != nullptr; } - - void AddFunction(const ONNX_NAMESPACE::FunctionProto* func_proto); - - // GraphProto to store name, version, initializer. - // When serializing <*this> Graph to a GraphProto, the nodes and - // functions in will also be fed into so that - // it's consistent with <*this> graph. - // This pointer is owned by parent model. - ONNX_NAMESPACE::GraphProto* graph_proto_; - - InitializedTensorSet name_to_initial_tensor_; - std::vector removed_initializer_indexes_; - - Type graph_type_ = Type::Main; - - IOnnxRuntimeOpSchemaCollectionPtr schema_registry_; - - std::vector> function_container_; - - // Graph nodes. - // Element in may be nullptr due to graph optimization. - std::vector> nodes_; - - // Wrapper of Graph nodes to provide iteration services that hide nullptr entries - GraphNodes iterable_nodes_{nodes_}; - - // Number of nodes. - // Normally this is smaller than the size of , as some - // elements in may be removed when doing graph optimization, - // or some elements may be merged, etc. - int num_of_nodes_ = 0; - - // A flag indicates whether <*this> graph needs to be resolved. - bool graph_resolve_needed_ = false; - - bool graph_proto_sync_needed_ = false; - - // The topological order of node index used to do node and op match verification temporarily. - std::vector nodes_in_topological_order_; - - // Full list of graph inputs. Matches number and order of inputs in the GraphProto. - std::vector graph_inputs_including_initializers_; - - // Graph inputs excluding initializers. - std::vector graph_inputs_excluding_initializers_; - - // Graph outputs. - std::vector graph_outputs_; - - // Graph value_info. - std::vector value_info_; - - // All node args owned by <*this> graph. Key is node arg name. - std::unordered_map> node_args_; - - const std::unordered_map domain_to_version_; - - std::unordered_map model_functions_; - - // Model IR version. - Version ir_version_{}; - - int name_generator_ = 0; - - ResolveContext resolve_context_; - - // the parent graph if this is a subgraph. - Graph* parent_graph_; - - // NodeArgs that come from outer scope. Used when building a graph so that - // these don't get recorded as graph inputs in the GraphProto. - std::unordered_set outer_scope_node_arg_names_; - - // Explicit graph input order to be used when constructing a Graph manually. - std::vector graph_input_order_; - - // Explicit graph output order to be used when constructing a Graph manually. - std::vector graph_output_order_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/graph_nodes.h b/onnxruntime/inc/core/graph/graph_nodes.h deleted file mode 100644 index dc30696..0000000 --- a/onnxruntime/inc/core/graph/graph_nodes.h +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -namespace onnxruntime { - -class Node; - -/** -Class that provides iteration over all valid nodes in the Graph. -*/ -class GraphNodes { - using TNodesContainer = std::vector>; - - public: - template - class NodeIterator; - - /** - Construct a GraphNodes instance to provide iteration over all valid nodes in the Graph - @param[in] nodes Nodes to iterate, skipping invalid entries. - */ - explicit GraphNodes(TNodesContainer& nodes) noexcept : nodes_(nodes) {} - - using ConstNodeIterator = NodeIterator; - using MutableNodeIterator = NodeIterator; - - ConstNodeIterator cbegin() const noexcept { - return {nodes_.cbegin(), nodes_.cend()}; - } - - ConstNodeIterator cend() const noexcept { - return {nodes_.cend(), nodes_.cend()}; - } - - ConstNodeIterator begin() const noexcept { - return cbegin(); - } - - ConstNodeIterator end() const noexcept { - return cend(); - } - - MutableNodeIterator begin() noexcept { - return {nodes_.begin(), nodes_.end()}; - } - - MutableNodeIterator end() noexcept { - return {nodes_.end(), nodes_.end()}; - } - - /** - @class NodeIterator - Iterator to provide const and non-const access to valid Node instances in a Graph. - @remarks Skips invalid nodes. - */ - template - class NodeIterator { - // get the type being returned by the iterator. can't use TIterator::value_type as that is always non-const - using IterType = typename std::remove_reference::reference>::type; - // and determine what we will return based on its constness - using T = typename std::conditional::value, - const Node, // return const Node if this is a const iterator - Node>::type; // else return Node - - public: - using iterator_category = std::input_iterator_tag; - using value_type = T; - using difference_type = typename TIterator::difference_type; - using pointer = T*; - using reference = T&; - using const_reference = std::add_const_t; - - /** Construct a NodeInterator and move to the first valid node. */ - NodeIterator(TIterator current, const TIterator end) noexcept : current_{current}, end_{end} { - // skip to next valid node, stopping at end if none are found - while (current_ < end && *current_ == nullptr) { - ++current_; - } - } - - bool operator==(const NodeIterator& other) const noexcept { - return (current_ == other.current_); - } - - bool operator!=(const NodeIterator& other) const noexcept { - return (current_ != other.current_); - } - - void operator++() { - if (current_ < end_) { - while (++current_ != end_) { - if (*current_ != nullptr) break; - } - } - } - - NodeIterator operator++(int) { - NodeIterator tmp{*this}; - ++(*this); - - return tmp; - } - - /** Return the current Node&. This will be const if the iterator was returned from a const GraphNodes instance. */ - reference operator*() { - // if iterator is valid we always have a non-nullptr node - // if this is a nullptr we're at end_ and this shouldn't be being called - return **current_; - } - - pointer operator->() { - return current_->get(); - } - - private: - TIterator current_; - const TIterator end_; - }; - - private: - TNodesContainer& nodes_; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/graph_transformer.h b/onnxruntime/inc/core/graph/graph_transformer.h deleted file mode 100644 index e1978bf..0000000 --- a/onnxruntime/inc/core/graph/graph_transformer.h +++ /dev/null @@ -1,118 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/graph/graph_viewer.h" -#include "core/graph/rewrite_rule.h" - -namespace onnxruntime { - -/** -@class GraphTransformer - -The interface for in-place transformation of a Graph. -*/ -class GraphTransformer { - public: - GraphTransformer(const std::string& name, const std::string& desc) - : name_(name), desc_(desc) { - } - - virtual ~GraphTransformer() = default; - - /** Gets the name of this graph transformer. */ - const std::string& Name() const noexcept { - return name_; - } - - /** Gets the description of this graph transformer. */ - const std::string& Description() const noexcept { - return desc_; - } - - /** Apply the in-place transformation defined by this transformer to the provided Graph instance. - @param[out] modified Set to true if the Graph was modified. - @returns Status with success or error information. - */ - virtual common::Status Apply(Graph& graph, bool& modified) const = 0; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphTransformer); - - const std::string name_; - const std::string desc_; -}; - -/** -@class RuleBasedGraphTransformer - -Rule based graph transformer that provides an API to register rewrite rules, -and an API to apply all applicable rules to a Graph. - -Represents an IGraphTransformer determined by a set of rewrite-rules. -The transformer will apply all the rewrite-rules iteratively as determined by the underlying rewriting-strategy. -Several rewriting-strategies are possible when traversing the graph and applying rewrite rules, -each with different trade offs. At the moment, we define one that performs top-down traversal of nodes. - -@TODO: Is a bottom-up traversal more efficient? -@TODO: Is it worth adding the max number of passes a rule should be applied for? -@TODO: We need to define a contract about whether a rewrite rule is allowed to leave - the graph in an inconsistent state (this will determine when and where we will be - calling Graph::resolve(). -*/ -class RuleBasedGraphTransformer : public GraphTransformer { - public: - RuleBasedGraphTransformer(const std::string& name, const std::string& desc) : GraphTransformer(name, desc) {} - - /** - Register a rewriting rule. - - @TODO (revisit needed): Using OpSignature* here will ask that OpSignature should be stored globally. - Otherwise, there will be multiple addresses/pointers for the same operator or function. - To avoid this, we may use OpSignature ID as the key, which should be name_domain_version. - We will use the string type instead of the OpSchema for now. We should probably add a version as well. - */ - Status Register(const std::string& op_type, std::unique_ptr rule); - - /** Check if the given op_type has any rules registered for it - @returns true if there are rules registered for this op_type.*/ - bool HasRules(const std::string& op_type) const { - return op_to_rules_.find(op_type) != op_to_rules_.cend(); - } - - /** - Gets the rewrite rules for the given op_type. - @returns a pointer to the vector containing all the rewrite rules registered for op_type if found. nullptr - otherwise. - */ - const std::vector>* GetRewriteRules(const std::string& op_type) const { - auto entry = op_to_rules_.find(op_type); - if (entry != op_to_rules_.cend()) - return &entry->second; - - return nullptr; - } - - private: - using RewriteRuleSet = std::unordered_map>>; - - RewriteRuleSet op_to_rules_; -}; - -/** -@class TopDownRuleBasedTransformer - -This is a rule-based Graph transformer that applies rules by performing top-down passes of the Graph. -*/ -class TopDownRuleBasedTransformer : public RuleBasedGraphTransformer { - public: - TopDownRuleBasedTransformer(const std::string& name, const std::string& desc) - : RuleBasedGraphTransformer(name, desc) {} - - // Performs a single top-down traversal of the graph and applies all registered rules. - common::Status Apply(Graph& graph, bool& modified) const override; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/graph_viewer.h b/onnxruntime/inc/core/graph/graph_viewer.h deleted file mode 100644 index 346cc53..0000000 --- a/onnxruntime/inc/core/graph/graph_viewer.h +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/graph/graph.h" - -namespace onnxruntime { -class Function; -struct IndexedSubGraph; -} // namespace onnxruntime - -namespace onnxruntime { - -/** -@class GraphViewer -Class that provides a read-only view of the Graph. -@remarks If the underlying Graph is changed, GetNodesInTopologicalOrder and GetRootNodes may become invalid. -*/ -class GraphViewer { - public: - /** - Construct a GraphViewer from the provided Graph instance. - */ - explicit GraphViewer(const Graph& graph); - - /** Gets the Graph name. */ - const std::string& Name() const noexcept; - - /** Gets the Graph description. */ - const std::string& Description() const noexcept; - - /** - Gets a tensor created from an initializer. - @param tensor_name The tensor name - @param[out] value Sets the pointer to the TensorProto if found, or nullptr if not. - @returns True if found. False if not. - */ - bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const; - - /** - Gets the Graph inputs, excluding initializers. - @returns Collection of NodeArg pointers for the graph inputs, excluding inputs that have matching initializers. - @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto. - */ - const std::vector& GetInputs() const noexcept; - - /** - Gets the Graph inputs, including any initializers. - @returns Collection of NodeArg pointers for all the graph inputs. - @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto. - */ - const std::vector& GetInputsIncludingInitializers() const noexcept; - - /** - Gets the Graph outputs. - @returns Collection of NodeArg pointers for all the graph outputs. - @remarks No nullptr values in the returned collection. The order will be the same as in the GraphProto. - */ - const std::vector& GetOutputs() const noexcept; - - /** Gets all ValueInfo NodeArg instances in the Graph. */ - const std::vector& GetValueInfo() const noexcept; - - /** - Gets the Node instance at the specified index. - @param node_index Index to retrieve Node from. - @remarks May return nullptr if index no longer points to a valid node due to the node being freed. - */ - const Node* GetNode(NodeIndex node_index) const; - - /** Gets an iterator over all the valid Nodes in the Graph. */ - const GraphNodes& Nodes() const noexcept; - - /** Gets the number of valid nodes in the Graph. */ - int NumberOfNodes() const noexcept; - - /** Gets the maximum NodeIndex value used by Nodes in the Graph. */ - int MaxNodeIndex() const noexcept; - - /** Gets the NodeIndex values for the Graph nodes, sorted into topological order. */ - const std::vector& GetNodesInTopologicalOrder() const; - - /** - Gets the NodeIndex values for the root nodes in the Graph. - The root nodes are the topmost nodes in the Graph that receive inputs from the Graph inputs - and no other nodes in the Graph. - */ - const std::vector& GetRootNodes() const; - - /** Gets all tensors created from initializers. */ - const InitializedTensorSet& GetAllInitializedTensors() const noexcept; - - /** - Gets the NodeArg instance for the given name. - @returns A NodeArg if found, a nullptr if not. - */ - const NodeArg* GetNodeArg(const std::string& name) const; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphViewer); - - const Graph* graph_; - - // The NodeIndex values of the graph nodes sorted in topological order. - std::vector nodes_in_topological_order_; - // Graph root nodes. - std::vector root_nodes_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/indexed_sub_graph.h b/onnxruntime/inc/core/graph/indexed_sub_graph.h deleted file mode 100644 index 7856c94..0000000 --- a/onnxruntime/inc/core/graph/indexed_sub_graph.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "core/graph/basic_types.h" -#include "core/graph/onnx_protobuf.h" - -namespace onnxruntime { - -class OpKernel; -class OpKernelInfo; - -/** -@class IndexedSubGraph - -Class containing information about a subgraph of Nodes from a Graph. -It contains a NodeIndex array of the Nodes covered by the subgraph, -and the meta definition needed for representing this subgraph as a FunctionProto, -which could be serialized/saved to a model file. -*/ -struct IndexedSubGraph { - struct MetaDef { - std::string name; ///< Name of customized SubGraph/FunctionProto - std::string domain; ///< Domain of customized SubGraph/FunctionProto - int since_version; ///< Since version of customized SubGraph/FunctionProto. - - ONNX_NAMESPACE::OperatorStatus status; ///< Status of customized SubGraph/FunctionProto. - - std::vector inputs; ///< Inputs of customized SubGraph/FunctionProto. - std::vector outputs; ///< Outputs of customized SubGraph/FunctionProto. - NodeAttributes attributes; ///< Attributes of customized SubGraph/FunctionProto. - - std::string doc_string; ///< Doc string of customized SubGraph/FunctionProto. - }; - - /** Nodes covered by this subgraph. The NodeIndex values are from the parent Graph.*/ - std::vector nodes; - - /** Set the meta definition needed to represent this subgraph as a FunctionProto - It's needed IF AND ONLY IF there are multiple indexes contained in #nodes. */ - void SetMetaDef(std::unique_ptr& meta_def_) { - meta_def = std::move(meta_def_); - } - - /** Gets the meta definition needed to represent this subgraph as a FunctionProto. - @returns MetaDef instance if it has been set. nullptr if not. */ - const MetaDef* GetMetaDef() const { - return meta_def.get(); - } - - private: - // subgraph meta definition. - std::unique_ptr meta_def; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/node_arg.h b/onnxruntime/inc/core/graph/node_arg.h deleted file mode 100644 index 06ff04c..0000000 --- a/onnxruntime/inc/core/graph/node_arg.h +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/graph/onnx_protobuf.h" - -namespace onnxruntime { - -// Node argument definition, for both input and output, -// including arg name, arg type (contains both type and shape). -// -// Design Question: in my opinion, shape should not be part of type. -// We may align the protobuf design with our operator registry interface, -// which has type specified for each operator, but no shape. Well, shape -// should be inferred with a separate shape inference function given -// input shapes, or input tensor data sometimes. -// With shape as part of type (current protobuf design), -// 1) we'll have to split the "TypeProto" into type and shape in this internal -// representation interface so that it could be easily used when doing type -// inference and matching with operator registry. -// 2) SetType should be always called before SetShape, otherwise, SetShape() -// will fail. Because shape is located in a TypeProto. -// Thoughts? -// - -/** -@class NodeArg -Class representing a data type that is input or output for a Node, including the shape if it is a Tensor. -*/ -class NodeArg { - public: - /** - Construct a new NodeArg. - @param name The name to use. - @param p_arg_type Optional TypeProto specifying type and shape information. - */ - NodeArg(const std::string& name, - const ONNX_NAMESPACE::TypeProto* p_arg_type); - - NodeArg(NodeArg&& other) = default; - - /** Gets the name. */ - const std::string& Name() const noexcept; - - /** Gets the data type. */ - ONNX_NAMESPACE::DataType Type() const noexcept; - - /** Gets the TypeProto - @returns TypeProto if type is set. nullptr otherwise. */ - const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept; - - /** Gets the shape if NodeArg is for a Tensor. - @returns TensorShapeProto if shape is set. nullptr if there's no shape specified. */ - const ONNX_NAMESPACE::TensorShapeProto* Shape() const; - - /** Sets the shape. - @remarks Shape can only be set if the TypeProto was provided to the ctor, or #SetType has been called, - as the shape information is stored as part of TypeProto. */ - void SetShape(const ONNX_NAMESPACE::TensorShapeProto& shape); - - /** Validate and merge type [and shape] info from input_type. - @returns Success unless there is existing type or shape info that can't be cleanly updated. */ - common::Status UpdateTypeAndShape(const ONNX_NAMESPACE::TypeProto& input_type); - - /** Validate and merge type [and shape] info from node_arg. - @returns Success unless there is existing type or shape info that can't be cleanly updated. */ - common::Status UpdateTypeAndShape(const NodeArg& node_arg); - - /** Gets this NodeArg as a ValueInfoProto. */ - const NodeArgInfo& ToProto() const noexcept { return node_arg_info_; } - - /** Gets a flag indicating whether this NodeArg exists or not. - Optional inputs are allowed in ONNX and an empty #Name represents a non-existent input argument. */ - bool Exists() const noexcept; - - private: - ORT_DISALLOW_COPY_AND_ASSIGNMENT(NodeArg); - friend class Graph; - - void SetType(ONNX_NAMESPACE::DataType p_type); - void SetType(const ONNX_NAMESPACE::TypeProto& type_proto); - - NodeArg& operator=(NodeArg&& other) = delete; - - // Node arg PType. - ONNX_NAMESPACE::DataType type_; - - // Node arg name, type and shape. - NodeArgInfo node_arg_info_; - - // Flag indicates whether <*this> node arg exists or not. - bool exists_; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/onnx_protobuf.h b/onnxruntime/inc/core/graph/onnx_protobuf.h deleted file mode 100644 index ff68b41..0000000 --- a/onnxruntime/inc/core/graph/onnx_protobuf.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -//TODO(): delete this file from public interface -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#else -#pragma warning(push) -#pragma warning(disable : 4018) /*'expression' : signed/unsigned mismatch */ -#pragma warning(disable : 4065) /*switch statement contains 'default' but no 'case' labels*/ -#pragma warning(disable : 4100) -#pragma warning(disable : 4146) /*unary minus operator applied to unsigned type, result still unsigned*/ -#pragma warning(disable : 4244) /*'conversion' conversion from 'type1' to 'type2', possible loss of data*/ -#pragma warning(disable : 4251) /*'identifier' : class 'type' needs to have dll-interface to be used by clients of class 'type2'*/ -#pragma warning(disable : 4267) /*'var' : conversion from 'size_t' to 'type', possible loss of data*/ -#pragma warning(disable : 4305) /*'identifier' : truncation from 'type1' to 'type2'*/ -#pragma warning(disable : 4307) /*'operator' : integral constant overflow*/ -#pragma warning(disable : 4309) /*'conversion' : truncation of constant value*/ -#pragma warning(disable : 4334) /*'operator' : result of 32-bit shift implicitly converted to 64 bits (was 64-bit shift intended?)*/ -#pragma warning(disable : 4355) /*'this' : used in base member initializer list*/ -#pragma warning(disable : 4506) /*no definition for inline function 'function'*/ -#pragma warning(disable : 4800) /*'type' : forcing value to bool 'true' or 'false' (performance warning)*/ -#pragma warning(disable : 4996) /*The compiler encountered a deprecated declaration.*/ -#endif -#include "onnx/defs/schema.h" -#include "onnx/onnx_pb.h" -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#else -#pragma warning(pop) -#endif diff --git a/onnxruntime/inc/core/graph/rewrite_rule.h b/onnxruntime/inc/core/graph/rewrite_rule.h deleted file mode 100644 index ec3cd95..0000000 --- a/onnxruntime/inc/core/graph/rewrite_rule.h +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/graph/graph_viewer.h" - -namespace onnxruntime { - -/** -@class RewriteRule - -The base class for a rewrite rule. A rewrite rule represents a semantics-preserving -transformation of a computation-graph. It can be used to represent, for example, -the elimination of operators that serve as no-ops (for example, dropout during -inference), as well as inlining of "function" definitions or the dual (replacing -a complex expression by an equivalent function-call). Unlike the more general -IGraphTransformer, a rewrite-rule is applied at a single node, representing the -root of an expression that is rewritten. -*/ -class RewriteRule { - public: - RewriteRule(const std::string& name, const std::string& desc) - : name_(name), desc_(desc) { - } - - virtual ~RewriteRule() = default; - - /** Gets the name of this rewrite rule. */ - const std::string& Name() const noexcept { - return name_; - } - - /** Gets the description of this rewrite rule. */ - const std::string& Description() const noexcept { - return desc_; - } - - /** Checks if the condition of the rule is satisfied, and if so applies the rule. - @param[in] graph_editor The GraphEditor. - @param[in] node The Node to apply the rewrite to. - @param[out] modified Set to indicate whether the node was modified or not. - @returns Status indicating success or providing error information */ - common::Status CheckConditionAndApply(Graph& graph, Node& node, bool& modified) { - return SatisfyCondition(node) ? Apply(graph, node, modified) : Status::OK(); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RewriteRule); - - const std::string name_; - const std::string desc_; - - /** Check if the Node satisfies a condition. - The rewrite rule is applied if the condition function returns true. This can include - a more complex pattern matching (conditions on the ascending or descending nodes of the - node for which this rule was triggered) or some other properties of the nodes. */ - virtual bool SatisfyCondition(const Node& node) = 0; - - /** - Apply the rewrite rule to a specific node. - The transformation happens in-place. The return-value of node may be different - from the input-value due to rewriting. - The value of "modified" indicates if the graph was modified or not. */ - virtual common::Status Apply(Graph& graph, Node& node, bool& modified) = 0; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/graph/schema_registry.h b/onnxruntime/inc/core/graph/schema_registry.h deleted file mode 100644 index d78eebd..0000000 --- a/onnxruntime/inc/core/graph/schema_registry.h +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "core/graph/constants.h" -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/platform/ort_mutex.h" - -#ifdef __GNUC__ -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wignored-qualifiers" -#pragma GCC diagnostic ignored "-Wunused-parameter" -#endif -#include "onnx/defs/schema.h" -#ifdef __GNUC__ -#pragma GCC diagnostic pop -#endif -#include -#include -#include "sstream" - -namespace onnxruntime { -using OpName_Domain_Version_Schema_Map = std::unordered_map< - std::string, - std::unordered_map>>; - -/** -@struct SchemaRegistryVersion -onnxruntime schema registry is a supplement to the built-in ONNX schema. -Every schema registry represent a collection of schema deltas from baseline_opset_version to opset_version -*/ -struct SchemaRegistryVersion { - int baseline_opset_version; - int opset_version; -}; - -using DomainToVersionMap = std::unordered_map; -using DomainToVersionRangeMap = std::unordered_map; - -class IOnnxRuntimeOpSchemaCollection : public ONNX_NAMESPACE::ISchemaRegistry { - public: - virtual DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const = 0; - - using ISchemaRegistry::GetSchema; - - virtual const ONNX_NAMESPACE::OpSchema* GetSchema( - const std::string& key, - const int maxInclusiveVersion, - const std::string& domain) const final { - const ONNX_NAMESPACE::OpSchema* latest_schema = nullptr; - int earliest_opset_where_unchanged = std::numeric_limits::max(); - GetSchemaAndHistory(key, maxInclusiveVersion, domain, &latest_schema, &earliest_opset_where_unchanged); - - assert(latest_schema == nullptr || (latest_schema->SinceVersion() <= maxInclusiveVersion && - earliest_opset_where_unchanged == latest_schema->SinceVersion())); - - return latest_schema; - } - - virtual void GetSchemaAndHistory( - const std::string& key, - int maxInclusiveVersion, - const std::string& domain, - const ONNX_NAMESPACE::OpSchema** latest_schema, - int* earliest_opset_where_unchanged) const = 0; -}; - -/** -@class OnnxRuntimeOpSchemaRegistry - -OnnxRuntimeOpSchemaRegistry is used to provide supplement for built-in ONNX schemas. -Each OnnxRuntimeOpSchemaRegistry must register complete opsets delta from a baseline version to max opset version. -(Please notice that baseline opsets are not include in the delta) - -For example, ONNXRuntime is build with ONNX 1.2 which is at opset7, to use ONNX opset8 and opset9, -user could create a OnnxRuntimeOpSchemaRegistry and config it as {baseline_opset_version = 7, opset_version = 9} -it means this OnnxRuntimeOpSchemaRegistry contains the complete delta from opset7 to opset9. -*/ -class OnnxRuntimeOpSchemaRegistry : public IOnnxRuntimeOpSchemaCollection { - public: - OnnxRuntimeOpSchemaRegistry() = default; - - common::Status SetBaselineAndOpsetVersionForDomain( - const std::string& domain, - int baseline_opset_version, - int opset_version); - - DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const override; - - // OnnxRuntimeOpSchemaRegistry must register complete delta for a opset. - common::Status RegisterOpSet( - std::vector& schemas, - const std::string& domain, - int baseline_opset_version, - int opset_version); - - using IOnnxRuntimeOpSchemaCollection::GetSchema; - - void GetSchemaAndHistory( - const std::string& key, - const int maxInclusiveVersion, - const std::string& domain, - const ONNX_NAMESPACE::OpSchema** latest_schema, - int* earliest_opset_where_unchanged) const override; - - bool empty() const { - return map_.empty(); - } - - private: - common::Status RegisterOpSchema(ONNX_NAMESPACE::OpSchema&& op_schema); - - common::Status RegisterOpSchemaInternal(ONNX_NAMESPACE::OpSchema&& op_schema); - - OrtMutex mutex_; - - OpName_Domain_Version_Schema_Map map_; - DomainToVersionRangeMap domain_version_range_map_; -}; - -/** -@class SchemaRegistryManager - -SchemaRegistryManager provides a view based on built-in ONNX schema and a list of -OnnxRuntimeOpSchemaRegistry as supplement. - -The user needs to make sure the customized schema registry is valid, otherwise the behavior is undefined. - -@todo We may add more consistency checks later. -*/ -class SchemaRegistryManager : public onnxruntime::IOnnxRuntimeOpSchemaCollection { - public: - /** - Register a new schema registry instance. - @remarks The schema registry priority is the reverse of registration order. i.e. the last registry added will be - searched first for a matching OpSchema. - */ - void RegisterRegistry(std::shared_ptr registry); - - /** Gets the latest opset versions. - @param is_onnx_only If true, return the latest ONNX schemas. If false, return the latest schemas for all domains. - */ - DomainToVersionMap GetLatestOpsetVersions(bool is_onnx_only) const override; - - /** - Gets the OpSchema and its history. - Searches custom schema registries starting with the last one added. \ - If the OpSchema is not found the default ONNX schema registry is searched. - - @param key Operator type. - @param max_inclusive_version Maximum opset version allowed, inclusive. - @param domain The domain of the operator. - @param[out] latest_schema Returns the latest OpSchema if found. nullptr otherwise. - @param[out] earliest_opset_where_unchanged The earliest opset version preceding max_inclusive_version where the - operator is known to be unchanged. - */ - void GetSchemaAndHistory( - const std::string& key, - const int max_inclusive_version, - const std::string& domain, - const ONNX_NAMESPACE::OpSchema** latest_schema, - int* earliest_opset_where_unchanged) const override; - - private: - std::deque> registries; -}; - -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/platform/ort_mutex.h b/onnxruntime/inc/core/platform/ort_mutex.h deleted file mode 100644 index d8415b6..0000000 --- a/onnxruntime/inc/core/platform/ort_mutex.h +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#ifdef _WIN32 -#include -#include -namespace onnxruntime { -using OrtMutex = std::mutex; -using OrtCondVar = std::condition_variable; -} // namespace onnxruntime -#else -#ifdef USE_NSYNC -#include "nsync.h" -#include //for unique_lock -#include //for cv_status -#else -#include -#include -#include -#include -#include -#endif -namespace onnxruntime { - -class OrtMutex { -#ifdef USE_NSYNC - nsync::nsync_mu data_ = NSYNC_MU_INIT; -#else - pthread_mutex_t data_ = PTHREAD_MUTEX_INITIALIZER; -#endif - - public: - constexpr OrtMutex() = default; -#ifdef USE_NSYNC - ~OrtMutex() = default; -#else - ~OrtMutex(); -#endif - - OrtMutex(const OrtMutex&) = delete; - OrtMutex& operator=(const OrtMutex&) = delete; - - void lock(); - bool try_lock() noexcept; - void unlock() noexcept; - -#ifdef USE_NSYNC - using native_handle_type = nsync::nsync_mu*; -#else - using native_handle_type = pthread_mutex_t*; -#endif - native_handle_type native_handle() { return &data_; } -}; - -class OrtCondVar { -#ifdef USE_NSYNC - nsync::nsync_cv native_cv_object = NSYNC_CV_INIT; -#else - pthread_cond_t native_cv_object = PTHREAD_COND_INITIALIZER; -#endif - public: - constexpr OrtCondVar() noexcept = default; - -#ifdef USE_NSYNC - ~OrtCondVar() = default; -#else - ~OrtCondVar(); -#endif - - OrtCondVar(const OrtCondVar&) = delete; - OrtCondVar& operator=(const OrtCondVar&) = delete; - - void notify_one() noexcept; - void notify_all() noexcept; - - void wait(std::unique_lock& __lk); - template - void wait(std::unique_lock& __lk, _Predicate __pred); - - /** - * returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns cv_status::no_timeout. - * @param cond_mutex A unique_lock object. - * @param rel_time A chrono::duration object that specifies the amount of time before the thread wakes up. - * @return returns cv_status::timeout if the wait terminates when Rel_time has elapsed. Otherwise, the method returns cv_status::no_timeout - */ - template - std::cv_status - wait_for(std::unique_lock& cond_mutex, - const std::chrono::duration& rel_time); -#ifdef USE_NSYNC - using native_handle_type = nsync::nsync_cv*; -#else - using native_handle_type = pthread_cond_t*; -#endif - - native_handle_type native_handle() { return &native_cv_object; } - - private: - void timed_wait_impl(std::unique_lock& __lk, - std::chrono::time_point); -}; - -template -void OrtCondVar::wait(std::unique_lock& __lk, _Predicate __pred) { - while (!__pred()) - wait(__lk); -} - -template -std::cv_status -OrtCondVar::wait_for(std::unique_lock& cond_mutex, - const std::chrono::duration& rel_time) { - //TODO: is it possible to use nsync_from_time_point_ ? - using namespace std::chrono; - if (rel_time <= duration::zero()) - return std::cv_status::timeout; - using SystemTimePointFloat = time_point >; - using SystemTimePoint = time_point; - SystemTimePointFloat max_time = SystemTimePoint::max(); - steady_clock::time_point steady_now = steady_clock::now(); - system_clock::time_point system_now = system_clock::now(); - if (max_time - rel_time > system_now) { - nanoseconds remain = duration_cast(rel_time); - if (remain < rel_time) - ++remain; - timed_wait_impl(cond_mutex, system_now + remain); - } else - timed_wait_impl(cond_mutex, SystemTimePoint::max()); - return steady_clock::now() - steady_now < rel_time ? std::cv_status::no_timeout : std::cv_status::timeout; -} -}; // namespace onnxruntime -#endif \ No newline at end of file diff --git a/onnxruntime/inc/core/providers/cpu/cpu_provider_factory.h b/onnxruntime/inc/core/providers/cpu/cpu_provider_factory.h deleted file mode 100644 index 32289eb..0000000 --- a/onnxruntime/inc/core/providers/cpu/cpu_provider_factory.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * \param use_arena zero: false. non-zero: true. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena) -ORT_ALL_ARGS_NONNULL; - -ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out) -ORT_ALL_ARGS_NONNULL; - -#ifdef __cplusplus -} -#endif diff --git a/onnxruntime/inc/core/providers/cuda/cuda_provider_factory.h b/onnxruntime/inc/core/providers/cuda/cuda_provider_factory.h deleted file mode 100644 index 3fc4b7b..0000000 --- a/onnxruntime/inc/core/providers/cuda/cuda_provider_factory.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * \param device_id cuda device id, starts from zero. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); - -#ifdef __cplusplus -} -#endif diff --git a/onnxruntime/inc/core/providers/mkldnn/mkldnn_provider_factory.h b/onnxruntime/inc/core/providers/mkldnn/mkldnn_provider_factory.h deleted file mode 100644 index 03ef115..0000000 --- a/onnxruntime/inc/core/providers/mkldnn/mkldnn_provider_factory.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/** - * \param use_arena zero: false. non-zero: true. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Mkldnn, _In_ OrtSessionOptions* options, int use_arena); - -#ifdef __cplusplus -} -#endif diff --git a/onnxruntime/inc/core/providers/providers.h b/onnxruntime/inc/core/providers/providers.h deleted file mode 100644 index fc16812..0000000 --- a/onnxruntime/inc/core/providers/providers.h +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -namespace onnxruntime { -class IExecutionProvider; - -struct IExecutionProviderFactory { - virtual ~IExecutionProviderFactory() {} - virtual std::unique_ptr CreateProvider() = 0; -}; -} // namespace onnxruntime diff --git a/onnxruntime/inc/core/session/onnxruntime_cxx_api.h b/onnxruntime/inc/core/session/onnxruntime_cxx_api.h deleted file mode 100644 index ab916f5..0000000 --- a/onnxruntime/inc/core/session/onnxruntime_cxx_api.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once -#include "onnxruntime_c_api.h" -#include -#include -#include -#include - -//TODO: encode error code in the message? -#define ORT_THROW_ON_ERROR(expr) \ - do { \ - OrtStatus* onnx_status = (expr); \ - if (onnx_status != nullptr) { \ - std::string ort_error_message = OrtGetErrorMessage(onnx_status); \ - OrtReleaseStatus(onnx_status); \ - throw std::runtime_error(ort_error_message); \ - } \ - } while (0); - -#define ORT_REDIRECT_SIMPLE_FUNCTION_CALL(NAME) \ - decltype(Ort##NAME(value.get())) NAME() { \ - return Ort##NAME(value.get()); \ - } - -namespace std { -template <> -struct default_delete { - void operator()(OrtAllocator* ptr) { - OrtReleaseAllocator(ptr); - } -}; - -template <> -struct default_delete { - void operator()(OrtEnv* ptr) { - OrtReleaseEnv(ptr); - } -}; - -template <> -struct default_delete { - void operator()(OrtRunOptions* ptr) { - OrtReleaseRunOptions(ptr); - } -}; - -template <> -struct default_delete { - void operator()(OrtTypeInfo* ptr) { - OrtReleaseTypeInfo(ptr); - } -}; - -template <> -struct default_delete { - void operator()(OrtTensorTypeAndShapeInfo* ptr) { - OrtReleaseTensorTypeAndShapeInfo(ptr); - } -}; - -template <> -struct default_delete { - void operator()(OrtSessionOptions* ptr) { - OrtReleaseSessionOptions(ptr); - } -}; -} // namespace std - -namespace onnxruntime { -class SessionOptionsWrapper { - private: - std::unique_ptr value; - OrtEnv* env_; - SessionOptionsWrapper(_In_ OrtEnv* env, OrtSessionOptions* p) : value(p), env_(env){}; - - public: - operator OrtSessionOptions*() { return value.get(); } - - //TODO: for the input arg, should we call addref here? - SessionOptionsWrapper(_In_ OrtEnv* env) : value(OrtCreateSessionOptions()), env_(env){}; - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(EnableSequentialExecution) - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableSequentialExecution) - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableProfiling) - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(EnableMemPattern) - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableMemPattern) - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(EnableCpuMemArena) - ORT_REDIRECT_SIMPLE_FUNCTION_CALL(DisableCpuMemArena) - void EnableProfiling(_In_ const char* profile_file_prefix) { - OrtEnableProfiling(value.get(), profile_file_prefix); - } - - void SetSessionLogId(const char* logid) { - OrtSetSessionLogId(value.get(), logid); - } - void SetSessionLogVerbosityLevel(uint32_t session_log_verbosity_level) { - OrtSetSessionLogVerbosityLevel(value.get(), session_log_verbosity_level); - } - void SetSessionThreadPoolSize(int session_thread_pool_size) { - OrtSetSessionThreadPoolSize(value.get(), session_thread_pool_size); - } - - SessionOptionsWrapper clone() const { - OrtSessionOptions* p = OrtCloneSessionOptions(value.get()); - return SessionOptionsWrapper(env_, p); - } -#ifdef _WIN32 - OrtSession* OrtCreateSession(_In_ const wchar_t* model_path) { - OrtSession* ret; - ORT_THROW_ON_ERROR(::OrtCreateSession(env_, model_path, value.get(), &ret)); - return ret; - } -#else - OrtSession* OrtCreateSession(_In_ const char* model_path) { - OrtSession* ret; - ORT_THROW_ON_ERROR(::OrtCreateSession(env_, model_path, value.get(), &ret)); - return ret; - } -#endif - void AppendCustomOpLibPath(_In_ const char* lib_path) { - OrtAppendCustomOpLibPath(value.get(), lib_path); - } -}; -inline OrtValue* OrtCreateTensorAsOrtValue(_Inout_ OrtAllocator* env, const std::vector& shape, ONNXTensorElementDataType type) { - OrtValue* ret; - ORT_THROW_ON_ERROR(::OrtCreateTensorAsOrtValue(env, shape.data(), shape.size(), type, &ret)); - return ret; -} - -inline OrtValue* OrtCreateTensorWithDataAsOrtValue(_In_ const OrtAllocatorInfo* info, _In_ void* p_data, size_t p_data_len, const std::vector& shape, ONNXTensorElementDataType type) { - OrtValue* ret; - ORT_THROW_ON_ERROR(::OrtCreateTensorWithDataAsOrtValue(info, p_data, p_data_len, shape.data(), shape.size(), type, &ret)); - return ret; -} - -inline std::vector GetTensorShape(const OrtTensorTypeAndShapeInfo* info) { - size_t dims = OrtGetNumOfDimensions(info); - std::vector ret(dims); - OrtGetDimensions(info, ret.data(), ret.size()); - return ret; -} -} // namespace onnxruntime - -#undef ORT_REDIRECT_SIMPLE_FUNCTION_CALL diff --git a/onnxruntime/inc/core/session/onnxruntime_c_api.h b/onnxruntime/inc/onnxruntime_c_api.h similarity index 70% rename from onnxruntime/inc/core/session/onnxruntime_c_api.h rename to onnxruntime/inc/onnxruntime_c_api.h index b6c51ba..b93451a 100644 --- a/onnxruntime/inc/core/session/onnxruntime_c_api.h +++ b/onnxruntime/inc/onnxruntime_c_api.h @@ -46,6 +46,14 @@ extern "C" { #define ORTCHAR_T char #endif +#ifndef ORT_TSTR +#ifdef _WIN32 +#define ORT_TSTR(X) L##X +#else +#define ORT_TSTR(X) (X) +#endif +#endif + // Any pointer marked with _In_ or _Out_, cannot be NULL. #ifdef __cplusplus @@ -142,6 +150,8 @@ ORT_RUNTIME_CLASS(RunOptions); ORT_RUNTIME_CLASS(TypeInfo); ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(Callback); +ORT_RUNTIME_CLASS(CustomOpDomain); // When passing in an allocator to any ORT function, be sure that the allocator object // is not destroyed until the last allocated object using it is freed. @@ -195,14 +205,12 @@ ORT_API(void, OrtEnableSequentialExecution, _In_ OrtSessionOptions* options); ORT_API(void, OrtDisableSequentialExecution, _In_ OrtSessionOptions* options); // Enable profiling for this session. -ORT_API(void, OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const char* profile_file_prefix); +ORT_API(void, OrtEnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); ORT_API(void, OrtDisableProfiling, _In_ OrtSessionOptions* options); -// Enable the memory pattern optimization. -// The idea is if the input shapes are the same, we could trace the internal memory allocation -// and generate a memory pattern for future request. So next time we could just do one allocation -// with a big chunk for all the internal memory allocation. +// deprecated ORT_API(void, OrtEnableMemPattern, _In_ OrtSessionOptions* options); +// deprecated ORT_API(void, OrtDisableMemPattern, _In_ OrtSessionOptions* options); // Enable the memory arena on CPU @@ -284,7 +292,7 @@ ORT_API_STATUS(OrtCreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, * \param out Should be freed by calling OrtReleaseValue */ ORT_API_STATUS(OrtCreateTensorWithDataAsOrtValue, _In_ const OrtAllocatorInfo* info, - _In_ void* p_data, size_t p_data_len, _In_ const size_t* shape, size_t shape_len, + _Inout_ void* p_data, size_t p_data_len, _In_ const size_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Out_ OrtValue** out); // This function doesn't work with string tensor @@ -316,8 +324,34 @@ ORT_API_STATUS(OrtGetStringTensorDataLength, _In_ const OrtValue* value, _Out_ s ORT_API_STATUS(OrtGetStringTensorContent, _In_ const OrtValue* value, _Out_ void* s, size_t s_len, _Out_ size_t* offsets, size_t offsets_len); -ORT_API_STATUS(OrtTensorProtoToOrtValue, _Inout_ OrtAllocator* allocator, - _In_ const void* input, int input_len, _Out_ OrtValue** out); +/** + * Create an OrtValue in CPU memory from a serialized TensorProto + * @param input serialized TensorProto object + * @param input_len length of 'input'. + * @param input_file_path A local file path of where the input was loaded from. Can be NULL if the tensor proto doesn't + * have any external data or it was loaded from current working dir. This path could be either a + * relative path or an absolute path. + * @param preallocated A preallocated buffer for the tensor. It should be allocated from CPU memory + * @param preallocated_size Length of the preallocated buffer in bytes, can be computed from + * the OrtGetTensorMemSizeInBytesFromTensorProto function. This function will return an error if the + * preallocated_size is not enough. + * @param out + * @return + */ +ORT_API_STATUS(OrtTensorProtoToOrtValue, _In_ const void* input, int input_len, + _In_opt_ const ORTCHAR_T* input_file_path, _Inout_ void* preallocated, size_t preallocated_size, + _Out_ OrtValue** out, _Out_ OrtCallback** deleter); + +/** + * f will be freed in this call + */ +ORT_API(void, OrtRunCallback, _Frees_ptr_opt_ OrtCallback* f); + +/** + * calculate the memory requirement for the OrtTensorProtoToOrtValue function + */ +ORT_API_STATUS(OrtGetTensorMemSizeInBytesFromTensorProto, _In_ const void* input, int input_len, size_t alignment, + _Out_ size_t* out); /** * Don't free the returned value @@ -373,7 +407,8 @@ typedef enum OrtAllocatorType { } OrtAllocatorType; /** - memory types for allocator, exec provider specific types should be extended in each provider + * memory types for allocator, exec provider specific types should be extended in each provider + * Whenever this struct is updated, please also update the MakeKey function in onnxruntime/core/framework/execution_provider.cc */ typedef enum OrtMemType { OrtMemTypeCPUInput = -2, // Any CPU memory used by non-CPU execution provider @@ -384,6 +419,12 @@ typedef enum OrtMemType { ORT_API_STATUS(OrtCreateAllocatorInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out); +/** + * Convenience function for special case of OrtCreateAllocatorInfo, for the CPU allocator. Uses name = "Cpu" and id = 0. + */ +ORT_API_STATUS(OrtCreateCpuAllocatorInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Out_ OrtAllocatorInfo** out) +ORT_ALL_ARGS_NONNULL; + /** * Test if two allocation info are equal * \return 0, equal. zero, not equal @@ -421,6 +462,120 @@ ORT_ALL_ARGS_NONNULL; ORT_API(const char*, OrtGetErrorMessage, _In_ const OrtStatus* status) ORT_ALL_ARGS_NONNULL; +/** + * APIs to support non-tensor types - map and sequence. + * Currently only the following types are supported + * Note: the following types should be kept in sync with data_types.h + * Map types + * ========= + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * + * Sequence types + * ============== + * std::vector + * std::vector + * std::vector + * std::vector + * std::vector> + * std::vector + */ + +/** + * If input OrtValue represents a map, you need to retrieve the keys and values + * separately. Use index=0 to retrieve keys and index=1 to retrieve values. + * If input OrtValue represents a sequence, use index to retrieve the index'th element + * of the sequence. + */ +ORT_API_STATUS(OrtGetValue, const OrtValue* value, int index, OrtAllocator* allocator, OrtValue** out); + +/** + * Returns 2 for type map and N for sequence where N is the number of elements + * in the sequence. + */ +ORT_API_STATUS(OrtGetValueCount, const OrtValue* value, size_t* out); + +/** + * To construct a map, use num_values = 2 and 'in' should be an arrary of 2 OrtValues + * representing keys and values. + * To construct a sequence, use num_values = N where N is the number of the elements in the + * sequence. 'in' should be an arrary of N OrtValues. + * \value_type should be either map or sequence. + */ +ORT_API_STATUS(OrtCreateValue, OrtValue** const in, int num_values, enum ONNXType value_type, + OrtValue** out); + +/* + * EXPERIMENTAL APIS - Subject to change. Released as a preview to get feedback and enable early testing +*/ + +/* + * Steps to use a custom op: + * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops + * 2 Create an OrtCustomOp structure for each op and add them to the domain + * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options +*/ +struct OrtKernelInfo; +typedef struct OrtKernelInfo OrtKernelInfo; + +/* + * These allow reading node attributes during kernel creation +*/ +ORT_API_STATUS(OrtKernelInfoGetAttribute_float, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ float* out); +ORT_API_STATUS(OrtKernelInfoGetAttribute_int64, _In_ OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out); + +/* + * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by + * the implementor of the custom op. +*/ +struct OrtCustomOp { + uint32_t version; // Initialize to ORT_API_VERSION + + // This callback creates the kernel, which is a user defined parameter that is passed to the Kernel* callbacks below. + void(ORT_API_CALL* CreateKernel)(_In_ struct OrtCustomOp* op, _In_ OrtKernelInfo* info, _Out_ void** op_kernel); + + // Returns the name of the op + const char*(ORT_API_CALL* GetName)(_In_ struct OrtCustomOp* op); + + // Returns the count and types of the input & output tensors + ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetInputTypeCount)(_In_ struct OrtCustomOp* op); + ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ struct OrtCustomOp* op); + + // Op kernel callbacks + void(ORT_API_CALL* KernelGetOutputShape)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ size_t output_index, _In_ OrtTensorTypeAndShapeInfo* output); + void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtValue** inputs, _In_ size_t input_count, _In_ OrtValue** outputs, _In_ size_t output_count); + void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); +}; +typedef struct OrtCustomOp OrtCustomOp; + +/* +* Create a custom op domain. After all sessions using it are released, call OrtReleaseCustomOpDomain +*/ +ORT_API(OrtCustomOpDomain*, OrtCreateCustomOpDomain, _In_ const char* domain, _In_ int op_version_start, _In_ int op_version_end); + +/* + * Add custom ops to the OrtCustomOpDomain + * Note: The OrtCustomOp* pointer must remain valid until the OrtCustomOpDomain using it is released +*/ +ORT_API_STATUS(OrtCustomOpDomain_Add, _In_ OrtCustomOpDomain* custom_op_domain, _In_ OrtCustomOp* op); + +/* + * Add a custom op domain to the OrtSessionOptions + * Note: The OrtCustomOpDomain* must not be deleted until the sessions using it are released +*/ +ORT_API_STATUS(OrtAddCustomOpDomain, _In_ OrtSessionOptions* options, OrtCustomOpDomain* custom_op_domain); +/* + * END EXPERIMENTAL +*/ + #ifdef __cplusplus } #endif diff --git a/package-lock.json b/package-lock.json index 5e66320..5b58b2a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,6 +1,6 @@ { "name": "onnxjs-node", - "version": "0.0.4", + "version": "0.3.0", "lockfileVersion": 1, "requires": true, "dependencies": { @@ -162,9 +162,9 @@ "dev": true }, "@types/node": { - "version": "10.12.27", - "resolved": "https://registry.npmjs.org/@types/node/-/node-10.12.27.tgz", - "integrity": "sha512-e9wgeY6gaY21on3ve0xAjgBVjGDWq/xUteK0ujsE53bUoxycMkqfnkUgMt6ffZtykZ5X12Mg3T7Pw4TRCObDKg==" + "version": "10.14.1", + "resolved": "https://registry.npmjs.org/@types/node/-/node-10.14.1.tgz", + "integrity": "sha512-Rymt08vh1GaW4vYB6QP61/5m/CFLGnFZP++bJpWbiNxceNa6RBipDmb413jvtSf/R1gg5a/jQVl2jY4XVRscEA==" }, "@types/rimraf": { "version": "2.0.2", @@ -191,9 +191,9 @@ } }, "ajv": { - "version": "6.9.2", - "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.9.2.tgz", - "integrity": "sha512-4UFy0/LgDo7Oa/+wOAlj44tp9K78u38E5/359eSrqEp1Z5PdVfimCcs7SluXMP755RUQu6d2b4AvF0R1C9RZjg==", + "version": "6.10.0", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.10.0.tgz", + "integrity": "sha512-nffhOpkymDECQyR0mnsUtoCE8RlX38G0rYP+wgLWFyZuUyuuojSSvi/+euOiQBIn63whYwYVIIH1TvE3tu4OEg==", "dev": true, "requires": { "fast-deep-equal": "^2.0.1", @@ -417,9 +417,9 @@ } }, "before-after-hook": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/before-after-hook/-/before-after-hook-1.3.2.tgz", - "integrity": "sha512-zyPgY5dgbf99c0uGUjhY4w+mxqEGxPKg9RQDl34VvrVh2bM31lFN+mwR1ZHepq/KA3VCPk1gwJZL6IIJqjLy2w==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/before-after-hook/-/before-after-hook-1.4.0.tgz", + "integrity": "sha512-l5r9ir56nda3qu14nAXIlyq1MmUSs0meCIaFAh8HwkFwP1F8eToOuS3ah2VAHHcY04jaYD7FpJC5JTXHYRbkzg==", "dev": true }, "big-integer": { @@ -1451,11 +1451,12 @@ } }, "globby": { - "version": "9.0.0", - "resolved": "https://registry.npmjs.org/globby/-/globby-9.0.0.tgz", - "integrity": "sha512-q0qiO/p1w/yJ0hk8V9x1UXlgsXUxlGd0AHUOXZVXBO6aznDtpx7M8D1kBrCAItoPm+4l8r6ATXV1JpjY2SBQOw==", + "version": "9.1.0", + "resolved": "https://registry.npmjs.org/globby/-/globby-9.1.0.tgz", + "integrity": "sha512-VtYjhHr7ncls724Of5W6Kaahz0ag7dB4G62/2HsN+xEKG6SrPzM1AJMerGxQTwJGnN9reeyxdvXbuZYpfssCvg==", "dev": true, "requires": { + "@types/glob": "^7.1.1", "array-union": "^1.0.2", "dir-glob": "^2.2.1", "fast-glob": "^2.2.6", @@ -2037,13 +2038,13 @@ } }, "mem": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/mem/-/mem-4.1.0.tgz", - "integrity": "sha512-I5u6Q1x7wxO0kdOpYBB28xueHADYps5uty/zg936CiG8NTe5sJL8EjrCuLneuDW3PlMdZBGDIn8BirEVdovZvg==", + "version": "4.2.0", + "resolved": "https://registry.npmjs.org/mem/-/mem-4.2.0.tgz", + "integrity": "sha512-5fJxa68urlY0Ir8ijatKa3eRz5lwXnRCTvo9+TbTGAuTFJOwpGcY0X05moBd0nW45965Njt4CDI2GFQoG8DvqA==", "dev": true, "requires": { "map-age-cleaner": "^0.1.1", - "mimic-fn": "^1.0.0", + "mimic-fn": "^2.0.0", "p-is-promise": "^2.0.0" } }, @@ -2125,9 +2126,9 @@ } }, "mimic-fn": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-1.2.0.tgz", - "integrity": "sha512-jf84uxzwiuiIVKiOLpfYk7N46TSy8ubTonmneY9vrpHNAnp0QBt2BxWV9dO3/j+BoVAb+a5G6YDPW3M5HOdMWQ==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/mimic-fn/-/mimic-fn-2.0.0.tgz", + "integrity": "sha512-jbex9Yd/3lmICXwYT6gA/j2mNQGU48wCh/VzRd+/Y/PjYQtlg1gLMdZqvu9s/xH7qKvngxRObl56XZR609IMbA==", "dev": true }, "minimatch": { @@ -2597,9 +2598,9 @@ } }, "onnxjs": { - "version": "0.1.3", - "resolved": "https://registry.npmjs.org/onnxjs/-/onnxjs-0.1.3.tgz", - "integrity": "sha512-oWvzZMtVF7y4hhuHnpldJNiNQKOwhUZafiLXkFeJyAt9ZkKfEk+mw+b0TKn6Omyx3RkezaoPs4T02597cxtunA==", + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/onnxjs/-/onnxjs-0.1.4.tgz", + "integrity": "sha512-3YwgYUYPbKEPrgJuoBwa4j52jTSrRyKuqXDAWrfk+UEXjuYyKNn5+Il/L4LCp+YLBCiHRhQc83Gu0q5P+i+ncA==", "requires": { "ndarray": "^1.0.18", "ndarray-gemm": "^1.0.0", @@ -2665,9 +2666,9 @@ "dev": true }, "p-limit": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.1.0.tgz", - "integrity": "sha512-NhURkNcrVB+8hNfLuysU8enY5xn2KXphsHBaC2YmRNTZRc7RWusw6apSpdEj3jo4CMb6W9nrF6tTnsJsJeyu6g==", + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-2.2.0.tgz", + "integrity": "sha512-pZbTJpoUsCzV48Mc9Nh51VbwO0X9cuPFE8gYwx9BTCt9SF8/b7Zljd2fVgOxhIF/HDTKgpVzs+GPhyKfjLLFRQ==", "dev": true, "requires": { "p-try": "^2.0.0" @@ -3757,9 +3758,9 @@ }, "dependencies": { "camelcase": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.0.0.tgz", - "integrity": "sha512-faqwZqnWxbxn+F1d399ygeamQNy3lPp/H9H6rNrqYh4FSVCtcY+3cub1MxA8o9mDd55mM8Aghuu/kuyYA6VTsA==", + "version": "5.2.0", + "resolved": "https://registry.npmjs.org/camelcase/-/camelcase-5.2.0.tgz", + "integrity": "sha512-IXFsBS2pC+X0j0N/GE7Dm7j3bsEBp+oTpb7F50dwEVX7rf3IgwO9XatnegTsDtniKCUtEJH4fSU6Asw7uoVLfQ==", "dev": true } } diff --git a/package.json b/package.json index 06577e2..76995c8 100644 --- a/package.json +++ b/package.json @@ -1,7 +1,7 @@ { "name": "onnxjs-node", "description": "", - "version": "0.0.4", + "version": "0.3.0", "author": "fs-eire", "main": "./lib/index.js", "scripts": { diff --git a/src/inference-session-wrap.cc b/src/inference-session-wrap.cc index efd9af9..8f543ae 100644 --- a/src/inference-session-wrap.cc +++ b/src/inference-session-wrap.cc @@ -1,5 +1,3 @@ -#include - #include "inference-session-wrap.h" #include "inference-session.h" #include "utils.h" diff --git a/src/inference-session.cc b/src/inference-session.cc index bc583e4..41ab739 100644 --- a/src/inference-session.cc +++ b/src/inference-session.cc @@ -1,4 +1,3 @@ -#include #include #include #include diff --git a/src/inference-session.h b/src/inference-session.h index 06f0ed5..b73b50d 100644 --- a/src/inference-session.h +++ b/src/inference-session.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/onnxruntime_c_api.h" +#include "onnxruntime_c_api.h" #include "tensor.h" diff --git a/src/tensor.cc b/src/tensor.cc index 51aa270..0da9707 100644 --- a/src/tensor.cc +++ b/src/tensor.cc @@ -2,7 +2,7 @@ #include #include -#include "core/session/onnxruntime_c_api.h" +#include "onnxruntime_c_api.h" #include "tensor.h" #include "utils.h" diff --git a/src/tensor.h b/src/tensor.h index a8a7b76..86b80c3 100644 --- a/src/tensor.h +++ b/src/tensor.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/onnxruntime_c_api.h" +#include "onnxruntime_c_api.h" // a simple structure that represents a tensor. //