Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into snnn-patch-13
Browse files Browse the repository at this point in the history
  • Loading branch information
snnn committed Nov 19, 2023
2 parents 7d74d6b + 97cc40d commit 7c6cf56
Show file tree
Hide file tree
Showing 75 changed files with 11,265 additions and 16,655 deletions.
20 changes: 7 additions & 13 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,6 @@ configure_file(${ONNXRUNTIME_ROOT}/python/_pybind_state.py.in
${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py)

if (onnxruntime_ENABLE_TRAINING)
file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/deprecated/*.py"
)
file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/python/training/*.py"
)
Expand Down Expand Up @@ -419,10 +416,6 @@ if (onnxruntime_ENABLE_TRAINING)
"${ORTTRAINING_SOURCE_DIR}/python/training/onnxblock/optim/*"
)
endif()
else()
file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/training/*.py"
)
endif()

if (onnxruntime_BUILD_UNIT_TESTS)
Expand All @@ -443,6 +436,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
file(GLOB onnxruntime_python_transformers_testdata_whisper CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/whisper/*.onnx"
)
file(GLOB onnxruntime_python_transformers_testdata_conformer CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/test/python/transformers/test_data/models/conformer/*.onnx"
)
endif()

file(GLOB onnxruntime_python_tools_srcs CONFIGURE_DEPENDS
Expand Down Expand Up @@ -556,6 +552,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/eager_test
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer
COMMAND ${CMAKE_COMMAND} -E copy
${ONNXRUNTIME_ROOT}/__init__.py
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/
Expand All @@ -577,9 +574,6 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy_if_different
${CMAKE_BINARY_DIR}/onnxruntime/capi/_pybind_state.py
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_capi_training_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/training/
COMMAND ${CMAKE_COMMAND} -E copy
$<TARGET_FILE:onnxruntime_pybind11_state>
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/
Expand Down Expand Up @@ -711,6 +705,9 @@ if (onnxruntime_BUILD_UNIT_TESTS)
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_whisper}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/whisper/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_testdata_conformer}
$<TARGET_FILE_DIR:${build_output_target}>/transformers/test_data/models/conformer/
)
endif()

Expand Down Expand Up @@ -750,9 +747,6 @@ if (onnxruntime_ENABLE_TRAINING)
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/data/
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/hooks/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_capi_training_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/capi/training/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_root_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/
Expand Down
47 changes: 30 additions & 17 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase {

using Variadic = TensorArray;

/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierachy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
hence memory could still be recycled properly.
Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety.
*/
struct OrtLiteCustomOp : public OrtCustomOp {
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
Expand Down Expand Up @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {

OrtLiteCustomOp(const char* op_name,
const char* execution_provider,
int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
start_ver_(start_ver),
end_ver_(end_ver) {
ShapeInferFn shape_infer_fn,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
shape_infer_fn_(shape_infer_fn),
start_ver_(start_ver),
end_ver_(end_ver) {
OrtCustomOp::version = ORT_API_VERSION;

OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
Expand Down Expand Up @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;

ShapeInferFn shape_infer_fn_ = {};

int start_ver_ = 1;
int end_ver_ = MAX_CUSTOM_OP_END_VER;

void* compute_fn_ = {};
void* compute_fn_return_status_ = {};
};

//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
Expand Down Expand Up @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
ComputeFn compute_fn,
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_(compute_fn),
shape_infer_fn_(shape_infer_fn) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
compute_fn_ = reinterpret_cast<void*>(compute_fn);
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
Expand All @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->compute_fn_ = static_cast<const MyType*>(this_)->compute_fn_;
auto me = static_cast<const MyType*>(this_);
kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
Expand All @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
ComputeFnReturnStatus compute_fn_return_status,
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_return_status_(compute_fn_return_status),
shape_infer_fn_(shape_infer_fn) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
Expand All @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->compute_fn_return_status_ = static_cast<const MyType*>(this_)->compute_fn_return_status_;
auto me = static_cast<const MyType*>(this_);
kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
Expand All @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
};
}
}

ComputeFn compute_fn_ = {};
ComputeFnReturnStatus compute_fn_return_status_ = {};
ShapeInferFn shape_infer_fn_ = {};
}; // struct OrtLiteCustomFunc

/////////////////////////// OrtLiteCustomStruct ///////////////////////////
Expand Down Expand Up @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
OrtLiteCustomStruct(const char* op_name,
const char* execution_provider,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
SetCompute(&CustomOp::Compute);

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
Expand Down
79 changes: 23 additions & 56 deletions js/node/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion js/node/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
},
"scripts": {
"buildr": "tsc && node ./script/build --config=RelWithDebInfo",
"preprepare": "node -e \"require('node:fs').copyFileSync('./node_modules/long/index.d.ts', './node_modules/long/umd/index.d.ts')\"",
"prepare": "tsc --build script test .",
"rebuild": "tsc && node ./script/build --rebuild",
"rebuildd": "tsc && node ./script/build --rebuild --config=Debug",
Expand All @@ -39,7 +40,7 @@
"jsonc": "^2.0.0",
"minimist": "^1.2.8",
"node-addon-api": "^6.0.0",
"onnx-proto": "^8.0.1"
"protobufjs": "^7.2.4"
},
"main": "dist/index.js",
"os": [
Expand Down
2 changes: 2 additions & 0 deletions js/node/test/ort-schema/protobuf/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
!onnx.js
!onnx.d.ts
21 changes: 21 additions & 0 deletions js/node/test/ort-schema/protobuf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# ONNX protobuf

This directory contains generated protobuf definition for onnx:

- onnx.js
- onnx.d.ts

These files are generated from [a fork of onnx-proto](https://github.com/fs-eire/onnx-proto/tree/update-v9).

The ONNX protobuf uses [email protected], which depends on [email protected], the version contains 2 bugs:

- type export does not work with commonjs. described in https://github.com/dcodeIO/long.js/pull/124. added a "postinstall" script to fix.
- in the generated typescript declaration file 'onnx.d.ts', the following line:
```ts
import Long = require("long");
```
need to be replaced to fix type import error:
```ts
import Long from "long";
```
this replacement is done and code format is also applied to file 'onnx.d.ts'.
Loading

0 comments on commit 7c6cf56

Please sign in to comment.