Skip to content

Commit

Permalink
[java] Multi-LoRA support (microsoft#22280)
Browse files Browse the repository at this point in the history
### Description
Java parts of Multi-LoRA support - microsoft#22046.

### Motivation and Context
API equivalence with Python & C#.

---------

Co-authored-by: Dmitri Smirnov <[email protected]>
  • Loading branch information
Craigacp and yuslepukhin authored Oct 1, 2024
1 parent 1fc2b94 commit 14d1bfc
Show file tree
Hide file tree
Showing 5 changed files with 381 additions and 8 deletions.
161 changes: 161 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtLoraAdapter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Objects;

/**
* A container for an adapter which can be supplied to {@link
* OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific
* execution of a model.
*/
public final class OrtLoraAdapter implements AutoCloseable {
static {
try {
OnnxRuntime.init();
} catch (IOException e) {
throw new RuntimeException("Failed to load onnx-runtime library", e);
}
}

private final long nativeHandle;

private boolean closed = false;

private OrtLoraAdapter(long nativeHandle) {
this.nativeHandle = nativeHandle;
}

/**
* Creates an instance of OrtLoraAdapter from a byte array.
*
* @param loraArray The LoRA stored in a byte array.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
public static OrtLoraAdapter create(byte[] loraArray) throws OrtException {
return create(loraArray, null);
}

/**
* Creates an instance of OrtLoraAdapter from a byte array.
*
* @param loraArray The LoRA stored in a byte array.
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
* allocator memory.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException {
Objects.requireNonNull(loraArray, "LoRA array must not be null");
long allocatorHandle = allocator == null ? 0 : allocator.handle;
return new OrtLoraAdapter(
createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle));
}

/**
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
*
* @param loraBuffer The buffer to load.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException {
return create(loraBuffer, null);
}

/**
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer.
*
* @param loraBuffer The buffer to load.
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
* allocator memory.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException {
Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null");
if (loraBuffer.remaining() == 0) {
throw new OrtException("Invalid LoRA buffer, no elements remaining.");
} else if (!loraBuffer.isDirect()) {
throw new OrtException("ByteBuffer is not direct.");
}
long allocatorHandle = allocator == null ? 0 : allocator.handle;
return new OrtLoraAdapter(
createLoraAdapterFromBuffer(
OnnxRuntime.ortApiHandle,
loraBuffer,
loraBuffer.position(),
loraBuffer.remaining(),
allocatorHandle));
}

/**
* Creates an instance of OrtLoraAdapter.
*
* @param adapterPath path to the adapter file that is going to be memory mapped.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
public static OrtLoraAdapter create(String adapterPath) throws OrtException {
return create(adapterPath, null);
}

/**
* Creates an instance of OrtLoraAdapter.
*
* @param adapterPath path to the adapter file that is going to be memory mapped.
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the
* allocator memory.
* @throws OrtException If the native call failed.
* @return An OrtLoraAdapter instance.
*/
static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException {
long allocatorHandle = allocator == null ? 0 : allocator.handle;
return new OrtLoraAdapter(
createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle));
}

/**
* Package accessor for native pointer.
*
* @return The native pointer.
*/
long getNativeHandle() {
return nativeHandle;
}

/** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */
void checkClosed() {
if (closed) {
throw new IllegalStateException("Trying to use a closed OrtLoraAdapter");
}
}

@Override
public void close() {
if (!closed) {
close(OnnxRuntime.ortApiHandle, nativeHandle);
closed = true;
} else {
throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter");
}
}

private static native long createLoraAdapter(
long apiHandle, String adapterPath, long allocatorHandle) throws OrtException;

private static native long createLoraAdapterFromArray(
long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException;

private static native long createLoraAdapterFromBuffer(
long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle)
throws OrtException;

private static native void close(long apiHandle, long nativeHandle);
}
34 changes: 26 additions & 8 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,8 @@ public static class SessionOptions implements AutoCloseable {
* The optimisation level to use. Needs to be kept in sync with the GraphOptimizationLevel enum
* in the C API.
*
* <p>See <a
* href="https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
* <p>See <a href=
* "https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html">Graph
* Optimizations</a> for more details.
*/
public enum OptLevel {
Expand Down Expand Up @@ -684,6 +684,7 @@ public enum ExecutionMode {
SEQUENTIAL(0),
/** Executes some nodes in parallel. */
PARALLEL(1);

private final int id;

ExecutionMode(int id) {
Expand Down Expand Up @@ -1391,17 +1392,19 @@ private native void addConfigEntry(
throws OrtException;

/*
* To use additional providers, you must build ORT with the extra providers enabled. Then call one of these
* functions to enable them in the session:
* To use additional providers, you must build ORT with the extra providers enabled. Then call
* one of these functions to enable them in the session:
*
* OrtSessionOptionsAppendExecutionProvider_CPU
* OrtSessionOptionsAppendExecutionProvider_CUDA
* OrtSessionOptionsAppendExecutionProvider_ROCM
* OrtSessionOptionsAppendExecutionProvider_<remaining providers...>
* The order they care called indicates the preference order as well. In other words call this method
* on your most preferred execution provider first followed by the less preferred ones.
* If none are called Ort will use its internal CPU execution provider.
*
* If a backend is unavailable then it throws an OrtException
* The order they are called indicates the preference order as well. In other words call this
* method on your most preferred execution provider first followed by the less preferred ones.
* If none are called ORT will use its internal CPU execution provider.
*
* If a backend is unavailable then it throws an OrtException.
*/
private native void addCPU(long apiHandle, long nativeHandle, int useArena) throws OrtException;

Expand Down Expand Up @@ -1579,6 +1582,18 @@ public void addRunConfigEntry(String key, String value) throws OrtException {
addRunConfigEntry(OnnxRuntime.ortApiHandle, nativeHandle, key, value);
}

/**
* Adds the specified adapter to the list of active adapters for this run.
*
* @param loraAdapter valid OrtLoraAdapter object
* @throws OrtException of the native library call failed
*/
public void addActiveLoraAdapter(OrtLoraAdapter loraAdapter) throws OrtException {
checkClosed();
loraAdapter.checkClosed();
addActiveLoraAdapter(OnnxRuntime.ortApiHandle, nativeHandle, loraAdapter.getNativeHandle());
}

/** Checks if the RunOptions is closed, if so throws {@link IllegalStateException}. */
private void checkClosed() {
if (closed) {
Expand Down Expand Up @@ -1619,6 +1634,9 @@ private native void setTerminate(long apiHandle, long nativeHandle, boolean term
private native void addRunConfigEntry(
long apiHandle, long nativeHandle, String key, String value) throws OrtException;

private native void addActiveLoraAdapter(
long apiHandle, long nativeHandle, long loraAdapterHandle) throws OrtException;

private static native void close(long apiHandle, long nativeHandle);
}

Expand Down
106 changes: 106 additions & 0 deletions java/src/main/native/ai_onnxruntime_OrtLoraAdapter.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
#include <string.h>
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtLoraAdapter.h"

/*
* Class: ai_onnxruntime_OrtLoraAdapter
* Method: createLoraAdapter
* Signature: (JLjava/lang/String;J)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
OrtLoraAdapter* lora;

#ifdef _WIN32
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, loraPath, NULL);
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath);
wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t));
if (newString == NULL) {
(*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
throwOrtException(jniEnv, 1, "Not enough memory");
return 0;
}
wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength);
checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora));
free(newString);
(*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath);
#else
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL);
checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora));
(*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath);
#endif

return (jlong) lora;
}

/*
* Class: ai_onnxruntime_OrtLoraAdapter
* Method: createLoraAdapterFromBuffer
* Signature: (JLjava/nio/ByteBuffer;IIJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer
(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
OrtLoraAdapter* lora;

// Extract the buffer
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
// Increment by bufferPos bytes
bufferArr = bufferArr + bufferPos;

// Create the adapter
checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora));

return (jlong) lora;
}

/*
* Class: ai_onnxruntime_OrtLoraAdapter
* Method: createLoraAdapterFromArray
* Signature: (J[BJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray
(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle;
OrtLoraAdapter* lora;

size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray);
if (loraLength == 0) {
throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length.");
return 0;
}

// Get a reference to the byte array elements
jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL);
checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora));
// Release the C array.
(*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT);

return (jlong) lora;
}

/*
* Class: ai_onnxruntime_OrtLoraAdapter
* Method: close
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) {
(void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
api->ReleaseLoraAdapter((OrtLoraAdapter*) loraHandle);
}

12 changes: 12 additions & 0 deletions java/src/main/native/ai_onnxruntime_OrtSession_RunOptions.c
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addRunConf
(*jniEnv)->ReleaseStringUTFChars(jniEnv, valueStr, value);
}

/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: addActiveLoraAdapter
* Signature: (JJJ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024RunOptions_addActiveLoraAdapter
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong nativeHandle, jlong loraHandle) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
checkOrtStatus(jniEnv, api, api->RunOptionsAddActiveLoraAdapter((OrtRunOptions*) nativeHandle, (OrtLoraAdapter*) loraHandle));
}

/*
* Class: ai_onnxruntime_OrtSession_RunOptions
* Method: setTerminate
Expand Down
Loading

0 comments on commit 14d1bfc

Please sign in to comment.