Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] JNI refactor for OrtJniUtil #12516

Merged
merged 8 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions java/src/main/java/ai/onnxruntime/MapInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package ai.onnxruntime;

import ai.onnxruntime.TensorInfo.OnnxTensorType;

/** Describes an {@link OnnxMap} object or output node. */
public class MapInfo implements ValueInfo {

Expand Down Expand Up @@ -42,6 +44,21 @@ public class MapInfo implements ValueInfo {
this.valueType = valueType;
}

/**
* Construct a MapInfo with the specified size, key type and value type.
*
* <p>Called from JNI.
*
* @param size The size.
* @param keyTypeInt The int representing the {@link OnnxTensorType} of the keys.
* @param valueTypeInt The int representing the {@link OnnxTensorType} of the values.
*/
MapInfo(int size, int keyTypeInt, int valueTypeInt) {
this.size = size;
this.keyType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(keyTypeInt));
this.valueType = OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(valueTypeInt));
}

@Override
public String toString() {
String initial = size == -1 ? "MapInfo(size=UNKNOWN" : "MapInfo(size=" + size;
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
11 changes: 5 additions & 6 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ public Object getValue() throws OrtException {
case BOOL:
return getBool(OnnxRuntime.ortApiHandle, nativeHandle);
case STRING:
return getString(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return getString(OnnxRuntime.ortApiHandle, nativeHandle);
case UNKNOWN:
default:
throw new OrtException("Extracting the value of an invalid Tensor.");
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle, carrier);
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
Expand Down Expand Up @@ -284,13 +284,12 @@ private native short getShort(long apiHandle, long nativeHandle, int onnxType)

private native long getLong(long apiHandle, long nativeHandle, int onnxType) throws OrtException;

private native String getString(long apiHandle, long nativeHandle, long allocatorHandle)
throws OrtException;
private native String getString(long apiHandle, long nativeHandle) throws OrtException;

private native boolean getBool(long apiHandle, long nativeHandle) throws OrtException;

private native void getArray(
long apiHandle, long nativeHandle, long allocatorHandle, Object carrier) throws OrtException;
private native void getArray(long apiHandle, long nativeHandle, Object carrier)
throws OrtException;

private native void close(long apiHandle, long nativeHandle);

Expand Down
19 changes: 19 additions & 0 deletions java/src/main/java/ai/onnxruntime/SequenceInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
*/
package ai.onnxruntime;

import ai.onnxruntime.TensorInfo.OnnxTensorType;

/** Describes an {@link OnnxSequence}, including it's element type if known. */
public class SequenceInfo implements ValueInfo {

Expand Down Expand Up @@ -35,6 +37,23 @@ public class SequenceInfo implements ValueInfo {
this.mapInfo = null;
}

/**
* Construct a sequence of known length, with the specified type. This sequence does not contain
* maps.
*
* <p>Called from JNI.
*
* @param length The length of the sequence.
* @param sequenceTypeInt The element type int of the sequence mapped from {@link OnnxTensorType}.
*/
SequenceInfo(int length, int sequenceTypeInt) {
this.length = length;
this.sequenceType =
OnnxJavaType.mapFromOnnxTensorType(OnnxTensorType.mapFromInt(sequenceTypeInt));
this.sequenceOfMaps = false;
this.mapInfo = null;
}

/**
* Construct a sequence of known length containing maps.
*
Expand Down
14 changes: 14 additions & 0 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.onnxType = onnxType;
}

/**
* Constructs a TensorInfo with the specified shape and native type int.
*
* <p>Called from JNI.
*
* @param shape The tensor shape.
* @param typeInt The native type int.
*/
TensorInfo(long[] shape, int typeInt) {
this.shape = shape;
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
}

/**
* Get a copy of the tensor's shape.
*
Expand Down
Loading