Skip to content

Commit

Permalink
Updating TensorInfo so it contains the named dimensions if any.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Dec 30, 2023
1 parent 780fc36 commit 4d5f8ea
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
32 changes: 30 additions & 2 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;
import java.util.stream.Collectors;

/** Describes an {@link OnnxTensor}, including it's size, shape and element type. */
public class TensorInfo implements ValueInfo {
Expand Down Expand Up @@ -159,6 +160,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The shape of the tensor. */
final long[] shape;

/** The names of the unbound dimensions. */
final String[] dimensionNames;

/** The Java type of this tensor. */
public final OnnxJavaType type;

Expand All @@ -177,6 +181,8 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
*/
TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
this.shape = shape;
this.dimensionNames = new String[shape.length];
Arrays.fill(dimensionNames, "");
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
Expand All @@ -188,10 +194,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
* <p>Called from JNI.
*
* @param shape The tensor shape.
* @param names The dimension names.
* @param typeInt The native type int.
*/
TensorInfo(long[] shape, int typeInt) {
TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
this.dimensionNames = names;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
Expand All @@ -206,6 +214,15 @@ public long[] getShape() {
return Arrays.copyOf(shape, shape.length);
}

/**
* Get a copy of the tensor's named dimensions.
*
* @return A copof the tensor's named dimensions.
*/
public String[] getDimensionNames() {
return Arrays.copyOf(dimensionNames, dimensionNames.length);
}

@Override
public String toString() {
return "TensorInfo(javaType="
Expand All @@ -214,7 +231,18 @@ public String toString() {
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape)
+ ")";
+ ",dimNames=["
+ Arrays.stream(dimensionNames)
.map(
a -> {
if (a.isEmpty()) {
return "\"\"";
} else {
return a;
}
})
.collect(Collectors.joining(","))
+ "])";
}

/**
Expand Down
26 changes: 22 additions & 4 deletions java/src/main/native/OrtJniUtil.c
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
if (code != ORT_OK) {
return NULL;
}
//printf("numDim %d\n",numDim);
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
if (code != ORT_OK) {
Expand All @@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
free(dimensions);
dimensions = NULL;

// Create the string array for the names.
const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
if (dimensionNames == NULL) {
throwOrtException(jniEnv, 1, "Not enough memory");
return NULL;
}
code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
if (code != ORT_OK) {
// extraction failed, exception has been thrown, return to Java.
free(dimensionNames);
return NULL;
}
jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
for (size_t i = 0; i < numDim; i++) {
jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
(*jniEnv)->SetObjectArrayElement(jniEnv, names, i, javaName);
}
free(dimensionNames);

// Create the TensorInfo object
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([JI)V");
//printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([J[Ljava/lang/String;I)V");
jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt);
return tensorInfo;
}

Expand Down
6 changes: 6 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,12 @@ public void testSymbolicDimensionAssignment() throws OrtException {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
assertEquals(2, aInfo.dimensionNames.length);
assertEquals("n", aInfo.dimensionNames[0]);
assertEquals("", aInfo.dimensionNames[1]);
TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo();
assertEquals(1, bInfo.dimensionNames.length);
assertEquals("m", bInfo.dimensionNames[0]);
}
}
// Check that when the options are assigned it overrides the symbolic dimension
Expand Down

0 comments on commit 4d5f8ea

Please sign in to comment.