forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[java] Multi-LoRA support (microsoft#22280)
### Description Java parts of Multi-LoRA support - microsoft#22046. ### Motivation and Context API equivalence with Python & C#. --------- Co-authored-by: Dmitri Smirnov <[email protected]>
- Loading branch information
1 parent
1fc2b94
commit 14d1bfc
Showing
5 changed files
with
381 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
/* | ||
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. | ||
* Copyright (c) Microsoft Corporation. All rights reserved. | ||
* Licensed under the MIT License. | ||
*/ | ||
package ai.onnxruntime; | ||
|
||
import java.io.IOException; | ||
import java.nio.ByteBuffer; | ||
import java.util.Objects; | ||
|
||
/** | ||
* A container for an adapter which can be supplied to {@link | ||
* OrtSession.RunOptions#addActiveLoraAdapter(OrtLoraAdapter)} to apply the adapter to a specific | ||
* execution of a model. | ||
*/ | ||
public final class OrtLoraAdapter implements AutoCloseable { | ||
static { | ||
try { | ||
OnnxRuntime.init(); | ||
} catch (IOException e) { | ||
throw new RuntimeException("Failed to load onnx-runtime library", e); | ||
} | ||
} | ||
|
||
private final long nativeHandle; | ||
|
||
private boolean closed = false; | ||
|
||
private OrtLoraAdapter(long nativeHandle) { | ||
this.nativeHandle = nativeHandle; | ||
} | ||
|
||
/** | ||
* Creates an instance of OrtLoraAdapter from a byte array. | ||
* | ||
* @param loraArray The LoRA stored in a byte array. | ||
* @throws OrtException If the native call failed. | ||
* @return An OrtLoraAdapter instance. | ||
*/ | ||
public static OrtLoraAdapter create(byte[] loraArray) throws OrtException { | ||
return create(loraArray, null); | ||
} | ||
|
||
/** | ||
* Creates an instance of OrtLoraAdapter from a byte array. | ||
* | ||
* @param loraArray The LoRA stored in a byte array. | ||
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the | ||
* allocator memory. | ||
* @throws OrtException If the native call failed. | ||
* @return An OrtLoraAdapter instance. | ||
*/ | ||
static OrtLoraAdapter create(byte[] loraArray, OrtAllocator allocator) throws OrtException { | ||
Objects.requireNonNull(loraArray, "LoRA array must not be null"); | ||
long allocatorHandle = allocator == null ? 0 : allocator.handle; | ||
return new OrtLoraAdapter( | ||
createLoraAdapterFromArray(OnnxRuntime.ortApiHandle, loraArray, allocatorHandle)); | ||
} | ||
|
||
/** | ||
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer. | ||
* | ||
* @param loraBuffer The buffer to load. | ||
* @throws OrtException If the native call failed. | ||
* @return An OrtLoraAdapter instance. | ||
*/ | ||
public static OrtLoraAdapter create(ByteBuffer loraBuffer) throws OrtException { | ||
return create(loraBuffer, null); | ||
} | ||
|
||
/** | ||
* Creates an instance of OrtLoraAdapter from a direct ByteBuffer. | ||
* | ||
* @param loraBuffer The buffer to load. | ||
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the | ||
* allocator memory. | ||
* @throws OrtException If the native call failed. | ||
* @return An OrtLoraAdapter instance. | ||
*/ | ||
static OrtLoraAdapter create(ByteBuffer loraBuffer, OrtAllocator allocator) throws OrtException { | ||
Objects.requireNonNull(loraBuffer, "LoRA buffer must not be null"); | ||
if (loraBuffer.remaining() == 0) { | ||
throw new OrtException("Invalid LoRA buffer, no elements remaining."); | ||
} else if (!loraBuffer.isDirect()) { | ||
throw new OrtException("ByteBuffer is not direct."); | ||
} | ||
long allocatorHandle = allocator == null ? 0 : allocator.handle; | ||
return new OrtLoraAdapter( | ||
createLoraAdapterFromBuffer( | ||
OnnxRuntime.ortApiHandle, | ||
loraBuffer, | ||
loraBuffer.position(), | ||
loraBuffer.remaining(), | ||
allocatorHandle)); | ||
} | ||
|
||
/** | ||
* Creates an instance of OrtLoraAdapter. | ||
* | ||
* @param adapterPath path to the adapter file that is going to be memory mapped. | ||
* @throws OrtException If the native call failed. | ||
* @return An OrtLoraAdapter instance. | ||
*/ | ||
public static OrtLoraAdapter create(String adapterPath) throws OrtException { | ||
return create(adapterPath, null); | ||
} | ||
|
||
/** | ||
* Creates an instance of OrtLoraAdapter. | ||
* | ||
* @param adapterPath path to the adapter file that is going to be memory mapped. | ||
* @param allocator optional allocator or null. If supplied, adapter parameters are copied to the | ||
* allocator memory. | ||
* @throws OrtException If the native call failed. | ||
* @return An OrtLoraAdapter instance. | ||
*/ | ||
static OrtLoraAdapter create(String adapterPath, OrtAllocator allocator) throws OrtException { | ||
long allocatorHandle = allocator == null ? 0 : allocator.handle; | ||
return new OrtLoraAdapter( | ||
createLoraAdapter(OnnxRuntime.ortApiHandle, adapterPath, allocatorHandle)); | ||
} | ||
|
||
/** | ||
* Package accessor for native pointer. | ||
* | ||
* @return The native pointer. | ||
*/ | ||
long getNativeHandle() { | ||
return nativeHandle; | ||
} | ||
|
||
/** Checks if the OrtLoraAdapter is closed, if so throws {@link IllegalStateException}. */ | ||
void checkClosed() { | ||
if (closed) { | ||
throw new IllegalStateException("Trying to use a closed OrtLoraAdapter"); | ||
} | ||
} | ||
|
||
@Override | ||
public void close() { | ||
if (!closed) { | ||
close(OnnxRuntime.ortApiHandle, nativeHandle); | ||
closed = true; | ||
} else { | ||
throw new IllegalStateException("Trying to close an already closed OrtLoraAdapter"); | ||
} | ||
} | ||
|
||
private static native long createLoraAdapter( | ||
long apiHandle, String adapterPath, long allocatorHandle) throws OrtException; | ||
|
||
private static native long createLoraAdapterFromArray( | ||
long apiHandle, byte[] loraBytes, long allocatorHandle) throws OrtException; | ||
|
||
private static native long createLoraAdapterFromBuffer( | ||
long apiHandle, ByteBuffer loraBuffer, int bufferPos, int bufferSize, long allocatorHandle) | ||
throws OrtException; | ||
|
||
private static native void close(long apiHandle, long nativeHandle); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
/* | ||
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. | ||
* Licensed under the MIT License. | ||
*/ | ||
#include <jni.h> | ||
#include <string.h> | ||
#include "onnxruntime/core/session/onnxruntime_c_api.h" | ||
#include "OrtJniUtil.h" | ||
#include "ai_onnxruntime_OrtLoraAdapter.h" | ||
|
||
/* | ||
* Class: ai_onnxruntime_OrtLoraAdapter | ||
* Method: createLoraAdapter | ||
* Signature: (JLjava/lang/String;J)J | ||
*/ | ||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapter | ||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jstring loraPath, jlong allocatorHandle) { | ||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. | ||
const OrtApi* api = (const OrtApi*) apiHandle; | ||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; | ||
OrtLoraAdapter* lora; | ||
|
||
#ifdef _WIN32 | ||
const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, loraPath, NULL); | ||
size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, loraPath); | ||
wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); | ||
if (newString == NULL) { | ||
(*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath); | ||
throwOrtException(jniEnv, 1, "Not enough memory"); | ||
return 0; | ||
} | ||
wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); | ||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(newString, allocator, &lora)); | ||
free(newString); | ||
(*jniEnv)->ReleaseStringChars(jniEnv, loraPath, cPath); | ||
#else | ||
const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, loraPath, NULL); | ||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapter(cPath, allocator, &lora)); | ||
(*jniEnv)->ReleaseStringUTFChars(jniEnv, loraPath, cPath); | ||
#endif | ||
|
||
return (jlong) lora; | ||
} | ||
|
||
/* | ||
* Class: ai_onnxruntime_OrtLoraAdapter | ||
* Method: createLoraAdapterFromBuffer | ||
* Signature: (JLjava/nio/ByteBuffer;IIJ)J | ||
*/ | ||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromBuffer | ||
(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong allocatorHandle) { | ||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. | ||
const OrtApi* api = (const OrtApi*) apiHandle; | ||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; | ||
OrtLoraAdapter* lora; | ||
|
||
// Extract the buffer | ||
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); | ||
// Increment by bufferPos bytes | ||
bufferArr = bufferArr + bufferPos; | ||
|
||
// Create the adapter | ||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) bufferArr, bufferSize, allocator, &lora)); | ||
|
||
return (jlong) lora; | ||
} | ||
|
||
/* | ||
* Class: ai_onnxruntime_OrtLoraAdapter | ||
* Method: createLoraAdapterFromArray | ||
* Signature: (J[BJ)J | ||
*/ | ||
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtLoraAdapter_createLoraAdapterFromArray | ||
(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jbyteArray jLoraArray, jlong allocatorHandle) { | ||
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. | ||
const OrtApi* api = (const OrtApi*) apiHandle; | ||
OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; | ||
OrtLoraAdapter* lora; | ||
|
||
size_t loraLength = (*jniEnv)->GetArrayLength(jniEnv, jLoraArray); | ||
if (loraLength == 0) { | ||
throwOrtException(jniEnv, 2, "Invalid LoRA, the byte array is zero length."); | ||
return 0; | ||
} | ||
|
||
// Get a reference to the byte array elements | ||
jbyte* loraArr = (*jniEnv)->GetByteArrayElements(jniEnv, jLoraArray, NULL); | ||
checkOrtStatus(jniEnv, api, api->CreateLoraAdapterFromArray((const uint8_t*) loraArr, loraLength, allocator, &lora)); | ||
// Release the C array. | ||
(*jniEnv)->ReleaseByteArrayElements(jniEnv, jLoraArray, loraArr, JNI_ABORT); | ||
|
||
return (jlong) lora; | ||
} | ||
|
||
/* | ||
* Class: ai_onnxruntime_OrtLoraAdapter | ||
* Method: close | ||
* Signature: (JJ)V | ||
*/ | ||
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtLoraAdapter_close | ||
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong loraHandle) { | ||
(void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object. | ||
const OrtApi* api = (const OrtApi*) apiHandle; | ||
api->ReleaseLoraAdapter((OrtLoraAdapter*) loraHandle); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.