diff --git a/java/src/main/java/ai/onnxruntime/MapInfo.java b/java/src/main/java/ai/onnxruntime/MapInfo.java index 552793139549d..5a2ac072f336a 100644 --- a/java/src/main/java/ai/onnxruntime/MapInfo.java +++ b/java/src/main/java/ai/onnxruntime/MapInfo.java @@ -4,6 +4,8 @@ */ package ai.onnxruntime; +import ai.onnxruntime.TensorInfo.OnnxTensorType; + /** Describes an {@link OnnxMap} object or output node. */ public class MapInfo implements ValueInfo { @@ -42,6 +44,21 @@ public class MapInfo implements ValueInfo { this.valueType = valueType; } + /** + * Construct a MapInfo with the specified size, key type and value type. + * + *

Called from JNI. + * + * @param size The size. + * @param keyTypeInt The int representing the {@link OnnxTensorType} of the keys. + * @param valueTypeInt The int representing the {@link OnnxTensorType} of the values. + */ + MapInfo(int size, int keyTypeInt, int valueTypeInt) { + this.size = size; + this.keyType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(keyTypeInt)); + this.valueType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(valueTypeInt)); + } + @Override public String toString() { String initial = size == -1 ? "MapInfo(size=UNKNOWN" : "MapInfo(size=" + size; diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index f71d26c8bfffa..653a3c9be91bd 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -90,14 +90,14 @@ public Object getValue() throws OrtException { case BOOL: return getBool(OnnxRuntime.ortApiHandle, nativeHandle); case STRING: - return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); + return getString(OnnxRuntime.ortApiHandle, nativeHandle); case UNKNOWN: default: throw new OrtException("Extracting the value of an invalid Tensor."); } } else { Object carrier = info.makeCarrier(); - getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier); + getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { // We read the strings out from native code in a flat array and then reshape // to the desired output shape. @@ -284,13 +284,12 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType) private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException; - private native String getString(long apiHandle, long nativeHandle, long allocatorHandle) - throws OrtException; + private native String getString(long apiHandle, long nativeHandle) throws OrtException; private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; - private native void getArray( - long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException; + private native void getArray(long apiHandle, long nativeHandle, Object carrier) + throws OrtException; private native void close(long apiHandle, long nativeHandle); diff --git a/java/src/main/java/ai/onnxruntime/SequenceInfo.java b/java/src/main/java/ai/onnxruntime/SequenceInfo.java index a417634b72de0..5497766546958 100644 --- a/java/src/main/java/ai/onnxruntime/SequenceInfo.java +++ b/java/src/main/java/ai/onnxruntime/SequenceInfo.java @@ -4,6 +4,8 @@ */ package ai.onnxruntime; +import ai.onnxruntime.TensorInfo.OnnxTensorType; + /** Describes an {@link OnnxSequence}, including it's element type if known. */ public class SequenceInfo implements ValueInfo { @@ -35,6 +37,23 @@ public class SequenceInfo implements ValueInfo { this.mapInfo = null; } + /** + * Construct a sequence of known length, with the specified type. This sequence does not contain + * maps. + * + *

Called from JNI. + * + * @param length The length of the sequence. + * @param sequenceTypeInt The element type int of the sequence mapped from {@link OnnxTensorType}. + */ + SequenceInfo(int length, int sequenceTypeInt) { + this.length = length; + this.sequenceType = + OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(sequenceTypeInt)); + this.sequenceOfMaps = false; + this.mapInfo = null; + } + /** * Construct a sequence of known length containing maps. * diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index ead90d83bcba6..4a7a3b833bc2b 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -120,6 +120,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { this.onnxType = onnxType; } + /** + * Constructs a TensorInfo with the specified shape and native type int. + * + *

Called from JNI. + * + * @param shape The tensor shape. + * @param typeInt The native type int. + */ + TensorInfo(long[] shape, int typeInt) { + this.shape = shape; + this.onnxType = OnnxTensorType.mapFromInt(typeInt); + this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); + } + /** * Get a copy of the tensor's shape. * diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 55f26f35b40df..a670179a0eb25 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -8,7 +8,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { // To silence unused-parameter error. - // This function must exist according to the JNI spec, but the arguments aren't necessary for the library to request a specific version. + // This function must exist according to the JNI spec, but the arguments aren't necessary for the library + // to request a specific version. (void)vm; (void) reserved; // Requesting 1.6 to support Android. Will need to be bumped to a later version to call interface default methods // from native code, or to access other new Java features. @@ -213,289 +214,271 @@ typedef union FP32 { float floatVal; } FP32; -jfloat convertHalfToFloat(uint16_t half) { +jfloat convertHalfToFloat(const uint16_t half) { FP32 output; output.intVal = (((half&0x8000)<<16) | (((half&0x7c00)+0x1C000)<<13) | ((half&0x03FF)<<13)); return output.floatVal; } -jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info) { - ONNXType type; - checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(info,&type)); +jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) { + ONNXType type = ONNX_TYPE_UNKNOWN; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetOnnxTypeFromTypeInfo(info, &type)); + if (code != ORT_OK) { + return NULL; + } - switch (type) { - case ONNX_TYPE_TENSOR: { - const OrtTensorTypeAndShapeInfo* tensorInfo; - checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(info,&tensorInfo)); - return convertToTensorInfo(jniEnv, api, (const OrtTensorTypeAndShapeInfo *) tensorInfo); - } - case ONNX_TYPE_SEQUENCE: { - const OrtSequenceTypeInfo* sequenceInfo; - checkOrtStatus(jniEnv,api,api->CastTypeInfoToSequenceTypeInfo(info,&sequenceInfo)); - return convertToSequenceInfo(jniEnv, api, sequenceInfo); - } - case ONNX_TYPE_MAP: { - const OrtMapTypeInfo* mapInfo; - checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(info,&mapInfo)); - return convertToMapInfo(jniEnv, api, mapInfo); - } - case ONNX_TYPE_UNKNOWN: - case ONNX_TYPE_OPAQUE: - case ONNX_TYPE_SPARSETENSOR: - default: { - throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found."); - return NULL; - } + switch (type) { + case ONNX_TYPE_TENSOR: { + const OrtTensorTypeAndShapeInfo* tensorInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(info, &tensorInfo)); + if (code == ORT_OK) { + return convertToTensorInfo(jniEnv, api, tensorInfo); + } else { + return NULL; + } + } + case ONNX_TYPE_SEQUENCE: { + const OrtSequenceTypeInfo* sequenceInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToSequenceTypeInfo(info, &sequenceInfo)); + if (code == ORT_OK) { + return convertToSequenceInfo(jniEnv, api, sequenceInfo); + } else { + return NULL; + } + } + case ONNX_TYPE_MAP: { + const OrtMapTypeInfo* mapInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToMapTypeInfo(info, &mapInfo)); + if (code == ORT_OK) { + return convertToMapInfo(jniEnv, api, mapInfo); + } else { + return NULL; + } } + case ONNX_TYPE_UNKNOWN: + case ONNX_TYPE_OPAQUE: + case ONNX_TYPE_SPARSETENSOR: + default: { + throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found."); + return NULL; + } + } } jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info) { - // Extract the information from the info struct. - ONNXTensorElementDataType onnxType; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxType)); - size_t numDim; - checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&numDim)); - //printf("numDim %d\n",numDim); - int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); - checkOrtStatus(jniEnv,api,api->GetDimensions(info, dimensions, numDim)); - jint onnxTypeInt = convertFromONNXDataFormat(onnxType); - - // Create the long array for the shape. - jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim)); - (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions); - // Free the dimensions array - free(dimensions); - dimensions = NULL; - - // Create the ONNXTensorType enum - char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; - jclass clazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); - jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); - jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,onnxTensorTypeMapFromInt,onnxTypeInt); - //printf("ONNXTensorType class %p, methodID %p, object %p\n",clazz,onnxTensorTypeMapFromInt,onnxTensorTypeJava); - - // Create the ONNXJavaType enum - char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; - clazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); - jmethodID javaDataTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); - jobject javaDataType = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,javaDataTypeMapFromONNXTensorType,onnxTensorTypeJava); - //printf("JavaDataType class %p, methodID %p, object %p\n",clazz,javaDataTypeMapFromONNXTensorType,javaDataType); - - // Create the TensorInfo object - char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; - clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); - jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JLai/onnxruntime/OnnxJavaType;Lai/onnxruntime/TensorInfo$OnnxTensorType;)V"); - //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); - jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, javaDataType, onnxTensorTypeJava); - return tensorInfo; + // Extract the information from the info struct. + ONNXTensorElementDataType onnxType; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxType)); + if (code != ORT_OK) { + return NULL; + } + size_t numDim = 0; + code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim)); + if (code != ORT_OK) { + return NULL; + } + //printf("numDim %d\n",numDim); + int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); + code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); + if (code != ORT_OK) { + free((void*) dimensions); + return NULL; + } + jint onnxTypeInt = convertFromONNXDataFormat(onnxType); + + // Create the long array for the shape. + jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim)); + (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions); + // Free the dimensions array + free(dimensions); + dimensions = NULL; + + // Create the TensorInfo object + static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; + jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); + jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V"); + //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); + jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt); + return tensorInfo; } jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info) { - // Create the java methods we need to call. - // Get the ONNXTensorType enum static method - char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; - jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); - jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); - - // Get the ONNXJavaType enum static method - char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; - jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); - jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); - - // Get the map info class - char *mapInfoClassName = "ai/onnxruntime/MapInfo"; - jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); - jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V"); - - // Extract the key type - ONNXTensorElementDataType keyType; - checkOrtStatus(jniEnv,api,api->GetMapKeyType(info,&keyType)); - - // Convert key type to java - jint onnxTypeKey = convertFromONNXDataFormat(keyType); - jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey); - jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey); - - // according to include/onnxruntime/core/framework/data_types.h only the following values are supported. - // string, int64, float, double - // So extract the value type, then convert it to a tensor type so we can get it's element type. - OrtTypeInfo* valueTypeInfo; - checkOrtStatus(jniEnv,api,api->GetMapValueType(info,&valueTypeInfo)); - const OrtTensorTypeAndShapeInfo* tensorValueInfo; - checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(valueTypeInfo,&tensorValueInfo)); - ONNXTensorElementDataType valueType; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorValueInfo,&valueType)); - api->ReleaseTypeInfo(valueTypeInfo); - tensorValueInfo = NULL; - valueTypeInfo = NULL; - - // Convert value type to java - jint onnxTypeValue = convertFromONNXDataFormat(valueType); - jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue); - jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue); + // Extract the key type + ONNXTensorElementDataType keyType; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetMapKeyType(info, &keyType)); + if (code != ORT_OK) { + return NULL; + } - // Construct map info - jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)-1,onnxJavaTypeKey,onnxJavaTypeValue); + // according to include/onnxruntime/core/framework/data_types.h only the following values are supported. + // string, int64, float, double + // So extract the value type, then convert it to a tensor type so we can get it's element type. + OrtTypeInfo* valueTypeInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->GetMapValueType(info, &valueTypeInfo)); + if (code != ORT_OK) { + return NULL; + } + const OrtTensorTypeAndShapeInfo* tensorValueInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(valueTypeInfo, &tensorValueInfo)); + if (code != ORT_OK) { + api->ReleaseTypeInfo(valueTypeInfo); + return NULL; + } + ONNXTensorElementDataType valueType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(tensorValueInfo, &valueType)); + api->ReleaseTypeInfo(valueTypeInfo); + tensorValueInfo = NULL; + valueTypeInfo = NULL; + if (code != ORT_OK) { + return NULL; + } - return mapInfo; -} + // Convert key type to java + jint onnxTypeKey = convertFromONNXDataFormat(keyType); + // Convert value type to java + jint onnxTypeValue = convertFromONNXDataFormat(valueType); -jobject createEmptyMapInfo(JNIEnv *jniEnv) { - // Create the ONNXJavaType enum - char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType"; - jclass clazz = (*jniEnv)->FindClass(jniEnv, onnxJavaTypeClassName); - jmethodID onnxJavaTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromInt", "(I)Lai/onnxruntime/OnnxJavaType;"); - jobject unknownType = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,onnxJavaTypeMapFromInt,0); + // Get the map info class + static const char *mapInfoClassName = "ai/onnxruntime/MapInfo"; + jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); + jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, mapInfoClazz, "", "(III)V"); - char *mapInfoClassName = "ai/onnxruntime/MapInfo"; - clazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); - jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz,"","(Lai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V"); - jobject mapInfo = (*jniEnv)->NewObject(jniEnv,clazz,mapInfoConstructor,unknownType,unknownType); + // Construct map info + jobject mapInfo = (*jniEnv)->NewObject(jniEnv, mapInfoClazz, mapInfoConstructor, (jint)-1, onnxTypeKey, onnxTypeValue); - return mapInfo; + return mapInfo; } jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info) { - // Get the sequence info class - char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; - jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); - - // according to include/onnxruntime/core/framework/data_types.h the following values are supported. - // tensor types, map and map - OrtTypeInfo* elementTypeInfo; - checkOrtStatus(jniEnv,api,api->GetSequenceElementType(info,&elementTypeInfo)); - ONNXType type; - checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(elementTypeInfo,&type)); - - jobject sequenceInfo; - - switch (type) { - case ONNX_TYPE_TENSOR: { - // Figure out element type - const OrtTensorTypeAndShapeInfo* elementTensorInfo; - checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(elementTypeInfo,&elementTensorInfo)); - ONNXTensorElementDataType element; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(elementTensorInfo,&element)); - - // Convert element type into ONNXTensorType - jint onnxTypeInt = convertFromONNXDataFormat(element); - // Get the ONNXTensorType enum static method - char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; - jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); - jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); - jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt); - - // Get the ONNXJavaType enum static method - char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; - jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); - jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); - jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava); - - // Construct sequence info - jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;)V"); - sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,onnxJavaType); - break; - } - case ONNX_TYPE_MAP: { - // Extract the map info - const OrtMapTypeInfo* mapInfo; - checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(elementTypeInfo,&mapInfo)); - - // Convert it using the existing convert function - jobject javaMapInfo = convertToMapInfo(jniEnv,api,mapInfo); + // Get the sequence info class + static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; + jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); + jobject sequenceInfo = NULL; + + // according to include/onnxruntime/core/framework/data_types.h the following values are supported. + // tensor types, map and map + OrtTypeInfo* elementTypeInfo = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSequenceElementType(info, &elementTypeInfo)); + if (code != ORT_OK) { + return NULL; + } + ONNXType type = ONNX_TYPE_UNKNOWN; + code = checkOrtStatus(jniEnv, api, api->GetOnnxTypeFromTypeInfo(elementTypeInfo, &type)); + if (code != ORT_OK) { + goto sequence_cleanup; + } - // Construct sequence info - jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/MapInfo;)V"); - sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,javaMapInfo); - break; - } - default: { - sequenceInfo = createEmptySequenceInfo(jniEnv); - throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence"); - break; - } + switch (type) { + case ONNX_TYPE_TENSOR: { + // Figure out element type + const OrtTensorTypeAndShapeInfo* elementTensorInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(elementTypeInfo, &elementTensorInfo)); + if (code != ORT_OK) { + goto sequence_cleanup; + } + ONNXTensorElementDataType element = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(elementTensorInfo, &element)); + if (code != ORT_OK) { + goto sequence_cleanup; + } + + // Convert element type into ONNXTensorType + jint onnxTypeInt = convertFromONNXDataFormat(element); + + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(II)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)-1, onnxTypeInt); + break; } - api->ReleaseTypeInfo(elementTypeInfo); - elementTypeInfo = NULL; - - return sequenceInfo; -} - -jobject createEmptySequenceInfo(JNIEnv *jniEnv) { - // Create the ONNXJavaType enum - char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType"; - jclass clazz = (*jniEnv)->FindClass(jniEnv, onnxJavaTypeClassName); - jmethodID onnxJavaTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromInt", "(I)Lai/onnxruntime/OnnxJavaType;"); - jobject unknownType = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,onnxJavaTypeMapFromInt,0); + case ONNX_TYPE_MAP: { + // Extract the map info + const OrtMapTypeInfo* mapInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToMapTypeInfo(elementTypeInfo, &mapInfo)); + if (code != ORT_OK) { + goto sequence_cleanup; + } + + // Convert it using the existing convert function + jobject javaMapInfo = convertToMapInfo(jniEnv, api, mapInfo); + + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(ILai/onnxruntime/MapInfo;)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)-1, javaMapInfo); + break; + } + default: { + throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid element type found in sequence"); + break; + } + } - char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; - clazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); - jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz,"","(ILai/onnxruntime/OnnxJavaType;)V"); - jobject sequenceInfo = (*jniEnv)->NewObject(jniEnv,clazz,sequenceInfoConstructor,-1,unknownType); +sequence_cleanup: + api->ReleaseTypeInfo(elementTypeInfo); + elementTypeInfo = NULL; - return sequenceInfo; + return sequenceInfo; } -size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray input) { - uint32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv,input); - size_t consumedSize = inputLength * onnxTypeSize(onnxType); +int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) { + int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray); + int64_t consumedSize = inputLength * onnxTypeSize(onnxType); switch (onnxType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t - jbyteArray typedArr = (jbyteArray) input; - (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * ) tensor); + jbyteArray typedArr = (jbyteArray)inputArray; + (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t - jshortArray typedArr = (jshortArray) input; - (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * ) tensor); + jshortArray typedArr = (jshortArray)inputArray; + (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t - jintArray typedArr = (jintArray) input; - (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * ) tensor); + jintArray typedArr = (jintArray)inputArray; + (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t - jlongArray typedArr = (jlongArray) input; - (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * ) tensor); + jlongArray typedArr = (jlongArray)inputArray; + (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported."); - return 0; + return -1; /* float *floatArr = malloc(sizeof(float) * inputLength); - uint16_t *halfArr = (uint16_t *) tensor; + uint16_t *halfArr = (uint16_t *) outputTensor; for (uint32_t i = 0; i < inputLength; i++) { floatArr[i] = convertHalfToFloat(halfArr[i]); } - jfloatArray typedArr = (jfloatArray) input; + jfloatArray typedArr = (jfloatArray) inputArray; (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr); free(floatArr); return consumedSize; */ } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float - jfloatArray typedArr = (jfloatArray) input; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * ) tensor); + jfloatArray typedArr = (jfloatArray)inputArray; + (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double - jdoubleArray typedArr = (jdoubleArray) input; - (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * ) tensor); + jdoubleArray typedArr = (jdoubleArray)inputArray; + (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported."); - return 0; + return -1; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - jbooleanArray typedArr = (jbooleanArray) input; - (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *) tensor); + jbooleanArray typedArr = (jbooleanArray)inputArray; + (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components @@ -503,524 +486,566 @@ size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxTy case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: default: { - throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid tensor element type."); - return 0; + throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type."); + return -1; } } } -size_t copyJavaToTensor(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, - size_t dimensionsRemaining, jarray input) { +int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) { if (dimensionsRemaining == 1) { // write out 1d array of the respective primitive type - return copyJavaToPrimitiveArray(jniEnv,onnxType,tensor,input); + return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor); } else { // recurse through the dimensions // Java arrays are objects until the final dimension - jobjectArray inputObjArr = (jobjectArray) input; - uint32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv,inputObjArr); - size_t sizeConsumed = 0; - for (uint32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv,inputObjArr,i); - sizeConsumed += copyJavaToTensor(jniEnv, onnxType, tensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); + jobjectArray inputObjArr = (jobjectArray)inputArray; + int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr); + int64_t sizeConsumed = 0; + for (int32_t i = 0; i < dimLength; i++) { + jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i); + int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed); + sizeConsumed += consumed; // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv,childArr); + (*jniEnv)->DeleteLocalRef(jniEnv, childArr); + // If we failed to copy an array then break and return. + if (consumed == -1) { + return -1; + } } return sizeConsumed; } } -size_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray output) { - uint32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv,output); +int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) { + int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray); if (outputLength == 0) return 0; - size_t consumedSize = outputLength * onnxTypeSize(onnxType); + int64_t consumedSize = outputLength * onnxTypeSize(onnxType); switch (onnxType) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t - jbyteArray typedArr = (jbyteArray) output; - (*jniEnv)->SetByteArrayRegion(jniEnv, typedArr, 0, outputLength, (jbyte * ) tensor); + jbyteArray typedArr = (jbyteArray)outputArray; + (*jniEnv)->SetByteArrayRegion(jniEnv, typedArr, 0, outputLength, (jbyte * )inputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t - jshortArray typedArr = (jshortArray) output; - (*jniEnv)->SetShortArrayRegion(jniEnv, typedArr, 0, outputLength, (jshort * ) tensor); + jshortArray typedArr = (jshortArray)outputArray; + (*jniEnv)->SetShortArrayRegion(jniEnv, typedArr, 0, outputLength, (jshort * )inputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t - jintArray typedArr = (jintArray) output; - (*jniEnv)->SetIntArrayRegion(jniEnv, typedArr, 0, outputLength, (jint * ) tensor); + jintArray typedArr = (jintArray)outputArray; + (*jniEnv)->SetIntArrayRegion(jniEnv, typedArr, 0, outputLength, (jint * )inputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t - jlongArray typedArr = (jlongArray) output; - (*jniEnv)->SetLongArrayRegion(jniEnv, typedArr, 0, outputLength, (jlong * ) tensor); + jlongArray typedArr = (jlongArray)outputArray; + (*jniEnv)->SetLongArrayRegion(jniEnv, typedArr, 0, outputLength, (jlong * )inputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { // stored as a uint16_t jfloat *floatArr = malloc(sizeof(jfloat) * outputLength); - if(floatArr == NULL) { + if (floatArr == NULL) { throwOrtException(jniEnv, 1, "Not enough memory"); - return 0; + return -1; } - uint16_t *halfArr = (uint16_t *) tensor; - for (uint32_t i = 0; i < outputLength; i++) { + uint16_t *halfArr = (uint16_t *)inputTensor; + for (int32_t i = 0; i < outputLength; i++) { floatArr[i] = convertHalfToFloat(halfArr[i]); } - jfloatArray typedArr = (jfloatArray) output; + jfloatArray typedArr = (jfloatArray)outputArray; (*jniEnv)->SetFloatArrayRegion(jniEnv, typedArr, 0, outputLength, floatArr); free(floatArr); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float - jfloatArray typedArr = (jfloatArray) output; - (*jniEnv)->SetFloatArrayRegion(jniEnv, typedArr, 0, outputLength, (jfloat * ) tensor); + jfloatArray typedArr = (jfloatArray)outputArray; + (*jniEnv)->SetFloatArrayRegion(jniEnv, typedArr, 0, outputLength, (jfloat * )inputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double - jdoubleArray typedArr = (jdoubleArray) output; - (*jniEnv)->SetDoubleArrayRegion(jniEnv, typedArr, 0, outputLength, (jdouble * ) tensor); + jdoubleArray typedArr = (jdoubleArray)outputArray; + (*jniEnv)->SetDoubleArrayRegion(jniEnv, typedArr, 0, outputLength, (jdouble * )inputTensor); return consumedSize; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string // Shouldn't reach here, as it's caught by a different codepath in the initial OnnxTensor.getArray call. throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported by this codepath, please raise a Github issue as it should not reach here."); - return 0; + return -1; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - jbooleanArray typedArr = (jbooleanArray) output; - (*jniEnv)->SetBooleanArrayRegion(jniEnv, typedArr, 0, outputLength, (jboolean *) tensor); + jbooleanArray typedArr = (jbooleanArray)outputArray; + (*jniEnv)->SetBooleanArrayRegion(jniEnv, typedArr, 0, outputLength, (jboolean *)inputTensor); return consumedSize; } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: { + // complex with float32 real and imaginary components + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64."); + return -1; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: { + // complex with float64 real and imaginary components + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128."); + return -1; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: { + // Non-IEEE floating-point format based on IEEE754 single-precision + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16."); + return -1; + } case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: default: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid tensor element type."); - return 0; + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED."); + return -1; } } } -size_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, - size_t dimensionsRemaining, jarray output) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyPrimitiveArrayToJava(jniEnv,onnxType,tensor,output); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray outputObjArr = (jobjectArray) output; - uint32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv,outputObjArr); - size_t sizeConsumed = 0; - for (uint32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv,outputObjArr,i); - sizeConsumed += copyTensorToJava(jniEnv, onnxType, tensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv,childArr); - } - return sizeConsumed; +int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, + size_t dimensionsRemaining, jarray outputArray) { + if (dimensionsRemaining == 1) { + // write out 1d array of the respective primitive type + return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray); + } else { + // recurse through the dimensions + // Java arrays are objects until the final dimension + jobjectArray outputObjArr = (jobjectArray)outputArray; + int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr); + int64_t sizeConsumed = 0; + for (int32_t i = 0; i < dimLength; i++) { + jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i); + int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); + sizeConsumed += consumed; + // Cleanup reference to childArr so it doesn't prevent GC. + (*jniEnv)->DeleteLocalRef(jniEnv, childArr); + // If we failed to copy an array then break and return. + if (consumed == -1) { + return -1; + } } + return sizeConsumed; + } } -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) { - // Get the buffer size needed - size_t totalStringLength; - checkOrtStatus(jniEnv,api,api->GetStringTensorDataLength(tensor,&totalStringLength)); - - // Create the character and offset buffers - char * characterBuffer; - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(char)*(totalStringLength+1),(void**)&characterBuffer)); - size_t * offsets; - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(size_t),(void**)&offsets)); - - // Get a view on the String data - checkOrtStatus(jniEnv,api,api->GetStringTensorContent(tensor,characterBuffer,totalStringLength,offsets,1)); - - size_t curSize = (offsets[0]) + 1; - characterBuffer[curSize-1] = '\0'; - jobject tempString = (*jniEnv)->NewStringUTF(jniEnv,characterBuffer); +jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { + jobject tempString = NULL; + // Get the buffer size needed + size_t totalStringLength = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength)); + if (code != ORT_OK) { + return NULL; + } - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,characterBuffer)); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,offsets)); + // Create the character and offset buffers, character is one larger to allow zero termination. + char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1)); + if (characterBuffer == NULL) { + throwOrtException(jniEnv, 1, "OOM error"); + } else { + size_t * offsets = malloc(sizeof(size_t)); + if (offsets != NULL) { + // Get a view on the String data + code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1)); + + if (code == ORT_OK) { + size_t curSize = (offsets[0]) + 1; + characterBuffer[curSize-1] = '\0'; + tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer); + } + + free((void*)characterBuffer); + free((void*)offsets); + } + } - return tempString; + return tempString; } -void copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray) { - // Get the buffer size needed - size_t totalStringLength; - checkOrtStatus(jniEnv,api,api->GetStringTensorDataLength(tensor,&totalStringLength)); - - // Create the character and offset buffers - char * characterBuffer; - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(char)*(totalStringLength+length),(void**)&characterBuffer)); - // length + 1 as we need to write out the final offset - size_t * offsets; - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(size_t)*(length+1),(void**)&offsets)); +OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) { + char * tempBuffer = NULL; + // Get the buffer size needed + size_t totalStringLength = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength)); + if (code != ORT_OK) { + return code; + } - // Get a view on the String data - checkOrtStatus(jniEnv,api,api->GetStringTensorContent(tensor,characterBuffer,totalStringLength,offsets,length)); + // Create the character and offset buffers + char * characterBuffer = malloc(sizeof(char)*(totalStringLength+length)); + if (characterBuffer == NULL) { + throwOrtException(jniEnv, 1, "Not enough memory"); + return ORT_FAIL; + } + // length + 1 as we need to write out the final offset + size_t * offsets = malloc(sizeof(size_t)*(length+1)); + if (offsets == NULL) { + free((void*)characterBuffer); + throwOrtException(jniEnv, 1, "Not enough memory"); + return ORT_FAIL; + } + // Get a view on the String data + code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, length)); + if (code == ORT_OK) { // Get the final offset, write to the end of the array. - checkOrtStatus(jniEnv,api,api->GetStringTensorDataLength(tensor,offsets+length)); - - char * tempBuffer = NULL; - size_t bufferSize = 0; - for (size_t i = 0; i < length; i++) { + code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, offsets+length)); + if (code == ORT_OK) { + size_t bufferSize = 0; + for (size_t i = 0; i < length; i++) { size_t curSize = (offsets[i+1] - offsets[i]) + 1; if (curSize > bufferSize) { - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,tempBuffer)); - checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,curSize,(void**)&tempBuffer)); - bufferSize = curSize; - } - if (tempBuffer == NULL) { + if (tempBuffer != NULL) { + free((void*)tempBuffer); + } + tempBuffer = malloc(sizeof(char) * curSize); + if (tempBuffer == NULL) { throwOrtException(jniEnv, 1, "Not enough memory"); - return; + goto string_tensor_cleanup; + } + bufferSize = curSize; } memcpy(tempBuffer,characterBuffer+offsets[i],curSize); tempBuffer[curSize-1] = '\0'; jobject tempString = (*jniEnv)->NewStringUTF(jniEnv,tempBuffer); (*jniEnv)->SetObjectArrayElement(jniEnv,outputArray,safecast_size_t_to_jsize(i),tempString); + } } + } - if (tempBuffer != NULL) { - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,tempBuffer)); - } - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,characterBuffer)); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,offsets)); +string_tensor_cleanup: + if (tempBuffer != NULL) { + free((void*)tempBuffer); + } + free((void*)offsets); + free((void*)characterBuffer); + return code; } -jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) { +jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { // Extract tensor info - OrtTensorTypeAndShapeInfo* tensorInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo)); + OrtTensorTypeAndShapeInfo* tensorInfo = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(tensor, &tensorInfo)); + if (code != ORT_OK) { + return NULL; + } // Get the element count of this tensor - size_t length; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length)); + size_t length = 0; + code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(tensorInfo, &length)); api->ReleaseTensorTypeAndShapeInfo(tensorInfo); + if (code != ORT_OK) { + return NULL; + } // Create the java array - jclass stringClazz = (*jniEnv)->FindClass(jniEnv,"java/lang/String"); - jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(length),stringClazz, NULL); + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(length), stringClazz, NULL); - copyStringTensorToArray(jniEnv, api, allocator, tensor, length, outputArray); + code = copyStringTensorToArray(jniEnv, api, tensor, length, outputArray); + if (code != ORT_OK) { + outputArray = NULL; + } return outputArray; } jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - // Extract tensor type - OrtTensorTypeAndShapeInfo* tensorInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo)); - ONNXTensorElementDataType value; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo,&value)); - - // Get the element count of this tensor - size_t length; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length)); + jlongArray outputArray = NULL; + // Extract tensor type + OrtTensorTypeAndShapeInfo* tensorInfo = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &tensorInfo)); + if (code == ORT_OK) { + ONNXTensorElementDataType value = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + code = checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo, &value)); + if ((code == ORT_OK) && ((value == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) || (value == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64))) { + // Get the element count of this tensor + size_t length = 0; + code = checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo, &length)); + if (code == ORT_OK) { + // Extract the values + uint8_t* arr = NULL; + code = checkOrtStatus(jniEnv,api,api->GetTensorMutableData(tensor, (void**)&arr)); + if (code == ORT_OK) { + // Create the java array and copy to it. + outputArray = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(length)); + int64_t consumed = copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray); + if (consumed == -1) { + outputArray = NULL; + } + } + } + } api->ReleaseTensorTypeAndShapeInfo(tensorInfo); - - // Extract the values - uint8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)tensor,(void**)&arr)); - - // Create the java array and copy to it. - jlongArray outputArray = (*jniEnv)->NewLongArray(jniEnv,safecast_size_t_to_jsize(length)); - copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray); - return outputArray; + } + return outputArray; } jfloatArray createFloatArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - // Extract tensor type - OrtTensorTypeAndShapeInfo* tensorInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo)); - ONNXTensorElementDataType value; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo,&value)); - - // Get the element count of this tensor - size_t length; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length)); + jfloatArray outputArray = NULL; + // Extract tensor type + OrtTensorTypeAndShapeInfo* tensorInfo = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &tensorInfo)); + if (code == ORT_OK) { + ONNXTensorElementDataType value = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + code = checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo, &value)); + if ((code == ORT_OK) && (value == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) { + // Get the element count of this tensor + size_t length = 0; + code = checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo, &length)); + if (code == ORT_OK) { + // Extract the values + uint8_t* arr = NULL; + code = checkOrtStatus(jniEnv,api,api->GetTensorMutableData(tensor, (void**)&arr)); + if (code == ORT_OK) { + // Create the java array and copy to it. + outputArray = (*jniEnv)->NewFloatArray(jniEnv, safecast_size_t_to_jsize(length)); + int64_t consumed = copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray); + if (consumed == -1) { + outputArray = NULL; + } + } + } + } api->ReleaseTensorTypeAndShapeInfo(tensorInfo); - - // Extract the values - uint8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)tensor,(void**)&arr)); - - // Create the java array and copy to it. - jfloatArray outputArray = (*jniEnv)->NewFloatArray(jniEnv,safecast_size_t_to_jsize(length)); - copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray); - return outputArray; + } + return outputArray; } jdoubleArray createDoubleArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - // Extract tensor type - OrtTensorTypeAndShapeInfo* tensorInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo)); - ONNXTensorElementDataType value; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo,&value)); - - // Get the element count of this tensor - size_t length; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length)); + jdoubleArray outputArray = NULL; + // Extract tensor type + OrtTensorTypeAndShapeInfo* tensorInfo = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &tensorInfo)); + if (code == ORT_OK) { + ONNXTensorElementDataType value = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + code = checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo, &value)); + if ((code == ORT_OK) && (value == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)) { + // Get the element count of this tensor + size_t length = 0; + code = checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo, &length)); + if (code == ORT_OK) { + // Extract the values + uint8_t* arr = NULL; + code = checkOrtStatus(jniEnv,api,api->GetTensorMutableData(tensor, (void**)&arr)); + if (code == ORT_OK) { + // Create the java array and copy to it. + outputArray = (*jniEnv)->NewDoubleArray(jniEnv, safecast_size_t_to_jsize(length)); + int64_t consumed = copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray); + if (consumed == -1) { + outputArray = NULL; + } + } + } + } api->ReleaseTensorTypeAndShapeInfo(tensorInfo); - - // Extract the values - uint8_t* arr; - checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)tensor,(void**)&arr)); - - // Create the java array and copy to it. - jdoubleArray outputArray = (*jniEnv)->NewDoubleArray(jniEnv,safecast_size_t_to_jsize(length)); - copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray); - return outputArray; + } + return outputArray; } jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) { - // Extract the type information - OrtTensorTypeAndShapeInfo* info; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &info)); - - // Construct the TensorInfo object - jobject tensorInfo = convertToTensorInfo(jniEnv, api, info); + // Extract the type information + OrtTensorTypeAndShapeInfo* info = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(tensor, &info)); + if (code != ORT_OK) { + return NULL; + } - // Release the info object - api->ReleaseTensorTypeAndShapeInfo(info); + // Construct the TensorInfo object + jobject tensorInfo = convertToTensorInfo(jniEnv, api, info); + // Release the info object + api->ReleaseTensorTypeAndShapeInfo(info); + if (tensorInfo == NULL) { + return NULL; + } - // Construct the ONNXTensor object - char *tensorClassName = "ai/onnxruntime/OnnxTensor"; - jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName); - jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "(JJLai/onnxruntime/TensorInfo;)V"); - jobject javaTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, tensorInfo); + // Construct the ONNXTensor object + static const char *tensorClassName = "ai/onnxruntime/OnnxTensor"; + jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName); + jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "(JJLai/onnxruntime/TensorInfo;)V"); + jobject javaTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, tensorInfo); - return javaTensor; + return javaTensor; } jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence) { - // Setup - // Get the ONNXTensorType enum static method - char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; - jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); - jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); - - // Get the ONNXJavaType enum static method - char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; - jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); - jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); - - // Get the sequence info class - char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; - jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); + // Get the sequence info class + static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo"; + jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName); - // Get the element count of this sequence - size_t count; - checkOrtStatus(jniEnv,api,api->GetValueCount(sequence,&count)); + // setup return value + jobject sequenceInfo = NULL; + // Get the element count of this sequence + size_t count = 0; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count)); + if (code != ORT_OK) { + return NULL; + } else if (count == 0) { + // Construct empty sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(II)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, 0, + convertFromONNXDataFormat(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)); + } else { // Extract the first element - OrtValue* firstElement; - checkOrtStatus(jniEnv,api,api->GetValue(sequence,0,allocator,&firstElement)); - ONNXType elementType; - checkOrtStatus(jniEnv,api,api->GetValueType(firstElement,&elementType)); - jobject sequenceInfo; - switch (elementType) { + OrtValue* firstElement = NULL; + code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, 0, allocator, &firstElement)); + if (code != ORT_OK) { + return NULL; + } + ONNXType elementType = ONNX_TYPE_UNKNOWN; + code = checkOrtStatus(jniEnv, api, api->GetValueType(firstElement, &elementType)); + if (code == ORT_OK) { + switch (elementType) { case ONNX_TYPE_TENSOR: { - // Figure out element type - OrtTensorTypeAndShapeInfo* firstElementInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(firstElement,&firstElementInfo)); - ONNXTensorElementDataType element; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(firstElementInfo,&element)); + // Figure out element type + OrtTensorTypeAndShapeInfo* firstElementInfo = NULL; + code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(firstElement, &firstElementInfo)); + if (code == ORT_OK) { + ONNXTensorElementDataType element = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(firstElementInfo, &element)); api->ReleaseTensorTypeAndShapeInfo(firstElementInfo); + if (code == ORT_OK) { + // Convert element type into ONNXTensorType + jint onnxTypeInt = convertFromONNXDataFormat(element); - // Convert element type into ONNXTensorType - jint onnxTypeInt = convertFromONNXDataFormat(element); - jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt); - jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava); - - // Construct sequence info - jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;)V"); - sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)count,onnxJavaType); - break; + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(II)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)count, onnxTypeInt); + } + } + break; } case ONNX_TYPE_MAP: { - // Extract key - OrtValue* keys; - checkOrtStatus(jniEnv,api,api->GetValue(firstElement,0,allocator,&keys)); - - // Extract key type - OrtTensorTypeAndShapeInfo* keysInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(keys,&keysInfo)); - ONNXTensorElementDataType key; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(keysInfo,&key)); - - // Get the element count of this map - size_t mapCount; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(keysInfo,&mapCount)); - - api->ReleaseTensorTypeAndShapeInfo(keysInfo); - - // Convert key type to java - jint onnxTypeKey = convertFromONNXDataFormat(key); - jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey); - jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey); - - // Extract value - OrtValue* values; - checkOrtStatus(jniEnv,api,api->GetValue(firstElement,1,allocator,&values)); - - // Extract value type - OrtTensorTypeAndShapeInfo* valuesInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(values,&valuesInfo)); - ONNXTensorElementDataType value; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(valuesInfo,&value)); - api->ReleaseTensorTypeAndShapeInfo(valuesInfo); - - // Convert value type to java - jint onnxTypeValue = convertFromONNXDataFormat(value); - jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue); - jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue); - - // Get the map info class - char *mapInfoClassName = "ai/onnxruntime/MapInfo"; - jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); - // Construct map info - jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V"); - jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)mapCount,onnxJavaTypeKey,onnxJavaTypeValue); - - // Free the intermediate tensors. - api->ReleaseValue(keys); - api->ReleaseValue(values); - - // Construct sequence info - jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/MapInfo;)V"); - sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)count,mapInfo); - break; + jobject mapInfo = createMapInfoFromValue(jniEnv, api, allocator, firstElement); + if (mapInfo != NULL) { + // Construct sequence info + jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(ILai/onnxruntime/MapInfo;)V"); + sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)count, mapInfo); + } + break; } default: { - sequenceInfo = createEmptySequenceInfo(jniEnv); - throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence"); - break; + throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid element type found in sequence"); + break; } + } } - // Free the intermediate tensor. api->ReleaseValue(firstElement); + } + + jobject javaSequence = NULL; + if (sequenceInfo != NULL) { // Construct the ONNXSequence object - char *sequenceClassName = "ai/onnxruntime/OnnxSequence"; + static const char *sequenceClassName = "ai/onnxruntime/OnnxSequence"; jclass sequenceClazz = (*jniEnv)->FindClass(jniEnv, sequenceClassName); - jmethodID sequenceConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceClazz, "", "(JJLai/onnxruntime/SequenceInfo;)V"); - jobject javaSequence = (*jniEnv)->NewObject(jniEnv, sequenceClazz, sequenceConstructor, (jlong)sequence, (jlong)allocator, sequenceInfo); + jmethodID sequenceConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceClazz, "", "(JJLai/onnxruntime/SequenceInfo;)V"); + javaSequence = (*jniEnv)->NewObject(jniEnv, sequenceClazz, sequenceConstructor, (jlong)sequence, (jlong)allocator, sequenceInfo); + } - return javaSequence; + return javaSequence; } jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map) { - // Setup - // Get the ONNXTensorType enum static method - char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType"; - jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName); - jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;"); - - // Get the ONNXJavaType enum static method - char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType"; - jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName); - jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;"); - - // Get the map info class - char *mapInfoClassName = "ai/onnxruntime/MapInfo"; - jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); - - // Extract key - OrtValue* keys; - checkOrtStatus(jniEnv,api,api->GetValue(map,0,allocator,&keys)); - - // Extract key type - OrtTensorTypeAndShapeInfo* keysInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(keys,&keysInfo)); - ONNXTensorElementDataType key; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(keysInfo,&key)); - - // Get the element count of this map - size_t mapCount; - checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(keysInfo,&mapCount)); - - api->ReleaseTensorTypeAndShapeInfo(keysInfo); - - // Convert key type to java - jint onnxTypeKey = convertFromONNXDataFormat(key); - jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey); - jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey); - - // Extract value - OrtValue* values; - checkOrtStatus(jniEnv,api,api->GetValue(map,1,allocator,&values)); - - // Extract value type - OrtTensorTypeAndShapeInfo* valuesInfo; - checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(values,&valuesInfo)); - ONNXTensorElementDataType value; - checkOrtStatus(jniEnv,api,api->GetTensorElementType(valuesInfo,&value)); - api->ReleaseTensorTypeAndShapeInfo(valuesInfo); - - // Convert value type to java - jint onnxTypeValue = convertFromONNXDataFormat(value); - jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue); - jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue); - - // Construct map info - jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V"); - jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)mapCount,onnxJavaTypeKey,onnxJavaTypeValue); - - // Free the intermediate tensors. - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,keys)); - checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,values)); - - // Construct the ONNXMap object - char *mapClassName = "ai/onnxruntime/OnnxMap"; - jclass mapClazz = (*jniEnv)->FindClass(jniEnv, mapClassName); - jmethodID mapConstructor = (*jniEnv)->GetMethodID(jniEnv,mapClazz, "", "(JJLai/onnxruntime/MapInfo;)V"); - jobject javaMap = (*jniEnv)->NewObject(jniEnv, mapClazz, mapConstructor, (jlong)map, (jlong) allocator, mapInfo); - - return javaMap; + jobject mapInfo = createMapInfoFromValue(jniEnv, api, allocator, map); + if (mapInfo == NULL) { + return NULL; + } + + // Get the map class & constructor + static const char *mapClassName = "ai/onnxruntime/OnnxMap"; + jclass mapClazz = (*jniEnv)->FindClass(jniEnv, mapClassName); + jmethodID mapConstructor = (*jniEnv)->GetMethodID(jniEnv, mapClazz, "", "(JJLai/onnxruntime/MapInfo;)V"); + + // Construct the ONNXMap object + jobject javaMap = (*jniEnv)->NewObject(jniEnv, mapClazz, mapConstructor, (jlong)map, (jlong) allocator, mapInfo); + + return javaMap; +} + +jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator * allocator, const OrtValue * map) { + // Extract key + OrtValue* keys = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue(map, 0, allocator, &keys)); + if (code != ORT_OK) { + return NULL; + } + + JavaTensorTypeShape keyInfo; + code = getTensorTypeShape(jniEnv, &keyInfo, api, keys); + api->ReleaseValue(keys); + if (code != ORT_OK) { + return NULL; + } + + // Extract value + OrtValue* values = NULL; + code = checkOrtStatus(jniEnv, api, api->GetValue(map, 1, allocator, &values)); + if (code != ORT_OK) { + return NULL; + } + + JavaTensorTypeShape valueInfo; + code = getTensorTypeShape(jniEnv, &valueInfo, api, values); + api->ReleaseValue(values); + if (code != ORT_OK) { + return NULL; + } + + // Convert key and value type to java + jint onnxTypeKey = convertFromONNXDataFormat(keyInfo.onnxTypeEnum); + jint onnxTypeValue = convertFromONNXDataFormat(valueInfo.onnxTypeEnum); + + // Get the map info class & constructor + static const char *mapInfoClassName = "ai/onnxruntime/MapInfo"; + jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName); + jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, mapInfoClazz, "", "(III)V"); + + // Construct map info + jobject mapInfo = (*jniEnv)->NewObject(jniEnv, mapInfoClazz, mapInfoConstructor, (jint)keyInfo.elementCount, onnxTypeKey, onnxTypeValue); + return mapInfo; } jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue) { - // Note this is the ONNXType C enum - ONNXType valueType; - checkOrtStatus(jniEnv,api,api->GetValueType(onnxValue,&valueType)); - switch (valueType) { - case ONNX_TYPE_TENSOR: { - return createJavaTensorFromONNX(jniEnv, api, allocator, onnxValue); - } - case ONNX_TYPE_SEQUENCE: { - return createJavaSequenceFromONNX(jniEnv, api, allocator, onnxValue); - } - case ONNX_TYPE_MAP: { - return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue); - } - case ONNX_TYPE_UNKNOWN: - case ONNX_TYPE_OPAQUE: - case ONNX_TYPE_OPTIONAL: - case ONNX_TYPE_SPARSETENSOR: { - throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR."); - break; - } - } + // Note this is the ONNXType C enum + ONNXType valueType = ONNX_TYPE_UNKNOWN; + OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetValueType(onnxValue,&valueType)); + if (code != ORT_OK) { return NULL; + } + switch (valueType) { + case ONNX_TYPE_TENSOR: { + return createJavaTensorFromONNX(jniEnv, api, allocator, onnxValue); + } + case ONNX_TYPE_SEQUENCE: { + return createJavaSequenceFromONNX(jniEnv, api, allocator, onnxValue); + } + case ONNX_TYPE_MAP: { + return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue); + } + case ONNX_TYPE_UNKNOWN: + case ONNX_TYPE_OPAQUE: + case ONNX_TYPE_OPTIONAL: + case ONNX_TYPE_SPARSETENSOR: + default: { + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR."); + return NULL; + } + } } jint throwOrtException(JNIEnv *jniEnv, int messageId, const char *message) { - jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message); + jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message); - char *className = "ai/onnxruntime/OrtException"; - jclass exClazz = (*jniEnv)->FindClass(jniEnv,className); - jmethodID exConstructor = (*jniEnv)->GetMethodID(jniEnv, exClazz, "", "(ILjava/lang/String;)V"); - jobject javaException = (*jniEnv)->NewObject(jniEnv, exClazz, exConstructor, messageId, messageStr); + static const char *className = "ai/onnxruntime/OrtException"; + jclass exClazz = (*jniEnv)->FindClass(jniEnv, className); + jmethodID exConstructor = (*jniEnv)->GetMethodID(jniEnv, exClazz, "", "(ILjava/lang/String;)V"); + jobject javaException = (*jniEnv)->NewObject(jniEnv, exClazz, exConstructor, messageId, messageStr); - return (*jniEnv)->Throw(jniEnv,javaException); + return (*jniEnv)->Throw(jniEnv, javaException); } jint convertErrorCode(OrtErrorCode code) { diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 075202333e181..616a20503ad42 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -38,29 +38,27 @@ OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, c jfloat convertHalfToFloat(uint16_t half); -jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info); +jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info); jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info); jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info); -jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); -jobject createEmptyMapInfo(JNIEnv *jniEnv); -jobject createEmptySequenceInfo(JNIEnv *jniEnv); +jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); -size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray input); +int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor); -size_t copyJavaToTensor(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray input); +int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor); -size_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray output); +int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray); -size_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray output); +int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray); -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor); +jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); -void copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray); +OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray); -jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor); +jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); @@ -74,6 +72,8 @@ jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map); +jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator * allocator, const OrtValue * map); + jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue); jint throwOrtException(JNIEnv *env, int messageId, const char *message); diff --git a/java/src/main/native/ai_onnxruntime_OnnxMap.c b/java/src/main/native/ai_onnxruntime_OnnxMap.c index dce0677c84672..161fa23ae81cc 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxMap.c +++ b/java/src/main/native/ai_onnxruntime_OnnxMap.c @@ -22,7 +22,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringKeys(JNIEnv* OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 0, allocator, &keys)); if (code == ORT_OK) { // Convert to Java String array - jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, keys); + jobjectArray output = createStringArrayFromTensor(jniEnv, api, keys); api->ReleaseValue(keys); @@ -72,7 +72,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringValues(JNIEn OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 1, allocator, &values)); if (code == ORT_OK) { // Convert to Java String array - jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, values); + jobjectArray output = createStringArrayFromTensor(jniEnv, api, values); api->ReleaseValue(values); diff --git a/java/src/main/native/ai_onnxruntime_OnnxSequence.c b/java/src/main/native/ai_onnxruntime_OnnxSequence.c index 12d3266f8a90f..22cfca6315768 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxSequence.c +++ b/java/src/main/native/ai_onnxruntime_OnnxSequence.c @@ -28,7 +28,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JN if (code == ORT_OK) { // Convert to Java String array - output = createStringArrayFromTensor(jniEnv, api, allocator, keys); + output = createStringArrayFromTensor(jniEnv, api, keys); // Release if valid api->ReleaseValue(element); } @@ -94,7 +94,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues( if (code == ORT_OK) { // Convert to Java String array - output = createStringArrayFromTensor(jniEnv, api, allocator, values); + output = createStringArrayFromTensor(jniEnv, api, values); // Release if valid api->ReleaseValue(element); } @@ -230,7 +230,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn OrtValue* element; code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element)); if (code == ORT_OK) { - jobject str = createStringFromStringTensor(jniEnv, api, allocator, element); + jobject str = createStringFromStringTensor(jniEnv, api, element); if (str == NULL) { api->ReleaseValue(element); // bail out as exception has been thrown diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index 06a8ad42e7f2d..1656b4043cfe9 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -44,8 +44,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor // Check if we're copying a scalar or not if (shapeLen == 0) { // Scalars are passed in as a single element array - size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj); - failed = copied == 0 ? 1 : failed; + int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData); + failed = copied == -1 ? 1 : failed; } else { // Extract the tensor shape information JavaTensorTypeShape typeShape; @@ -53,9 +53,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor if (code == ORT_OK) { // Copy the java array into the tensor - size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount, - typeShape.dimensions, dataObj); - failed = copied == 0 ? 1 : failed; + int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount, + typeShape.dimensions, dataObj, tensorData); + failed = copied == -1 ? 1 : failed; } else { failed = 1; } @@ -367,12 +367,11 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong * Signature: (JJ)Ljava/lang/String; */ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; // Extract a String array - if this becomes a performance issue we'll refactor later. - jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle, - (OrtValue*) handle); + jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtValue*) handle); if (outputArray != NULL) { // Get reference to the string jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0); @@ -410,7 +409,7 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool * Signature: (JJLjava/lang/Object;)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtValue* value = (OrtValue*) handle; @@ -418,7 +417,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value); if (code == ORT_OK) { if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier); + copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier); } else { uint8_t* arr = NULL; code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));