Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
RandyShuai committed Sep 19, 2023
1 parent d00792e commit 8fb1e4f
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 26 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2055,6 +2055,7 @@ struct KernelContext {
void* GetGPUComputeStream() const;
Logger GetLogger() const;
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
OrtKernelContext* GetOrtKernelContext() const { return ctx_; }

private:
OrtKernelContext* ctx_;
Expand Down
216 changes: 190 additions & 26 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -297,28 +297,98 @@ class Tensor<std::string_view> : public TensorBase {
};

using TensorPtr = std::unique_ptr<Custom::TensorBase>;

////////////////////////////// OrtTensorShape //////////////////////////////////
// struct OrtTensorShape {
// OrtTensorShape(const OrtTensorTypeAndShapeInfo* tensor_shape = nullptr) {
// if (tensor_shape) {
// auto ort_api = GetApi();
// size_t dims = 0;
// ort_api.GetDimensionsCount(tensor_shape, &dims);
// }
// }
// size_t size() const {
// return dims.size();
// }
// int64_t operator[](size_t ith) const {
// return dims.at(ith);
// }
// void append(int64_t dim) {
// dims.push_back(dim);
// }
// std::vector<int64_t> dims;
// };
//////////////////////////// OrtLiteCustomOp ////////////////////////////////
using TensorPtrs = std::vector<TensorPtr>;

struct Variadic : public TensorBase {
Variadic(OrtKernelContext* ctx,
size_t indice,
bool is_input) : TensorBase(ctx,
indice,
is_input) {
if (is_input) {
auto input_count = ctx_.GetInputCount();
for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
auto const_value = ctx_.GetInput(indice);
auto type_shape_info = const_value.GetTensorTypeAndShapeInfo();
auto type = type_shape_info.GetElementType();
TensorPtr tensor;
switch (type) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
tensor = std::make_unique<Custom::Tensor<bool>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
tensor = std::make_unique<Custom::Tensor<float>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
tensor = std::make_unique<Custom::Tensor<double>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
tensor = std::make_unique<Custom::Tensor<uint8_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
tensor = std::make_unique<Custom::Tensor<int8_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
tensor = std::make_unique<Custom::Tensor<uint16_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
tensor = std::make_unique<Custom::Tensor<int16_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
tensor = std::make_unique<Custom::Tensor<uint32_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
tensor = std::make_unique<Custom::Tensor<int32_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
tensor = std::make_unique<Custom::Tensor<uint64_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
tensor = std::make_unique<Custom::Tensor<int64_t>>(ctx, ith_input, true);
break;
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
tensor = std::make_unique<Custom::Tensor<std::string>>(ctx, ith_input, true);
break;
default:
ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
break;
}
tensors_.emplace_back(tensor.release());
} // for
} else {
// a Variadic used for output is populated by the Compute so leave tensors_ empty here
}
}
template <typename T>
T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
auto tensor = std::make_unique<Tensor<T>>(ctx_.GetOrtKernelContext(), ith_output, false);
auto raw_output = tensor.get()->Allocate(shape);
tensors_.emplace_back(tensor.release());
return raw_output;
}
Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
auto tensor = std::make_unique<Tensor<std::string>>(ctx_.GetOrtKernelContext(), ith_output, false);
Tensor<std::string>& output = *tensor;
tensors_.emplace_back(tensor.release());
return output;
}
const void* DataRaw() const override {
ORT_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
return nullptr;
}
size_t SizeInBytes() const override {
ORT_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
return 0;
}
size_t Size() const {
return tensors_.size();
}
const TensorPtr& operator[](size_t indice) const {
return tensors_.at(indice);
}
private:
TensorPtrs tensors_;
};

using TensorShapeVec = std::vector<Ort::TensorTypeAndShapeInfo>;
using ShapeInferenceFn = std::function<void(const TensorShapeVec& input_shapes, TensorShapeVec& output_shape)>;
Expand Down Expand Up @@ -374,6 +444,42 @@ struct OrtLiteCustomOp : public OrtCustomOp {
}
#endif

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(context, ith_input, true));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
tensors.push_back(std::make_unique<Variadic>(context, ith_output, false));
std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(context, tensors, num_input, num_output, ep);
return std::tuple_cat(current, next);
}

#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
Expand Down Expand Up @@ -561,6 +667,46 @@ struct OrtLiteCustomOp : public OrtCustomOp {
}
#endif

template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!input_types.empty()) {
ORT_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
}
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}

template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!input_types.empty()) {
ORT_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
}
input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}

template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!output_types.empty()) {
ORT_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
}
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}

template <typename T, typename... Ts>
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
if (!output_types.empty()) {
ORT_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
}
output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
ParseArgs<Ts...>(input_types, output_types);
}

#define PARSE_INPUT_BASE(pack_type, onnx_type) \
template <typename T, typename... Ts> \
static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
Expand Down Expand Up @@ -661,12 +807,30 @@ struct OrtLiteCustomOp : public OrtCustomOp {
return self->output_types_[indice];
};

OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
return INPUT_OUTPUT_OPTIONAL;
OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
};

OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
return INPUT_OUTPUT_OPTIONAL;
OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
};

OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
return 1;
};

OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
return 0;
};

OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
return 1;
};

OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
return 0;
};

OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; };
Expand Down
55 changes: 55 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ static constexpr PATH_TYPE SEQUENCE_MODEL_URI_2 = TSTR("testdata/optional_sequen
#endif
static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx");
static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx");
static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_VARIADIC_2 = TSTR("testdata/custom_op_library/copy_variadic_2_inputs_2_outputs.onnx");
static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_VARIADIC_3 = TSTR("testdata/custom_op_library/copy_variadic_3_inputs_3_outputs.onnx");
#if !defined(DISABLE_FLOAT8_TYPES)
static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_FLOAT8_URI = TSTR("testdata/custom_op_library/custom_op_test_float8.onnx");
#endif
Expand Down Expand Up @@ -1406,6 +1408,59 @@ TEST(CApiTest, test_custom_op_library) {
#endif
}

// It has memory leak. The OrtCustomOpDomain created in custom_op_library.cc:RegisterCustomOps function was not freed
#if defined(__ANDROID__)
TEST(CApiTest, DISABLED_test_custom_op_library) {
// To accomodate a reduced op build pipeline
#elif defined(REDUCED_OPS_BUILD) && defined(USE_CUDA)
TEST(CApiTest, DISABLED_test_custom_op_library) {
#else
TEST(CApiTest, test_custom_op_library_copy_variadic) {
#endif
std::cout << "Running inference using custom op shared library" << std::endl;

std::vector<Input> inputs(2);
inputs[0].name = "input_0";
inputs[0].dims = {15};
inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};
inputs[1].name = "input_1";
inputs[1].dims = {15};
inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f,
10.0f, 9.9f, 8.8f, 7.7f, 6.6f,
5.5f, 4.4f, 3.3f, 2.2f, 1.1f};

// prepare expected inputs and outputs
std::vector<int64_t> expected_dims_y = {15};
std::vector<float> expected_values_y = inputs[1].values;

onnxruntime::PathString lib_name;
#if defined(_WIN32)
lib_name = ORT_TSTR("custom_op_library.dll");
#elif defined(__APPLE__)
lib_name = ORT_TSTR("libcustom_op_library.dylib");
#else
lib_name = ORT_TSTR("./libcustom_op_library.so");
#endif

TestInference<float>(*ort_env, CUSTOM_OP_LIBRARY_COPY_VARIADIC_2,
inputs, "output_1", expected_dims_y,
expected_values_y, 0, nullptr, lib_name.c_str());

inputs.push_back({});
inputs[2].name = "input_2";
inputs[2].dims = {15};
inputs[2].values = {6.6f, 7.7f, 8.8f, 9.9f, 10.0f,
1.1f, 2.2f, 3.3f, 4.4f, 5.5f,
11.1f, 12.2f, 13.3f, 14.4f, 15.5f};

expected_values_y = inputs[2].values;
TestInference<float>(*ort_env, CUSTOM_OP_LIBRARY_COPY_VARIADIC_3,
inputs, "output_2", expected_dims_y,
expected_values_y, 0, nullptr, lib_name.c_str());
}

#if !defined(DISABLE_FLOAT8_TYPES)

struct InputF8 {
Expand Down
22 changes: 22 additions & 0 deletions onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,25 @@ void FilterFloat8(const Ort::Custom::Tensor<Ort::Float8E4M3FN_t>& floats_in,
}
#endif

// a sample custom op accepting variadic inputs, and generate variadic outputs by simply 1:1 copying.
template<typename T>
Ort::Status CopyVariadic(const Ort::Custom::Variadic& inputs, Ort::Custom::Variadic& outputs) {
for (size_t ith_input = 0; ith_input < inputs.Size(); ++ith_input) {
const auto& input = inputs[ith_input];
const auto& input_shape = input->Shape();
const T* raw_input = reinterpret_cast<const T*>(input->DataRaw());
auto num_elements = input->NumberOfElement();
T* raw_output = outputs.AllocateOutput<T>(ith_input, input_shape);
if (!raw_output) {
return Ort::Status("Failed to allocate output!", OrtErrorCode::ORT_FAIL);
}
for (int64_t jth_elem = 0; jth_elem < num_elements; ++jth_elem) {
raw_output[jth_elem] = raw_input[jth_elem];
}
}
return Ort::Status{nullptr};
}

void RegisterOps(Ort::CustomOpDomain& domain) {
static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpOne{Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider", KernelOne)};
static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)};
Expand All @@ -175,6 +194,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
static const std::unique_ptr<OrtLiteCustomOp> c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)};
static const std::unique_ptr<OrtLiteCustomOp> c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
static const std::unique_ptr<OrtLiteCustomOp> c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)};
static const std::unique_ptr<OrtLiteCustomOp> c_CopyVariadic{Ort::Custom::CreateLiteCustomOp("CopyVariadic", "CPUExecutionProvider", CopyVariadic<float>)};

#if !defined(DISABLE_FLOAT8_TYPES)
static const CustomOpOneFloat8 c_CustomOpOneFloat8;
Expand All @@ -189,6 +209,8 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
domain.Add(c_Select.get());
domain.Add(c_Fill.get());
domain.Add(c_Box.get());
domain.Add(c_CopyVariadic.get());

#if !defined(DISABLE_FLOAT8_TYPES)
domain.Add(&c_CustomOpOneFloat8);
domain.Add(c_FilterFloat8.get());
Expand Down

0 comments on commit 8fb1e4f

Please sign in to comment.