From 4d5f8ea8a30435148b94858d2ac6988cf48f8abd Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 29 Dec 2023 21:43:34 -0500 Subject: [PATCH] Updating TensorInfo so it contains the named dimensions if any. --- .../main/java/ai/onnxruntime/TensorInfo.java | 32 +++++++++++++++++-- java/src/main/native/OrtJniUtil.c | 26 ++++++++++++--- .../java/ai/onnxruntime/InferenceTest.java | 6 ++++ 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 69ccb954e8afe..4861583cb9457 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -7,6 +7,7 @@ import java.lang.reflect.Array; import java.nio.Buffer; import java.util.Arrays; +import java.util.stream.Collectors; /** Describes an {@link OnnxTensor}, including it's size, shape and element type. */ public class TensorInfo implements ValueInfo { @@ -159,6 +160,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { /** The shape of the tensor. */ final long[] shape; + /** The names of the unbound dimensions. */ + final String[] dimensionNames; + /** The Java type of this tensor. */ public final OnnxJavaType type; @@ -177,6 +181,8 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { */ TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) { this.shape = shape; + this.dimensionNames = new String[shape.length]; + Arrays.fill(dimensionNames, ""); this.type = type; this.onnxType = onnxType; this.numElements = elementCount(shape); @@ -188,10 +194,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { *

Called from JNI. * * @param shape The tensor shape. + * @param names The dimension names. * @param typeInt The native type int. */ - TensorInfo(long[] shape, int typeInt) { + TensorInfo(long[] shape, String[] names, int typeInt) { this.shape = shape; + this.dimensionNames = names; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); this.numElements = elementCount(shape); @@ -206,6 +214,15 @@ public long[] getShape() { return Arrays.copyOf(shape, shape.length); } + /** + * Get a copy of the tensor's named dimensions. + * + * @return A copof the tensor's named dimensions. + */ + public String[] getDimensionNames() { + return Arrays.copyOf(dimensionNames, dimensionNames.length); + } + @Override public String toString() { return "TensorInfo(javaType=" @@ -214,7 +231,18 @@ public String toString() { + onnxType.toString() + ",shape=" + Arrays.toString(shape) - + ")"; + + ",dimNames=[" + + Arrays.stream(dimensionNames) + .map( + a -> { + if (a.isEmpty()) { + return "\"\""; + } else { + return a; + } + }) + .collect(Collectors.joining(",")) + + "])"; } /** diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 879ba8a310618..96a4cebdb5c51 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT if (code != ORT_OK) { return NULL; } - //printf("numDim %d\n",numDim); int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim); code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim)); if (code != ORT_OK) { @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT free(dimensions); dimensions = NULL; + // Create the string array for the names. + const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim); + if (dimensionNames == NULL) { + throwOrtException(jniEnv, 1, "Not enough memory"); + return NULL; + } + code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim)); + if (code != ORT_OK) { + // extraction failed, exception has been thrown, return to Java. + free(dimensionNames); + return NULL; + } + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL); + for (size_t i = 0; i < numDim; i++) { + jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, names, i, javaName); + } + free(dimensionNames); + // Create the TensorInfo object static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo"; jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName); - jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([JI)V"); - //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor); - jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt); + jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "", "([J[Ljava/lang/String;I)V"); + jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt); return tensorInfo; } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index e975117fb75bd..a8dea8bd62a00 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -588,6 +588,12 @@ public void testSymbolicDimensionAssignment() throws OrtException { Map infoMap = session.getInputInfo(); TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo(); assertArrayEquals(new long[] {-1, 2}, aInfo.shape); + assertEquals(2, aInfo.dimensionNames.length); + assertEquals("n", aInfo.dimensionNames[0]); + assertEquals("", aInfo.dimensionNames[1]); + TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo(); + assertEquals(1, bInfo.dimensionNames.length); + assertEquals("m", bInfo.dimensionNames[0]); } } // Check that when the options are assigned it overrides the symbolic dimension