Skip to content

Commit

Permalink
[C#] Expose Multi-Lora support in C# (microsoft#22281)
Browse files Browse the repository at this point in the history
### Description


### Motivation and Context
microsoft#22046
  • Loading branch information
yuslepukhin authored Oct 2, 2024
1 parent 4e15b22 commit 224f065
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1657,7 +1657,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
list(APPEND onnxruntime_customopregistration_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
if (${CMAKE_SYSTEM_NAME} MATCHES "AIX")
list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp)
list(APPEND onnxruntime_customopregistration_test_LIBS onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers iconv re2 libprotobuf-lite onnx_proto nsync_cpp)
endif()
AddTest(DYN
TARGET onnxruntime_customopregistration_test
Expand Down
80 changes: 79 additions & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,29 @@ public struct OrtApi
public IntPtr ReleaseROCMProviderOptions;
public IntPtr CreateAndRegisterAllocatorV2;
public IntPtr RunAsync;
public IntPtr UpdateTensorRTProviderOptionsWithValue;
public IntPtr GetTensorRTProviderOptionsByName;
public IntPtr UpdateCUDAProviderOptionsWithValue;
public IntPtr GetCUDAProviderOptionsByName;
public IntPtr KernelContext_GetResource;
public IntPtr SetUserLoggingFunction;
public IntPtr ShapeInferContext_GetInputCount;
public IntPtr ShapeInferContext_GetInputTypeShape;
public IntPtr ShapeInferContext_GetAttribute;
public IntPtr ShapeInferContext_SetOutputTypeShape;
public IntPtr SetSymbolicDimensions;
public IntPtr ReadOpAttr;
public IntPtr SetDeterministicCompute;
public IntPtr KernelContext_ParallelFor;
public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO_V2;
public IntPtr SessionOptionsAppendExecutionProvider_VitisAI;
public IntPtr KernelContext_GetScratchBuffer;
public IntPtr KernelInfoGetAllocator;
public IntPtr AddExternalInitializersFromFilesInMemory;
public IntPtr CreateLoraAdapter;
public IntPtr CreateLoraAdapterFromArray;
public IntPtr ReleaseLoraAdapter;
public IntPtr RunOptionsAddActiveLoraAdapter;
}

internal static class NativeMethods
Expand Down Expand Up @@ -540,6 +563,13 @@ static NativeMethods()
OrtReleaseROCMProviderOptions = (DOrtReleaseROCMProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseROCMProviderOptions, typeof(DOrtReleaseROCMProviderOptions));
OrtCreateAndRegisterAllocatorV2 = (DCreateAndRegisterAllocatorV2)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocatorV2, typeof(DCreateAndRegisterAllocatorV2));
OrtRunAsync = (DOrtRunAsync)Marshal.GetDelegateForFunctionPointer(api_.RunAsync, typeof(DOrtRunAsync));
CreateLoraAdapter = (DCreateLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.CreateLoraAdapter,
typeof(DCreateLoraAdapter));
CreateLoraAdapterFromArray = (DCreateLoraAdapterFromArray)Marshal.GetDelegateForFunctionPointer (api_.CreateLoraAdapterFromArray, typeof(DCreateLoraAdapterFromArray));
ReleaseLoraAdapter = (DReleaseLoraAdapter)Marshal.GetDelegateForFunctionPointer(api_.ReleaseLoraAdapter,
typeof(DReleaseLoraAdapter));
OrtRunOptionsAddActiveLoraAdapter = (DOrtRunOptionsAddActiveLoraAdapter)Marshal.GetDelegateForFunctionPointer(
api_.RunOptionsAddActiveLoraAdapter, typeof(DOrtRunOptionsAddActiveLoraAdapter));
}

internal class NativeLib
Expand Down Expand Up @@ -1263,7 +1293,49 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca

#endregion

#region RunOptions API
#region LoraAdapter API
/// <summary>
/// Memory maps the adapter file, wraps it into the adapter object
/// and returns it.
/// </summary>
/// <param name="adapter_path">absolute path to the adapter file</param>
/// <param name="allocator">optional device allocator or null</param>
/// <param name="lora_adapter">New LoraAdapter object</param>
/// <returns></returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapter(
byte[] adapter_path, // This takes const ORTCHAR_T* use GetPlatformSerializedString
IntPtr /* OrtAllocator */ allocator, // optional
out IntPtr lora_adapter
);
public static DCreateLoraAdapter CreateLoraAdapter;

/// <summary>
/// Creates LoraAdapter instance from a byte array that must
/// represents a valid LoraAdapter formst.
/// </summary>
/// <param name="bytes">bytes</param>
/// <param name="size">size in bytes</param>
/// <param name="allocator">optional device allocator</param>
/// <param name="lora_adapter">resuling LoraAdapter instance</param>
/// <returns></returns>
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DCreateLoraAdapterFromArray(
byte[] bytes,
UIntPtr size,
IntPtr /* OrtAllocator */ allocator, // optional
out IntPtr lora_adapter
);
public static DCreateLoraAdapterFromArray CreateLoraAdapterFromArray;


[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DReleaseLoraAdapter(IntPtr /* OrtLoraAdapter* */ lora_adapter);
public static DReleaseLoraAdapter ReleaseLoraAdapter;

#endregion

#region RunOptions API

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateRunOptions(out IntPtr /* OrtRunOptions** */ runOptions);
Expand Down Expand Up @@ -1308,6 +1380,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsUnsetTerminate(IntPtr /* OrtRunOptions* */ options);
public static DOrtRunOptionsUnsetTerminate OrtRunOptionsUnsetTerminate;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtRunOptionsAddActiveLoraAdapter(
IntPtr /* OrtRunOptions* */ options,
IntPtr /* OrtLoraAdapter* */ lora_adapter);
public static DOrtRunOptionsAddActiveLoraAdapter OrtRunOptionsAddActiveLoraAdapter;

/// <summary>
/// Add run config entry
/// </summary>
Expand Down
81 changes: 81 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtLoraAdapter.shared.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Represents Lora Adapter in memory
/// </summary>
public class OrtLoraAdapter : SafeHandle
{
/// <summary>
/// Creates an instance of OrtLoraAdapter from file.
/// The adapter file is memory mapped. If allocator parameter
/// is provided, then lora parameters are copied to the memory
/// allocated by the specified allocator.
/// </summary>
/// <param name="adapterPath">path to the adapter file</param>
/// <param name="ortAllocator">optional allocator, can be null, must be a device allocator</param>
/// <returns>New instance of LoraAdapter</returns>
public static OrtLoraAdapter Create(string adapterPath, OrtAllocator ortAllocator)
{
var platformPath = NativeOnnxValueHelper.GetPlatformSerializedString(adapterPath);
var allocatorHandle = (ortAllocator != null) ? ortAllocator.Pointer : IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapter(platformPath, allocatorHandle,
out IntPtr adapterHandle));
return new OrtLoraAdapter(adapterHandle);
}

/// <summary>
/// Creates an instance of OrtLoraAdapter from an array of bytes. The API
/// makes a copy of the bytes internally.
/// </summary>
/// <param name="bytes">array of bytes containing valid LoraAdapter format</param>
/// <param name="ortAllocator">optional device allocator or null</param>
/// <returns>new instance of LoraAdapter</returns>
public static OrtLoraAdapter Create(byte[] bytes, OrtAllocator ortAllocator)
{
var allocatorHandle = (ortAllocator != null) ? ortAllocator.Pointer : IntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.CreateLoraAdapterFromArray(bytes,
new UIntPtr((uint)bytes.Length), allocatorHandle, out IntPtr adapterHandle));
return new OrtLoraAdapter(adapterHandle);
}

internal OrtLoraAdapter(IntPtr adapter)
: base(adapter, true)
{
}

internal IntPtr Handle
{
get
{
return handle;
}
}

#region SafeHandle

/// <summary>
/// Overrides SafeHandle.IsInvalid
/// </summary>
/// <value>returns true if handle is equal to Zero</value>
public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

/// <summary>
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
/// the native instance of OrtLoraAdapter
/// </summary>
/// <returns>always returns true</returns>
protected override bool ReleaseHandle()
{
NativeMethods.ReleaseLoraAdapter(handle);
handle = IntPtr.Zero;
return true;
}

#endregion
}
}
12 changes: 12 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/RunOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ public void AddRunConfigEntry(string configKey, string configValue)
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddRunConfigEntry(handle, utf8Key, utf8Value));
}

/// <summary>
/// Appends the specified lora adapter to the list of active lora adapters
/// for this RunOptions instance. All run calls with this instant will
/// make use of the activated Lora Adapters. An adapter is considered active
/// if it is added to RunOptions that are used during Run() calls.
/// </summary>
/// <param name="loraAdapter">Lora adapter instance</param>
public void AddActiveLoraAdapter(OrtLoraAdapter loraAdapter)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsAddActiveLoraAdapter(handle, loraAdapter.Handle));
}

#region SafeHandle
/// <summary>
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1678,6 +1678,7 @@ private void TestInferenceSessionWithByteArray()
}
}


void TestCPUAllocatorInternal(InferenceSession session)
{
int device_id = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ private void RunPretrainedModel(InferenceSession session, RunOptions runOptions,

var orderedInputNames = new List<string>(inputContainer.Count);
var orderdedInputs = new List<OrtValue>(inputContainer.Count);
foreach(var pair in inputContainer)
foreach (var pair in inputContainer)
{
orderedInputNames.Add(pair.Key);
orderdedInputs.Add(pair.Value);
Expand Down Expand Up @@ -772,7 +772,7 @@ private void TestPreTrainedModels(string opsetDir, string modelName, bool useOrt
throw new Exception($"Opset {opset} Model {modelName}. Can't determine model file name. Found these :{modelNamesList}");
}

using(var runOptions = new RunOptions())
using (var runOptions = new RunOptions())
using (var session = new InferenceSession(onnxModelFileName))
{
string testDataDirNamePattern = "test_data*";
Expand Down Expand Up @@ -1077,7 +1077,7 @@ private static void VerifyContainerContent(IReadOnlyList<OrtValue> results,
Assert.Equal(result.GetStringTensorAsArray(), expectedValue.AsTensor<string>().ToArray(), new ExactComparer<string>());
break;
default:
Assert.Fail($"VerifyTensorResults cannot handle ElementType: { resultTypeShape.ElementDataType}");
Assert.Fail($"VerifyTensorResults cannot handle ElementType: {resultTypeShape.ElementDataType}");
break;
}
}
Expand Down Expand Up @@ -1251,6 +1251,88 @@ private void TestModelSerialization()
}
}

private static OrtLoraAdapter CreateLoraAdapterFromFile()
{
var adapterPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx_adapter");
return OrtLoraAdapter.Create(adapterPath, null);
}

private static OrtLoraAdapter CreateLoraAdapterFromArray()
{
var adapterPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx_adapter");
var adapterBytes = File.ReadAllBytes(adapterPath);
return OrtLoraAdapter.Create(adapterBytes, null);
}

// See tests below for running with Lora Adapters
[Fact(DisplayName = "TestInferenceWithBaseLoraModel")]
private void TestInferenceWithBaseLoraModel()
{
var modelPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx");

var inputShape = new long[] { 4, 4 };
var inputData = new float[16];
Array.Fill(inputData, 1);
using var inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape);

var expectedOutput = new float[] {
28, 32, 36, 40,
28, 32, 36, 40,
28, 32, 36, 40,
28, 32, 36, 40 };

using var session = new InferenceSession(modelPath);
using var runOptions = new RunOptions();

using var outputs = session.Run(runOptions, ["input_x"], [inputOrtValue], ["output"]);
Assert.Single(outputs);
var output = outputs[0].GetTensorDataAsSpan<float>();
Assert.Equal(expectedOutput.Length, output.Length);
Assert.Equal(expectedOutput, output.ToArray(), new FloatComparer());
}


private static void TestInferenceWithLoraAdapter(OrtLoraAdapter ortLoraAdapter)
{
var modelPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx");
var adapterPath = Path.Combine(Directory.GetCurrentDirectory(), "two_params_lora_model.onnx_adapter");

var inputShape = new long[] { 4, 4 };
var inputData = new float[16];
Array.Fill(inputData, 1);
using var inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape);

var expectedOutput = new float[] {
154, 176, 198, 220,
154, 176, 198, 220,
154, 176, 198, 220,
154, 176, 198, 220 };

using var session = new InferenceSession(modelPath);
using var runOptions = new RunOptions();
runOptions.AddActiveLoraAdapter(ortLoraAdapter);

using var outputs = session.Run(runOptions, ["input_x"], [inputOrtValue], ["output"]);
Assert.Single(outputs);
var output = outputs[0].GetTensorDataAsSpan<float>();
Assert.Equal(expectedOutput.Length, output.Length);
Assert.Equal(expectedOutput, output.ToArray(), new FloatComparer());
}

[Fact(DisplayName = "TestInferenceWithLoraAdapterFromFile")]
private void TestInferenceWithLoraAdapterFromFile()
{
using var ortAdapter = CreateLoraAdapterFromFile();
TestInferenceWithLoraAdapter(ortAdapter);
}

[Fact(DisplayName = "TestInferenceWithLoraAdapterFromArray")]
private void TestInferenceWithLoraAdapterFromArray()
{
using var ortAdapter = CreateLoraAdapterFromArray();
TestInferenceWithLoraAdapter(ortAdapter);
}

// TestGpu() will test
// - the CUDA EP on CUDA enabled builds
// - the DML EP on DML enabled builds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(OnnxRuntimeCSharpRoot)\..\onnxruntime\test\testdata\lora\two_params_lora_model.onnx">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(OnnxRuntimeCSharpRoot)\..\onnxruntime\test\testdata\lora\two_params_lora_model.onnx_adapter">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<BuildEnvVars Include="OnnxRuntimeBuildDirectory=$(OnnxRuntimeBuildDirectory)" />
</ItemGroup>

Expand Down

0 comments on commit 224f065

Please sign in to comment.