From 02e00dc02365f126b01fb89f3a61f7731da755cb Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 15 Sep 2024 18:31:55 -0400 Subject: [PATCH] [java] Adding ability to load a model from a memory mapped byte buffer (#20062) ### Description Adds support for constructing an `OrtSession` from a `java.nio.ByteBuffer`. These buffers can be memory-mapped from files, which means copies of the model protobuf do not need to be held in Java, reducing peak memory usage during session construction. ### Motivation and Context Reduces memory usage during model construction by not requiring as many copies on the Java side. Should help with #19599. --- .../java/ai/onnxruntime/OrtEnvironment.java | 49 ++++++++++++++++++- .../main/java/ai/onnxruntime/OrtSession.java | 35 +++++++++++++ .../main/native/ai_onnxruntime_OrtSession.c | 25 +++++++++- .../java/ai/onnxruntime/InferenceTest.java | 31 ++++++++++++ 4 files changed, 138 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 26137e88478b5..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime; @@ -7,6 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Objects; import java.util.logging.Logger; @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption return new OrtSession(this, modelPath, allocator, options); } + /** + * Create a session using the specified {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options) + throws OrtException { + return createSession(modelBuffer, defaultAllocator, options); + } + + /** + * Create a session using the default {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException { + return createSession(modelBuffer, new OrtSession.SessionOptions()); + } + + /** + * Create a session using the specified {@link SessionOptions} and model buffer. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param allocator The memory allocator to use. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. 
+ */ OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + Objects.requireNonNull(modelBuffer, "model array must not be null"); + if (modelBuffer.remaining() == 0) { + throw new OrtException("Invalid model buffer, no elements remaining."); + } else if (!modelBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + return new OrtSession(this, modelBuffer, allocator, options); + } + /** * Create a session using the specified {@link SessionOptions}, model and the default memory * allocator. diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8fe73ff69e169..f87cbc76ef141 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -11,6 +11,7 @@ import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable { allocator); } + /** + * Creates a session reading the model from the supplied byte buffer. + * + * <p>Must be a direct byte buffer. + * + * @param env The environment. + * @param modelBuffer The model protobuf as a byte buffer. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. + */ + OrtSession( + OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + this( + createSession( + OnnxRuntime.ortApiHandle, + env.getNativeHandle(), + modelBuffer, + modelBuffer.position(), + modelBuffer.remaining(), + options.getNativeHandle()), + allocator); + } + /** * Private constructor to build the Java object wrapped around a native session. * @@ -514,6 +540,15 @@ private static native long createSession( private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession( + long apiHandle, + long envHandle, + ByteBuffer modelBuffer, + int bufferPos, + int bufferSize, + long optsHandle) + throws OrtException; + private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index f4d5ab080cd31..ee8cdee659296 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ #include @@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la return (jlong)session; } +/* + * Class: ai_onnxruntime_OrtSession + * Method: createSession + * Signature: (JJLjava/nio/ByteBuffer;IIJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*)envHandle; + OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle; + OrtSession* session = NULL; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + + // Create the session + checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session)); + + return (jlong)session; +} + /* * Class: ai_onnxruntime_OrtSession * Method: createSession diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 3340a2e5e9f3a..f76e1b3b20e19 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -20,10 +20,14 @@ import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; import java.io.File; import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.LongBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -338,6 +342,33 @@ public void partialInputsTest() throws 
OrtException { } } + @Test + public void createSessionFromByteBuffer() throws IOException, OrtException { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r"); + FileChannel channel = file.getChannel()) { + MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size()); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelBuffer, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); + } + } + } + } + @Test public void createSessionFromByteArray() throws IOException, OrtException { Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");