From 71657d1eb8b0a24a4b6584d9e904506a0b4e1521 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 14 Jan 2024 17:53:26 -0500 Subject: [PATCH] [java] Fix double close (#19133) ### Description The `OnnxValue` and `OrtProviderOptions` implementations now check to see if they've been closed before accessing the native pointer, and also before close is called. ### Motivation and Context Before they could be closed twice which SIGSEGV'd the JVM. Fixes #19125. --- .../src/main/java/ai/onnxruntime/OnnxMap.java | 27 +++++++++++++-- .../java/ai/onnxruntime/OnnxSequence.java | 27 +++++++++++++-- .../java/ai/onnxruntime/OnnxSparseTensor.java | 18 ++++++++-- .../main/java/ai/onnxruntime/OnnxTensor.java | 24 +++++++++++--- .../java/ai/onnxruntime/OnnxTensorLike.java | 16 +++++++++ .../main/java/ai/onnxruntime/OnnxValue.java | 9 ++++- .../ai/onnxruntime/OrtProviderOptions.java | 30 ++++++++++++++++- .../ai/onnxruntime/OrtTrainingSession.java | 33 +++++++++++++++++-- .../StringConfigProviderOptions.java | 1 + .../java/ai/onnxruntime/InferenceTest.java | 2 ++ .../java/ai/onnxruntime/OnnxTensorTest.java | 27 +++++++++++++-- .../test/java/ai/onnxruntime/TestHelpers.java | 12 +++++++ 12 files changed, 208 insertions(+), 18 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxMap.java b/java/src/main/java/ai/onnxruntime/OnnxMap.java index 354ebec61274d..cde9f0de4ff0a 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxMap.java +++ b/java/src/main/java/ai/onnxruntime/OnnxMap.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; +import java.util.logging.Logger; /** * A container for a map returned by {@link OrtSession#run(Map)}. @@ -16,6 +17,7 @@ * values: String, Long, Float, Double. */ public class OnnxMap implements OnnxValue { + private static final Logger logger = Logger.getLogger(OnnxMap.class.getName()); static { try { @@ -107,6 +109,8 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) { private final OnnxMapValueType valueType; + private boolean closed; + /** * Constructs an OnnxMap containing a reference to the native map along with the type information. * @@ -122,6 +126,7 @@ public static OnnxMapValueType mapFromOnnxJavaType(OnnxJavaType type) { this.info = info; this.stringKeys = info.keyType == OnnxJavaType.STRING; this.valueType = OnnxMapValueType.mapFromOnnxJavaType(info.valueType); + this.closed = false; } /** @@ -146,6 +151,7 @@ public OnnxValueType getType() { */ @Override public Map getValue() throws OrtException { + checkClosed(); Object[] keys = getMapKeys(); Object[] values = getMapValues(); HashMap map = new HashMap<>(OrtUtil.capacityFromSize(keys.length)); @@ -222,10 +228,27 @@ public String toString() { return "ONNXMap(size=" + size() + ",info=" + info.toString() + ")"; } + @Override + public synchronized boolean isClosed() { + return closed; + } + /** Closes this map, releasing the native memory backing it and it's elements. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed map."); + } + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } } private native String[] getStringKeys(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/java/ai/onnxruntime/OnnxSequence.java b/java/src/main/java/ai/onnxruntime/OnnxSequence.java index 93e1be21588b4..7722514b913b6 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSequence.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSequence.java @@ -8,6 +8,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.logging.Logger; /** * A sequence of {@link OnnxValue}s all of the same type. @@ -24,6 +25,7 @@ * */ public class OnnxSequence implements OnnxValue { + private static final Logger logger = Logger.getLogger(OnnxSequence.class.getName()); static { try { @@ -40,6 +42,8 @@ public class OnnxSequence implements OnnxValue { private final SequenceInfo info; + private boolean closed; + /** * Creates the wrapper object for a native sequence. * @@ -53,6 +57,7 @@ public class OnnxSequence implements OnnxValue { this.nativeHandle = nativeHandle; this.allocatorHandle = allocatorHandle; this.info = info; + this.closed = false; } @Override @@ -76,6 +81,7 @@ public OnnxValueType getType() { */ @Override public List getValue() throws OrtException { + checkClosed(); if (info.sequenceOfMaps) { OnnxMap[] maps = getMaps(OnnxRuntime.ortApiHandle, nativeHandle, allocatorHandle); return Collections.unmodifiableList(Arrays.asList(maps)); @@ -110,10 +116,27 @@ public String toString() { return "OnnxSequence(info=" + info.toString() + ")"; } + @Override + public synchronized boolean isClosed() { + return closed; + } + /** Closes this sequence, releasing the native memory backing it and it's elements. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed sequence."); + } + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } } private native OnnxMap[] getMaps(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java index 53bd4c7f9b3e6..804fe742ad624 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxSparseTensor.java @@ -14,6 +14,7 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Arrays; +import java.util.logging.Logger; /** * A Java object wrapping an OnnxSparseTensor. @@ -22,6 +23,7 @@ * different static inner class representing each type. */ public final class OnnxSparseTensor extends OnnxTensorLike { + private static final Logger logger = Logger.getLogger(OnnxSparseTensor.class.getName()); private final SparseTensorType sparseTensorType; // Held to prevent deallocation while used in native code. @@ -198,6 +200,7 @@ public OnnxValueType getType() { @Override public SparseTensor getValue() throws OrtException { + checkClosed(); Buffer buffer = getValuesBuffer(); long[] indicesShape = getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); switch (sparseTensorType) { @@ -234,8 +237,13 @@ public SparseTensor getValue() throws OrtException { } @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed OnnxSparseTensor."); + } } /** @@ -257,6 +265,7 @@ public SparseTensorType getSparseTensorType() { * @return The indices. */ public Buffer getIndicesBuffer() { + checkClosed(); switch (sparseTensorType) { case COO: case CSRC: @@ -295,6 +304,7 @@ public Buffer getIndicesBuffer() { * @return The inner indices. */ public LongBuffer getInnerIndicesBuffer() { + checkClosed(); if (sparseTensorType == SparseTensorType.CSRC) { LongBuffer buf = getInnerIndicesBuffer(OnnxRuntime.ortApiHandle, nativeHandle) @@ -320,6 +330,7 @@ public LongBuffer getInnerIndicesBuffer() { * @return The data buffer. */ public Buffer getValuesBuffer() { + checkClosed(); ByteBuffer buffer = getValuesBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); switch (info.type) { @@ -396,6 +407,7 @@ public Buffer getValuesBuffer() { * @return The indices shape. */ public long[] getIndicesShape() { + checkClosed(); return getIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); } @@ -405,6 +417,7 @@ public long[] getIndicesShape() { * @return The indices shape. */ public long[] getInnerIndicesShape() { + checkClosed(); if (sparseTensorType == SparseTensorType.CSRC) { return getInnerIndicesShape(OnnxRuntime.ortApiHandle, nativeHandle); } else { @@ -420,6 +433,7 @@ public long[] getInnerIndicesShape() { * @return The values shape. */ public long[] getValuesShape() { + checkClosed(); return getValuesShape(OnnxRuntime.ortApiHandle, nativeHandle); } diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 0078adb6402f8..e1ee2c14fd9d1 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -14,12 +14,14 @@ import java.nio.LongBuffer; import java.nio.ShortBuffer; import java.util.Optional; +import java.util.logging.Logger; /** * A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be * returned as outputs. */ public class OnnxTensor extends OnnxTensorLike { + private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName()); /** * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does @@ -97,6 +99,7 @@ public OnnxValueType getType() { */ @Override public Object getValue() throws OrtException { + checkClosed(); if (info.isScalar()) { switch (info.type) { case FLOAT: @@ -144,16 +147,21 @@ public Object getValue() throws OrtException { @Override public String toString() { - return "OnnxTensor(info=" + info.toString() + ")"; + return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")"; } /** - * Closes the tensor, releasing it's underlying memory (if it's not backed by an NIO buffer). If - * it is backed by a buffer then the memory is released when the buffer is GC'd. + * Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). If it + * is backed by a buffer then the memory is released when the buffer is GC'd. */ @Override - public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed tensor."); + } } /** @@ -165,6 +173,7 @@ public void close() { * @return A ByteBuffer copy of the OnnxTensor. */ public ByteBuffer getByteBuffer() { + checkClosed(); if (info.type != OnnxJavaType.STRING) { ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle); ByteBuffer output = ByteBuffer.allocate(buffer.capacity()); @@ -183,6 +192,7 @@ public ByteBuffer getByteBuffer() { * @return A FloatBuffer copy of the OnnxTensor. */ public FloatBuffer getFloatBuffer() { + checkClosed(); if (info.type == OnnxJavaType.FLOAT) { // if it's fp32 use the efficient copy. FloatBuffer buffer = getBuffer().asFloatBuffer(); @@ -212,6 +222,7 @@ public FloatBuffer getFloatBuffer() { * @return A DoubleBuffer copy of the OnnxTensor. */ public DoubleBuffer getDoubleBuffer() { + checkClosed(); if (info.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = getBuffer().asDoubleBuffer(); DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity()); @@ -230,6 +241,7 @@ public DoubleBuffer getDoubleBuffer() { * @return A ShortBuffer copy of the OnnxTensor. */ public ShortBuffer getShortBuffer() { + checkClosed(); if ((info.type == OnnxJavaType.INT16) || (info.type == OnnxJavaType.FLOAT16) || (info.type == OnnxJavaType.BFLOAT16)) { @@ -250,6 +262,7 @@ public ShortBuffer getShortBuffer() { * @return An IntBuffer copy of the OnnxTensor. */ public IntBuffer getIntBuffer() { + checkClosed(); if (info.type == OnnxJavaType.INT32) { IntBuffer buffer = getBuffer().asIntBuffer(); IntBuffer output = IntBuffer.allocate(buffer.capacity()); @@ -268,6 +281,7 @@ public IntBuffer getIntBuffer() { * @return A LongBuffer copy of the OnnxTensor. */ public LongBuffer getLongBuffer() { + checkClosed(); if (info.type == OnnxJavaType.INT64) { LongBuffer buffer = getBuffer().asLongBuffer(); LongBuffer output = LongBuffer.allocate(buffer.capacity()); diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java index c2989fe296dc2..bbfd4e981ece2 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java @@ -28,6 +28,9 @@ public abstract class OnnxTensorLike implements OnnxValue { /** The size and shape information for this tensor. */ protected final TensorInfo info; + /** Is this value closed? */ + protected boolean closed; + /** * Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor). * @@ -39,6 +42,7 @@ public abstract class OnnxTensorLike implements OnnxValue { this.nativeHandle = nativeHandle; this.allocatorHandle = allocatorHandle; this.info = info; + this.closed = false; } /** @@ -59,4 +63,16 @@ long getNativeHandle() { public TensorInfo getInfo() { return info; } + + @Override + public synchronized boolean isClosed() { + return closed; + } + + /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OnnxValue"); + } + } } diff --git a/java/src/main/java/ai/onnxruntime/OnnxValue.java b/java/src/main/java/ai/onnxruntime/OnnxValue.java index 752a0e74267d3..e829bc80f09f6 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxValue.java +++ b/java/src/main/java/ai/onnxruntime/OnnxValue.java @@ -64,7 +64,14 @@ public enum OnnxValueType { */ public ValueInfo getInfo(); - /** Closes the OnnxValue, freeing it's native memory. */ + /** + * Checks if this value is closed (i.e., the native object has been released). + * + * @return True if the value is closed and the native object has been released. + */ + public boolean isClosed(); + + /** Closes the OnnxValue, freeing its native memory. */ @Override public void close(); diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java index 39a5121fad7a2..70af10ff8cd79 100644 --- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java @@ -5,11 +5,14 @@ package ai.onnxruntime; import java.io.IOException; +import java.util.logging.Logger; /** An abstract base class for execution provider options classes. */ // Note this lives in ai.onnxruntime to allow subclasses to access the OnnxRuntime.ortApiHandle // package private field. public abstract class OrtProviderOptions implements AutoCloseable { + private static final Logger logger = Logger.getLogger(OrtProviderOptions.class.getName()); + static { try { OnnxRuntime.init(); @@ -21,6 +24,9 @@ public abstract class OrtProviderOptions implements AutoCloseable { /** The native pointer. */ protected final long nativeHandle; + /** Is the native object closed? */ + protected boolean closed; + /** * Constructs a OrtProviderOptions wrapped around a native pointer. * @@ -28,6 +34,7 @@ public abstract class OrtProviderOptions implements AutoCloseable { */ protected OrtProviderOptions(long nativeHandle) { this.nativeHandle = nativeHandle; + this.closed = false; } /** @@ -46,9 +53,30 @@ protected static long getApiHandle() { */ public abstract OrtProvider getProvider(); + /** + * Is the native object closed? + * + * @return True if the native object has been released. + */ + public synchronized boolean isClosed() { + return closed; + } + @Override public void close() { - close(OnnxRuntime.ortApiHandle, nativeHandle); + if (!closed) { + close(OnnxRuntime.ortApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing an already closed tensor."); + } + } + + /** Checks if the OrtProviderOptions is closed, if so throws {@link IllegalStateException}. */ + protected void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtProviderOptions"); + } } /** diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 49ddf29c22335..eeede3a1bed0b 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -12,6 +12,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.logging.Logger; /** * Wraps an ONNX training model and allows training and inference calls. @@ -1049,8 +1050,12 @@ private native void exportModelForInference( /** Wrapper class for the checkpoint state. */ static final class OrtCheckpointState implements AutoCloseable { + private static final Logger logger = Logger.getLogger(OrtCheckpointState.class.getName()); + final long nativeHandle; + private boolean closed; + /** * Wraps an object around the checkpoint native handle. * @@ -1058,6 +1063,7 @@ static final class OrtCheckpointState implements AutoCloseable { */ OrtCheckpointState(long nativeHandle) { this.nativeHandle = nativeHandle; + this.closed = false; } /** @@ -1097,6 +1103,7 @@ static OrtCheckpointState loadCheckpoint(String checkpoint) throws OrtException * @throws OrtException If the checkpoint failed to save. */ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException { + checkClosed(); Objects.requireNonNull(outputPath, "checkpoint path must not be null"); String outputStr = outputPath.toString(); saveCheckpoint( @@ -1115,6 +1122,7 @@ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtExc * @throws OrtException If the call failed. */ public void addProperty(String name, float value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1127,6 +1135,7 @@ public void addProperty(String name, float value) throws OrtException { * @throws OrtException If the call failed. */ public void addProperty(String name, int value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1139,6 +1148,7 @@ public void addProperty(String name, int value) throws OrtException { * @throws OrtException If the call failed. */ public void addProperty(String name, String value) throws OrtException { + checkClosed(); addProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value); } @@ -1152,6 +1162,7 @@ public void addProperty(String name, String value) throws OrtException { * @throws OrtException If the property does not exist, or is of the wrong type. */ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getFloatProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1169,6 +1180,7 @@ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtExc * @throws OrtException If the property does not exist, or is of the wrong type. */ public int getIntProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getIntProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1186,6 +1198,7 @@ public int getIntProperty(OrtAllocator allocator, String name) throws OrtExcepti * @throws OrtException If the property does not exist, or is of the wrong type. */ public String getStringProperty(OrtAllocator allocator, String name) throws OrtException { + checkClosed(); return getStringProperty( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -1194,9 +1207,25 @@ public String getStringProperty(OrtAllocator allocator, String name) throws OrtE name); } + /** Checks if the OrtCheckpointState is closed, if so throws {@link IllegalStateException}. */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtCheckpointState"); + } + } + + public synchronized boolean isClosed() { + return closed; + } + @Override - public void close() { - close(OnnxRuntime.ortTrainingApiHandle, nativeHandle); + public synchronized void close() { + if (!closed) { + close(OnnxRuntime.ortTrainingApiHandle, nativeHandle); + closed = true; + } else { + logger.warning("Closing a checkpoint twice"); + } } /* diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java index 02207b2949e54..961163035c9a6 100644 --- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java +++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java @@ -32,6 +32,7 @@ protected StringConfigProviderOptions(long nativeHandle) { * @throws OrtException If the addition failed. */ public void add(String key, String value) throws OrtException { + checkClosed(); Objects.requireNonNull(key, "Key must not be null"); Objects.requireNonNull(value, "Value must not be null"); options.put(key, value); diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index e975117fb75bd..f6f9da1829402 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -69,7 +69,9 @@ public void environmentTest() { // Checks that the environment instance is the same. OrtEnvironment otherEnv = OrtEnvironment.getEnvironment(); assertSame(env, otherEnv); + TestHelpers.quietLogger(OrtEnvironment.class); otherEnv = OrtEnvironment.getEnvironment("test-name"); + TestHelpers.loudLogger(OrtEnvironment.class); assertSame(env, otherEnv); } diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index a5f285ba86a14..c060cf73ecf14 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -4,6 +4,10 @@ */ package ai.onnxruntime; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + import ai.onnxruntime.platform.Fp16Conversions; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -97,8 +101,8 @@ public void testBufferCreation() throws OrtException { float[] arrValues = new float[] {0, 1, 2, 3, 4}; try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) { // array creation isn't backed by buffers - Assertions.assertFalse(t.ownsBuffer()); - Assertions.assertFalse(t.getBufferRef().isPresent()); + assertFalse(t.ownsBuffer()); + assertFalse(t.getBufferRef().isPresent()); FloatBuffer buf = t.getFloatBuffer(); float[] output = new float[arrValues.length]; buf.get(output); @@ -146,7 +150,7 @@ public void testBufferCreation() throws OrtException { directBuffer.rewind(); try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) { // direct buffers don't trigger a copy - Assertions.assertFalse(t.ownsBuffer()); + assertFalse(t.ownsBuffer()); // tensors backed by buffers can get the buffer ref back out Assertions.assertTrue(t.getBufferRef().isPresent()); FloatBuffer buf = t.getFloatBuffer(); @@ -428,4 +432,21 @@ public void testBf16RoundTrip() { } } } + + @Test + public void testClose() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + long[] input = new long[] {1, 2, 3, 4, 5}; + OnnxTensor value = OnnxTensor.createTensor(env, input); + assertFalse(value.isClosed()); + long[] output = (long[]) value.getValue(); + assertArrayEquals(input, output); + value.close(); + // check use after close throws + assertThrows(IllegalStateException.class, value::getValue); + // check double close doesn't crash (emits warning) + TestHelpers.quietLogger(OnnxTensor.class); + value.close(); + TestHelpers.loudLogger(OnnxTensor.class); + } } diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 55d8169434d48..c13cdf222b15b 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -22,6 +22,8 @@ import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; import java.util.regex.Pattern; import org.junit.jupiter.api.Assertions; @@ -258,6 +260,16 @@ static void flattenStringBase(String[] input, List output) { output.addAll(Arrays.asList(input)); } + static void loudLogger(Class loggerClass) { + Logger l = Logger.getLogger(loggerClass.getName()); + l.setLevel(Level.INFO); + } + + static void quietLogger(Class loggerClass) { + Logger l = Logger.getLogger(loggerClass.getName()); + l.setLevel(Level.OFF); + } + public static Path getResourcePath(String path) { return new File(TestHelpers.class.getResource(path).getFile()).toPath(); }