Skip to content

Commit

Permalink
[java] Fix double close (#19133)
Browse files Browse the repository at this point in the history
### Description
The `OnnxValue` and `OrtProviderOptions` implementations now check to
see if they've been closed before accessing the native pointer, and also
before close is called.

### Motivation and Context
Before they could be closed twice which SIGSEGV'd the JVM. Fixes #19125.
  • Loading branch information
Craigacp authored Jan 14, 2024
1 parent c3ce9df commit 71657d1
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 18 deletions.
27 changes: 25 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxMap.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
* A container for a map returned by {@link OrtSession#run(Map)}.
Expand All @@ -16,6 +17,7 @@
* values: String, Long, Float, Double.
*/
public class OnnxMap implements OnnxValue {
private static final Logger logger = Logger.getLogger(OnnxMap.class.getName());

static {
try {
Expand Down Expand Up @@ -107,6 +109,8 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) {

private final OnnxMapValueType valueType;

private boolean closed;

/**
* Constructs an OnnxMap containing a reference to the native map along with the type information.
*
Expand All @@ -122,6 +126,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) {
this.info = info;
this.stringKeys = info.keyType == OnnxJavaType.STRING;
this.valueType = OnnxMapValueType.mapFromOnnxJavaType(info.valueType);
this.closed = false;
}

/**
Expand All @@ -146,6 +151,7 @@ public OnnxValueType getType() {
*/
@Override
public Map<? extends Object, ? extends Object> getValue() throws OrtException {
checkClosed();
Object[] keys = getMapKeys();
Object[] values = getMapValues();
HashMap<Object, Object> map = new HashMap<>(OrtUtil.capacityFromSize(keys.length));
Expand Down Expand Up @@ -222,10 +228,27 @@ public String toString() {
return "ONNXMap(size=" + size() + ",info=" + info.toString() + ")";
}

@Override
public synchronized boolean isClosed() {
return closed;
}

/** Closes this map, releasing the native memory backing it and it's elements. */
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed map.");
}
}

/** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
protected void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OnnxValue");
}
}

private native String[] getStringKeys(long apiHandle, long nativeHandle, long allocatorHandle)
Expand Down
27 changes: 25 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxSequence.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;

/**
* A sequence of {@link OnnxValue}s all of the same type.
Expand All @@ -24,6 +25,7 @@
* </ul>
*/
public class OnnxSequence implements OnnxValue {
private static final Logger logger = Logger.getLogger(OnnxSequence.class.getName());

static {
try {
Expand All @@ -40,6 +42,8 @@ public class OnnxSequence implements OnnxValue {

private final SequenceInfo info;

private boolean closed;

/**
* Creates the wrapper object for a native sequence.
*
Expand All @@ -53,6 +57,7 @@ public class OnnxSequence implements OnnxValue {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
this.closed = false;
}

@Override
Expand All @@ -76,6 +81,7 @@ public OnnxValueType getType() {
*/
@Override
public List<? extends OnnxValue> getValue() throws OrtException {
checkClosed();
if (info.sequenceOfMaps) {
OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle);
return Collections.unmodifiableList(Arrays.asList(maps));
Expand Down Expand Up @@ -110,10 +116,27 @@ public String toString() {
return "OnnxSequence(info=" + info.toString() + ")";
}

@Override
public synchronized boolean isClosed() {
return closed;
}

/** Closes this sequence, releasing the native memory backing it and it's elements. */
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed sequence.");
}
}

/** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
protected void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OnnxValue");
}
}

private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle)
Expand Down
18 changes: 16 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Arrays;
import java.util.logging.Logger;

/**
* A Java object wrapping an OnnxSparseTensor.
Expand All @@ -22,6 +23,7 @@
* different static inner class representing each type.
*/
public final class OnnxSparseTensor extends OnnxTensorLike {
private static final Logger logger = Logger.getLogger(OnnxSparseTensor.class.getName());
private final SparseTensorType sparseTensorType;

// Held to prevent deallocation while used in native code.
Expand Down Expand Up @@ -198,6 +200,7 @@ public OnnxValueType getType() {

@Override
public SparseTensor<? extends Buffer> getValue() throws OrtException {
checkClosed();
Buffer buffer = getValuesBuffer();
long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
switch (sparseTensorType) {
Expand Down Expand Up @@ -234,8 +237,13 @@ public SparseTensor<? extends Buffer> getValue() throws OrtException {
}

@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed OnnxSparseTensor.");
}
}

/**
Expand All @@ -257,6 +265,7 @@ public SparseTensorType getSparseTensorType() {
* @return The indices.
*/
public Buffer getIndicesBuffer() {
checkClosed();
switch (sparseTensorType) {
case COO:
case CSRC:
Expand Down Expand Up @@ -295,6 +304,7 @@ public Buffer getIndicesBuffer() {
* @return The inner indices.
*/
public LongBuffer getInnerIndicesBuffer() {
checkClosed();
if (sparseTensorType == SparseTensorType.CSRC) {
LongBuffer buf =
getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle)
Expand All @@ -320,6 +330,7 @@ public LongBuffer getInnerIndicesBuffer() {
* @return The data buffer.
*/
public Buffer getValuesBuffer() {
checkClosed();
ByteBuffer buffer =
getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder());
switch (info.type) {
Expand Down Expand Up @@ -396,6 +407,7 @@ public Buffer getValuesBuffer() {
* @return The indices shape.
*/
public long[] getIndicesShape() {
checkClosed();
return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
}

Expand All @@ -405,6 +417,7 @@ public long[] getIndicesShape() {
* @return The indices shape.
*/
public long[] getInnerIndicesShape() {
checkClosed();
if (sparseTensorType == SparseTensorType.CSRC) {
return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle);
} else {
Expand All @@ -420,6 +433,7 @@ public long[] getInnerIndicesShape() {
* @return The values shape.
*/
public long[] getValuesShape() {
checkClosed();
return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle);
}

Expand Down
24 changes: 19 additions & 5 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Optional;
import java.util.logging.Logger;

/**
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
* returned as outputs.
*/
public class OnnxTensor extends OnnxTensorLike {
private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName());

/**
* This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does
Expand Down Expand Up @@ -97,6 +99,7 @@ public OnnxValueType getType() {
*/
@Override
public Object getValue() throws OrtException {
checkClosed();
if (info.isScalar()) {
switch (info.type) {
case FLOAT:
Expand Down Expand Up @@ -144,16 +147,21 @@ public Object getValue() throws OrtException {

@Override
public String toString() {
return "OnnxTensor(info=" + info.toString() + ")";
return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")";
}

/**
* Closes the tensor, releasing it's underlying memory (if it's not backed by an NIO buffer). If
* it is backed by a buffer then the memory is released when the buffer is GC'd.
* Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). If it
* is backed by a buffer then the memory is released when the buffer is GC'd.
*/
@Override
public void close() {
close(OnnxRuntime.ortApiHandle, nativeHandle);
public synchronized void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
logger.warning("Closing an already closed tensor.");
}
}

/**
Expand All @@ -165,6 +173,7 @@ public void close() {
* @return A ByteBuffer copy of the OnnxTensor.
*/
public ByteBuffer getByteBuffer() {
checkClosed();
if (info.type != OnnxJavaType.STRING) {
ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle);
ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
Expand All @@ -183,6 +192,7 @@ public ByteBuffer getByteBuffer() {
* @return A FloatBuffer copy of the OnnxTensor.
*/
public FloatBuffer getFloatBuffer() {
checkClosed();
if (info.type == OnnxJavaType.FLOAT) {
// if it's fp32 use the efficient copy.
FloatBuffer buffer = getBuffer().asFloatBuffer();
Expand Down Expand Up @@ -212,6 +222,7 @@ public FloatBuffer getFloatBuffer() {
* @return A DoubleBuffer copy of the OnnxTensor.
*/
public DoubleBuffer getDoubleBuffer() {
checkClosed();
if (info.type == OnnxJavaType.DOUBLE) {
DoubleBuffer buffer = getBuffer().asDoubleBuffer();
DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity());
Expand All @@ -230,6 +241,7 @@ public DoubleBuffer getDoubleBuffer() {
* @return A ShortBuffer copy of the OnnxTensor.
*/
public ShortBuffer getShortBuffer() {
checkClosed();
if ((info.type == OnnxJavaType.INT16)
|| (info.type == OnnxJavaType.FLOAT16)
|| (info.type == OnnxJavaType.BFLOAT16)) {
Expand All @@ -250,6 +262,7 @@ public ShortBuffer getShortBuffer() {
* @return An IntBuffer copy of the OnnxTensor.
*/
public IntBuffer getIntBuffer() {
checkClosed();
if (info.type == OnnxJavaType.INT32) {
IntBuffer buffer = getBuffer().asIntBuffer();
IntBuffer output = IntBuffer.allocate(buffer.capacity());
Expand All @@ -268,6 +281,7 @@ public IntBuffer getIntBuffer() {
* @return A LongBuffer copy of the OnnxTensor.
*/
public LongBuffer getLongBuffer() {
checkClosed();
if (info.type == OnnxJavaType.INT64) {
LongBuffer buffer = getBuffer().asLongBuffer();
LongBuffer output = LongBuffer.allocate(buffer.capacity());
Expand Down
16 changes: 16 additions & 0 deletions java/src/main/java/ai/onnxruntime/OnnxTensorLike.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ public abstract class OnnxTensorLike implements OnnxValue {
/** The size and shape information for this tensor. */
protected final TensorInfo info;

/** Is this value closed? */
protected boolean closed;

/**
* Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor).
*
Expand All @@ -39,6 +42,7 @@ public abstract class OnnxTensorLike implements OnnxValue {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
this.closed = false;
}

/**
Expand All @@ -59,4 +63,16 @@ long getNativeHandle() {
public TensorInfo getInfo() {
return info;
}

@Override
public synchronized boolean isClosed() {
return closed;
}

/** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
protected void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OnnxValue");
}
}
}
9 changes: 8 additions & 1 deletion java/src/main/java/ai/onnxruntime/OnnxValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ public enum OnnxValueType {
*/
public ValueInfo getInfo();

/** Closes the OnnxValue, freeing it's native memory. */
/**
* Checks if this value is closed (i.e., the native object has been released).
*
* @return True if the value is closed and the native object has been released.
*/
public boolean isClosed();

/** Closes the OnnxValue, freeing its native memory. */
@Override
public void close();

Expand Down
Loading

0 comments on commit 71657d1

Please sign in to comment.