Skip to content

Commit

Permalink
[java] Enable output pinning in OrtSession and OrtTrainingSession (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
Craigacp authored Sep 26, 2023
1 parent ccb73fd commit aed43f4
Show file tree
Hide file tree
Showing 9 changed files with 668 additions and 119 deletions.
202 changes: 178 additions & 24 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ public Result run(Map<String, ? extends OnnxTensorLike> inputs, RunOptions runOp
*/
public Result run(Map<String, ? extends OnnxTensorLike> inputs, Set<String> requestedOutputs)
throws OrtException {
return run(inputs, requestedOutputs, null);
return run(inputs, requestedOutputs, Collections.emptyMap(), null);
}

/**
Expand All @@ -259,17 +259,90 @@ public Result run(
Set<String> requestedOutputs,
RunOptions runOptions)
throws OrtException {
return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions);
}

/**
* Scores an input feed dict, returning the map of pinned outputs.
*
* <p>The outputs are sorted based on the supplied map traversal order.
*
* <p>Note: pinned outputs are not owned by the {@link Result} object, and are <b>not</b> closed
* when the result object is closed.
*
* @param inputs The inputs to score.
* @param pinnedOutputs The requested outputs which the user has allocated.
* @return The inferred outputs.
* @throws OrtException If there was an error in native code, the input or output names are
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(
Map<String, ? extends OnnxTensorLike> inputs, Map<String, ? extends OnnxValue> pinnedOutputs)
throws OrtException {
return run(inputs, Collections.emptySet(), pinnedOutputs, null);
}

/**
* Scores an input feed dict, returning the map of requested and pinned outputs.
*
* <p>The outputs are sorted based on the supplied set traversal order with pinned outputs first,
* then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name
* appears in both the requested outputs and the pinned outputs.
*
* <p>Note: pinned outputs are not owned by the {@link Result} object, and are <b>not</b> closed
* when the result object is closed.
*
* @param inputs The inputs to score.
* @param requestedOutputs The requested outputs which ORT will allocate.
* @param pinnedOutputs The requested outputs which the user has allocated.
* @return The inferred outputs.
* @throws OrtException If there was an error in native code, the input or output names are
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(
Map<String, ? extends OnnxTensorLike> inputs,
Set<String> requestedOutputs,
Map<String, ? extends OnnxValue> pinnedOutputs)
throws OrtException {
return run(inputs, requestedOutputs, pinnedOutputs, null);
}

/**
* Scores an input feed dict, returning the map of requested and pinned outputs.
*
* <p>The outputs are sorted based on the supplied set traversal order with pinned outputs first,
* then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name
* appears in both the requested outputs and the pinned outputs.
*
* <p>Note: pinned outputs are not owned by the {@link Result} object, and are <b>not</b> closed
* when the result object is closed.
*
* @param inputs The inputs to score.
* @param requestedOutputs The requested outputs which ORT will allocate.
* @param pinnedOutputs The requested outputs which the user has allocated.
* @param runOptions The RunOptions to control this run.
* @return The inferred outputs.
* @throws OrtException If there was an error in native code, the input or output names are
* invalid, or if there are zero or too many inputs or outputs.
*/
public Result run(
Map<String, ? extends OnnxTensorLike> inputs,
Set<String> requestedOutputs,
Map<String, ? extends OnnxValue> pinnedOutputs,
RunOptions runOptions)
throws OrtException {
if (!closed) {
if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) {
throw new OrtException(
"Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size());
}
if (requestedOutputs.isEmpty() || (requestedOutputs.size() > numOutputs)) {
int totalOutputs = requestedOutputs.size() + pinnedOutputs.size();
if ((totalOutputs == 0) || (totalOutputs > numOutputs)) {
throw new OrtException(
"Unexpected number of requestedOutputs, expected [1,"
"Unexpected number of requestedOutputs & pinnedOutputs, expected [1,"
+ numOutputs
+ ") found "
+ requestedOutputs.size());
+ totalOutputs);
}
String[] inputNamesArray = new String[inputs.size()];
long[] inputHandles = new long[inputs.size()];
Expand All @@ -284,20 +357,41 @@ public Result run(
"Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString());
}
}
String[] outputNamesArray = new String[requestedOutputs.size()];
String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()];
OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length];
long[] outputHandles = new long[outputNamesArray.length];
i = 0;
for (Map.Entry<String, ? extends OnnxValue> e : pinnedOutputs.entrySet()) {
if (outputNames.contains(e.getKey())) {
outputNamesArray[i] = e.getKey();
outputValues[i] = e.getValue();
outputHandles[i] = getHandle(e.getValue());
i++;
} else {
throw new OrtException(
"Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString());
}
}
for (String s : requestedOutputs) {
if (outputNames.contains(s)) {
outputNamesArray[i] = s;
i++;
if (!pinnedOutputs.containsKey(s)) {
outputNamesArray[i] = s;
// outputValues and outputHandles can be null/0 for these outputs as ORT will allocate
// them.
i++;
} else {
throw new OrtException(
"Output '"
+ s
+ "' was found in both the requested outputs and the pinned outputs");
}
} else {
throw new OrtException(
"Unknown output name " + s + ", expected one of " + outputNames.toString());
}
}
long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle();

OnnxValue[] outputValues =
boolean[] ownedByResult =
run(
OnnxRuntime.ortApiHandle,
nativeHandle,
Expand All @@ -307,13 +401,40 @@ public Result run(
inputNamesArray.length,
outputNamesArray,
outputNamesArray.length,
outputValues,
outputHandles,
runOptionsHandle);
return new Result(outputNamesArray, outputValues);
return new Result(outputNamesArray, outputValues, ownedByResult);
} else {
throw new IllegalStateException("Trying to score a closed OrtSession.");
}
}

/**
* Pulls out the native handle by casting it to the appropriate type.
*
* @param v The OnnxValue.
* @return The native handle.
*/
static long getHandle(OnnxValue v) {
/*
* Note this method exists as interface methods are all public, but we do not want users to be
* able to access the native pointer via a public API so can't add a method to OnnxValue which
* exposes it.
*/
if (v instanceof OnnxTensorLike) {
return ((OnnxTensorLike) v).nativeHandle;
} else if (v instanceof OnnxSequence) {
return ((OnnxSequence) v).nativeHandle;
} else if (v instanceof OnnxMap) {
return ((OnnxMap) v).nativeHandle;
} else {
throw new IllegalArgumentException(
"Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found "
+ v.getClass());
}
}

/**
* Gets the metadata for the currently loaded model.
*
Expand Down Expand Up @@ -409,8 +530,9 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long
throws OrtException;

/**
* The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other
* handles must be valid pointers.
* The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can
* contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but
* all other handles must be valid pointers.
*
* @param apiHandle The pointer to the api.
* @param nativeHandle The pointer to the session.
Expand All @@ -419,12 +541,14 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long
* @param inputs The input tensors.
* @param numInputs The number of inputs.
* @param outputNamesArray The requested output names.
* @param outputValues The OnnxValue output array.
* @param outputHandles The OrtValue output pointer array.
* @param numOutputs The number of requested outputs.
* @param runOptionsHandle The (possibly null) pointer to the run options.
* @return The OnnxValues produced by this run.
* @return A boolean array representing if the OnnxValues were allocated by this run call.
* @throws OrtException If the native call failed in some way.
*/
private native OnnxValue[] run(
private native boolean[] run(
long apiHandle,
long nativeHandle,
long allocatorHandle,
Expand All @@ -433,6 +557,8 @@ private native OnnxValue[] run(
long numInputs,
String[] outputNamesArray,
long numOutputs,
OnnxValue[] outputValues,
long[] outputHandles,
long runOptionsHandle)
throws OrtException;

Expand Down Expand Up @@ -1417,9 +1543,13 @@ private native void addRunConfigEntry(
/**
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s.
*
* <p>When this is closed it closes all the {@link OnnxValue}s inside it. If you maintain a
* reference to a value after this object has been closed it will throw an {@link
* <p>When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you
* maintain a reference to a value after this object has been closed it will throw an {@link
* IllegalStateException} upon access.
*
* <p>{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed
* by the {@link Result#close()} method. Ownership of each output can be checked with {@link
* Result#isResultOwner(int)}.
*/
public static class Result implements AutoCloseable, Iterable<Map.Entry<String, OnnxValue>> {

Expand All @@ -1429,6 +1559,8 @@ public static class Result implements AutoCloseable, Iterable<Map.Entry<String,

private final List<OnnxValue> list;

private final boolean[] ownedByResult;

private boolean closed;

/**
Expand All @@ -1437,21 +1569,23 @@ public static class Result implements AutoCloseable, Iterable<Map.Entry<String,
* @param names The output names.
* @param values The output values.
*/
Result(String[] names, OnnxValue[] values) {
if (names.length != values.length) {
Result(String[] names, OnnxValue[] values, boolean[] ownedByResult) {
if ((names.length != values.length) || (names.length != ownedByResult.length)) {
throw new IllegalArgumentException(
"Expected same number of names and values, found names.length = "
"Expected same number of names, values and ownedByResult, found names.length = "
+ names.length
+ ", values.length = "
+ values.length);
+ values.length
+ ", ownedByResult.length = "
+ ownedByResult.length);
}

map = new LinkedHashMap<>(OrtUtil.capacityFromSize(names.length));
list = new ArrayList<>(names.length);
list = new ArrayList<>(Arrays.asList(values));
this.ownedByResult = ownedByResult;

for (int i = 0; i < names.length; i++) {
map.put(names[i], values[i]);
list.add(values[i]);
}
this.closed = false;
}
Expand All @@ -1460,8 +1594,11 @@ public static class Result implements AutoCloseable, Iterable<Map.Entry<String,
public void close() {
if (!closed) {
closed = true;
for (OnnxValue t : map.values()) {
t.close();
for (int i = 0; i < list.size(); i++) {
if (ownedByResult[i]) {
OnnxValue value = list.get(i);
value.close();
}
}
} else {
logger.warning("Closing an already closed Result");
Expand Down Expand Up @@ -1494,6 +1631,23 @@ public OnnxValue get(int index) {
}
}

/**
* Gets the value from the container at the specified index.
*
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
* ArrayIndexOutOfBoundsException} if the index is invalid.
*
* @param index The index to lookup.
* @return Is that value owned by this result object?
*/
public boolean isResultOwner(int index) {
if (!closed) {
return ownedByResult[index];
} else {
throw new IllegalStateException("Result is closed");
}
}

/**
* Returns the number of outputs in this Result.
*
Expand Down
Loading

0 comments on commit aed43f4

Please sign in to comment.