Skip to content

Commit

Permalink
[Java] JNI refactor for OrtJniUtil (#12516)
Browse files Browse the repository at this point in the history
Refactoring more JNI methods in OrtJniUtil.
Make the strings const.
Removing unnecessary use of OrtAllocator.
  • Loading branch information
Craigacp authored Sep 9, 2022
1 parent 60e4d01 commit 5d55b07
Show file tree
Hide file tree
Showing 9 changed files with 697 additions and 624 deletions.
17 changes: 17 additions & 0 deletions java/src/main/java/ai/onnxruntime/MapInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package ai.onnxruntime;

import ai.onnxruntime.TensorInfo.OnnxTensorType;

/** Describes an {@link OnnxMap} object or output node. */
public class MapInfo implements ValueInfo {

Expand Down Expand Up @@ -42,6 +44,21 @@ public class MapInfo implements ValueInfo {
this.valueType = valueType;
}

/**
* Construct a MapInfo with the specified size, key type and value type.
*
* <p>Called from JNI.
*
* @param size The size.
* @param keyTypeInt The int representing the {@link OnnxTensorType} of the keys.
* @param valueTypeInt The int representing the {@link OnnxTensorType} of the values.
*/
MapInfo(int size, int keyTypeInt, int valueTypeInt) {
this.size = size;
this.keyType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(keyTypeInt));
this.valueType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(valueTypeInt));
}

@Override
public String toString() {
String initial = size == -1 ? "MapInfo(size=UNKNOWN" : "MapInfo(size=" + size;
Expand Down
11 changes: 5 additions & 6 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ public Object getValue() throws OrtException {
case BOOL:
return getBool(OnnxRuntime.ortApiHandle, nativeHandle);
case STRING:
return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return getString(OnnxRuntime.ortApiHandle, nativeHandle);
case UNKNOWN:
default:
throw new OrtException("Extracting the value of an invalid Tensor.");
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
Expand Down Expand Up @@ -284,13 +284,12 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)

private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException;

private native String getString(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native String getString(long apiHandle, long nativeHandle) throws OrtException;

private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;

private native void getArray(
long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException;
private native void getArray(long apiHandle, long nativeHandle, Object carrier)
throws OrtException;

private native void close(long apiHandle, long nativeHandle);

Expand Down
19 changes: 19 additions & 0 deletions java/src/main/java/ai/onnxruntime/SequenceInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package ai.onnxruntime;

import ai.onnxruntime.TensorInfo.OnnxTensorType;

/** Describes an {@link OnnxSequence}, including it's element type if known. */
public class SequenceInfo implements ValueInfo {

Expand Down Expand Up @@ -35,6 +37,23 @@ public class SequenceInfo implements ValueInfo {
this.mapInfo = null;
}

/**
* Construct a sequence of known length, with the specified type. This sequence does not contain
* maps.
*
* <p>Called from JNI.
*
* @param length The length of the sequence.
* @param sequenceTypeInt The element type int of the sequence mapped from {@link OnnxTensorType}.
*/
SequenceInfo(int length, int sequenceTypeInt) {
this.length = length;
this.sequenceType =
OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(sequenceTypeInt));
this.sequenceOfMaps = false;
this.mapInfo = null;
}

/**
* Construct a sequence of known length containing maps.
*
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.onnxType = onnxType;
}

/**
* Constructs a TensorInfo with the specified shape and native type int.
*
* <p>Called from JNI.
*
* @param shape The tensor shape.
* @param typeInt The native type int.
*/
TensorInfo(long[] shape, int typeInt) {
this.shape = shape;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
}

/**
* Get a copy of the tensor's shape.
*
Expand Down
Loading

0 comments on commit 5d55b07

Please sign in to comment.