From 76734a56b9972c5914e9036f2c5aa7b2c5847870 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 7 Sep 2022 12:09:56 -0400 Subject: [PATCH] Removing unnecessary use of OrtAllocator. --- .../main/java/ai/onnxruntime/OnnxTensor.java | 11 +++---- java/src/main/native/OrtJniUtil.c | 33 +++++++++++++------ java/src/main/native/OrtJniUtil.h | 6 ++-- java/src/main/native/ai_onnxruntime_OnnxMap.c | 4 +-- .../main/native/ai_onnxruntime_OnnxSequence.c | 6 ++-- .../main/native/ai_onnxruntime_OnnxTensor.c | 9 +++-- 6 files changed, 40 insertions(+), 29 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index f71d26c8bfffa..653a3c9be91bd 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -90,14 +90,14 @@ public Object getValue() throws OrtException { case BOOL: return getBool(OnnxRuntime.ortApiHandle, nativeHandle); case STRING: - return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); + return getString(OnnxRuntime.ortApiHandle, nativeHandle); case UNKNOWN: default: throw new OrtException("Extracting the value of an invalid Tensor."); } } else { Object carrier = info.makeCarrier(); - getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier); + getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { // We read the strings out from native code in a flat array and then reshape // to the desired output shape. @@ -284,13 +284,12 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType) private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException; - private native String getString(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native String getString(long apiHandle, long nativeHandle) throws OrtException; private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; - private native void getArray( - long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException; + private native void getArray(long apiHandle, long nativeHandle, Object carrier) + throws OrtException; private native void close(long apiHandle, long nativeHandle); diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index b2e758f7ce913..ca949c94ab690 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -581,13 +581,25 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT (*jniEnv)->SetBooleanArrayRegion(jniEnv, typedArr, 0, outputLength, (jboolean *)inputTensor); return consumedSize; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: { + // complex with float32 real and imaginary components + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64."); + return -1; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: { + // complex with float64 real and imaginary components + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128."); + return -1; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: { + // Non-IEEE floating-point format based on IEEE754 single-precision + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16."); + return -1; + } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: default: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type."); - return -1; + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED."); + return -1; } } } @@ -618,7 +630,7 @@ int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, con } } -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) { +jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { jobject tempString = NULL; // Get the buffer size needed size_t totalStringLength = 0; @@ -651,7 +663,7 @@ jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllo return tempString; } -OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray) { +OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) { char * tempBuffer = NULL; // Get the buffer size needed size_t totalStringLength = 0; @@ -711,7 +723,7 @@ OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllo return code; } -jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) { +jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { // Extract tensor info OrtTensorTypeAndShapeInfo* tensorInfo = NULL; OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(tensor, &tensorInfo)); @@ -731,7 +743,7 @@ jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, Ort jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(length), stringClazz, NULL); - code = copyStringTensorToArray(jniEnv, api, allocator, tensor, length, outputArray); + code = copyStringTensorToArray(jniEnv, api, tensor, length, outputArray); if (code != ORT_OK) { outputArray = NULL; } @@ -1025,7 +1037,8 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca case ONNX_TYPE_UNKNOWN: case ONNX_TYPE_OPAQUE: case ONNX_TYPE_OPTIONAL: - case ONNX_TYPE_SPARSETENSOR: { + case ONNX_TYPE_SPARSETENSOR: + default: { throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR."); return NULL; } diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index d671ceb50d8ed..616a20503ad42 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -54,11 +54,11 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray); -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor); +jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); -OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray); +OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray); -jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor); +jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); diff --git a/java/src/main/native/ai_onnxruntime_OnnxMap.c b/java/src/main/native/ai_onnxruntime_OnnxMap.c index dce0677c84672..161fa23ae81cc 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxMap.c +++ b/java/src/main/native/ai_onnxruntime_OnnxMap.c @@ -22,7 +22,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringKeys(JNIEnv* OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 0, allocator, &keys)); if (code == ORT_OK) { // Convert to Java String array - jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, keys); + jobjectArray output = createStringArrayFromTensor(jniEnv, api, keys); api->ReleaseValue(keys); @@ -72,7 +72,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringValues(JNIEn OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 1, allocator, &values)); if (code == ORT_OK) { // Convert to Java String array - jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, values); + jobjectArray output = createStringArrayFromTensor(jniEnv, api, values); api->ReleaseValue(values); diff --git a/java/src/main/native/ai_onnxruntime_OnnxSequence.c b/java/src/main/native/ai_onnxruntime_OnnxSequence.c index 12d3266f8a90f..22cfca6315768 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxSequence.c +++ b/java/src/main/native/ai_onnxruntime_OnnxSequence.c @@ -28,7 +28,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JN if (code == ORT_OK) { // Convert to Java String array - output = createStringArrayFromTensor(jniEnv, api, allocator, keys); + output = createStringArrayFromTensor(jniEnv, api, keys); // Release if valid api->ReleaseValue(element); } @@ -94,7 +94,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues( if (code == ORT_OK) { // Convert to Java String array - output = createStringArrayFromTensor(jniEnv, api, allocator, values); + output = createStringArrayFromTensor(jniEnv, api, values); // Release if valid api->ReleaseValue(element); } @@ -230,7 +230,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn OrtValue* element; code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); if (code == ORT_OK) { - jobject str = createStringFromStringTensor(jniEnv, api, allocator, element); + jobject str = createStringFromStringTensor(jniEnv, api, element); if (str == NULL) { api->ReleaseValue(element); // bail out as exception has been thrown diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index edc507075feb1..1656b4043cfe9 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -367,12 +367,11 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong * Signature: (JJ)Ljava/lang/String; */ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; // Extract a String array - if this becomes a performance issue we'll refactor later. - jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle, - (OrtValue*) handle); + jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtValue*) handle); if (outputArray != NULL) { // Get reference to the string jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0); @@ -410,7 +409,7 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool * Signature: (JJLjava/lang/Object;)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtValue* value = (OrtValue*) handle; @@ -418,7 +417,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value); if (code == ORT_OK) { if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier); + copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier); } else { uint8_t* arr = NULL; code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));