Skip to content

Commit

Permalink
[objc] Add check for ORTValue being a tensor in ORTValue methods that…
Browse files Browse the repository at this point in the history
… should only be used with tensors. (#19946)

Add check to report error instead of crashing.
  • Loading branch information
edgchen1 authored Mar 18, 2024
1 parent 7e0d424 commit 4d31076
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions objectivec/ort_value.mm
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ - (nullable ORTValueTypeInfo*)typeInfoWithError:(NSError**)error {
- (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError**)error {
try {
const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
if (!tensorTypeAndShapeInfo) {
ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
}
return CXXAPIToPublicTensorTypeAndShapeInfo(tensorTypeAndShapeInfo);
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
Expand All @@ -156,6 +159,9 @@ - (nullable ORTTensorTypeAndShapeInfo*)tensorTypeAndShapeInfoWithError:(NSError*
- (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
try {
const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
if (!tensorTypeAndShapeInfo) {
ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
}
if (tensorTypeAndShapeInfo.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
ORT_CXX_API_THROW(
"This ORTValue holds string data. Please call tensorStringDataWithError: "
Expand All @@ -182,6 +188,9 @@ - (nullable NSMutableData*)tensorDataWithError:(NSError**)error {
- (nullable NSArray<NSString*>*)tensorStringDataWithError:(NSError**)error {
try {
const auto tensorTypeAndShapeInfo = _typeInfo->GetTensorTypeAndShapeInfo();
if (!tensorTypeAndShapeInfo) {
ORT_CXX_API_THROW("ORTValue is not a tensor.", ORT_RUNTIME_EXCEPTION);
}
const size_t elementCount = tensorTypeAndShapeInfo.GetElementCount();
const size_t tensorStringDataLength = _value->GetStringTensorDataLength();
std::vector<char> tensorStringData(tensorStringDataLength, '\0');
Expand Down

0 comments on commit 4d31076

Please sign in to comment.