diff --git a/java/src/main/java/ai/onnxruntime/MapInfo.java b/java/src/main/java/ai/onnxruntime/MapInfo.java
index 552793139549d..5a2ac072f336a 100644
--- a/java/src/main/java/ai/onnxruntime/MapInfo.java
+++ b/java/src/main/java/ai/onnxruntime/MapInfo.java
@@ -4,6 +4,8 @@
*/
package ai.onnxruntime;
+import ai.onnxruntime.TensorInfo.OnnxTensorType;
+
/** Describes an {@link OnnxMap} object or output node. */
public class MapInfo implements ValueInfo {
@@ -42,6 +44,21 @@ public class MapInfo implements ValueInfo {
this.valueType = valueType;
}
+ /**
+ * Construct a MapInfo with the specified size, key type and value type.
+ *
+ *
Called from JNI.
+ *
+ * @param size The size.
+ * @param keyTypeInt The int representing the {@link OnnxTensorType} of the keys.
+ * @param valueTypeInt The int representing the {@link OnnxTensorType} of the values.
+ */
+ MapInfo(int size, int keyTypeInt, int valueTypeInt) {
+ this.size = size;
+ this.keyType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(keyTypeInt));
+ this.valueType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(valueTypeInt));
+ }
+
@Override
public String toString() {
String initial = size == -1 ? "MapInfo(size=UNKNOWN" : "MapInfo(size=" + size;
diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
index f71d26c8bfffa..653a3c9be91bd 100644
--- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java
+++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
@@ -90,14 +90,14 @@ public Object getValue() throws OrtException {
case BOOL:
return getBool(OnnxRuntime.ortApiHandle, nativeHandle);
case STRING:
- return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
+ return getString(OnnxRuntime.ortApiHandle, nativeHandle);
case UNKNOWN:
default:
throw new OrtException("Extracting the value of an invalid Tensor.");
}
} else {
Object carrier = info.makeCarrier();
- getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
+ getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
@@ -284,13 +284,12 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)
private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException;
- private native String getString(long apiHandle, long nativeHandle, long allocatorHandle)
- throws OrtException;
+ private native String getString(long apiHandle, long nativeHandle) throws OrtException;
private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;
- private native void getArray(
- long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException;
+ private native void getArray(long apiHandle, long nativeHandle, Object carrier)
+ throws OrtException;
private native void close(long apiHandle, long nativeHandle);
diff --git a/java/src/main/java/ai/onnxruntime/SequenceInfo.java b/java/src/main/java/ai/onnxruntime/SequenceInfo.java
index a417634b72de0..5497766546958 100644
--- a/java/src/main/java/ai/onnxruntime/SequenceInfo.java
+++ b/java/src/main/java/ai/onnxruntime/SequenceInfo.java
@@ -4,6 +4,8 @@
*/
package ai.onnxruntime;
+import ai.onnxruntime.TensorInfo.OnnxTensorType;
+
/** Describes an {@link OnnxSequence}, including it's element type if known. */
public class SequenceInfo implements ValueInfo {
@@ -35,6 +37,23 @@ public class SequenceInfo implements ValueInfo {
this.mapInfo = null;
}
+ /**
+ * Construct a sequence of known length, with the specified type. This sequence does not contain
+ * maps.
+ *
+ *
Called from JNI.
+ *
+ * @param length The length of the sequence.
+ * @param sequenceTypeInt The element type int of the sequence mapped from {@link OnnxTensorType}.
+ */
+ SequenceInfo(int length, int sequenceTypeInt) {
+ this.length = length;
+ this.sequenceType =
+ OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(sequenceTypeInt));
+ this.sequenceOfMaps = false;
+ this.mapInfo = null;
+ }
+
/**
* Construct a sequence of known length containing maps.
*
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index ead90d83bcba6..4a7a3b833bc2b 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -120,6 +120,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.onnxType = onnxType;
}
+ /**
+ * Constructs a TensorInfo with the specified shape and native type int.
+ *
+ *
Called from JNI.
+ *
+ * @param shape The tensor shape.
+ * @param typeInt The native type int.
+ */
+ TensorInfo(long[] shape, int typeInt) {
+ this.shape = shape;
+ this.onnxType = OnnxTensorType.mapFromInt(typeInt);
+ this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
+ }
+
/**
* Get a copy of the tensor's shape.
*
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 55f26f35b40df..a670179a0eb25 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -8,7 +8,8 @@
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
// To silence unused-parameter error.
- // This function must exist according to the JNI spec, but the arguments aren't necessary for the library to request a specific version.
+ // This function must exist according to the JNI spec, but the arguments aren't necessary for the library
+ // to request a specific version.
(void)vm; (void) reserved;
// Requesting 1.6 to support Android. Will need to be bumped to a later version to call interface default methods
// from native code, or to access other new Java features.
@@ -213,289 +214,271 @@ typedef union FP32 {
float floatVal;
} FP32;
-jfloat convertHalfToFloat(uint16_t half) {
+jfloat convertHalfToFloat(const uint16_t half) {
FP32 output;
output.intVal = (((half&0x8000)<<16) | (((half&0x7c00)+0x1C000)<<13) | ((half&0x03FF)<<13));
return output.floatVal;
}
-jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info) {
- ONNXType type;
- checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(info,&type));
+jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info) {
+ ONNXType type = ONNX_TYPE_UNKNOWN;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetOnnxTypeFromTypeInfo(info, &type));
+ if (code != ORT_OK) {
+ return NULL;
+ }
- switch (type) {
- case ONNX_TYPE_TENSOR: {
- const OrtTensorTypeAndShapeInfo* tensorInfo;
- checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(info,&tensorInfo));
- return convertToTensorInfo(jniEnv, api, (const OrtTensorTypeAndShapeInfo *) tensorInfo);
- }
- case ONNX_TYPE_SEQUENCE: {
- const OrtSequenceTypeInfo* sequenceInfo;
- checkOrtStatus(jniEnv,api,api->CastTypeInfoToSequenceTypeInfo(info,&sequenceInfo));
- return convertToSequenceInfo(jniEnv, api, sequenceInfo);
- }
- case ONNX_TYPE_MAP: {
- const OrtMapTypeInfo* mapInfo;
- checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(info,&mapInfo));
- return convertToMapInfo(jniEnv, api, mapInfo);
- }
- case ONNX_TYPE_UNKNOWN:
- case ONNX_TYPE_OPAQUE:
- case ONNX_TYPE_SPARSETENSOR:
- default: {
- throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found.");
- return NULL;
- }
+ switch (type) {
+ case ONNX_TYPE_TENSOR: {
+ const OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(info, &tensorInfo));
+ if (code == ORT_OK) {
+ return convertToTensorInfo(jniEnv, api, tensorInfo);
+ } else {
+ return NULL;
+ }
+ }
+ case ONNX_TYPE_SEQUENCE: {
+ const OrtSequenceTypeInfo* sequenceInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToSequenceTypeInfo(info, &sequenceInfo));
+ if (code == ORT_OK) {
+ return convertToSequenceInfo(jniEnv, api, sequenceInfo);
+ } else {
+ return NULL;
+ }
+ }
+ case ONNX_TYPE_MAP: {
+ const OrtMapTypeInfo* mapInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToMapTypeInfo(info, &mapInfo));
+ if (code == ORT_OK) {
+ return convertToMapInfo(jniEnv, api, mapInfo);
+ } else {
+ return NULL;
+ }
}
+ case ONNX_TYPE_UNKNOWN:
+ case ONNX_TYPE_OPAQUE:
+ case ONNX_TYPE_SPARSETENSOR:
+ default: {
+ throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"Invalid ONNXType found.");
+ return NULL;
+ }
+ }
}
jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info) {
- // Extract the information from the info struct.
- ONNXTensorElementDataType onnxType;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(info,&onnxType));
- size_t numDim;
- checkOrtStatus(jniEnv,api,api->GetDimensionsCount(info,&numDim));
- //printf("numDim %d\n",numDim);
- int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
- checkOrtStatus(jniEnv,api,api->GetDimensions(info, dimensions, numDim));
- jint onnxTypeInt = convertFromONNXDataFormat(onnxType);
-
- // Create the long array for the shape.
- jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
- (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
- // Free the dimensions array
- free(dimensions);
- dimensions = NULL;
-
- // Create the ONNXTensorType enum
- char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
- jclass clazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
- jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
- jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,onnxTensorTypeMapFromInt,onnxTypeInt);
- //printf("ONNXTensorType class %p, methodID %p, object %p\n",clazz,onnxTensorTypeMapFromInt,onnxTensorTypeJava);
-
- // Create the ONNXJavaType enum
- char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
- clazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
- jmethodID javaDataTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
- jobject javaDataType = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,javaDataTypeMapFromONNXTensorType,onnxTensorTypeJava);
- //printf("JavaDataType class %p, methodID %p, object %p\n",clazz,javaDataTypeMapFromONNXTensorType,javaDataType);
-
- // Create the TensorInfo object
- char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
- clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
- jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JLai/onnxruntime/OnnxJavaType;Lai/onnxruntime/TensorInfo$OnnxTensorType;)V");
- //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
- jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, javaDataType, onnxTensorTypeJava);
- return tensorInfo;
+ // Extract the information from the info struct.
+ ONNXTensorElementDataType onnxType;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(info, &onnxType));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+ size_t numDim = 0;
+ code = checkOrtStatus(jniEnv, api, api->GetDimensionsCount(info, &numDim));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+ //printf("numDim %d\n",numDim);
+ int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
+ code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
+ if (code != ORT_OK) {
+ free((void*) dimensions);
+ return NULL;
+ }
+ jint onnxTypeInt = convertFromONNXDataFormat(onnxType);
+
+ // Create the long array for the shape.
+ jlongArray shape = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(numDim));
+ (*jniEnv)->SetLongArrayRegion(jniEnv, shape, 0, safecast_size_t_to_jsize(numDim), (jlong*)dimensions);
+ // Free the dimensions array
+ free(dimensions);
+ dimensions = NULL;
+
+ // Create the TensorInfo object
+ static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
+ jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
+ jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V");
+ //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
+ jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
+ return tensorInfo;
}
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info) {
- // Create the java methods we need to call.
- // Get the ONNXTensorType enum static method
- char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
- jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
- jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
-
- // Get the ONNXJavaType enum static method
- char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
- jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
- jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
-
- // Get the map info class
- char *mapInfoClassName = "ai/onnxruntime/MapInfo";
- jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
- jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V");
-
- // Extract the key type
- ONNXTensorElementDataType keyType;
- checkOrtStatus(jniEnv,api,api->GetMapKeyType(info,&keyType));
-
- // Convert key type to java
- jint onnxTypeKey = convertFromONNXDataFormat(keyType);
- jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey);
- jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey);
-
- // according to include/onnxruntime/core/framework/data_types.h only the following values are supported.
- // string, int64, float, double
- // So extract the value type, then convert it to a tensor type so we can get it's element type.
- OrtTypeInfo* valueTypeInfo;
- checkOrtStatus(jniEnv,api,api->GetMapValueType(info,&valueTypeInfo));
- const OrtTensorTypeAndShapeInfo* tensorValueInfo;
- checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(valueTypeInfo,&tensorValueInfo));
- ONNXTensorElementDataType valueType;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorValueInfo,&valueType));
- api->ReleaseTypeInfo(valueTypeInfo);
- tensorValueInfo = NULL;
- valueTypeInfo = NULL;
-
- // Convert value type to java
- jint onnxTypeValue = convertFromONNXDataFormat(valueType);
- jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue);
- jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue);
+ // Extract the key type
+ ONNXTensorElementDataType keyType;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetMapKeyType(info, &keyType));
+ if (code != ORT_OK) {
+ return NULL;
+ }
- // Construct map info
- jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)-1,onnxJavaTypeKey,onnxJavaTypeValue);
+ // according to include/onnxruntime/core/framework/data_types.h only the following values are supported.
+ // string, int64, float, double
+ // So extract the value type, then convert it to a tensor type so we can get it's element type.
+ OrtTypeInfo* valueTypeInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->GetMapValueType(info, &valueTypeInfo));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+ const OrtTensorTypeAndShapeInfo* tensorValueInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(valueTypeInfo, &tensorValueInfo));
+ if (code != ORT_OK) {
+ api->ReleaseTypeInfo(valueTypeInfo);
+ return NULL;
+ }
+ ONNXTensorElementDataType valueType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(tensorValueInfo, &valueType));
+ api->ReleaseTypeInfo(valueTypeInfo);
+ tensorValueInfo = NULL;
+ valueTypeInfo = NULL;
+ if (code != ORT_OK) {
+ return NULL;
+ }
- return mapInfo;
-}
+ // Convert key type to java
+ jint onnxTypeKey = convertFromONNXDataFormat(keyType);
+ // Convert value type to java
+ jint onnxTypeValue = convertFromONNXDataFormat(valueType);
-jobject createEmptyMapInfo(JNIEnv *jniEnv) {
- // Create the ONNXJavaType enum
- char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType";
- jclass clazz = (*jniEnv)->FindClass(jniEnv, onnxJavaTypeClassName);
- jmethodID onnxJavaTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromInt", "(I)Lai/onnxruntime/OnnxJavaType;");
- jobject unknownType = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,onnxJavaTypeMapFromInt,0);
+ // Get the map info class
+ static const char *mapInfoClassName = "ai/onnxruntime/MapInfo";
+ jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
+ jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, mapInfoClazz, "", "(III)V");
- char *mapInfoClassName = "ai/onnxruntime/MapInfo";
- clazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
- jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz,"","(Lai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V");
- jobject mapInfo = (*jniEnv)->NewObject(jniEnv,clazz,mapInfoConstructor,unknownType,unknownType);
+ // Construct map info
+ jobject mapInfo = (*jniEnv)->NewObject(jniEnv, mapInfoClazz, mapInfoConstructor, (jint)-1, onnxTypeKey, onnxTypeValue);
- return mapInfo;
+ return mapInfo;
}
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info) {
- // Get the sequence info class
- char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
- jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
-
- // according to include/onnxruntime/core/framework/data_types.h the following values are supported.
- // tensor types, map and map
- OrtTypeInfo* elementTypeInfo;
- checkOrtStatus(jniEnv,api,api->GetSequenceElementType(info,&elementTypeInfo));
- ONNXType type;
- checkOrtStatus(jniEnv,api,api->GetOnnxTypeFromTypeInfo(elementTypeInfo,&type));
-
- jobject sequenceInfo;
-
- switch (type) {
- case ONNX_TYPE_TENSOR: {
- // Figure out element type
- const OrtTensorTypeAndShapeInfo* elementTensorInfo;
- checkOrtStatus(jniEnv,api,api->CastTypeInfoToTensorInfo(elementTypeInfo,&elementTensorInfo));
- ONNXTensorElementDataType element;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(elementTensorInfo,&element));
-
- // Convert element type into ONNXTensorType
- jint onnxTypeInt = convertFromONNXDataFormat(element);
- // Get the ONNXTensorType enum static method
- char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
- jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
- jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
- jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt);
-
- // Get the ONNXJavaType enum static method
- char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
- jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
- jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
- jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava);
-
- // Construct sequence info
- jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;)V");
- sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,onnxJavaType);
- break;
- }
- case ONNX_TYPE_MAP: {
- // Extract the map info
- const OrtMapTypeInfo* mapInfo;
- checkOrtStatus(jniEnv,api,api->CastTypeInfoToMapTypeInfo(elementTypeInfo,&mapInfo));
-
- // Convert it using the existing convert function
- jobject javaMapInfo = convertToMapInfo(jniEnv,api,mapInfo);
+ // Get the sequence info class
+ static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
+ jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
+ jobject sequenceInfo = NULL;
+
+ // according to include/onnxruntime/core/framework/data_types.h the following values are supported.
+ // tensor types, map and map
+ OrtTypeInfo* elementTypeInfo = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetSequenceElementType(info, &elementTypeInfo));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+ ONNXType type = ONNX_TYPE_UNKNOWN;
+ code = checkOrtStatus(jniEnv, api, api->GetOnnxTypeFromTypeInfo(elementTypeInfo, &type));
+ if (code != ORT_OK) {
+ goto sequence_cleanup;
+ }
- // Construct sequence info
- jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/MapInfo;)V");
- sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)-1,javaMapInfo);
- break;
- }
- default: {
- sequenceInfo = createEmptySequenceInfo(jniEnv);
- throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence");
- break;
- }
+ switch (type) {
+ case ONNX_TYPE_TENSOR: {
+ // Figure out element type
+ const OrtTensorTypeAndShapeInfo* elementTensorInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToTensorInfo(elementTypeInfo, &elementTensorInfo));
+ if (code != ORT_OK) {
+ goto sequence_cleanup;
+ }
+ ONNXTensorElementDataType element = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(elementTensorInfo, &element));
+ if (code != ORT_OK) {
+ goto sequence_cleanup;
+ }
+
+ // Convert element type into ONNXTensorType
+ jint onnxTypeInt = convertFromONNXDataFormat(element);
+
+ // Construct sequence info
+ jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(II)V");
+ sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)-1, onnxTypeInt);
+ break;
}
- api->ReleaseTypeInfo(elementTypeInfo);
- elementTypeInfo = NULL;
-
- return sequenceInfo;
-}
-
-jobject createEmptySequenceInfo(JNIEnv *jniEnv) {
- // Create the ONNXJavaType enum
- char *onnxJavaTypeClassName = "ai/onnxruntime/OnnxJavaType";
- jclass clazz = (*jniEnv)->FindClass(jniEnv, onnxJavaTypeClassName);
- jmethodID onnxJavaTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,clazz, "mapFromInt", "(I)Lai/onnxruntime/OnnxJavaType;");
- jobject unknownType = (*jniEnv)->CallStaticObjectMethod(jniEnv,clazz,onnxJavaTypeMapFromInt,0);
+ case ONNX_TYPE_MAP: {
+ // Extract the map info
+ const OrtMapTypeInfo* mapInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->CastTypeInfoToMapTypeInfo(elementTypeInfo, &mapInfo));
+ if (code != ORT_OK) {
+ goto sequence_cleanup;
+ }
+
+ // Convert it using the existing convert function
+ jobject javaMapInfo = convertToMapInfo(jniEnv, api, mapInfo);
+
+ // Construct sequence info
+ jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(ILai/onnxruntime/MapInfo;)V");
+ sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)-1, javaMapInfo);
+ break;
+ }
+ default: {
+ throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid element type found in sequence");
+ break;
+ }
+ }
- char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
- clazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
- jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz,"","(ILai/onnxruntime/OnnxJavaType;)V");
- jobject sequenceInfo = (*jniEnv)->NewObject(jniEnv,clazz,sequenceInfoConstructor,-1,unknownType);
+sequence_cleanup:
+ api->ReleaseTypeInfo(elementTypeInfo);
+ elementTypeInfo = NULL;
- return sequenceInfo;
+ return sequenceInfo;
}
-size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray input) {
- uint32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv,input);
- size_t consumedSize = inputLength * onnxTypeSize(onnxType);
+int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) {
+ int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray);
+ int64_t consumedSize = inputLength * onnxTypeSize(onnxType);
switch (onnxType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t
- jbyteArray typedArr = (jbyteArray) input;
- (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * ) tensor);
+ jbyteArray typedArr = (jbyteArray)inputArray;
+ (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t
- jshortArray typedArr = (jshortArray) input;
- (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * ) tensor);
+ jshortArray typedArr = (jshortArray)inputArray;
+ (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t
- jintArray typedArr = (jintArray) input;
- (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * ) tensor);
+ jintArray typedArr = (jintArray)inputArray;
+ (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t
- jlongArray typedArr = (jlongArray) input;
- (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * ) tensor);
+ jlongArray typedArr = (jlongArray)inputArray;
+ (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported.");
- return 0;
+ return -1;
/*
float *floatArr = malloc(sizeof(float) * inputLength);
- uint16_t *halfArr = (uint16_t *) tensor;
+ uint16_t *halfArr = (uint16_t *) outputTensor;
for (uint32_t i = 0; i < inputLength; i++) {
floatArr[i] = convertHalfToFloat(halfArr[i]);
}
- jfloatArray typedArr = (jfloatArray) input;
+ jfloatArray typedArr = (jfloatArray) inputArray;
(*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr);
free(floatArr);
return consumedSize;
*/
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float
- jfloatArray typedArr = (jfloatArray) input;
- (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * ) tensor);
+ jfloatArray typedArr = (jfloatArray)inputArray;
+ (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double
- jdoubleArray typedArr = (jdoubleArray) input;
- (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * ) tensor);
+ jdoubleArray typedArr = (jdoubleArray)inputArray;
+ (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported.");
- return 0;
+ return -1;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
- jbooleanArray typedArr = (jbooleanArray) input;
- (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *) tensor);
+ jbooleanArray typedArr = (jbooleanArray)inputArray;
+ (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components
@@ -503,524 +486,566 @@ size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxTy
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
default: {
- throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid tensor element type.");
- return 0;
+ throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type.");
+ return -1;
}
}
}
-size_t copyJavaToTensor(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize,
- size_t dimensionsRemaining, jarray input) {
+int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) {
if (dimensionsRemaining == 1) {
// write out 1d array of the respective primitive type
- return copyJavaToPrimitiveArray(jniEnv,onnxType,tensor,input);
+ return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor);
} else {
// recurse through the dimensions
// Java arrays are objects until the final dimension
- jobjectArray inputObjArr = (jobjectArray) input;
- uint32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv,inputObjArr);
- size_t sizeConsumed = 0;
- for (uint32_t i = 0; i < dimLength; i++) {
- jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv,inputObjArr,i);
- sizeConsumed += copyJavaToTensor(jniEnv, onnxType, tensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr);
+ jobjectArray inputObjArr = (jobjectArray)inputArray;
+ int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr);
+ int64_t sizeConsumed = 0;
+ for (int32_t i = 0; i < dimLength; i++) {
+ jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i);
+ int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed);
+ sizeConsumed += consumed;
// Cleanup reference to childArr so it doesn't prevent GC.
- (*jniEnv)->DeleteLocalRef(jniEnv,childArr);
+ (*jniEnv)->DeleteLocalRef(jniEnv, childArr);
+ // If we failed to copy an array then break and return.
+ if (consumed == -1) {
+ return -1;
+ }
}
return sizeConsumed;
}
}
-size_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray output) {
- uint32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv,output);
+int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) {
+ int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray);
if (outputLength == 0) return 0;
- size_t consumedSize = outputLength * onnxTypeSize(onnxType);
+ int64_t consumedSize = outputLength * onnxTypeSize(onnxType);
switch (onnxType) {
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t
- jbyteArray typedArr = (jbyteArray) output;
- (*jniEnv)->SetByteArrayRegion(jniEnv, typedArr, 0, outputLength, (jbyte * ) tensor);
+ jbyteArray typedArr = (jbyteArray)outputArray;
+ (*jniEnv)->SetByteArrayRegion(jniEnv, typedArr, 0, outputLength, (jbyte * )inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t
- jshortArray typedArr = (jshortArray) output;
- (*jniEnv)->SetShortArrayRegion(jniEnv, typedArr, 0, outputLength, (jshort * ) tensor);
+ jshortArray typedArr = (jshortArray)outputArray;
+ (*jniEnv)->SetShortArrayRegion(jniEnv, typedArr, 0, outputLength, (jshort * )inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t
- jintArray typedArr = (jintArray) output;
- (*jniEnv)->SetIntArrayRegion(jniEnv, typedArr, 0, outputLength, (jint * ) tensor);
+ jintArray typedArr = (jintArray)outputArray;
+ (*jniEnv)->SetIntArrayRegion(jniEnv, typedArr, 0, outputLength, (jint * )inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t
- jlongArray typedArr = (jlongArray) output;
- (*jniEnv)->SetLongArrayRegion(jniEnv, typedArr, 0, outputLength, (jlong * ) tensor);
+ jlongArray typedArr = (jlongArray)outputArray;
+ (*jniEnv)->SetLongArrayRegion(jniEnv, typedArr, 0, outputLength, (jlong * )inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { // stored as a uint16_t
jfloat *floatArr = malloc(sizeof(jfloat) * outputLength);
- if(floatArr == NULL) {
+ if (floatArr == NULL) {
throwOrtException(jniEnv, 1, "Not enough memory");
- return 0;
+ return -1;
}
- uint16_t *halfArr = (uint16_t *) tensor;
- for (uint32_t i = 0; i < outputLength; i++) {
+ uint16_t *halfArr = (uint16_t *)inputTensor;
+ for (int32_t i = 0; i < outputLength; i++) {
floatArr[i] = convertHalfToFloat(halfArr[i]);
}
- jfloatArray typedArr = (jfloatArray) output;
+ jfloatArray typedArr = (jfloatArray)outputArray;
(*jniEnv)->SetFloatArrayRegion(jniEnv, typedArr, 0, outputLength, floatArr);
free(floatArr);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float
- jfloatArray typedArr = (jfloatArray) output;
- (*jniEnv)->SetFloatArrayRegion(jniEnv, typedArr, 0, outputLength, (jfloat * ) tensor);
+ jfloatArray typedArr = (jfloatArray)outputArray;
+ (*jniEnv)->SetFloatArrayRegion(jniEnv, typedArr, 0, outputLength, (jfloat * )inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double
- jdoubleArray typedArr = (jdoubleArray) output;
- (*jniEnv)->SetDoubleArrayRegion(jniEnv, typedArr, 0, outputLength, (jdouble * ) tensor);
+ jdoubleArray typedArr = (jdoubleArray)outputArray;
+ (*jniEnv)->SetDoubleArrayRegion(jniEnv, typedArr, 0, outputLength, (jdouble * )inputTensor);
return consumedSize;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string
// Shouldn't reach here, as it's caught by a different codepath in the initial OnnxTensor.getArray call.
throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported by this codepath, please raise a Github issue as it should not reach here.");
- return 0;
+ return -1;
}
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
- jbooleanArray typedArr = (jbooleanArray) output;
- (*jniEnv)->SetBooleanArrayRegion(jniEnv, typedArr, 0, outputLength, (jboolean *) tensor);
+ jbooleanArray typedArr = (jbooleanArray)outputArray;
+ (*jniEnv)->SetBooleanArrayRegion(jniEnv, typedArr, 0, outputLength, (jboolean *)inputTensor);
return consumedSize;
}
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: {
+ // complex with float32 real and imaginary components
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64.");
+ return -1;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: {
+ // complex with float64 real and imaginary components
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128.");
+ return -1;
+ }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: {
+ // Non-IEEE floating-point format based on IEEE754 single-precision
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16.");
+ return -1;
+ }
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
default: {
- throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid tensor element type.");
- return 0;
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Invalid inputTensor element type ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED.");
+ return -1;
}
}
}
-size_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize,
- size_t dimensionsRemaining, jarray output) {
- if (dimensionsRemaining == 1) {
- // write out 1d array of the respective primitive type
- return copyPrimitiveArrayToJava(jniEnv,onnxType,tensor,output);
- } else {
- // recurse through the dimensions
- // Java arrays are objects until the final dimension
- jobjectArray outputObjArr = (jobjectArray) output;
- uint32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv,outputObjArr);
- size_t sizeConsumed = 0;
- for (uint32_t i = 0; i < dimLength; i++) {
- jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv,outputObjArr,i);
- sizeConsumed += copyTensorToJava(jniEnv, onnxType, tensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr);
- // Cleanup reference to childArr so it doesn't prevent GC.
- (*jniEnv)->DeleteLocalRef(jniEnv,childArr);
- }
- return sizeConsumed;
+int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize,
+ size_t dimensionsRemaining, jarray outputArray) {
+ if (dimensionsRemaining == 1) {
+ // write out 1d array of the respective primitive type
+ return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray);
+ } else {
+ // recurse through the dimensions
+ // Java arrays are objects until the final dimension
+ jobjectArray outputObjArr = (jobjectArray)outputArray;
+ int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr);
+ int64_t sizeConsumed = 0;
+ for (int32_t i = 0; i < dimLength; i++) {
+ jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i);
+ int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr);
+ sizeConsumed += consumed;
+ // Cleanup reference to childArr so it doesn't prevent GC.
+ (*jniEnv)->DeleteLocalRef(jniEnv, childArr);
+ // If we failed to copy an array then break and return.
+ if (consumed == -1) {
+ return -1;
+ }
}
+ return sizeConsumed;
+ }
}
-jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
- // Get the buffer size needed
- size_t totalStringLength;
- checkOrtStatus(jniEnv,api,api->GetStringTensorDataLength(tensor,&totalStringLength));
-
- // Create the character and offset buffers
- char * characterBuffer;
- checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(char)*(totalStringLength+1),(void**)&characterBuffer));
- size_t * offsets;
- checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(size_t),(void**)&offsets));
-
- // Get a view on the String data
- checkOrtStatus(jniEnv,api,api->GetStringTensorContent(tensor,characterBuffer,totalStringLength,offsets,1));
-
- size_t curSize = (offsets[0]) + 1;
- characterBuffer[curSize-1] = '\0';
- jobject tempString = (*jniEnv)->NewStringUTF(jniEnv,characterBuffer);
+jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
+ jobject tempString = NULL;
+ // Get the buffer size needed
+ size_t totalStringLength = 0;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength));
+ if (code != ORT_OK) {
+ return NULL;
+ }
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,characterBuffer));
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,offsets));
+ // Create the character and offset buffers, character is one larger to allow zero termination.
+ char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1));
+ if (characterBuffer == NULL) {
+ throwOrtException(jniEnv, 1, "OOM error");
+ } else {
+ size_t * offsets = malloc(sizeof(size_t));
+ if (offsets != NULL) {
+ // Get a view on the String data
+ code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1));
+
+ if (code == ORT_OK) {
+ size_t curSize = (offsets[0]) + 1;
+ characterBuffer[curSize-1] = '\0';
+ tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer);
+ }
+
+ free((void*)characterBuffer);
+ free((void*)offsets);
+ }
+ }
- return tempString;
+ return tempString;
}
-void copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray) {
- // Get the buffer size needed
- size_t totalStringLength;
- checkOrtStatus(jniEnv,api,api->GetStringTensorDataLength(tensor,&totalStringLength));
-
- // Create the character and offset buffers
- char * characterBuffer;
- checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(char)*(totalStringLength+length),(void**)&characterBuffer));
- // length + 1 as we need to write out the final offset
- size_t * offsets;
- checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,sizeof(size_t)*(length+1),(void**)&offsets));
+OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) {
+ char * tempBuffer = NULL;
+ // Get the buffer size needed
+ size_t totalStringLength = 0;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength));
+ if (code != ORT_OK) {
+ return code;
+ }
- // Get a view on the String data
- checkOrtStatus(jniEnv,api,api->GetStringTensorContent(tensor,characterBuffer,totalStringLength,offsets,length));
+ // Create the character and offset buffers
+ char * characterBuffer = malloc(sizeof(char)*(totalStringLength+length));
+ if (characterBuffer == NULL) {
+ throwOrtException(jniEnv, 1, "Not enough memory");
+ return ORT_FAIL;
+ }
+ // length + 1 as we need to write out the final offset
+ size_t * offsets = malloc(sizeof(size_t)*(length+1));
+ if (offsets == NULL) {
+ free((void*)characterBuffer);
+ throwOrtException(jniEnv, 1, "Not enough memory");
+ return ORT_FAIL;
+ }
+ // Get a view on the String data
+ code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, length));
+ if (code == ORT_OK) {
// Get the final offset, write to the end of the array.
- checkOrtStatus(jniEnv,api,api->GetStringTensorDataLength(tensor,offsets+length));
-
- char * tempBuffer = NULL;
- size_t bufferSize = 0;
- for (size_t i = 0; i < length; i++) {
+ code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, offsets+length));
+ if (code == ORT_OK) {
+ size_t bufferSize = 0;
+ for (size_t i = 0; i < length; i++) {
size_t curSize = (offsets[i+1] - offsets[i]) + 1;
if (curSize > bufferSize) {
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,tempBuffer));
- checkOrtStatus(jniEnv,api,api->AllocatorAlloc(allocator,curSize,(void**)&tempBuffer));
- bufferSize = curSize;
- }
- if (tempBuffer == NULL) {
+ if (tempBuffer != NULL) {
+ free((void*)tempBuffer);
+ }
+ tempBuffer = malloc(sizeof(char) * curSize);
+ if (tempBuffer == NULL) {
throwOrtException(jniEnv, 1, "Not enough memory");
- return;
+ goto string_tensor_cleanup;
+ }
+ bufferSize = curSize;
}
memcpy(tempBuffer,characterBuffer+offsets[i],curSize);
tempBuffer[curSize-1] = '\0';
jobject tempString = (*jniEnv)->NewStringUTF(jniEnv,tempBuffer);
(*jniEnv)->SetObjectArrayElement(jniEnv,outputArray,safecast_size_t_to_jsize(i),tempString);
+ }
}
+ }
- if (tempBuffer != NULL) {
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,tempBuffer));
- }
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,characterBuffer));
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,offsets));
+string_tensor_cleanup:
+ if (tempBuffer != NULL) {
+ free((void*)tempBuffer);
+ }
+ free((void*)offsets);
+ free((void*)characterBuffer);
+ return code;
}
-jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
+jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
// Extract tensor info
- OrtTensorTypeAndShapeInfo* tensorInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo));
+ OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(tensor, &tensorInfo));
+ if (code != ORT_OK) {
+ return NULL;
+ }
// Get the element count of this tensor
- size_t length;
- checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length));
+ size_t length = 0;
+ code = checkOrtStatus(jniEnv, api, api->GetTensorShapeElementCount(tensorInfo, &length));
api->ReleaseTensorTypeAndShapeInfo(tensorInfo);
+ if (code != ORT_OK) {
+ return NULL;
+ }
// Create the java array
- jclass stringClazz = (*jniEnv)->FindClass(jniEnv,"java/lang/String");
- jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv,safecast_size_t_to_jsize(length),stringClazz, NULL);
+ jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
+ jobjectArray outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(length), stringClazz, NULL);
- copyStringTensorToArray(jniEnv, api, allocator, tensor, length, outputArray);
+ code = copyStringTensorToArray(jniEnv, api, tensor, length, outputArray);
+ if (code != ORT_OK) {
+ outputArray = NULL;
+ }
return outputArray;
}
jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
- // Extract tensor type
- OrtTensorTypeAndShapeInfo* tensorInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo));
- ONNXTensorElementDataType value;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo,&value));
-
- // Get the element count of this tensor
- size_t length;
- checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length));
+ jlongArray outputArray = NULL;
+ // Extract tensor type
+ OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &tensorInfo));
+ if (code == ORT_OK) {
+ ONNXTensorElementDataType value = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo, &value));
+ if ((code == ORT_OK) && ((value == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) || (value == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64))) {
+ // Get the element count of this tensor
+ size_t length = 0;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo, &length));
+ if (code == ORT_OK) {
+ // Extract the values
+ uint8_t* arr = NULL;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorMutableData(tensor, (void**)&arr));
+ if (code == ORT_OK) {
+ // Create the java array and copy to it.
+ outputArray = (*jniEnv)->NewLongArray(jniEnv, safecast_size_t_to_jsize(length));
+ int64_t consumed = copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray);
+ if (consumed == -1) {
+ outputArray = NULL;
+ }
+ }
+ }
+ }
api->ReleaseTensorTypeAndShapeInfo(tensorInfo);
-
- // Extract the values
- uint8_t* arr;
- checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)tensor,(void**)&arr));
-
- // Create the java array and copy to it.
- jlongArray outputArray = (*jniEnv)->NewLongArray(jniEnv,safecast_size_t_to_jsize(length));
- copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray);
- return outputArray;
+ }
+ return outputArray;
}
jfloatArray createFloatArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
- // Extract tensor type
- OrtTensorTypeAndShapeInfo* tensorInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo));
- ONNXTensorElementDataType value;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo,&value));
-
- // Get the element count of this tensor
- size_t length;
- checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length));
+ jfloatArray outputArray = NULL;
+ // Extract tensor type
+ OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &tensorInfo));
+ if (code == ORT_OK) {
+ ONNXTensorElementDataType value = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo, &value));
+ if ((code == ORT_OK) && (value == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)) {
+ // Get the element count of this tensor
+ size_t length = 0;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo, &length));
+ if (code == ORT_OK) {
+ // Extract the values
+ uint8_t* arr = NULL;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorMutableData(tensor, (void**)&arr));
+ if (code == ORT_OK) {
+ // Create the java array and copy to it.
+ outputArray = (*jniEnv)->NewFloatArray(jniEnv, safecast_size_t_to_jsize(length));
+ int64_t consumed = copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray);
+ if (consumed == -1) {
+ outputArray = NULL;
+ }
+ }
+ }
+ }
api->ReleaseTensorTypeAndShapeInfo(tensorInfo);
-
- // Extract the values
- uint8_t* arr;
- checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)tensor,(void**)&arr));
-
- // Create the java array and copy to it.
- jfloatArray outputArray = (*jniEnv)->NewFloatArray(jniEnv,safecast_size_t_to_jsize(length));
- copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray);
- return outputArray;
+ }
+ return outputArray;
}
jdoubleArray createDoubleArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
- // Extract tensor type
- OrtTensorTypeAndShapeInfo* tensorInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor,&tensorInfo));
- ONNXTensorElementDataType value;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo,&value));
-
- // Get the element count of this tensor
- size_t length;
- checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo,&length));
+ jdoubleArray outputArray = NULL;
+ // Extract tensor type
+ OrtTensorTypeAndShapeInfo* tensorInfo = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &tensorInfo));
+ if (code == ORT_OK) {
+ ONNXTensorElementDataType value = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorElementType(tensorInfo, &value));
+ if ((code == ORT_OK) && (value == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)) {
+ // Get the element count of this tensor
+ size_t length = 0;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(tensorInfo, &length));
+ if (code == ORT_OK) {
+ // Extract the values
+ uint8_t* arr = NULL;
+ code = checkOrtStatus(jniEnv,api,api->GetTensorMutableData(tensor, (void**)&arr));
+ if (code == ORT_OK) {
+ // Create the java array and copy to it.
+ outputArray = (*jniEnv)->NewDoubleArray(jniEnv, safecast_size_t_to_jsize(length));
+ int64_t consumed = copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray);
+ if (consumed == -1) {
+ outputArray = NULL;
+ }
+ }
+ }
+ }
api->ReleaseTensorTypeAndShapeInfo(tensorInfo);
-
- // Extract the values
- uint8_t* arr;
- checkOrtStatus(jniEnv,api,api->GetTensorMutableData((OrtValue*)tensor,(void**)&arr));
-
- // Create the java array and copy to it.
- jdoubleArray outputArray = (*jniEnv)->NewDoubleArray(jniEnv,safecast_size_t_to_jsize(length));
- copyPrimitiveArrayToJava(jniEnv, value, arr, outputArray);
- return outputArray;
+ }
+ return outputArray;
}
jobject createJavaTensorFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor) {
- // Extract the type information
- OrtTensorTypeAndShapeInfo* info;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(tensor, &info));
-
- // Construct the TensorInfo object
- jobject tensorInfo = convertToTensorInfo(jniEnv, api, info);
+ // Extract the type information
+ OrtTensorTypeAndShapeInfo* info = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(tensor, &info));
+ if (code != ORT_OK) {
+ return NULL;
+ }
- // Release the info object
- api->ReleaseTensorTypeAndShapeInfo(info);
+ // Construct the TensorInfo object
+ jobject tensorInfo = convertToTensorInfo(jniEnv, api, info);
+ // Release the info object
+ api->ReleaseTensorTypeAndShapeInfo(info);
+ if (tensorInfo == NULL) {
+ return NULL;
+ }
- // Construct the ONNXTensor object
- char *tensorClassName = "ai/onnxruntime/OnnxTensor";
- jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName);
- jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "(JJLai/onnxruntime/TensorInfo;)V");
- jobject javaTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, tensorInfo);
+ // Construct the ONNXTensor object
+ static const char *tensorClassName = "ai/onnxruntime/OnnxTensor";
+ jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorClassName);
+ jmethodID tensorConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "(JJLai/onnxruntime/TensorInfo;)V");
+ jobject javaTensor = (*jniEnv)->NewObject(jniEnv, clazz, tensorConstructor, (jlong) tensor, (jlong) allocator, tensorInfo);
- return javaTensor;
+ return javaTensor;
}
jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* sequence) {
- // Setup
- // Get the ONNXTensorType enum static method
- char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
- jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
- jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
-
- // Get the ONNXJavaType enum static method
- char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
- jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
- jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
-
- // Get the sequence info class
- char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
- jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
+ // Get the sequence info class
+ static const char *sequenceInfoClassName = "ai/onnxruntime/SequenceInfo";
+ jclass sequenceInfoClazz = (*jniEnv)->FindClass(jniEnv, sequenceInfoClassName);
- // Get the element count of this sequence
- size_t count;
- checkOrtStatus(jniEnv,api,api->GetValueCount(sequence,&count));
+ // setup return value
+ jobject sequenceInfo = NULL;
+ // Get the element count of this sequence
+ size_t count = 0;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValueCount(sequence, &count));
+ if (code != ORT_OK) {
+ return NULL;
+ } else if (count == 0) {
+ // Construct empty sequence info
+ jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(II)V");
+ sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, 0,
+ convertFromONNXDataFormat(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED));
+ } else {
// Extract the first element
- OrtValue* firstElement;
- checkOrtStatus(jniEnv,api,api->GetValue(sequence,0,allocator,&firstElement));
- ONNXType elementType;
- checkOrtStatus(jniEnv,api,api->GetValueType(firstElement,&elementType));
- jobject sequenceInfo;
- switch (elementType) {
+ OrtValue* firstElement = NULL;
+ code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, 0, allocator, &firstElement));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+ ONNXType elementType = ONNX_TYPE_UNKNOWN;
+ code = checkOrtStatus(jniEnv, api, api->GetValueType(firstElement, &elementType));
+ if (code == ORT_OK) {
+ switch (elementType) {
case ONNX_TYPE_TENSOR: {
- // Figure out element type
- OrtTensorTypeAndShapeInfo* firstElementInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(firstElement,&firstElementInfo));
- ONNXTensorElementDataType element;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(firstElementInfo,&element));
+ // Figure out element type
+ OrtTensorTypeAndShapeInfo* firstElementInfo = NULL;
+ code = checkOrtStatus(jniEnv, api, api->GetTensorTypeAndShape(firstElement, &firstElementInfo));
+ if (code == ORT_OK) {
+ ONNXTensorElementDataType element = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+ code = checkOrtStatus(jniEnv, api, api->GetTensorElementType(firstElementInfo, &element));
api->ReleaseTensorTypeAndShapeInfo(firstElementInfo);
+ if (code == ORT_OK) {
+ // Convert element type into ONNXTensorType
+ jint onnxTypeInt = convertFromONNXDataFormat(element);
- // Convert element type into ONNXTensorType
- jint onnxTypeInt = convertFromONNXDataFormat(element);
- jobject onnxTensorTypeJava = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeInt);
- jobject onnxJavaType = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJava);
-
- // Construct sequence info
- jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;)V");
- sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)count,onnxJavaType);
- break;
+ // Construct sequence info
+ jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(II)V");
+ sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)count, onnxTypeInt);
+ }
+ }
+ break;
}
case ONNX_TYPE_MAP: {
- // Extract key
- OrtValue* keys;
- checkOrtStatus(jniEnv,api,api->GetValue(firstElement,0,allocator,&keys));
-
- // Extract key type
- OrtTensorTypeAndShapeInfo* keysInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(keys,&keysInfo));
- ONNXTensorElementDataType key;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(keysInfo,&key));
-
- // Get the element count of this map
- size_t mapCount;
- checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(keysInfo,&mapCount));
-
- api->ReleaseTensorTypeAndShapeInfo(keysInfo);
-
- // Convert key type to java
- jint onnxTypeKey = convertFromONNXDataFormat(key);
- jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey);
- jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey);
-
- // Extract value
- OrtValue* values;
- checkOrtStatus(jniEnv,api,api->GetValue(firstElement,1,allocator,&values));
-
- // Extract value type
- OrtTensorTypeAndShapeInfo* valuesInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(values,&valuesInfo));
- ONNXTensorElementDataType value;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(valuesInfo,&value));
- api->ReleaseTensorTypeAndShapeInfo(valuesInfo);
-
- // Convert value type to java
- jint onnxTypeValue = convertFromONNXDataFormat(value);
- jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue);
- jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue);
-
- // Get the map info class
- char *mapInfoClassName = "ai/onnxruntime/MapInfo";
- jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
- // Construct map info
- jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V");
- jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)mapCount,onnxJavaTypeKey,onnxJavaTypeValue);
-
- // Free the intermediate tensors.
- api->ReleaseValue(keys);
- api->ReleaseValue(values);
-
- // Construct sequence info
- jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceInfoClazz,"","(ILai/onnxruntime/MapInfo;)V");
- sequenceInfo = (*jniEnv)->NewObject(jniEnv,sequenceInfoClazz,sequenceInfoConstructor,(jint)count,mapInfo);
- break;
+ jobject mapInfo = createMapInfoFromValue(jniEnv, api, allocator, firstElement);
+ if (mapInfo != NULL) {
+ // Construct sequence info
+ jmethodID sequenceInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceInfoClazz, "", "(ILai/onnxruntime/MapInfo;)V");
+ sequenceInfo = (*jniEnv)->NewObject(jniEnv, sequenceInfoClazz, sequenceInfoConstructor, (jint)count, mapInfo);
+ }
+ break;
}
default: {
- sequenceInfo = createEmptySequenceInfo(jniEnv);
- throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"Invalid element type found in sequence");
- break;
+ throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid element type found in sequence");
+ break;
}
+ }
}
-
// Free the intermediate tensor.
api->ReleaseValue(firstElement);
+ }
+
+ jobject javaSequence = NULL;
+ if (sequenceInfo != NULL) {
// Construct the ONNXSequence object
- char *sequenceClassName = "ai/onnxruntime/OnnxSequence";
+ static const char *sequenceClassName = "ai/onnxruntime/OnnxSequence";
jclass sequenceClazz = (*jniEnv)->FindClass(jniEnv, sequenceClassName);
- jmethodID sequenceConstructor = (*jniEnv)->GetMethodID(jniEnv,sequenceClazz, "", "(JJLai/onnxruntime/SequenceInfo;)V");
- jobject javaSequence = (*jniEnv)->NewObject(jniEnv, sequenceClazz, sequenceConstructor, (jlong)sequence, (jlong)allocator, sequenceInfo);
+ jmethodID sequenceConstructor = (*jniEnv)->GetMethodID(jniEnv, sequenceClazz, "", "(JJLai/onnxruntime/SequenceInfo;)V");
+ javaSequence = (*jniEnv)->NewObject(jniEnv, sequenceClazz, sequenceConstructor, (jlong)sequence, (jlong)allocator, sequenceInfo);
+ }
- return javaSequence;
+ return javaSequence;
}
jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map) {
- // Setup
- // Get the ONNXTensorType enum static method
- char *onnxTensorTypeClassName = "ai/onnxruntime/TensorInfo$OnnxTensorType";
- jclass onnxTensorTypeClazz = (*jniEnv)->FindClass(jniEnv, onnxTensorTypeClassName);
- jmethodID onnxTensorTypeMapFromInt = (*jniEnv)->GetStaticMethodID(jniEnv,onnxTensorTypeClazz, "mapFromInt", "(I)Lai/onnxruntime/TensorInfo$OnnxTensorType;");
-
- // Get the ONNXJavaType enum static method
- char *javaDataTypeClassName = "ai/onnxruntime/OnnxJavaType";
- jclass onnxJavaTypeClazz = (*jniEnv)->FindClass(jniEnv, javaDataTypeClassName);
- jmethodID onnxJavaTypeMapFromONNXTensorType = (*jniEnv)->GetStaticMethodID(jniEnv,onnxJavaTypeClazz, "mapFromOnnxTensorType", "(Lai/onnxruntime/TensorInfo$OnnxTensorType;)Lai/onnxruntime/OnnxJavaType;");
-
- // Get the map info class
- char *mapInfoClassName = "ai/onnxruntime/MapInfo";
- jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
-
- // Extract key
- OrtValue* keys;
- checkOrtStatus(jniEnv,api,api->GetValue(map,0,allocator,&keys));
-
- // Extract key type
- OrtTensorTypeAndShapeInfo* keysInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(keys,&keysInfo));
- ONNXTensorElementDataType key;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(keysInfo,&key));
-
- // Get the element count of this map
- size_t mapCount;
- checkOrtStatus(jniEnv,api,api->GetTensorShapeElementCount(keysInfo,&mapCount));
-
- api->ReleaseTensorTypeAndShapeInfo(keysInfo);
-
- // Convert key type to java
- jint onnxTypeKey = convertFromONNXDataFormat(key);
- jobject onnxTensorTypeJavaKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeKey);
- jobject onnxJavaTypeKey = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaKey);
-
- // Extract value
- OrtValue* values;
- checkOrtStatus(jniEnv,api,api->GetValue(map,1,allocator,&values));
-
- // Extract value type
- OrtTensorTypeAndShapeInfo* valuesInfo;
- checkOrtStatus(jniEnv,api,api->GetTensorTypeAndShape(values,&valuesInfo));
- ONNXTensorElementDataType value;
- checkOrtStatus(jniEnv,api,api->GetTensorElementType(valuesInfo,&value));
- api->ReleaseTensorTypeAndShapeInfo(valuesInfo);
-
- // Convert value type to java
- jint onnxTypeValue = convertFromONNXDataFormat(value);
- jobject onnxTensorTypeJavaValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxTensorTypeClazz,onnxTensorTypeMapFromInt,onnxTypeValue);
- jobject onnxJavaTypeValue = (*jniEnv)->CallStaticObjectMethod(jniEnv,onnxJavaTypeClazz,onnxJavaTypeMapFromONNXTensorType,onnxTensorTypeJavaValue);
-
- // Construct map info
- jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,mapInfoClazz,"","(ILai/onnxruntime/OnnxJavaType;Lai/onnxruntime/OnnxJavaType;)V");
- jobject mapInfo = (*jniEnv)->NewObject(jniEnv,mapInfoClazz,mapInfoConstructor,(jint)mapCount,onnxJavaTypeKey,onnxJavaTypeValue);
-
- // Free the intermediate tensors.
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,keys));
- checkOrtStatus(jniEnv,api,api->AllocatorFree(allocator,values));
-
- // Construct the ONNXMap object
- char *mapClassName = "ai/onnxruntime/OnnxMap";
- jclass mapClazz = (*jniEnv)->FindClass(jniEnv, mapClassName);
- jmethodID mapConstructor = (*jniEnv)->GetMethodID(jniEnv,mapClazz, "", "(JJLai/onnxruntime/MapInfo;)V");
- jobject javaMap = (*jniEnv)->NewObject(jniEnv, mapClazz, mapConstructor, (jlong)map, (jlong) allocator, mapInfo);
-
- return javaMap;
+ jobject mapInfo = createMapInfoFromValue(jniEnv, api, allocator, map);
+ if (mapInfo == NULL) {
+ return NULL;
+ }
+
+ // Get the map class & constructor
+ static const char *mapClassName = "ai/onnxruntime/OnnxMap";
+ jclass mapClazz = (*jniEnv)->FindClass(jniEnv, mapClassName);
+ jmethodID mapConstructor = (*jniEnv)->GetMethodID(jniEnv, mapClazz, "", "(JJLai/onnxruntime/MapInfo;)V");
+
+ // Construct the ONNXMap object
+ jobject javaMap = (*jniEnv)->NewObject(jniEnv, mapClazz, mapConstructor, (jlong)map, (jlong) allocator, mapInfo);
+
+ return javaMap;
+}
+
+jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator * allocator, const OrtValue * map) {
+ // Extract key
+ OrtValue* keys = NULL;
+ OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue(map, 0, allocator, &keys));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+
+ JavaTensorTypeShape keyInfo;
+ code = getTensorTypeShape(jniEnv, &keyInfo, api, keys);
+ api->ReleaseValue(keys);
+ if (code != ORT_OK) {
+ return NULL;
+ }
+
+ // Extract value
+ OrtValue* values = NULL;
+ code = checkOrtStatus(jniEnv, api, api->GetValue(map, 1, allocator, &values));
+ if (code != ORT_OK) {
+ return NULL;
+ }
+
+ JavaTensorTypeShape valueInfo;
+ code = getTensorTypeShape(jniEnv, &valueInfo, api, values);
+ api->ReleaseValue(values);
+ if (code != ORT_OK) {
+ return NULL;
+ }
+
+ // Convert key and value type to java
+ jint onnxTypeKey = convertFromONNXDataFormat(keyInfo.onnxTypeEnum);
+ jint onnxTypeValue = convertFromONNXDataFormat(valueInfo.onnxTypeEnum);
+
+ // Get the map info class & constructor
+ static const char *mapInfoClassName = "ai/onnxruntime/MapInfo";
+ jclass mapInfoClazz = (*jniEnv)->FindClass(jniEnv, mapInfoClassName);
+ jmethodID mapInfoConstructor = (*jniEnv)->GetMethodID(jniEnv, mapInfoClazz, "", "(III)V");
+
+ // Construct map info
+ jobject mapInfo = (*jniEnv)->NewObject(jniEnv, mapInfoClazz, mapInfoConstructor, (jint)keyInfo.elementCount, onnxTypeKey, onnxTypeValue);
+ return mapInfo;
}
jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue) {
- // Note this is the ONNXType C enum
- ONNXType valueType;
- checkOrtStatus(jniEnv,api,api->GetValueType(onnxValue,&valueType));
- switch (valueType) {
- case ONNX_TYPE_TENSOR: {
- return createJavaTensorFromONNX(jniEnv, api, allocator, onnxValue);
- }
- case ONNX_TYPE_SEQUENCE: {
- return createJavaSequenceFromONNX(jniEnv, api, allocator, onnxValue);
- }
- case ONNX_TYPE_MAP: {
- return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue);
- }
- case ONNX_TYPE_UNKNOWN:
- case ONNX_TYPE_OPAQUE:
- case ONNX_TYPE_OPTIONAL:
- case ONNX_TYPE_SPARSETENSOR: {
- throwOrtException(jniEnv,convertErrorCode(ORT_NOT_IMPLEMENTED),"These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR.");
- break;
- }
- }
+ // Note this is the ONNXType C enum
+ ONNXType valueType = ONNX_TYPE_UNKNOWN;
+ OrtErrorCode code = checkOrtStatus(jniEnv,api,api->GetValueType(onnxValue,&valueType));
+ if (code != ORT_OK) {
return NULL;
+ }
+ switch (valueType) {
+ case ONNX_TYPE_TENSOR: {
+ return createJavaTensorFromONNX(jniEnv, api, allocator, onnxValue);
+ }
+ case ONNX_TYPE_SEQUENCE: {
+ return createJavaSequenceFromONNX(jniEnv, api, allocator, onnxValue);
+ }
+ case ONNX_TYPE_MAP: {
+ return createJavaMapFromONNX(jniEnv, api, allocator, onnxValue);
+ }
+ case ONNX_TYPE_UNKNOWN:
+ case ONNX_TYPE_OPAQUE:
+ case ONNX_TYPE_OPTIONAL:
+ case ONNX_TYPE_SPARSETENSOR:
+ default: {
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "These types are unsupported - ONNX_TYPE_UNKNOWN, ONNX_TYPE_OPAQUE, ONNX_TYPE_SPARSETENSOR.");
+ return NULL;
+ }
+ }
}
jint throwOrtException(JNIEnv *jniEnv, int messageId, const char *message) {
- jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message);
+ jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message);
- char *className = "ai/onnxruntime/OrtException";
- jclass exClazz = (*jniEnv)->FindClass(jniEnv,className);
- jmethodID exConstructor = (*jniEnv)->GetMethodID(jniEnv, exClazz, "", "(ILjava/lang/String;)V");
- jobject javaException = (*jniEnv)->NewObject(jniEnv, exClazz, exConstructor, messageId, messageStr);
+ static const char *className = "ai/onnxruntime/OrtException";
+ jclass exClazz = (*jniEnv)->FindClass(jniEnv, className);
+ jmethodID exConstructor = (*jniEnv)->GetMethodID(jniEnv, exClazz, "", "(ILjava/lang/String;)V");
+ jobject javaException = (*jniEnv)->NewObject(jniEnv, exClazz, exConstructor, messageId, messageStr);
- return (*jniEnv)->Throw(jniEnv,javaException);
+ return (*jniEnv)->Throw(jniEnv, javaException);
}
jint convertErrorCode(OrtErrorCode code) {
diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h
index 075202333e181..616a20503ad42 100644
--- a/java/src/main/native/OrtJniUtil.h
+++ b/java/src/main/native/OrtJniUtil.h
@@ -38,29 +38,27 @@ OrtErrorCode getTensorTypeShape(JNIEnv * jniEnv, JavaTensorTypeShape * output, c
jfloat convertHalfToFloat(uint16_t half);
-jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, OrtTypeInfo * info);
+jobject convertToValueInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTypeInfo * info);
jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorTypeAndShapeInfo * info);
jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInfo * info);
-jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
-jobject createEmptyMapInfo(JNIEnv *jniEnv);
-jobject createEmptySequenceInfo(JNIEnv *jniEnv);
+jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
-size_t copyJavaToPrimitiveArray(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray input);
+int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor);
-size_t copyJavaToTensor(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray input);
+int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor);
-size_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, jarray output);
+int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray);
-size_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, uint8_t* tensor, size_t tensorSize, size_t dimensionsRemaining, jarray output);
+int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray);
-jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
+jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
-void copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor, size_t length, jobjectArray outputArray);
+OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray);
-jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* tensor);
+jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
jlongArray createLongArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
@@ -74,6 +72,8 @@ jobject createJavaSequenceFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca
jobject createJavaMapFromONNX(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* map);
+jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator * allocator, const OrtValue * map);
+
jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue);
jint throwOrtException(JNIEnv *env, int messageId, const char *message);
diff --git a/java/src/main/native/ai_onnxruntime_OnnxMap.c b/java/src/main/native/ai_onnxruntime_OnnxMap.c
index dce0677c84672..161fa23ae81cc 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxMap.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxMap.c
@@ -22,7 +22,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringKeys(JNIEnv*
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 0, allocator, &keys));
if (code == ORT_OK) {
// Convert to Java String array
- jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
+ jobjectArray output = createStringArrayFromTensor(jniEnv, api, keys);
api->ReleaseValue(keys);
@@ -72,7 +72,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxMap_getStringValues(JNIEn
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetValue((OrtValue*)handle, 1, allocator, &values));
if (code == ORT_OK) {
// Convert to Java String array
- jobjectArray output = createStringArrayFromTensor(jniEnv, api, allocator, values);
+ jobjectArray output = createStringArrayFromTensor(jniEnv, api, values);
api->ReleaseValue(values);
diff --git a/java/src/main/native/ai_onnxruntime_OnnxSequence.c b/java/src/main/native/ai_onnxruntime_OnnxSequence.c
index 12d3266f8a90f..22cfca6315768 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxSequence.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxSequence.c
@@ -28,7 +28,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringKeys(JN
if (code == ORT_OK) {
// Convert to Java String array
- output = createStringArrayFromTensor(jniEnv, api, allocator, keys);
+ output = createStringArrayFromTensor(jniEnv, api, keys);
// Release if valid
api->ReleaseValue(element);
}
@@ -94,7 +94,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStringValues(
if (code == ORT_OK) {
// Convert to Java String array
- output = createStringArrayFromTensor(jniEnv, api, allocator, values);
+ output = createStringArrayFromTensor(jniEnv, api, values);
// Release if valid
api->ReleaseValue(element);
}
@@ -230,7 +230,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OnnxSequence_getStrings(JNIEn
OrtValue* element;
code = checkOrtStatus(jniEnv, api, api->GetValue(sequence, (int)i, allocator, &element));
if (code == ORT_OK) {
- jobject str = createStringFromStringTensor(jniEnv, api, allocator, element);
+ jobject str = createStringFromStringTensor(jniEnv, api, element);
if (str == NULL) {
api->ReleaseValue(element);
// bail out as exception has been thrown
diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
index 06a8ad42e7f2d..1656b4043cfe9 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
@@ -44,8 +44,8 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
// Check if we're copying a scalar or not
if (shapeLen == 0) {
// Scalars are passed in as a single element array
- size_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, tensorData, dataObj);
- failed = copied == 0 ? 1 : failed;
+ int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData);
+ failed = copied == -1 ? 1 : failed;
} else {
// Extract the tensor shape information
JavaTensorTypeShape typeShape;
@@ -53,9 +53,9 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
if (code == ORT_OK) {
// Copy the java array into the tensor
- size_t copied = copyJavaToTensor(jniEnv, onnxType, tensorData, typeShape.elementCount,
- typeShape.dimensions, dataObj);
- failed = copied == 0 ? 1 : failed;
+ int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount,
+ typeShape.dimensions, dataObj, tensorData);
+ failed = copied == -1 ? 1 : failed;
} else {
failed = 1;
}
@@ -367,12 +367,11 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_getLong
* Signature: (JJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OnnxTensor_getString
- (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle) {
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
// Extract a String array - if this becomes a performance issue we'll refactor later.
- jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtAllocator*) allocatorHandle,
- (OrtValue*) handle);
+ jobjectArray outputArray = createStringArrayFromTensor(jniEnv, api, (OrtValue*) handle);
if (outputArray != NULL) {
// Get reference to the string
jobject output = (*jniEnv)->GetObjectArrayElement(jniEnv, outputArray, 0);
@@ -410,7 +409,7 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
* Signature: (JJLjava/lang/Object;)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
- (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jlong allocatorHandle, jobject carrier) {
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtValue* value = (OrtValue*) handle;
@@ -418,7 +417,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, value);
if (code == ORT_OK) {
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
- copyStringTensorToArray(jniEnv, api, (OrtAllocator*) allocatorHandle, value, typeShape.elementCount, carrier);
+ copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier);
} else {
uint8_t* arr = NULL;
code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));