Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Ort objects to be stored in a resizable std::vector #22608

Merged
merged 5 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,9 @@
* constructors to construct an instance of a Status object from exceptions.
*/
struct Status : detail::Base<OrtStatus> {
using Base = detail::Base<OrtStatus>;
using Base::Base;

explicit Status(std::nullptr_t) noexcept {} ///< Create an empty object, must be assigned a valid one to be used
explicit Status(OrtStatus* status) noexcept; ///< Takes ownership of OrtStatus instance returned from the C API.
explicit Status(const Exception&) noexcept; ///< Creates status instance out of exception
Expand Down Expand Up @@ -728,6 +731,9 @@
*
*/
struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
using Base = detail::Base<OrtCustomOpDomain>;
using Base::Base;

explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used

/// \brief Wraps OrtApi::CreateCustomOpDomain
Expand Down Expand Up @@ -963,8 +969,10 @@
*
*/
struct ModelMetadata : detail::Base<OrtModelMetadata> {
explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
using Base = detail::Base<OrtModelMetadata>;
using Base::Base;

explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used

Check warning on line 975 in include/onnxruntime/core/session/onnxruntime_cxx_api.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api.h:975: Lines should be <= 120 characters long [whitespace/line_length] [2]

/** \brief Returns a copy of the producer name.
*
Expand Down Expand Up @@ -1237,6 +1245,9 @@
*
*/
struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
using Base = detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo>;
using Base::Base;

explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
Expand All @@ -1258,6 +1269,9 @@
*
*/
struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
using Base = detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo>;
using Base::Base;

explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
Expand Down Expand Up @@ -1293,6 +1307,9 @@
*
*/
struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
using Base = detail::MapTypeInfoImpl<OrtMapTypeInfo>;
using Base::Base;

explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
Expand Down Expand Up @@ -1324,6 +1341,9 @@
/// the information about contained sequence or map depending on the ONNXType.
/// </summary>
struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
using Base = detail::TypeInfoImpl<OrtTypeInfo>;
using Base::Base;

explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop

Expand Down Expand Up @@ -1661,11 +1681,11 @@
*/
struct Value : detail::ValueImpl<OrtValue> {
using Base = detail::ValueImpl<OrtValue>;
using Base::Base;
using OrtSparseValuesParam = detail::OrtSparseValuesParam;
using Shape = detail::Shape;

explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
Value(Value&&) = default;
Value& operator=(Value&&) = default;

Expand Down Expand Up @@ -1941,6 +1961,10 @@
/// This struct provides life time management for custom op attribute
/// </summary>
struct OpAttr : detail::Base<OrtOpAttr> {
using Base = detail::Base<OrtOpAttr>;
using Base::Base;

explicit OpAttr(std::nullptr_t) {}
OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
};

Expand Down Expand Up @@ -2183,6 +2207,8 @@
/// so it does not destroy the pointer the kernel does not own.
/// </summary>
struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
using Base = detail::KernelInfoImpl<OrtKernelInfo>;
using Base::Base;
explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
Expand All @@ -2192,6 +2218,9 @@
/// Create and own custom defined operation.
/// </summary>
struct Op : detail::Base<OrtOp> {
using Base = detail::Base<OrtOp>;
using Base::Base;

explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used

explicit Op(OrtOp*); ///< Take ownership of the OrtOp
Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ inline void ThrowOnError(const Status& st) {
}
}

inline Status::Status(OrtStatus* status) noexcept : Base<OrtStatus>{status} {
inline Status::Status(OrtStatus* status) noexcept : detail::Base<OrtStatus>{status} {
}

inline Status::Status(const std::exception& e) noexcept {
Expand Down Expand Up @@ -1908,7 +1908,7 @@ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::

inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}

inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
inline Op::Op(OrtOp* p) : detail::Base<OrtOp>(p) {}

inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
const char** type_constraint_names,
Expand Down
26 changes: 26 additions & 0 deletions onnxruntime/test/shared_lib/test_nontensor_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,32 @@ TEST(CApiTest, SparseTensorFillSparseTensorFormatAPI) {
}
}

TEST(CApi, TestResize) {
std::vector<Ort::Value> values;
values.resize(10);

std::vector<Ort::Status> sts;
sts.resize(5);

std::vector<Ort::CustomOpDomain> domains;
domains.resize(5);

std::vector<Ort::TensorTypeAndShapeInfo> type_and_shape;
type_and_shape.resize(5);

std::vector<Ort::SequenceTypeInfo> seq_type_info;
seq_type_info.resize(5);

std::vector<Ort::MapTypeInfo> map_type_info;
map_type_info.resize(5);

std::vector<Ort::TypeInfo> type_info;
type_info.resize(5);

std::vector<Ort::OpAttr> op_attr;
op_attr.resize(5);
}

TEST(CApiTest, SparseTensorFillSparseFormatStringsAPI) {
auto allocator = Ort::AllocatorWithDefaultOptions();
Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
Expand Down
Loading