Skip to content

Commit

Permalink
Removing unnecessary use of OrtAllocator.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Sep 7, 2022
1 parent 5070ceb commit 76734a5
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 29 deletions.
11 changes: 5 additions & 6 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ public Object getValue() throws OrtException {
case BOOL:
return getBool(OnnxRuntime.ortApiHandle, nativeHandle);
case STRING:
return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return getString(OnnxRuntime.ortApiHandle, nativeHandle);
case UNKNOWN:
default:
throw new OrtException("Extracting the value of an invalid Tensor.");
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
Expand Down Expand Up @@ -284,13 +284,12 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)

private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException;

private native String getString(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native String getString(long apiHandle, long nativeHandle) throws OrtException;

private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;

private native void getArray(
long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException;
private native void getArray(long apiHandle, long nativeHandle, Object carrier)
throws OrtException;

private native void close(long apiHandle, long nativeHandle);

Expand Down
33 changes: 23 additions & 10 deletions java/src/main/native/OrtJniUtil.c
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,25 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT
(*jniEnv)->SetBooleanArrayRegion(jniEnv, typedArr, 0, outputLength, (jboolean *)inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: {
// complex with float32 real and imaginary components
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64.");
return -1;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: {
// complex with float64 real and imaginary components
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128.");
return -1;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: {
// Non-IEEE floating-point format based on IEEE754 single-precision
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16.");
return -1;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type.");
return -1;
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED.");
return -1;
}
}
}
Expand Down Expand Up @@ -618,7 +630,7 @@ int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, con
}
}

jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
jobject tempString = NULL;
// Get the buffer size needed
size_t totalStringLength = 0;
Expand Down Expand Up @@ -651,7 +663,7 @@ jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllo
return tempString;
}

OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray) {
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) {
char * tempBuffer = NULL;
// Get the buffer size needed
size_t totalStringLength = 0;
Expand Down Expand Up @@ -711,7 +723,7 @@ OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllo
return code;
}

jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
// Extract tensor info
OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(tensor, &tensorInfo));
Expand All @@ -731,7 +743,7 @@ jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, Ort
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(length), stringClazz, NULL);

code = copyStringTensorToArray(jniEnv, api, allocator, tensor, length, outputArray);
code = copyStringTensorToArray(jniEnv, api, tensor, length, outputArray);
if (code != ORT_OK) {
outputArray = NULL;
}
Expand Down Expand Up @@ -1025,7 +1037,8 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca
case ONNX_TYPE_UNKNOWN:
case ONNX_TYPE_OPAQUE:
case ONNX_TYPE_OPTIONAL:
case ONNX_TYPE_SPARSETENSOR: {
case ONNX_TYPE_SPARSETENSOR:
default: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR.");
return NULL;
}
Expand Down
6 changes: 3 additions & 3 deletions java/src/main/native/OrtJniUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT

int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray);

jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);

OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray);
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray);

jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);

jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);

Expand Down
4 changes: 2 additions & 2 deletions java/src/main/native/ai_onnxruntime_OnnxMap.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringKeys(JNIEnv*
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 0, allocator, &keys));
if (code == ORT_OK) {
// Convert to Java String array
jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
jobjectArray output = createStringArrayFromTensor(jniEnv, api, keys);

api->ReleaseValue(keys);

Expand Down Expand Up @@ -72,7 +72,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringValues(JNIEn
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java String array
jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, values);
jobjectArray output = createStringArrayFromTensor(jniEnv, api, values);

api->ReleaseValue(values);

Expand Down
6 changes: 3 additions & 3 deletions java/src/main/native/ai_onnxruntime_OnnxSequence.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JN

if (code == ORT_OK) {
// Convert to Java String array
output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
output = createStringArrayFromTensor(jniEnv, api, keys);
// Release if valid
api->ReleaseValue(element);
}
Expand Down Expand Up @@ -94,7 +94,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(

if (code == ORT_OK) {
// Convert to Java String array
output = createStringArrayFromTensor(jniEnv, api, allocator, values);
output = createStringArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
Expand Down Expand Up @@ -230,7 +230,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
jobject str = createStringFromStringTensor(jniEnv, api, allocator, element);
jobject str = createStringFromStringTensor(jniEnv, api, element);
if (str == NULL) {
api->ReleaseValue(element);
// bail out as exception has been thrown
Expand Down
9 changes: 4 additions & 5 deletions java/src/main/native/ai_onnxruntime_OnnxTensor.c
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,11 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
* Signature: (JJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
// Extract a String array - if this becomes a performance issue we'll refactor later.
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle,
(OrtValue*) handle);
jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtValue*) handle);
if (outputArray != NULL) {
// Get reference to the string
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
Expand Down Expand Up @@ -410,15 +409,15 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
* Signature: (JJLjava/lang/Object;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) {
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtValue* value = (OrtValue*) handle;
JavaTensorTypeShape typeShape;
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value);
if (code == ORT_OK) {
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier);
copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier);
} else {
uint8_t* arr = NULL;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));
Expand Down

0 comments on commit 76734a5

Please sign in to comment.