Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom op API for thread pool #18980

Merged
merged 8 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4528,6 +4528,19 @@
* \since Version 1.17.
*/
ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);

/**
 * Run fn in parallel.
 *
 * fn is called as fn(usr_data, i), where i identifies the invocation. Because the invocations may be
 * dispatched to a thread pool, fn should be safe to call concurrently with itself.
 *
 * \param[in] context Kernel context of the custom op issuing the call
 * \param[in] fn Function accepting usr_data and an integer as iterator
 * \param[in] total The number of times fn is to be invoked
 * \param[in] num_batch Maximum number of batches into which the "total" invocations are divided.
 *                      When zero, no batch limit is applied
 * \param[in] usr_data User data to be passed back to fn
 *
 * \since Version 1.17.
 */
ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context,
                _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch,
                _In_ void* usr_data);

Check warning on line 4543 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/session/onnxruntime_c_api.h#L4543

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
include/onnxruntime/core/session/onnxruntime_c_api.h:4543:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
};

/*
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,7 @@ struct KernelContext {
Logger GetLogger() const;
OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const;
OrtKernelContext* GetOrtKernelContext() const { return ctx_; }
/// Forwards to the C API KernelContext_ParallelFor with this context: fn receives usr_data and an index,
/// total is the number of invocations, num_batch is passed through as given.
/// The returned status is handed to ThrowOnError.
void ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const;

private:
OrtKernelContext* ctx_;
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,10 @@
return Logger{out};
}

inline void KernelContext::ParallelFor(void (*fn)(void*, size_t), size_t total, size_t num_batch, void* usr_data) const {

Check warning on line 1661 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/session/onnxruntime_cxx_inline.h#L1661

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
include/onnxruntime/core/session/onnxruntime_cxx_inline.h:1661:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
ThrowOnError(GetApi().KernelContext_ParallelFor(ctx_, fn, total, num_batch, usr_data));
}

// Builds the attribute through the C API CreateOpAttr, which writes its result into p_.
// The status returned by the C API is handed to Ort::ThrowOnError.
inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
  auto* const status = GetApi().CreateOpAttr(name, data, len, type, &p_);
  Ort::ThrowOnError(status);
}
Expand Down
25 changes: 25 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "core/session/custom_ops.h"
#include "core/session/inference_session.h"
#include "core/session/ort_apis.h"
#include "core/platform/threadpool.h"

#if !defined(ORT_MINIMAL_BUILD)
static constexpr uint32_t min_ort_version_with_optional_io_support = 8;
Expand Down Expand Up @@ -377,6 +378,30 @@
API_IMPL_END
};

// Invokes fn(usr_data, i) `total` times on behalf of a custom op kernel.
//
// - context and fn must be non-null; otherwise ORT_INVALID_ARGUMENT is returned instead of
//   silently reporting success without running anything.
// - usr_data is an opaque cookie handed back to fn; it is not inspected here and may be null.
// - total == 0 is a successful no-op.
// - num_batch != 0 dispatches through ThreadPool::TryBatchParallelFor with num_batch batches;
//   num_batch == 0 dispatches through ThreadPool::TrySimpleParallelFor.
// - When the kernel context has no operator thread pool, fn is run sequentially on the calling
//   thread so the work is never dropped.
ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context,
                    _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch,
                    _In_ void* usr_data) {
  API_IMPL_BEGIN
  if (context == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid context");
  }
  if (fn == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid fn");
  }
  if (total == 0) {
    return nullptr;  // nothing to run
  }

  const auto* ctx = reinterpret_cast<const onnxruntime::OpKernelContext*>(context);
  auto* tp = ctx->GetOperatorThreadPool();
  if (tp == nullptr) {
    // No pool available: fall back to a plain loop so the caller's work is still done.
    for (size_t i = 0; i < total; ++i) {
      fn(usr_data, i);
    }
    return nullptr;
  }

  if (num_batch) {
    onnxruntime::concurrency::ThreadPool::TryBatchParallelFor(
        tp,
        static_cast<std::ptrdiff_t>(total),
        [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast<size_t>(ith)); },
        static_cast<std::ptrdiff_t>(num_batch));
  } else {
    onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor(
        tp,
        static_cast<std::ptrdiff_t>(total),
        [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast<size_t>(ith)); });
  }
  return nullptr;
  API_IMPL_END
};

#ifdef _WIN32
#pragma warning(pop)
#endif
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2722,6 +2722,7 @@ static constexpr OrtApi ort_api_1_to_17 = {
&OrtApis::SetSymbolicDimensions,
&OrtApis::ReadOpAttr,
&OrtApis::SetDeterministicCompute,
&OrtApis::KernelContext_ParallelFor,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -502,4 +502,6 @@
ORT_API_STATUS_IMPL(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
ORT_API_STATUS_IMPL(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value);

// Runs fn(usr_data, index) `total` times for a custom op kernel; see the C API header for parameter details.
ORT_API_STATUS_IMPL(KernelContext_ParallelFor, _In_ const OrtKernelContext* context,
                    _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch,
                    _In_ void* usr_data);

Check warning on line 505 in onnxruntime/core/session/ort_apis.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/session/ort_apis.h#L505

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/session/ort_apis.h:505:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

} // namespace OrtApis
37 changes: 33 additions & 4 deletions onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,45 @@ struct KernelOne {
}
};

// Pair of buffers for a float-to-float element copy: read from `from`, write to `to`.
// Both pointers default to null.
struct DataI {
  const float* from{nullptr};
  float* to{nullptr};
};

// Pair of buffers for a float-to-int32_t element conversion: read from `from`, write to `to`.
// Both pointers default to null.
struct DataII {
  const float* from{nullptr};
  int32_t* to{nullptr};
};

// floats to floats
// Copies element `ith` from the source buffer to the destination buffer of the DataI
// that raw_data points to.
void CopyI(void* raw_data, size_t ith) {
  const auto& buffers = *static_cast<DataI*>(raw_data);
  buffers.to[ith] = buffers.from[ith];
}

// floats to int32_t
// Rounds element `ith` of the source buffer and stores it as int32_t in the destination
// buffer of the DataII that raw_data points to.
void CopyII(void* raw_data, size_t ith) {
  const auto* buffers = static_cast<DataII*>(raw_data);
  const float source = buffers->from[ith];
  buffers->to[ith] = static_cast<int32_t>(round(source));
}

// lite custom op as a function
void KernelTwo(const Ort::Custom::Tensor<float>& X,
void KernelTwo(OrtKernelContext* context,
const Ort::Custom::Tensor<float>& X,
Ort::Custom::Tensor<int32_t>& Y) {
const auto& shape = X.Shape();
auto X_raw = X.Data();
auto Y_raw = Y.Allocate(shape);
std::vector<float> floats(static_cast<size_t>(X.NumberOfElement()), 0.f);

DataI data_i = {X_raw, floats.data()};
auto total = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
for (int64_t i = 0; i < total; i++) {
Y_raw[i] = static_cast<int32_t>(round(X_raw[i]));
}

Ort::KernelContext ctx(context);
ctx.ParallelFor(CopyI, static_cast<size_t>(total), 0, &data_i); // test simple parallel for

DataII data_ii = {floats.data(), Y_raw};
ctx.ParallelFor(CopyII, static_cast<size_t>(total), 2, &data_ii); // test batch parallel for
}

template <typename T>
Expand Down
Loading