diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 47356c3fe3608..45f81783421e0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2055,6 +2055,7 @@ struct KernelContext { void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; + OrtKernelContext* GetOrtKernelContext() const { return ctx_; } private: OrtKernelContext* ctx_; diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 92d2619217f06..16dd9c9b9df42 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -297,28 +297,98 @@ class Tensor : public TensorBase { }; using TensorPtr = std::unique_ptr; - -////////////////////////////// OrtTensorShape ////////////////////////////////// -// struct OrtTensorShape { -// OrtTensorShape(const OrtTensorTypeAndShapeInfo* tensor_shape = nullptr) { -// if (tensor_shape) { -// auto ort_api = GetApi(); -// size_t dims = 0; -// ort_api.GetDimensionsCount(tensor_shape, &dims); -// } -// } -// size_t size() const { -// return dims.size(); -// } -// int64_t operator[](size_t ith) const { -// return dims.at(ith); -// } -// void append(int64_t dim) { -// dims.push_back(dim); -// } -// std::vector dims; -// }; -//////////////////////////// OrtLiteCustomOp //////////////////////////////// +using TensorPtrs = std::vector; + +struct Variadic : public TensorBase { + Variadic(OrtKernelContext* ctx, + size_t indice, + bool is_input) : TensorBase(ctx, + indice, + is_input) { + if (is_input) { + auto input_count = ctx_.GetInputCount(); + for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { + auto const_value = ctx_.GetInput(ith_input); + auto type_shape_info = 
const_value.GetTensorTypeAndShapeInfo(); + auto type = type_shape_info.GetElementType(); + TensorPtr tensor; + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + tensor = std::make_unique>(ctx, ith_input, true); + break; + default: + ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + break; + } + tensors_.emplace_back(tensor.release()); + } // for + } else { + // a Variadic used for output is populated by the Compute so leave tensors_ empty here + } + } + template + T* AllocateOutput(size_t ith_output, const std::vector& shape) { + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), ith_output, false); + auto raw_output = tensor.get()->Allocate(shape); + tensors_.emplace_back(tensor.release()); + return raw_output; + } + Tensor& AllocateStringTensor(size_t ith_output) { + auto tensor = 
std::make_unique>(ctx_.GetOrtKernelContext(), ith_output, false); + Tensor& output = *tensor; + tensors_.emplace_back(tensor.release()); + return output; + } + const void* DataRaw() const override { + ORT_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); + return nullptr; + } + size_t SizeInBytes() const override { + ORT_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); + return 0; + } + size_t Size() const { + return tensors_.size(); + } + const TensorPtr& operator[](size_t indice) const { + return tensors_.at(indice); + } + private: + TensorPtrs tensors_; +}; using TensorShapeVec = std::vector; using ShapeInferenceFn = std::function; @@ -374,6 +444,42 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; + auto next = CreateTuple(context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; + auto next = CreateTuple(context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = 
std::tuple{reinterpret_cast(tensors.back().get())}; + auto next = CreateTuple(context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + tensors.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; + auto next = CreateTuple(context, tensors, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + #define CREATE_TUPLE_INPUT(data_type) \ template \ static typename std::enable_if*>::value, std::tuple>::type \ @@ -561,6 +667,46 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + if (!input_types.empty()) { + ORT_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION); + } + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + if (!input_types.empty()) { + ORT_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION); + } + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + if (!output_types.empty()) { + ORT_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION); + } + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + 
ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + if (!output_types.empty()) { + ORT_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION); + } + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + #define PARSE_INPUT_BASE(pack_type, onnx_type) \ template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ @@ -661,12 +807,30 @@ struct OrtLiteCustomOp : public OrtCustomOp { return self->output_types_[indice]; }; - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) { + auto self = reinterpret_cast(op); + return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC; }; - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) { + auto self = reinterpret_cast(op); + return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? 
INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { + return 0; }; OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 62c69306eb5fa..5947463f87ae5 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -184,6 +184,8 @@ static constexpr PATH_TYPE SEQUENCE_MODEL_URI_2 = TSTR("testdata/optional_sequen #endif static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx"); static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_VARIADIC_2 = TSTR("testdata/custom_op_library/copy_variadic_2_inputs_2_outputs.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_VARIADIC_3 = TSTR("testdata/custom_op_library/copy_variadic_3_inputs_3_outputs.onnx"); #if !defined(DISABLE_FLOAT8_TYPES) static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_FLOAT8_URI = TSTR("testdata/custom_op_library/custom_op_test_float8.onnx"); #endif @@ -1406,6 +1408,59 @@ TEST(CApiTest, test_custom_op_library) { #endif } +// It has memory leak. 
The OrtCustomOpDomain created in custom_op_library.cc:RegisterCustomOps function was not freed +#if defined(__ANDROID__) +TEST(CApiTest, DISABLED_test_custom_op_library_copy_variadic) { +// To accommodate a reduced op build pipeline +#elif defined(REDUCED_OPS_BUILD) && defined(USE_CUDA) +TEST(CApiTest, DISABLED_test_custom_op_library_copy_variadic) { +#else +TEST(CApiTest, test_custom_op_library_copy_variadic) { +#endif + std::cout << "Running inference using custom op shared library" << std::endl; + + std::vector inputs(2); + inputs[0].name = "input_0"; + inputs[0].dims = {15}; + inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + inputs[1].name = "input_1"; + inputs[1].dims = {15}; + inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f, + 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, + 5.5f, 4.4f, 3.3f, 2.2f, 1.1f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {15}; + std::vector expected_values_y = inputs[1].values; + + onnxruntime::PathString lib_name; +#if defined(_WIN32) + lib_name = ORT_TSTR("custom_op_library.dll"); +#elif defined(__APPLE__) + lib_name = ORT_TSTR("libcustom_op_library.dylib"); +#else + lib_name = ORT_TSTR("./libcustom_op_library.so"); +#endif + + TestInference(*ort_env, CUSTOM_OP_LIBRARY_COPY_VARIADIC_2, + inputs, "output_1", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); + + inputs.push_back({}); + inputs[2].name = "input_2"; + inputs[2].dims = {15}; + inputs[2].values = {6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + + expected_values_y = inputs[2].values; + TestInference(*ort_env, CUSTOM_OP_LIBRARY_COPY_VARIADIC_3, + inputs, "output_2", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); +} + #if !defined(DISABLE_FLOAT8_TYPES) struct InputF8 { diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index 
758c96b0f238f..37fc17800fe8c 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -166,6 +166,25 @@ void FilterFloat8(const Ort::Custom::Tensor& floats_in, } #endif +// A sample custom op that accepts variadic inputs and, for each input, allocates an output of the same shape and copies the elements 1:1. +template +Ort::Status CopyVariadic(const Ort::Custom::Variadic& inputs, Ort::Custom::Variadic& outputs) { + for (size_t ith_input = 0; ith_input < inputs.Size(); ++ith_input) { + const auto& input = inputs[ith_input]; + const auto& input_shape = input->Shape(); + const T* raw_input = reinterpret_cast(input->DataRaw()); + auto num_elements = input->NumberOfElement(); + T* raw_output = outputs.AllocateOutput(ith_input, input_shape); + if (!raw_output) { + return Ort::Status("Failed to allocate output!", OrtErrorCode::ORT_FAIL); + } + for (int64_t jth_elem = 0; jth_elem < num_elements; ++jth_elem) { + raw_output[jth_elem] = raw_input[jth_elem]; + } + } + return Ort::Status{nullptr}; +} + void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpOne{Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider", KernelOne)}; static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; @@ -175,6 +194,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; + static const std::unique_ptr c_CopyVariadic{Ort::Custom::CreateLiteCustomOp("CopyVariadic", "CPUExecutionProvider", CopyVariadic)}; #if !defined(DISABLE_FLOAT8_TYPES) static const CustomOpOneFloat8 c_CustomOpOneFloat8; @@ -189,6 +209,8 @@ void 
RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_Select.get()); domain.Add(c_Fill.get()); domain.Add(c_Box.get()); + domain.Add(c_CopyVariadic.get()); + #if !defined(DISABLE_FLOAT8_TYPES) domain.Add(&c_CustomOpOneFloat8); domain.Add(c_FilterFloat8.get());