diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 26137e88478b5..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -7,6 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Objects; import java.util.logging.Logger; @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption return new OrtSession(this, modelPath, allocator, options); } + /** + * Create a session using the specified {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options) + throws OrtException { + return createSession(modelBuffer, defaultAllocator, options); + } + + /** + * Create a session using the default {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException { + return createSession(modelBuffer, new OrtSession.SessionOptions()); + } + + /** + * Create a session using the specified {@link SessionOptions} and model buffer. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param allocator The memory allocator to use. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + Objects.requireNonNull(modelBuffer, "model array must not be null"); + if (modelBuffer.remaining() == 0) { + throw new OrtException("Invalid model buffer, no elements remaining."); + } else if (!modelBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + return new OrtSession(this, modelBuffer, allocator, options); + } + /** * Create a session using the specified {@link SessionOptions}, model and the default memory * allocator. diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8fe73ff69e169..f87cbc76ef141 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -11,6 +11,7 @@ import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable { allocator); } + /** + * Creates a session reading the model from the supplied byte buffer. + * + *
Must be a direct byte buffer.
+ *
+ * @param env The environment.
+ * @param modelBuffer The model protobuf as a byte buffer.
+ * @param allocator The allocator to use.
+ * @param options Session configuration options.
+ * @throws OrtException If the model was corrupted or some other error occurred in native code.
+ */
+ OrtSession(
+ OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
+ throws OrtException {
+ this(
+ createSession(
+ OnnxRuntime.ortApiHandle,
+ env.getNativeHandle(),
+ modelBuffer,
+ modelBuffer.position(),
+ modelBuffer.remaining(),
+ options.getNativeHandle()),
+ allocator);
+ }
+
/**
* Private constructor to build the Java object wrapped around a native session.
*
@@ -514,6 +540,15 @@ private static native long createSession(
private static native long createSession(
long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException;
+ private static native long createSession(
+ long apiHandle,
+ long envHandle,
+ ByteBuffer modelBuffer,
+ int bufferPos,
+ int bufferSize,
+ long optsHandle)
+ throws OrtException;
+
private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException;
private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle)
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c
index f4d5ab080cd31..ee8cdee659296 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession.c
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include