Skip to content

Commit

Permalink
Fixing size_t cast for Windows and slightly tidying up the toString.
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp committed Jan 14, 2024
1 parent 4d5f8ea commit 41d4ff8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 19 deletions.
55 changes: 37 additions & 18 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The names of the unbound dimensions. */
final String[] dimensionNames;

/** If there are non-empty dimension names */
private final boolean hasNames;

/** The Java type of this tensor. */
public final OnnxJavaType type;

Expand All @@ -183,6 +186,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.shape = shape;
this.dimensionNames = new String[shape.length];
Arrays.fill(dimensionNames, "");
this.hasNames = false;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
Expand All @@ -200,6 +204,14 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
this.dimensionNames = names;
boolean hasNames = false;
for (String s : names) {
if (!s.isEmpty()) {
hasNames = true;
break;
}
}
this.hasNames = hasNames;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
Expand All @@ -225,24 +237,31 @@ public String[] getDimensionNames() {

@Override
public String toString() {
return "TensorInfo(javaType="
+ type.toString()
+ ",onnxType="
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape)
+ ",dimNames=["
+ Arrays.stream(dimensionNames)
.map(
a -> {
if (a.isEmpty()) {
return "\"\"";
} else {
return a;
}
})
.collect(Collectors.joining(","))
+ "])";
String output =
"TensorInfo(javaType="
+ type.toString()
+ ",onnxType="
+ onnxType.toString()
+ ",shape="
+ Arrays.toString(shape);
if (hasNames) {
output =
output
+ ",dimNames=["
+ Arrays.stream(dimensionNames)
.map(
a -> {
if (a.isEmpty()) {
return "\"\"";
} else {
return a;
}
})
.collect(Collectors.joining(","))
+ "]";
}
output = output + ")";
return output;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion java/src/main/native/OrtJniUtil.c
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
for (size_t i = 0; i < numDim; i++) {
jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
(*jniEnv)->SetObjectArrayElement(jniEnv, names, i, javaName);
(*jniEnv)->SetObjectArrayElement(jniEnv, names, i, safecast_size_t_to_jsize(javaName));
}
free(dimensionNames);

Expand Down

0 comments on commit 41d4ff8

Please sign in to comment.