diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index e1ee2c14fd9d1..3f276a3670156 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -54,7 +54,7 @@ public class OnnxTensor extends OnnxTensorLike { * the state of this buffer without first getting the reference via {@link #getBufferRef()}. * * @return True if the buffer in this OnnxTensor was allocated by it on construction (i.e., it is - * a copy of a user buffer.) + * a copy of a user buffer or array.) */ public boolean ownsBuffer() { return this.ownsBuffer; @@ -62,8 +62,8 @@ public boolean ownsBuffer() { /** * Returns a reference to the buffer which backs this {@code OnnxTensor}. If the tensor is not - * backed by a buffer (i.e., it was created from a Java array, or is backed by memory allocated by - * ORT) this method returns an empty {@link Optional}. + * backed by a buffer (i.e., it is backed by memory allocated by ORT) this method returns an empty + * {@link Optional}. * *

Changes to the buffer elements will be reflected in the native {@code OrtValue}, this can be * used to repeatedly update a single tensor for multiple different inferences without allocating @@ -77,7 +77,116 @@ public boolean ownsBuffer() { * @return A reference to the buffer. */ public Optional getBufferRef() { - return Optional.ofNullable(buffer); + return Optional.ofNullable(duplicate(buffer)); + } + + /** + * Duplicates the buffer to ensure concurrent reads don't disrupt the buffer position. Concurrent + * writes will modify the underlying memory in a racy way, don't do that. + * + *

Can be replaced to a call to buf.duplicate() in Java 9+. + * + * @param buf The buffer to duplicate. + * @return A copy of the buffer which refers to the same underlying memory, but has an independent + * position, limit and mark. + */ + private static Buffer duplicate(Buffer buf) { + if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).duplicate().order(ByteOrder.nativeOrder()); + } else if (buf instanceof ShortBuffer) { + return ((ShortBuffer) buf).duplicate(); + } else if (buf instanceof IntBuffer) { + return ((IntBuffer) buf).duplicate(); + } else if (buf instanceof LongBuffer) { + return ((LongBuffer) buf).duplicate(); + } else if (buf instanceof FloatBuffer) { + return ((FloatBuffer) buf).duplicate(); + } else if (buf instanceof DoubleBuffer) { + return ((DoubleBuffer) buf).duplicate(); + } else { + throw new IllegalStateException("Unknown buffer type " + buf.getClass()); + } + } + + /** + * Checks that the buffer is the right type for the {@code info.type}, and if it's a {@link + * ByteBuffer} then convert it to the right type. If it's not convertible it throws {@link + * IllegalStateException}. + * + *

Note this method converts FP16 and BFLOAT16 ShortBuffers into FP32 FloatBuffers, to preserve + * compatibility with existing {@link #getValue} calls. + * + * @param buf The buffer to convert. + * @return The buffer with the expected type. + */ + private Buffer castBuffer(Buffer buf) { + switch (info.type) { + case FLOAT: + if (buf instanceof FloatBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asFloatBuffer(); + } + break; + case DOUBLE: + if (buf instanceof DoubleBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asDoubleBuffer(); + } + break; + case BOOL: + case INT8: + case UINT8: + if (buf instanceof ByteBuffer) { + return buf; + } + break; + case BFLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer bf16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer bf16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertBf16BufferToFloatBuffer(bf16Buf); + } + break; + case FLOAT16: + if (buf instanceof ShortBuffer) { + ShortBuffer fp16Buf = (ShortBuffer) buf; + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } else if (buf instanceof ByteBuffer) { + ShortBuffer fp16Buf = ((ByteBuffer) buf).asShortBuffer(); + return Fp16Conversions.convertFp16BufferToFloatBuffer(fp16Buf); + } + break; + case INT16: + if (buf instanceof ShortBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asShortBuffer(); + } + break; + case INT32: + if (buf instanceof IntBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asIntBuffer(); + } + break; + case INT64: + if (buf instanceof LongBuffer) { + return buf; + } else if (buf instanceof ByteBuffer) { + return ((ByteBuffer) buf).asLongBuffer(); + } + break; + } + throw new IllegalStateException( + "Invalid buffer type for cast operation, found " + + buf.getClass() + + " expected something convertible to " + + info.type); } @Override @@ -133,15 +242,26 @@ public Object getValue() throws OrtException { Object carrier = info.makeCarrier(); if (info.getNumElements() > 0) { // If the tensor has values copy them out - getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); - } - if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { - // We read the strings out from native code in a flat array and then reshape - // to the desired output shape. - return OrtUtil.reshape((String[]) carrier, info.shape); - } else { - return carrier; + if (info.type == OnnxJavaType.STRING) { + // We read the strings out from native code in a flat array and then reshape + // to the desired output shape if necessary. + getStringArray(OnnxRuntime.ortApiHandle, nativeHandle, (String[]) carrier); + if (info.shape.length != 1) { + carrier = OrtUtil.reshape((String[]) carrier, info.shape); + } + } else { + // Wrap ORT owned memory in buffer, otherwise use our reference + Buffer buf; + if (buffer == null) { + buf = castBuffer(getBuffer()); + } else { + buf = castBuffer(duplicate(buffer)); + } + // Copy out buffer into arrays + OrtUtil.fillArrayFromBuffer(info, buf, 0, carrier); + } } + return carrier; } } @@ -175,8 +295,8 @@ public synchronized void close() { public ByteBuffer getByteBuffer() { checkClosed(); if (info.type != OnnxJavaType.STRING) { - ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); - ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); + ByteBuffer buffer = getBuffer(); + ByteBuffer output = ByteBuffer.allocate(buffer.capacity()).order(ByteOrder.nativeOrder()); output.put(buffer); output.rewind(); return output; @@ -201,12 +321,12 @@ public FloatBuffer getFloatBuffer() { output.rewind(); return output; } else if (info.type == OnnxJavaType.FLOAT16) { - // if it's fp16 we need to copy it out by hand. + // if it's fp16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertFp16BufferToFloatBuffer(buffer); } else if (info.type == OnnxJavaType.BFLOAT16) { - // if it's bf16 we need to copy it out by hand. + // if it's bf16 we need to convert it. ByteBuffer buf = getBuffer(); ShortBuffer buffer = buf.asShortBuffer(); return Fp16Conversions.convertBf16BufferToFloatBuffer(buffer); @@ -331,7 +451,7 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType) private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException; - private native void getArray(long apiHandle, long nativeHandle, Object carrier) + private native void getStringArray(long apiHandle, long nativeHandle, String[] carrier) throws OrtException; private native void close(long apiHandle, long nativeHandle); @@ -387,21 +507,32 @@ static OnnxTensor createTensor(OrtEnvironment env, OrtAllocator allocator, Objec info); } } else { + Buffer buf; if (info.shape.length == 0) { - data = OrtUtil.convertBoxedPrimitiveToArray(info.type, data); - if (data == null) { + buf = OrtUtil.convertBoxedPrimitiveToBuffer(info.type, data); + if (buf == null) { throw new OrtException( "Failed to convert a boxed primitive to an array, this is an error with the ORT Java API, please report this message & stack trace. JavaType = " + info.type + ", object = " + data); } + } else { + buf = OrtUtil.convertArrayToBuffer(info, data); } return new OnnxTensor( - createTensor( - OnnxRuntime.ortApiHandle, allocator.handle, data, info.shape, info.onnxType.value), + createTensorFromBuffer( + OnnxRuntime.ortApiHandle, + allocator.handle, + buf, + 0, + info.type.size * info.numElements, + info.shape, + info.onnxType.value), allocator.handle, - info); + info, + buf, + true); } } else { throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator."); @@ -627,7 +758,26 @@ static OnnxTensor createTensor( */ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long[] shape) throws OrtException { - return createTensor(env, env.defaultAllocator, data, shape); + return createTensor(env, env.defaultAllocator, data, shape, OnnxJavaType.INT16); + } + + /** + * Create an OnnxTensor backed by a direct ShortBuffer. The buffer should be in nativeOrder. + * + *

If the supplied buffer is not a direct buffer, a direct copy is created tied to the lifetime + * of the tensor. Uses the default allocator. + * + * @param env The current OrtEnvironment. + * @param data The tensor data. + * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. + * @return An OnnxTensor of the required shape. + * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. + */ + public static OnnxTensor createTensor( + OrtEnvironment env, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { + return createTensor(env, env.defaultAllocator, data, shape, type); } /** @@ -640,15 +790,23 @@ public static OnnxTensor createTensor(OrtEnvironment env, ShortBuffer data, long * @param allocator The allocator to use. * @param data The tensor data. * @param shape The shape of tensor. + * @param type The type of the data in the buffer, can be either {@link OnnxJavaType#INT16}, + * {@link OnnxJavaType#FLOAT16} or {@link OnnxJavaType#BFLOAT16}. * @return An OnnxTensor of the required shape. * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. */ static OnnxTensor createTensor( - OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape) + OrtEnvironment env, OrtAllocator allocator, ShortBuffer data, long[] shape, OnnxJavaType type) throws OrtException { if (!allocator.isClosed()) { - OnnxJavaType type = OnnxJavaType.INT16; - return createTensor(type, allocator, data, shape); + if ((type == OnnxJavaType.BFLOAT16) + || (type == OnnxJavaType.FLOAT16) + || (type == OnnxJavaType.INT16)) { + return createTensor(type, allocator, data, shape); + } else { + throw new IllegalArgumentException( + "Only int16, float16 or bfloat16 tensors can be created from ShortBuffer."); + } } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); } @@ -768,10 +926,6 @@ private static OnnxTensor createTensor( tuple.isCopy); } - private static native long createTensor( - long apiHandle, long allocatorHandle, Object data, long[] shape, int onnxType) - throws OrtException; - private static native long createTensorFromBuffer( long apiHandle, long allocatorHandle, diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 4f3dee3c00b91..2f44236e4ef67 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -26,10 +26,10 @@ public final class OrtUtil { private OrtUtil() {} /** - * Converts an long shape into a int shape. + * Converts a long shape into an int shape. * - *

Validates that the shape has more than 1 elements, less than 9 elements, each element is - * less than {@link Integer#MAX_VALUE} and that each entry is non-negative. + *

Validates that the shape has more than 1 element, less than 9 elements, each element is less + * than {@link Integer#MAX_VALUE} and that each entry is non-negative. * * @param shape The long shape. * @return The int shape. @@ -460,6 +460,308 @@ static Object convertBoxedPrimitiveToArray(OnnxJavaType javaType, Object data) { } } + /** + * Stores a boxed primitive in a single element buffer of the unboxed type. + * + *

If it's not a boxed primitive then it returns null. + * + * @param javaType The type of the boxed primitive. + * @param data The boxed primitive. + * @return The primitive in a direct buffer. + */ + static Buffer convertBoxedPrimitiveToBuffer(OnnxJavaType javaType, Object data) { + switch (javaType) { + case FLOAT: + { + FloatBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asFloatBuffer(); + buf.put(0, (Float) data); + return buf; + } + case DOUBLE: + { + DoubleBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer(); + buf.put(0, (Double) data); + return buf; + } + case BOOL: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, ((boolean) data) ? (byte) 1 : (byte) 0); + return buf; + } + case UINT8: + case INT8: + { + ByteBuffer buf = ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()); + buf.put(0, (Byte) data); + return buf; + } + case FLOAT16: + case BFLOAT16: + case INT16: + { + ShortBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asShortBuffer(); + buf.put(0, (Short) data); + return buf; + } + case INT32: + { + IntBuffer buf = + ByteBuffer.allocateDirect(javaType.size).order(ByteOrder.nativeOrder()).asIntBuffer(); + buf.put(0, (Integer) data); + return buf; + } + case INT64: + { + LongBuffer buf = + ByteBuffer.allocateDirect(javaType.size) + .order(ByteOrder.nativeOrder()) + .asLongBuffer(); + buf.put(0, (Long) data); + return buf; + } + case STRING: + case UNKNOWN: + default: + return null; + } + } + + /** + * Copies a Java (possibly multidimensional) array into a direct {@link Buffer}. + * + *

Throws {@link IllegalArgumentException} if the array is not an array of Java primitives or + * if the array is ragged. + * + * @param info The tensor info object containing the types and shape of the array. + * @param array The array object. + * @return A direct buffer containing all the elements. + */ + static Buffer convertArrayToBuffer(TensorInfo info, Object array) { + ByteBuffer byteBuffer = + ByteBuffer.allocateDirect((int) info.numElements * info.type.size) + .order(ByteOrder.nativeOrder()); + + Buffer buffer; + switch (info.type) { + case FLOAT: + buffer = byteBuffer.asFloatBuffer(); + break; + case DOUBLE: + buffer = byteBuffer.asDoubleBuffer(); + break; + case BOOL: + case INT8: + case UINT8: + // no-op, it's already a bytebuffer + buffer = byteBuffer; + break; + case BFLOAT16: + case FLOAT16: + case INT16: + buffer = byteBuffer.asShortBuffer(); + break; + case INT32: + buffer = byteBuffer.asIntBuffer(); + break; + case INT64: + buffer = byteBuffer.asLongBuffer(); + break; + case STRING: + case UNKNOWN: + default: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + + fillBufferFromArray(info, array, 0, buffer); + + if (buffer.remaining() != 0) { + throw new IllegalArgumentException( + "Failed to copy all elements into the buffer, expected to copy " + + info.numElements + + " into a buffer of capacity " + + buffer.capacity() + + " but had " + + buffer.remaining() + + " values left over."); + } + buffer.rewind(); + + return buffer; + } + + /** + * Fills the provided buffer with the values from the array, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param array The array object to read from. + * @param curDim The current dimension we're processing. + * @param buffer The buffer to write to. + */ + private static void fillBufferFromArray( + TensorInfo info, Object array, int curDim, Buffer buffer) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.put(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.put(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.put(bArr); + break; + case FLOAT16: + case BFLOAT16: + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.put(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.put(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.put(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + boolBuf.put(boolArr[i] ? (byte) 1 : (byte) 0); + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillBufferFromArray(info, Array.get(array, i), curDim + 1, buffer); + } + } + } + } + + /** + * Fills the provided array with the values from the buffer, recursing through the array + * structure. + * + * @param info The tensor info containing the type and shape of the array. + * @param buffer The buffer to read from. + * @param curDim The current dimension we're processing. + * @param array The array object to write to. + */ + static void fillArrayFromBuffer(TensorInfo info, Buffer buffer, int curDim, Object array) { + if (curDim == info.shape.length - 1) { + // Reached primitive values, copy into buffer + switch (info.type) { + case FLOAT16: + case BFLOAT16: + case FLOAT: + float[] fArr = (float[]) array; + FloatBuffer fBuf = (FloatBuffer) buffer; + fBuf.get(fArr); + break; + case DOUBLE: + double[] dArr = (double[]) array; + DoubleBuffer dBuf = (DoubleBuffer) buffer; + dBuf.get(dArr); + break; + case INT8: + case UINT8: + byte[] bArr = (byte[]) array; + ByteBuffer bBuf = (ByteBuffer) buffer; + bBuf.get(bArr); + break; + case INT16: + short[] sArr = (short[]) array; + ShortBuffer sBuf = (ShortBuffer) buffer; + sBuf.get(sArr); + break; + case INT32: + int[] iArr = (int[]) array; + IntBuffer iBuf = (IntBuffer) buffer; + iBuf.get(iArr); + break; + case INT64: + long[] lArr = (long[]) array; + LongBuffer lBuf = (LongBuffer) buffer; + lBuf.get(lArr); + break; + case BOOL: + boolean[] boolArr = (boolean[]) array; + ByteBuffer boolBuf = (ByteBuffer) buffer; + for (int i = 0; i < boolArr.length; i++) { + // Test to see if the byte is non-zero, non-zero bytes are true, zero bytes are false. + boolArr[i] = boolBuf.get() != 0; + } + break; + case STRING: + case UNKNOWN: + throw new IllegalArgumentException( + "Unexpected type, expected Java primitive found " + info.type); + } + } else { + // Recurse through array + long expectedSize = info.shape[curDim]; + long actualSize = Array.getLength(array); + if (expectedSize != actualSize) { + throw new IllegalArgumentException( + "Mismatch in array sizes, expected " + + expectedSize + + " at dim " + + curDim + + " from shape " + + Arrays.toString(info.shape) + + ", found " + + actualSize); + } else { + for (int i = 0; i < actualSize; i++) { + fillArrayFromBuffer(info, buffer, curDim + 1, Array.get(array, i)); + } + } + } + } + /** * Returns expected JDK map capacity for a given size, this factors in the default JDK load factor * diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 1c21387b50455..f3e9f21ef408d 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -323,6 +323,9 @@ public long getNumElements() { * all elements as that's the expected format of the native code. It can be reshaped to the * correct shape using {@link OrtUtil#reshape(String[],long[])}. * + *

For fp16 and bf16 tensors the output carrier type is float, and so this method produces + * multidimensional float arrays. + * * @return A multidimensional array of the appropriate primitive type (or String). * @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is * greater than an int). @@ -335,6 +338,8 @@ public Object makeCarrier() throws OrtException { + Arrays.toString(shape)); } switch (type) { + case BFLOAT16: + case FLOAT16: case FLOAT: return OrtUtil.newFloatArray(shape); case DOUBLE: diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 7b26291581395..6a3c279073860 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -502,104 +502,6 @@ jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSeque return sequenceInfo; } -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor) { - int32_t inputLength = (*jniEnv)->GetArrayLength(jniEnv, inputArray); - int64_t consumedSize = inputLength * onnxTypeSize(onnxType); - switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: // maps to c type uint8_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { // maps to c type int8_t - jbyteArray typedArr = (jbyteArray)inputArray; - (*jniEnv)->GetByteArrayRegion(jniEnv, typedArr, 0, inputLength, (jbyte * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: // maps to c type uint16_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { // maps to c type int16_t - jshortArray typedArr = (jshortArray)inputArray; - (*jniEnv)->GetShortArrayRegion(jniEnv, typedArr, 0, inputLength, (jshort * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: // maps to c type uint32_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { // maps to c type int32_t - jintArray typedArr = (jintArray)inputArray; - (*jniEnv)->GetIntArrayRegion(jniEnv, typedArr, 0, inputLength, (jint * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: // maps to c type uint64_t - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { // maps to c type int64_t - jlongArray typedArr = (jlongArray)inputArray; - (*jniEnv)->GetLongArrayRegion(jniEnv, typedArr, 0, inputLength, (jlong * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: // Non-IEEE floating-point format based on IEEE754 single-precision - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "16-bit float not supported."); - return -1; - /* - float *floatArr = malloc(sizeof(float) * inputLength); - uint16_t *halfArr = (uint16_t *) outputTensor; - for (uint32_t i = 0; i < inputLength; i++) { - floatArr[i] = convertHalfToFloat(halfArr[i]); - } - jfloatArray typedArr = (jfloatArray) inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, floatArr); - free(floatArr); - return consumedSize; - */ - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { // maps to c type float - jfloatArray typedArr = (jfloatArray)inputArray; - (*jniEnv)->GetFloatArrayRegion(jniEnv, typedArr, 0, inputLength, (jfloat * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { // maps to c type double - jdoubleArray typedArr = (jdoubleArray)inputArray; - (*jniEnv)->GetDoubleArrayRegion(jniEnv, typedArr, 0, inputLength, (jdouble * )outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: { // maps to c++ type std::string - throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "String is not supported."); - return -1; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - jbooleanArray typedArr = (jbooleanArray)inputArray; - (*jniEnv)->GetBooleanArrayRegion(jniEnv, typedArr, 0, inputLength, (jboolean *)outputTensor); - return consumedSize; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64: // complex with float32 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128: // complex with float64 real and imaginary components - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED: - default: { - throwOrtException(jniEnv, convertErrorCode(ORT_INVALID_ARGUMENT), "Invalid outputTensor element type."); - return -1; - } - } -} - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyJavaToPrimitiveArray(jniEnv, onnxType, inputArray, outputTensor); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray inputObjArr = (jobjectArray)inputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, inputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, inputObjArr, i); - int64_t consumed = copyJavaToTensor(jniEnv, onnxType, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr, outputTensor + sizeConsumed); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. - if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray) { int32_t outputLength = (*jniEnv)->GetArrayLength(jniEnv, outputArray); if (outputLength == 0) return 0; @@ -697,65 +599,6 @@ int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxT } } -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, - size_t dimensionsRemaining, jarray outputArray) { - if (dimensionsRemaining == 1) { - // write out 1d array of the respective primitive type - return copyPrimitiveArrayToJava(jniEnv, onnxType, inputTensor, outputArray); - } else { - // recurse through the dimensions - // Java arrays are objects until the final dimension - jobjectArray outputObjArr = (jobjectArray)outputArray; - int32_t dimLength = (*jniEnv)->GetArrayLength(jniEnv, outputObjArr); - int64_t sizeConsumed = 0; - for (int32_t i = 0; i < dimLength; i++) { - jarray childArr = (jarray) (*jniEnv)->GetObjectArrayElement(jniEnv, outputObjArr, i); - int64_t consumed = copyTensorToJava(jniEnv, onnxType, inputTensor + sizeConsumed, tensorSize - sizeConsumed, dimensionsRemaining - 1, childArr); - sizeConsumed += consumed; - // Cleanup reference to childArr so it doesn't prevent GC. - (*jniEnv)->DeleteLocalRef(jniEnv, childArr); - // If we failed to copy an array then break and return. - if (consumed == -1) { - return -1; - } - } - return sizeConsumed; - } -} - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor) { - jobject tempString = NULL; - // Get the buffer size needed - size_t totalStringLength = 0; - OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetStringTensorDataLength(tensor, &totalStringLength)); - if (code != ORT_OK) { - return NULL; - } - - // Create the character and offset buffers, character is one larger to allow zero termination. - char * characterBuffer = malloc(sizeof(char)*(totalStringLength+1)); - if (characterBuffer == NULL) { - throwOrtException(jniEnv, 1, "OOM error"); - } else { - size_t * offsets = malloc(sizeof(size_t)); - if (offsets != NULL) { - // Get a view on the String data - code = checkOrtStatus(jniEnv, api, api->GetStringTensorContent(tensor, characterBuffer, totalStringLength, offsets, 1)); - - if (code == ORT_OK) { - size_t curSize = (offsets[0]) + 1; - characterBuffer[curSize-1] = '\0'; - tempString = (*jniEnv)->NewStringUTF(jniEnv, characterBuffer); - } - - free((void*)characterBuffer); - free((void*)offsets); - } - } - - return tempString; -} - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray) { size_t bufferSize = 16; char * tempBuffer = malloc(bufferSize); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 023bc0c739583..7f41e06371f2a 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -54,16 +54,8 @@ jobject convertToMapInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtMapTypeInf jobject convertToSequenceInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtSequenceTypeInfo * info); -int64_t copyJavaToPrimitiveArray(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, jarray inputArray, uint8_t* outputTensor); - -int64_t copyJavaToTensor(JNIEnv* jniEnv, ONNXTensorElementDataType onnxType, size_t tensorSize, size_t dimensionsRemaining, jarray inputArray, uint8_t* outputTensor); - int64_t copyPrimitiveArrayToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, jarray outputArray); -int64_t copyTensorToJava(JNIEnv *jniEnv, ONNXTensorElementDataType onnxType, const uint8_t* inputTensor, size_t tensorSize, size_t dimensionsRemaining, jarray outputArray); - -jobject createStringFromStringTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); - OrtErrorCode copyStringTensorToArray(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor, size_t length, jobjectArray outputArray); jobjectArray createStringArrayFromTensor(JNIEnv *jniEnv, const OrtApi * api, OrtValue* tensor); diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index b694f57357bb5..d757bd6281499 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -8,72 +8,6 @@ #include "OrtJniUtil.h" #include "ai_onnxruntime_OnnxTensor.h" -/* - * Class: ai_onnxruntime_OnnxTensor - * Method: createTensor - * Signature: (JJLjava/lang/Object;[JI)J - */ -JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensor - (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jobject dataObj, - jlongArray shape, jint onnxTypeJava) { - (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. - const OrtApi* api = (const OrtApi*) apiHandle; - OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; - // Convert type to ONNX C enum - ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); - - // Extract the shape information - jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); - jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); - - // Create the OrtValue - OrtValue* ortValue = NULL; - OrtErrorCode code = checkOrtStatus(jniEnv, api, - api->CreateTensorAsOrtValue( - allocator, (int64_t*)shapeArr, shapeLen, onnxType, &ortValue - ) - ); - (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); - - int failed = 0; - if (code == ORT_OK) { - // Get a reference to the OrtValue's data - uint8_t* tensorData = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&tensorData)); - if (code == ORT_OK) { - // Check if we're copying a scalar or not - if (shapeLen == 0) { - // Scalars are passed in as a single element array - int64_t copied = copyJavaToPrimitiveArray(jniEnv, onnxType, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - // Extract the tensor shape information - JavaTensorTypeShape typeShape; - code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); - - if (code == ORT_OK) { - // Copy the java array into the tensor - int64_t copied = copyJavaToTensor(jniEnv, onnxType, typeShape.elementCount, - typeShape.dimensions, dataObj, tensorData); - failed = copied == -1 ? 1 : failed; - } else { - failed = 1; - } - } - } else { - failed = 1; - } - } - - if (failed) { - api->ReleaseValue(ortValue); - ortValue = NULL; - } - - // Return the pointer to the OrtValue - return (jlong) ortValue; -} - /* * Class: ai_onnxruntime_OnnxTensor * Method: createTensorFromBuffer @@ -227,7 +161,7 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer size_t sizeBytes = typeShape.elementCount * typeSize; uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData((OrtValue*)handle, (void**)&arr)); + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr)); if (code == ORT_OK) { return (*jniEnv)->NewDirectByteBuffer(jniEnv, arr, (jlong)sizeBytes); @@ -401,11 +335,11 @@ JNIEXPORT jboolean JNICALL Java_ai_onnxruntime_OnnxTensor_getBool /* * Class: ai_onnxruntime_OnnxTensor - * Method: getArray - * Signature: (JJLjava/lang/Object;)V + * Method: getStringArray + * Signature: (JJ[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobject carrier) { +JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getStringArray + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jobjectArray carrier) { (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; OrtValue* value = (OrtValue*) handle; @@ -415,12 +349,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OnnxTensor_getArray if (typeShape.onnxTypeEnum == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { copyStringTensorToArray(jniEnv, api, value, typeShape.elementCount, carrier); } else { - uint8_t* arr = NULL; - code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(value, (void**)&arr)); - if (code == ORT_OK) { - copyTensorToJava(jniEnv, typeShape.onnxTypeEnum, arr, typeShape.elementCount, - typeShape.dimensions, (jarray)carrier); - } + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Non-string types are not supported by this codepath, please raise a Github issue as it should not reach here."); } } } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 11141a3a65a3e..7cb6305923279 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -495,12 +495,12 @@ public void throwWrongInputName() throws OrtException { container.put("wrong_name", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect name."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unknown input name")); + } finally { + OnnxValue.close(container.values()); } } } @@ -522,12 +522,57 @@ public void throwWrongInputType() throws OrtException { container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for incorrect type."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected input data type")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongSizeInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + float[] wrongSizeData = Arrays.copyOf(inputData, 2 * 224 * 224); + Object tensor = OrtUtil.reshape(wrongSizeData, new long[] {1, 2, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Got invalid dimensions for input")); + } finally { + OnnxValue.close(container.values()); + } + } + } + + @Test + public void throwWrongRankInput() throws OrtException { + SqueezeNetTuple tuple = openSessionSqueezeNet(); + try (OrtSession session = tuple.session) { + + float[] inputData = tuple.inputData; + NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); + Map container = new HashMap<>(); + Object tensor = OrtUtil.reshape(inputData, new long[] {1, 1, 3, 224, 224}); + container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); + try { + session.run(container); + fail("Should throw exception for incorrect size."); + } catch (OrtException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Invalid rank for input")); + } finally { + OnnxValue.close(container.values()); } } } @@ -550,12 +595,12 @@ public void throwExtraInputs() throws OrtException { container.put("extra", OnnxTensor.createTensor(env, tensor)); try { session.run(container); - OnnxValue.close(container.values()); fail("Should throw exception for too many inputs."); } catch (OrtException e) { - OnnxValue.close(container.values()); String msg = e.getMessage(); assertTrue(msg.contains("Unexpected number of inputs")); + } finally { + OnnxValue.close(container.values()); } } } @@ -565,12 +610,11 @@ public void testMultiThreads() throws OrtException, InterruptedException { int numThreads = 10; int loop = 10; SqueezeNetTuple tuple = openSessionSqueezeNet(); + Map container = new HashMap<>(); try (OrtSession session = tuple.session) { - float[] inputData = tuple.inputData; float[] expectedOutput = tuple.outputData; NodeInfo inputMeta = session.getInputInfo().values().iterator().next(); - Map container = new HashMap<>(); long[] inputShape = ((TensorInfo) inputMeta.getInfo()).shape; Object tensor = OrtUtil.reshape(inputData, inputShape); container.put(inputMeta.getName(), OnnxTensor.createTensor(env, tensor)); @@ -592,8 +636,9 @@ public void testMultiThreads() throws OrtException, InterruptedException { } executor.shutdown(); executor.awaitTermination(1, TimeUnit.MINUTES); - OnnxValue.close(container.values()); assertTrue(executor.isTerminated()); + } finally { + OnnxValue.close(container.values()); } } diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index ea210d96c1507..064f14f3b51ff 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -12,8 +12,11 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; +import java.nio.IntBuffer; import java.nio.ShortBuffer; +import java.util.ArrayList; import java.util.Collections; +import java.util.List; import java.util.SplittableRandom; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -93,30 +96,108 @@ public void testScalarCreation() throws OrtException { } @Test - public void testBufferCreation() throws OrtException { + public void testArrayCreation() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); - // Test creating a value from an array - // Arrays result in tensors allocated by ORT, so they do not have a backing java.nio.Buffer + // Test creating a value from a single dimensional array float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { - // array creation isn't backed by buffers - assertFalse(t.ownsBuffer()); - assertFalse(t.getBufferRef().isPresent()); - FloatBuffer buf = t.getFloatBuffer(); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); float[] output = new float[arrValues.length]; buf.get(output); Assertions.assertArrayEquals(arrValues, output); - // Can't modify the tensor through this buffer. + // Can modify the tensor through this buffer. buf.put(0, 25); - Assertions.assertArrayEquals(arrValues, output); + Assertions.assertArrayEquals(new float[] {25, 1, 2, 3, 4}, (float[]) t.getValue()); } + // Test creating a value from a multidimensional float array + float[][][] arr3dValues = + new float[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, arr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + float[][][] output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + + // Can modify the tensor through the buffer. + FloatBuffer buf = (FloatBuffer) t.getBufferRef().get(); + buf.put(0, 25); + buf.put(12, 32); + buf.put(13, 33); + buf.put(23, 35); + arr3dValues[0][0][0] = 25; + arr3dValues[2][0][0] = 32; + arr3dValues[2][0][1] = 33; + arr3dValues[3][1][2] = 35; + output = (float[][][]) t.getValue(); + Assertions.assertArrayEquals(arr3dValues, output); + } + + // Test creating a value from a multidimensional int array + int[][][] iArr3dValues = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}, {9, 10, 11}}, + {{12, 13, 14}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, iArr3dValues)) { + Assertions.assertArrayEquals(new long[] {4, 2, 3}, t.getInfo().getShape()); + Assertions.assertTrue(t.ownsBuffer()); + Assertions.assertTrue(t.getBufferRef().isPresent()); + int[][][] output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + + // Can modify the tensor through the buffer. + IntBuffer buf = (IntBuffer) t.getBufferRef().get(); + buf.put(0, 25); + iArr3dValues[0][0][0] = 25; + output = (int[][][]) t.getValue(); + Assertions.assertArrayEquals(iArr3dValues, output); + } + + // Test creating a value from a ragged array throws + int[][][] ragged = + new int[][][] { + {{0, 1, 2}, {3, 4, 5}}, + {{6, 7, 8}}, + {{12, 13}, {15, 16, 17}}, + {{18, 19, 20}, {21, 22, 23}} + }; + try (OnnxTensor t = OnnxTensor.createTensor(env, ragged)) { + Assertions.fail("Can't create tensors from ragged arrays"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("ragged")); + } + + // Test creating a value from a non-array, non-primitive type throws. + List list = new ArrayList<>(5); + list.add(5); + try (OnnxTensor t = OnnxTensor.createTensor(env, list)) { + Assertions.fail("Can't create tensors from lists"); + } catch (OrtException e) { + Assertions.assertTrue(e.getMessage().contains("Cannot convert")); + } + } + + @Test + public void testBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + // Test creating a value from a non-direct byte buffer // Non-direct byte buffers are allocated on the Java heap and must be copied into off-heap - // direct byte buffers - // which can be directly passed to ORT + // direct byte buffers which can be directly passed to ORT + float[] arrValues = new float[] {0, 1, 2, 3, 4}; FloatBuffer nonDirectBuffer = FloatBuffer.allocate(5); nonDirectBuffer.put(arrValues); nonDirectBuffer.rewind(); @@ -335,10 +416,12 @@ public void testFp32ToFp16() throws OrtException { String modelPath = TestHelpers.getResourcePath("/java-fp32-to-fp16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -347,6 +430,8 @@ public void testFp32ToFp16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToFp16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.fp16ToFloat(Fp16Conversions.floatToFp16(input[i][j])); } } floatBuf.rewind(); @@ -354,25 +439,31 @@ public void testFp32ToFp16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertFp16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound fp16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } } @@ -382,10 +473,12 @@ public void testFp32ToBf16() throws OrtException { String modelPath = TestHelpers.getResourcePath("/java-fp32-to-bf16.onnx").toString(); SplittableRandom rng = new SplittableRandom(1); - float[][] input = new float[10][5]; + int dim1 = 10, dim2 = 5; + float[][] input = new float[dim1][dim2]; + float[][] expectedOutput = new float[dim1][dim2]; FloatBuffer floatBuf = - ByteBuffer.allocateDirect(4 * 10 * 5).order(ByteOrder.nativeOrder()).asFloatBuffer(); - ShortBuffer shortBuf = ShortBuffer.allocate(10 * 5); + ByteBuffer.allocateDirect(4 * dim1 * dim2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + ShortBuffer shortBuf = ShortBuffer.allocate(dim1 * dim2); // Generate data for (int i = 0; i < input.length; i++) { @@ -394,6 +487,8 @@ public void testFp32ToBf16() throws OrtException { input[i][j] = Float.intBitsToFloat(bits); floatBuf.put(input[i][j]); shortBuf.put(Fp16Conversions.floatToBf16(input[i][j])); + expectedOutput[i][j] = + Fp16Conversions.bf16ToFloat(Fp16Conversions.floatToBf16(input[i][j])); } } floatBuf.rewind(); @@ -401,25 +496,31 @@ public void testFp32ToBf16() throws OrtException { try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); OrtSession session = env.createSession(modelPath, opts); - OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {10, 5}); + OnnxTensor tensor = OnnxTensor.createTensor(env, floatBuf, new long[] {dim1, dim2}); OrtSession.Result result = session.run(Collections.singletonMap("input", tensor))) { OnnxTensor output = (OnnxTensor) result.get(0); // Check outbound Java side cast to fp32 works FloatBuffer castOutput = output.getFloatBuffer(); - float[] expectedFloatArr = new float[10 * 5]; + float[] expectedFloatArr = new float[dim1 * dim2]; Fp16Conversions.convertBf16BufferToFloatBuffer(shortBuf).get(expectedFloatArr); - float[] actualFloatArr = new float[10 * 5]; + float[] actualFloatArr = new float[dim1 * dim2]; castOutput.get(actualFloatArr); Assertions.assertArrayEquals(expectedFloatArr, actualFloatArr); // Check bits are correct ShortBuffer outputBuf = output.getShortBuffer(); - short[] expectedShortArr = new short[10 * 5]; + short[] expectedShortArr = new short[dim1 * dim2]; shortBuf.get(expectedShortArr); - short[] actualShortArr = new short[10 * 5]; + short[] actualShortArr = new short[dim1 * dim2]; outputBuf.get(actualShortArr); Assertions.assertArrayEquals(expectedShortArr, actualShortArr); + + // Check outbound bf16 -> float[] conversion + float[][] floats = (float[][]) output.getValue(); + for (int i = 0; i < dim1; i++) { + Assertions.assertArrayEquals(expectedOutput[i], floats[i]); + } } }