diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index e1ee2c14fd9d1..3f276a3670156 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -54,7 +54,7 @@ public class OnnxTensor extends OnnxTensorLike { * the state of this buffer without first getting the reference via {@link #getBufferRef()}. * * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is - * a copy of a user buffer.) + * a copy of a user buffer or array.) */ public boolean ownsBuffer() { return this.ownsBuffer; @@ -62,8 +62,8 @@ public boolean ownsBuffer() { /** * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not - * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by - * ORT) this method returns an empty {@link Optional}. + * backed by a buffer (i.e., it is backed by memory allocated by ORT) this method returns an empty + * {@link Optional}. * *
Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be
* used to repeatedly update a single tensor for multiple different inferences without allocating
@@ -77,7 +77,116 @@ public boolean ownsBuffer() {
* @return A reference to the buffer.
*/
public Optional Can be replaced to a call to buf.duplicate() in Java 9+.
+ *
+ * @param buf The buffer to duplicate.
+ * @return A copy of the buffer which refers to the same underlying memory, but has an independent
+ * position, limit and mark.
+ */
+ private static Buffer duplicate(Buffer buf) {
+ if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder());
+ } else if (buf instanceof ShortBuffer) {
+ return ((ShortBuffer) buf).duplicate();
+ } else if (buf instanceof IntBuffer) {
+ return ((IntBuffer) buf).duplicate();
+ } else if (buf instanceof LongBuffer) {
+ return ((LongBuffer) buf).duplicate();
+ } else if (buf instanceof FloatBuffer) {
+ return ((FloatBuffer) buf).duplicate();
+ } else if (buf instanceof DoubleBuffer) {
+ return ((DoubleBuffer) buf).duplicate();
+ } else {
+ throw new IllegalStateException("Unknown buffer type " + buf.getClass());
+ }
+ }
+
+ /**
+ * Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link
+ * ByteBuffer} then convert it to the right type. If it's not convertible it throws {@link
+ * IllegalStateException}.
+ *
+ * Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve
+ * compatibility with existing {@link #getValue} calls.
+ *
+ * @param buf The buffer to convert.
+ * @return The buffer with the expected type.
+ */
+ private Buffer castBuffer(Buffer buf) {
+ switch (info.type) {
+ case FLOAT:
+ if (buf instanceof FloatBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asFloatBuffer();
+ }
+ break;
+ case DOUBLE:
+ if (buf instanceof DoubleBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asDoubleBuffer();
+ }
+ break;
+ case BOOL:
+ case INT8:
+ case UINT8:
+ if (buf instanceof ByteBuffer) {
+ return buf;
+ }
+ break;
+ case BFLOAT16:
+ if (buf instanceof ShortBuffer) {
+ ShortBuffer bf16Buf = (ShortBuffer) buf;
+ return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf);
+ } else if (buf instanceof ByteBuffer) {
+ ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer();
+ return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf);
+ }
+ break;
+ case FLOAT16:
+ if (buf instanceof ShortBuffer) {
+ ShortBuffer fp16Buf = (ShortBuffer) buf;
+ return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf);
+ } else if (buf instanceof ByteBuffer) {
+ ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer();
+ return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf);
+ }
+ break;
+ case INT16:
+ if (buf instanceof ShortBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asShortBuffer();
+ }
+ break;
+ case INT32:
+ if (buf instanceof IntBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asIntBuffer();
+ }
+ break;
+ case INT64:
+ if (buf instanceof LongBuffer) {
+ return buf;
+ } else if (buf instanceof ByteBuffer) {
+ return ((ByteBuffer) buf).asLongBuffer();
+ }
+ break;
+ }
+ throw new IllegalStateException(
+ "Invalid buffer type for cast operation, found "
+ + buf.getClass()
+ + " expected something convertible to "
+ + info.type);
}
@Override
@@ -133,15 +242,26 @@ public Object getValue() throws OrtException {
Object carrier = info.makeCarrier();
if (info.getNumElements() > 0) {
// If the tensor has values copy them out
- getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
- }
- if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
- // We read the strings out from native code in a flat array and then reshape
- // to the desired output shape.
- return OrtUtil.reshape((String[]) carrier, info.shape);
- } else {
- return carrier;
+ if (info.type == OnnxJavaType.STRING) {
+ // We read the strings out from native code in a flat array and then reshape
+ // to the desired output shape if necessary.
+ getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier);
+ if (info.shape.length != 1) {
+ carrier = OrtUtil.reshape((String[]) carrier, info.shape);
+ }
+ } else {
+ // Wrap ORT owned memory in buffer, otherwise use our reference
+ Buffer buf;
+ if (buffer == null) {
+ buf = castBuffer(getBuffer());
+ } else {
+ buf = castBuffer(duplicate(buffer));
+ }
+ // Copy out buffer into arrays
+ OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier);
+ }
}
+ return carrier;
}
}
@@ -175,8 +295,8 @@ public synchronized void close() {
public ByteBuffer getByteBuffer() {
checkClosed();
if (info.type != OnnxJavaType.STRING) {
- ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle);
- ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
+ ByteBuffer buffer = getBuffer();
+ ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder());
output.put(buffer);
output.rewind();
return output;
@@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() {
output.rewind();
return output;
} else if (info.type == OnnxJavaType.FLOAT16) {
- // if it's fp16 we need to copy it out by hand.
+ // if it's fp16 we need to convert it.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer);
} else if (info.type == OnnxJavaType.BFLOAT16) {
- // if it's bf16 we need to copy it out by hand.
+ // if it's bf16 we need to convert it.
ByteBuffer buf = getBuffer();
ShortBuffer buffer = buf.asShortBuffer();
return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer);
@@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)
private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;
- private native void getArray(long apiHandle, long nativeHandle, Object carrier)
+ private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier)
throws OrtException;
private native void close(long apiHandle, long nativeHandle);
@@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec
info);
}
} else {
+ Buffer buf;
if (info.shape.length == 0) {
- data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data);
- if (data == null) {
+ buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data);
+ if (buf == null) {
throw new OrtException(
"Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = "
+ info.type
+ ", object = "
+ data);
}
+ } else {
+ buf = OrtUtil.convertArrayToBuffer(info, data);
}
return new OnnxTensor(
- createTensor(
- OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value),
+ createTensorFromBuffer(
+ OnnxRuntime.ortApiHandle,
+ allocator.handle,
+ buf,
+ 0,
+ info.type.size * info.numElements,
+ info.shape,
+ info.onnxType.value),
allocator.handle,
- info);
+ info,
+ buf,
+ true);
}
} else {
throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator.");
@@ -627,7 +758,26 @@ static OnnxTensor createTensor(
*/
public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape)
throws OrtException {
- return createTensor(env, env.defaultAllocator, data, shape);
+ return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16);
+ }
+
+ /**
+ * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder.
+ *
+ * If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime
+ * of the tensor. Uses the default allocator.
+ *
+ * @param env The current OrtEnvironment.
+ * @param data The tensor data.
+ * @param shape The shape of tensor.
+ * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16},
+ * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}.
+ * @return An OnnxTensor of the required shape.
+ * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match.
+ */
+ public static OnnxTensor createTensor(
+ OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException {
+ return createTensor(env, env.defaultAllocator, data, shape, type);
}
/**
@@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long
* @param allocator The allocator to use.
* @param data The tensor data.
* @param shape The shape of tensor.
+ * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16},
+ * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}.
* @return An OnnxTensor of the required shape.
* @throws OrtException Thrown if there is an onnx error or if the data and shape don't match.
*/
static OnnxTensor createTensor(
- OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape)
+ OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type)
throws OrtException {
if (!allocator.isClosed()) {
- OnnxJavaType type = OnnxJavaType.INT16;
- return createTensor(type, allocator, data, shape);
+ if ((type == OnnxJavaType.BFLOAT16)
+ || (type == OnnxJavaType.FLOAT16)
+ || (type == OnnxJavaType.INT16)) {
+ return createTensor(type, allocator, data, shape);
+ } else {
+ throw new IllegalArgumentException(
+ "Only int16, float16 or bfloat16 tensors can be created from ShortBuffer.");
+ }
} else {
throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
}
@@ -768,10 +926,6 @@ private static OnnxTensor createTensor(
tuple.isCopy);
}
- private static native long createTensor(
- long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType)
- throws OrtException;
-
private static native long createTensorFromBuffer(
long apiHandle,
long allocatorHandle,
diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java
index 4f3dee3c00b91..2f44236e4ef67 100644
--- a/java/src/main/java/ai/onnxruntime/OrtUtil.java
+++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java
@@ -26,10 +26,10 @@ public final class OrtUtil {
private OrtUtil() {}
/**
- * Converts an long shape into a int shape.
+ * Converts a long shape into an int shape.
*
- * Validates that the shape has more than 1 elements, less than 9 elements, each element is
- * less than {@link Integer#MAX_VALUE} and that each entry is non-negative.
+ * Validates that the shape has more than 1 element, less than 9 elements, each element is less
+ * than {@link Integer#MAX_VALUE} and that each entry is non-negative.
*
* @param shape The long shape.
* @return The int shape.
@@ -460,6 +460,308 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) {
}
}
+ /**
+ * Stores a boxed primitive in a single element buffer of the unboxed type.
+ *
+ * If it's not a boxed primitive then it returns null.
+ *
+ * @param javaType The type of the boxed primitive.
+ * @param data The boxed primitive.
+ * @return The primitive in a direct buffer.
+ */
+ static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) {
+ switch (javaType) {
+ case FLOAT:
+ {
+ FloatBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asFloatBuffer();
+ buf.put(0, (Float) data);
+ return buf;
+ }
+ case DOUBLE:
+ {
+ DoubleBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asDoubleBuffer();
+ buf.put(0, (Double) data);
+ return buf;
+ }
+ case BOOL:
+ {
+ ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder());
+ buf.put(0, ((boolean) data) ? (byte) 1 : (byte) 0);
+ return buf;
+ }
+ case UINT8:
+ case INT8:
+ {
+ ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder());
+ buf.put(0, (Byte) data);
+ return buf;
+ }
+ case FLOAT16:
+ case BFLOAT16:
+ case INT16:
+ {
+ ShortBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asShortBuffer();
+ buf.put(0, (Short) data);
+ return buf;
+ }
+ case INT32:
+ {
+ IntBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()).asIntBuffer();
+ buf.put(0, (Integer) data);
+ return buf;
+ }
+ case INT64:
+ {
+ LongBuffer buf =
+ ByteBuffer.allocateDirect(javaType.size)
+ .order(ByteOrder.nativeOrder())
+ .asLongBuffer();
+ buf.put(0, (Long) data);
+ return buf;
+ }
+ case STRING:
+ case UNKNOWN:
+ default:
+ return null;
+ }
+ }
+
+ /**
+ * Copies a Java (possibly multidimensional) array into a direct {@link Buffer}.
+ *
+ * Throws {@link IllegalArgumentException} if the array is not an array of Java primitives or
+ * if the array is ragged.
+ *
+ * @param info The tensor info object containing the types and shape of the array.
+ * @param array The array object.
+ * @return A direct buffer containing all the elements.
+ */
+ static Buffer convertArrayToBuffer(TensorInfo info, Object array) {
+ ByteBuffer byteBuffer =
+ ByteBuffer.allocateDirect((int) info.numElements * info.type.size)
+ .order(ByteOrder.nativeOrder());
+
+ Buffer buffer;
+ switch (info.type) {
+ case FLOAT:
+ buffer = byteBuffer.asFloatBuffer();
+ break;
+ case DOUBLE:
+ buffer = byteBuffer.asDoubleBuffer();
+ break;
+ case BOOL:
+ case INT8:
+ case UINT8:
+ // no-op, it's already a bytebuffer
+ buffer = byteBuffer;
+ break;
+ case BFLOAT16:
+ case FLOAT16:
+ case INT16:
+ buffer = byteBuffer.asShortBuffer();
+ break;
+ case INT32:
+ buffer = byteBuffer.asIntBuffer();
+ break;
+ case INT64:
+ buffer = byteBuffer.asLongBuffer();
+ break;
+ case STRING:
+ case UNKNOWN:
+ default:
+ throw new IllegalArgumentException(
+ "Unexpected type, expected Java primitive found " + info.type);
+ }
+
+ fillBufferFromArray(info, array, 0, buffer);
+
+ if (buffer.remaining() != 0) {
+ throw new IllegalArgumentException(
+ "Failed to copy all elements into the buffer, expected to copy "
+ + info.numElements
+ + " into a buffer of capacity "
+ + buffer.capacity()
+ + " but had "
+ + buffer.remaining()
+ + " values left over.");
+ }
+ buffer.rewind();
+
+ return buffer;
+ }
+
+ /**
+ * Fills the provided buffer with the values from the array, recursing through the array
+ * structure.
+ *
+ * @param info The tensor info containing the type and shape of the array.
+ * @param array The array object to read from.
+ * @param curDim The current dimension we're processing.
+ * @param buffer The buffer to write to.
+ */
+ private static void fillBufferFromArray(
+ TensorInfo info, Object array, int curDim, Buffer buffer) {
+ if (curDim == info.shape.length - 1) {
+ // Reached primitive values, copy into buffer
+ switch (info.type) {
+ case FLOAT:
+ float[] fArr = (float[]) array;
+ FloatBuffer fBuf = (FloatBuffer) buffer;
+ fBuf.put(fArr);
+ break;
+ case DOUBLE:
+ double[] dArr = (double[]) array;
+ DoubleBuffer dBuf = (DoubleBuffer) buffer;
+ dBuf.put(dArr);
+ break;
+ case INT8:
+ case UINT8:
+ byte[] bArr = (byte[]) array;
+ ByteBuffer bBuf = (ByteBuffer) buffer;
+ bBuf.put(bArr);
+ break;
+ case FLOAT16:
+ case BFLOAT16:
+ case INT16:
+ short[] sArr = (short[]) array;
+ ShortBuffer sBuf = (ShortBuffer) buffer;
+ sBuf.put(sArr);
+ break;
+ case INT32:
+ int[] iArr = (int[]) array;
+ IntBuffer iBuf = (IntBuffer) buffer;
+ iBuf.put(iArr);
+ break;
+ case INT64:
+ long[] lArr = (long[]) array;
+ LongBuffer lBuf = (LongBuffer) buffer;
+ lBuf.put(lArr);
+ break;
+ case BOOL:
+ boolean[] boolArr = (boolean[]) array;
+ ByteBuffer boolBuf = (ByteBuffer) buffer;
+ for (int i = 0; i < boolArr.length; i++) {
+ boolBuf.put(boolArr[i] ? (byte) 1 : (byte) 0);
+ }
+ break;
+ case STRING:
+ case UNKNOWN:
+ throw new IllegalArgumentException(
+ "Unexpected type, expected Java primitive found " + info.type);
+ }
+ } else {
+ // Recurse through array
+ long expectedSize = info.shape[curDim];
+ long actualSize = Array.getLength(array);
+ if (expectedSize != actualSize) {
+ throw new IllegalArgumentException(
+ "Mismatch in array sizes, expected "
+ + expectedSize
+ + " at dim "
+ + curDim
+ + " from shape "
+ + Arrays.toString(info.shape)
+ + ", found "
+ + actualSize);
+ } else {
+ for (int i = 0; i < actualSize; i++) {
+ fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer);
+ }
+ }
+ }
+ }
+
+ /**
+ * Fills the provided array with the values from the buffer, recursing through the array
+ * structure.
+ *
+ * @param info The tensor info containing the type and shape of the array.
+ * @param buffer The buffer to read from.
+ * @param curDim The current dimension we're processing.
+ * @param array The array object to write to.
+ */
+ static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) {
+ if (curDim == info.shape.length - 1) {
+ // Reached primitive values, copy into buffer
+ switch (info.type) {
+ case FLOAT16:
+ case BFLOAT16:
+ case FLOAT:
+ float[] fArr = (float[]) array;
+ FloatBuffer fBuf = (FloatBuffer) buffer;
+ fBuf.get(fArr);
+ break;
+ case DOUBLE:
+ double[] dArr = (double[]) array;
+ DoubleBuffer dBuf = (DoubleBuffer) buffer;
+ dBuf.get(dArr);
+ break;
+ case INT8:
+ case UINT8:
+ byte[] bArr = (byte[]) array;
+ ByteBuffer bBuf = (ByteBuffer) buffer;
+ bBuf.get(bArr);
+ break;
+ case INT16:
+ short[] sArr = (short[]) array;
+ ShortBuffer sBuf = (ShortBuffer) buffer;
+ sBuf.get(sArr);
+ break;
+ case INT32:
+ int[] iArr = (int[]) array;
+ IntBuffer iBuf = (IntBuffer) buffer;
+ iBuf.get(iArr);
+ break;
+ case INT64:
+ long[] lArr = (long[]) array;
+ LongBuffer lBuf = (LongBuffer) buffer;
+ lBuf.get(lArr);
+ break;
+ case BOOL:
+ boolean[] boolArr = (boolean[]) array;
+ ByteBuffer boolBuf = (ByteBuffer) buffer;
+ for (int i = 0; i < boolArr.length; i++) {
+ // Test to see if the byte is non-zero, non-zero bytes are true, zero bytes are false.
+ boolArr[i] = boolBuf.get() != 0;
+ }
+ break;
+ case STRING:
+ case UNKNOWN:
+ throw new IllegalArgumentException(
+ "Unexpected type, expected Java primitive found " + info.type);
+ }
+ } else {
+ // Recurse through array
+ long expectedSize = info.shape[curDim];
+ long actualSize = Array.getLength(array);
+ if (expectedSize != actualSize) {
+ throw new IllegalArgumentException(
+ "Mismatch in array sizes, expected "
+ + expectedSize
+ + " at dim "
+ + curDim
+ + " from shape "
+ + Arrays.toString(info.shape)
+ + ", found "
+ + actualSize);
+ } else {
+ for (int i = 0; i < actualSize; i++) {
+ fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i));
+ }
+ }
+ }
+ }
+
/**
* Returns expected JDK map capacity for a given size, this factors in the default JDK load factor
*
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 1c21387b50455..f3e9f21ef408d 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -323,6 +323,9 @@ public long getNumElements() {
* all elements as that's the expected format of the native code. It can be reshaped to the
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
*
+ * For fp16 and bf16 tensors the output carrier type is float, and so this method produces
+ * multidimensional float arrays.
+ *
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is
* greater than an int).
@@ -335,6 +338,8 @@ public Object makeCarrier() throws OrtException {
+ Arrays.toString(shape));
}
switch (type) {
+ case BFLOAT16:
+ case FLOAT16:
case FLOAT:
return OrtUtil.newFloatArray(shape);
case DOUBLE:
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 7b26291581395..6a3c279073860 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -502,104 +502,6 @@ jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSeque
return sequenceInfo;
}
-int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) {
- int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray);
- int64_t consumedSize = inputLength * onnxTypeSize(onnxType);
- switch (onnxType) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t
- jbyteArray typedArr = (jbyteArray)inputArray;
- (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t
- jshortArray typedArr = (jshortArray)inputArray;
- (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t
- jintArray typedArr = (jintArray)inputArray;
- (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t
- jlongArray typedArr = (jlongArray)inputArray;
- (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: {
- throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported.");
- return -1;
- /*
- float *floatArr = malloc(sizeof(float) * inputLength);
- uint16_t *halfArr = (uint16_t *) outputTensor;
- for (uint32_t i = 0; i < inputLength; i++) {
- floatArr[i] = convertHalfToFloat(halfArr[i]);
- }
- jfloatArray typedArr = (jfloatArray) inputArray;
- (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr);
- free(floatArr);
- return consumedSize;
- */
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float
- jfloatArray typedArr = (jfloatArray)inputArray;
- (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double
- jdoubleArray typedArr = (jdoubleArray)inputArray;
- (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string
- throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported.");
- return -1;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
- jbooleanArray typedArr = (jbooleanArray)inputArray;
- (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor);
- return consumedSize;
- }
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED:
- default: {
- throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type.");
- return -1;
- }
- }
-}
-
-int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) {
- if (dimensionsRemaining == 1) {
- // write out 1d array of the respective primitive type
- return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor);
- } else {
- // recurse through the dimensions
- // Java arrays are objects until the final dimension
- jobjectArray inputObjArr = (jobjectArray)inputArray;
- int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr);
- int64_t sizeConsumed = 0;
- for (int32_t i = 0; i < dimLength; i++) {
- jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i);
- int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed);
- sizeConsumed += consumed;
- // Cleanup reference to childArr so it doesn't prevent GC.
- (*jniEnv)->DeleteLocalRef(jniEnv, childArr);
- // If we failed to copy an array then break and return.
- if (consumed == -1) {
- return -1;
- }
- }
- return sizeConsumed;
- }
-}
-
int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) {
int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray);
if (outputLength == 0) return 0;
@@ -697,65 +599,6 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT
}
}
-int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize,
- size_t dimensionsRemaining, jarray outputArray) {
- if (dimensionsRemaining == 1) {
- // write out 1d array of the respective primitive type
- return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray);
- } else {
- // recurse through the dimensions
- // Java arrays are objects until the final dimension
- jobjectArray outputObjArr = (jobjectArray)outputArray;
- int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr);
- int64_t sizeConsumed = 0;
- for (int32_t i = 0; i < dimLength; i++) {
- jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i);
- int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr);
- sizeConsumed += consumed;
- // Cleanup reference to childArr so it doesn't prevent GC.
- (*jniEnv)->DeleteLocalRef(jniEnv, childArr);
- // If we failed to copy an array then break and return.
- if (consumed == -1) {
- return -1;
- }
- }
- return sizeConsumed;
- }
-}
-
-jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) {
- jobject tempString = NULL;
- // Get the buffer size needed
- size_t totalStringLength = 0;
- OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength));
- if (code != ORT_OK) {
- return NULL;
- }
-
- // Create the character and offset buffers, character is one larger to allow zero termination.
- char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1));
- if (characterBuffer == NULL) {
- throwOrtException(jniEnv, 1, "OOM error");
- } else {
- size_t * offsets = malloc(sizeof(size_t));
- if (offsets != NULL) {
- // Get a view on the String data
- code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1));
-
- if (code == ORT_OK) {
- size_t curSize = (offsets[0]) + 1;
- characterBuffer[curSize-1] = '\0';
- tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer);
- }
-
- free((void*)characterBuffer);
- free((void*)offsets);
- }
- }
-
- return tempString;
-}
-
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) {
size_t bufferSize = 16;
char * tempBuffer = malloc(bufferSize);
diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h
index 023bc0c739583..7f41e06371f2a 100644
--- a/java/src/main/native/OrtJniUtil.h
+++ b/java/src/main/native/OrtJniUtil.h
@@ -54,16 +54,8 @@ jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInf
jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info);
-int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor);
-
-int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor);
-
int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray);
-int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray);
-
-jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
-
OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray);
jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor);
diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
index b694f57357bb5..d757bd6281499 100644
--- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c
+++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c
@@ -8,72 +8,6 @@
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OnnxTensor.h"
-/*
- * Class: ai_onnxruntime_OnnxTensor
- * Method: createTensor
- * Signature: (JJLjava/lang/Object;[JI)J
- */
-JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor
- (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj,
- jlongArray shape, jint onnxTypeJava) {
- (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
- const OrtApi* api = (const OrtApi*) apiHandle;
- OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
- // Convert type to ONNX C enum
- ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava);
-
- // Extract the shape information
- jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL);
- jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape);
-
- // Create the OrtValue
- OrtValue* ortValue = NULL;
- OrtErrorCode code = checkOrtStatus(jniEnv, api,
- api->CreateTensorAsOrtValue(
- allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue
- )
- );
- (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT);
-
- int failed = 0;
- if (code == ORT_OK) {
- // Get a reference to the OrtValue's data
- uint8_t* tensorData = NULL;
- code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData));
- if (code == ORT_OK) {
- // Check if we're copying a scalar or not
- if (shapeLen == 0) {
- // Scalars are passed in as a single element array
- int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData);
- failed = copied == -1 ? 1 : failed;
- } else {
- // Extract the tensor shape information
- JavaTensorTypeShape typeShape;
- code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue);
-
- if (code == ORT_OK) {
- // Copy the java array into the tensor
- int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount,
- typeShape.dimensions, dataObj, tensorData);
- failed = copied == -1 ? 1 : failed;
- } else {
- failed = 1;
- }
- }
- } else {
- failed = 1;
- }
- }
-
- if (failed) {
- api->ReleaseValue(ortValue);
- ortValue = NULL;
- }
-
- // Return the pointer to the OrtValue
- return (jlong) ortValue;
-}
-
/*
* Class: ai_onnxruntime_OnnxTensor
* Method: createTensorFromBuffer
@@ -227,7 +161,7 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer
size_t sizeBytes = typeShape.elementCount * typeSize;
uint8_t* arr = NULL;
- code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr));
+ code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr));
if (code == ORT_OK) {
return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes);
@@ -401,11 +335,11 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool
/*
* Class: ai_onnxruntime_OnnxTensor
- * Method: getArray
- * Signature: (JJLjava/lang/Object;)V
+ * Method: getStringArray
+ * Signature: (JJ[Ljava/lang/String;)V
*/
-JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
- (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) {
+JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getStringArray
+ (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobjectArray carrier) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtValue* value = (OrtValue*) handle;
@@ -415,12 +349,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray
if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier);
} else {
- uint8_t* arr = NULL;
- code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr));
- if (code == ORT_OK) {
- copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount,
- typeShape.dimensions, (jarray)carrier);
- }
+ throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Non-string types are not supported by this codepath, please raise a Github issue as it should not reach here.");
}
}
}
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index 11141a3a65a3e..7cb6305923279 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -495,12 +495,12 @@ public void throwWrongInputName() throws OrtException {
container.put("wrong_name", OnnxTensor.createTensor(env, tensor));
try {
session.run(container);
- OnnxValue.close(container.values());
fail("Should throw exception for incorrect name.");
} catch (OrtException e) {
- OnnxValue.close(container.values());
String msg = e.getMessage();
assertTrue(msg.contains("Unknown input name"));
+ } finally {
+ OnnxValue.close(container.values());
}
}
}
@@ -522,12 +522,57 @@ public void throwWrongInputType() throws OrtException {
container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor));
try {
session.run(container);
- OnnxValue.close(container.values());
fail("Should throw exception for incorrect type.");
} catch (OrtException e) {
- OnnxValue.close(container.values());
String msg = e.getMessage();
assertTrue(msg.contains("Unexpected input data type"));
+ } finally {
+ OnnxValue.close(container.values());
+ }
+ }
+ }
+
+ @Test
+ public void throwWrongSizeInput() throws OrtException {
+ SqueezeNetTuple tuple = openSessionSqueezeNet();
+ try (OrtSession session = tuple.session) {
+
+ float[] inputData = tuple.inputData;
+ NodeInfo inputMeta = session.getInputInfo().values().iterator().next();
+ Map