Skip to content

Commit

Permalink
[C#] Expose Adapters API and add tests (#998)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin authored Oct 23, 2024
1 parent 28f77a3 commit 147a311
Show file tree
Hide file tree
Showing 14 changed files with 348 additions and 34 deletions.
62 changes: 62 additions & 0 deletions src/csharp/Adapters.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
/// <summary>
/// A container of adapters. Wraps the native adapters handle created by
/// OgaCreateAdapters and destroys it with OgaDestroyAdapters on release.
/// </summary>
public class Adapters : SafeHandle
{
    /// <summary>
    /// Creates a container for adapters
    /// used to load, unload and hold them.
    /// Throws on error.
    /// </summary>
    /// <param name="model">Reference to a loaded model</param>
    /// <exception cref="ArgumentNullException">model is null</exception>
    public Adapters(Model model) : base(IntPtr.Zero, true)
    {
        // Fail with a descriptive exception instead of a NullReferenceException
        // on model.Handle below.
        if (model == null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        Result.VerifySuccess(NativeMethods.OgaCreateAdapters(model.Handle, out handle));
    }

    /// <summary>
    /// Method that loads adapter data and assigns it a name that
    /// it can be referred to. Throws on error.
    /// </summary>
    /// <param name="adapterPath">file path to load</param>
    /// <param name="adapterName">adapter name</param>
    public void LoadAdapter(string adapterPath, string adapterName)
    {
        Result.VerifySuccess(NativeMethods.OgaLoadAdapter(handle,
            StringUtils.ToUtf8(adapterPath), StringUtils.ToUtf8(adapterName)));
    }

    /// <summary>
    /// Unloads the adapter that was loaded by the LoadAdapter method.
    /// Throws on error.
    /// </summary>
    /// <param name="adapterName">name the adapter was given in LoadAdapter</param>
    public void UnloadAdapter(string adapterName)
    {
        Result.VerifySuccess(NativeMethods.OgaUnloadAdapter(handle, StringUtils.ToUtf8(adapterName)));
    }

    /// <summary>
    /// Raw native handle, for passing to NativeMethods calls.
    /// </summary>
    internal IntPtr Handle { get { return handle; } }

    /// <summary>
    /// Implement SafeHandle override. The handle is invalid while it is zero,
    /// i.e. before creation succeeds and after ReleaseHandle has run.
    /// </summary>
    public override bool IsInvalid => handle == IntPtr.Zero;

    /// <summary>
    /// Implement SafeHandle override. Destroys the native adapters object
    /// and clears the stored handle.
    /// </summary>
    /// <returns>always true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OgaDestroyAdapters(handle);
        handle = IntPtr.Zero;
        return true;
    }
}
}
1 change: 0 additions & 1 deletion src/csharp/Exceptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Collections.Generic;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
27 changes: 27 additions & 0 deletions src/csharp/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,33 @@ public ReadOnlySpan<int> GetSequence(ulong index)
}
}

/// <summary>
/// Asks the native generator for the output registered under the given
/// name and wraps the returned handle in a Tensor.
/// Throws on error.
/// </summary>
/// <param name="outputName">name of the output to fetch</param>
/// <returns>a disposable instance of Tensor</returns>
public Tensor GetOutput(string outputName)
{
    var utf8Name = StringUtils.ToUtf8(outputName);
    IntPtr status = NativeMethods.OgaGenerator_GetOutput(_generatorHandle, utf8Name, out IntPtr tensorHandle);
    Result.VerifySuccess(status);
    return new Tensor(tensorHandle);
}

/// <summary>
/// Activates one of the loaded adapters.
/// Throws on error.
/// </summary>
/// <param name="adapters">Adapters container</param>
/// <param name="adapterName">adapter name that was previously loaded</param>
/// <exception cref="ArgumentNullException">adapters is null</exception>
public void SetActiveAdapter(Adapters adapters, string adapterName)
{
    // Fail with a descriptive exception instead of a NullReferenceException
    // on adapters.Handle below.
    if (adapters == null)
    {
        throw new ArgumentNullException(nameof(adapters));
    }

    Result.VerifySuccess(NativeMethods.OgaSetActiveAdapter(_generatorHandle,
                                                           adapters.Handle,
                                                           StringUtils.ToUtf8(adapterName)));
}

~Generator()
{
Dispose(false);
Expand Down
1 change: 0 additions & 1 deletion src/csharp/MultiModalProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
25 changes: 25 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,15 @@ internal class NativeLib
public static extern IntPtr /* const in32_t* */ OgaGenerator_GetSequenceData(IntPtr /* const OgaGenerator* */ generator,
UIntPtr /* size_t */ index);

/// <summary>
/// Fetches the generator output with the given name; the resulting
/// handle is returned through the out parameter.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerator_GetOutput(IntPtr /* const OgaGenerator* */ generator,
byte[] /* const char* */ outputName, out IntPtr tensor);

/// <summary>
/// Sets the active adapter on the generator, identified by name within
/// the given adapters container.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaSetActiveAdapter(IntPtr /* OgaGenerator* */ generator,
IntPtr /* OgaAdapters* */ adapters,
byte[] /* const char* */ adapterName);

/// <summary>
/// Creates a native sequences object and returns its handle through the out parameter.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateSequences(out IntPtr /* OgaSequences** */ sequences);

Expand Down Expand Up @@ -262,5 +271,21 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq

/// <summary>
/// Destroys a native string array.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyStringArray(IntPtr /* OgaStringArray* */ stringArray);

/// <summary>
/// Creates a native adapters container for the given model and returns
/// its handle through the out parameter.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateAdapters(IntPtr /* const OgaModel* */ model,
out IntPtr /* OgaAdapters** */ adapters);

/// <summary>
/// Destroys a native adapters container.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyAdapters(IntPtr /* OgaAdapters* */ adapters);

/// <summary>
/// Loads adapter data from the given file path into the container
/// under the given name.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaLoadAdapter(IntPtr /* OgaAdapters* */ adapters,
byte[] /* const char* */ adapterFilePath,
byte[] /* const char* */ adapterName);

/// <summary>
/// Unloads the adapter with the given name from the container.
/// </summary>
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaUnloadAdapter(IntPtr /* OgaAdapters* */ adapters,
byte[] /* const char* */ adapterName);
}
}
1 change: 0 additions & 1 deletion src/csharp/Result.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
1 change: 0 additions & 1 deletion src/csharp/Sequences.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
51 changes: 50 additions & 1 deletion src/csharp/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;
using System.Diagnostics;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down Expand Up @@ -34,6 +34,14 @@ public Tensor(IntPtr data, Int64[] shape, ElementType type)
{
Result.VerifySuccess(NativeMethods.OgaCreateTensorFromBuffer(data, shape, (UIntPtr)shape.Length, type, out _tensorHandle));
}

/// <summary>
/// Wraps an already created native tensor handle.
/// </summary>
/// <param name="tensorHandle">native tensor handle; asserted to be non-zero</param>
internal Tensor(IntPtr tensorHandle)
{
    Debug.Assert(IntPtr.Zero != tensorHandle);
    _disposed = false;
    _tensorHandle = tensorHandle;
}

internal IntPtr Handle { get { return _tensorHandle; } }

~Tensor()
Expand All @@ -54,6 +62,47 @@ public Int64[] Shape()
return shape;
}

/// <summary>
/// Computes number of elements in the tensor
/// given the shape. An empty shape yields 1.
/// </summary>
/// <param name="shape">shape</param>
/// <returns>product of dimensions</returns>
/// <exception cref="ArgumentNullException">shape is null</exception>
/// <exception cref="OverflowException">the product does not fit in Int64</exception>
public static Int64 ElementsFromShape(Int64[] shape)
{
    if (shape == null)
    {
        throw new ArgumentNullException(nameof(shape));
    }

    Int64 size = 1;
    foreach (Int64 dim in shape)
    {
        // checked: report overflow instead of silently wrapping to a
        // meaningless element count.
        size = checked(size * dim);
    }
    return size;
}

/// <summary>
/// Returns the element count of this tensor, computed as the
/// product of the dimensions reported by Shape().
/// </summary>
/// <returns>number of elements</returns>
public Int64 NumElements()
{
    Int64[] dims = Shape();
    return ElementsFromShape(dims);
}

/// <summary>
/// Return a ReadOnlySpan to tensor data.
/// No type checks are made: T is not validated against the
/// tensor's element type. The span length is NumElements().
/// </summary>
/// <typeparam name="T">element type used to view the data</typeparam>
/// <returns>read only span</returns>
/// <exception cref="OverflowException">the element count does not fit in Int32</exception>
public ReadOnlySpan<T> GetData<T>()
{
    var elements = NumElements();
    Result.VerifySuccess(NativeMethods.OgaTensorGetData(Handle, out IntPtr data));

    // Span lengths are Int32; a checked cast prevents a silently truncated
    // length when the tensor holds more than Int32.MaxValue elements.
    int length = checked((int)elements);

    // NOTE(review): the span points at the memory returned by OgaTensorGetData;
    // presumably it is only valid while this Tensor is alive - confirm.
    unsafe
    {
        return new ReadOnlySpan<T>(data.ToPointer(), length);
    }
}

public void Dispose()
{
Dispose(true);
Expand Down
2 changes: 0 additions & 2 deletions src/csharp/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;
using System.Text;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
1 change: 0 additions & 1 deletion src/csharp/TokenizerStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
Expand Down
1 change: 0 additions & 1 deletion src/csharp/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;
using System.Text;

namespace Microsoft.ML.OnnxRuntimeGenAI
Expand Down
51 changes: 49 additions & 2 deletions test/c_api_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
#include <models/model.h>
#include <iostream>
#include <ort_genai.h>
#include "../src/span.h"

#ifndef MODEL_PATH
#define MODEL_PATH "../../test/test_models/"
#endif
Expand Down Expand Up @@ -372,8 +374,16 @@ TEST(CAPITests, TopKTopPCAPI) {

#endif // TEST_PHI2

TEST(CAPITests, AdaptersTest) {
#if TEST_PHI2
TEST(CAPITests, AdaptersTest) {

#ifdef USE_CUDA
using OutputType = Ort::Float16_t;
#else
using OutputType = float;
#endif


// The python unit tests create the adapter model.
// In order to run this test, the python unit test must have been run first.
auto model = OgaModel::Create(MODEL_PATH "adapters");
Expand All @@ -392,6 +402,32 @@ TEST(CAPITests, AdaptersTest) {
for (auto& string : input_strings)
tokenizer->Encode(string, *input_sequences);

// Run base scenario
size_t output_size = 0;
std::vector<int64_t> output_shape;
std::vector<OutputType> base_output;
{
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 20);
params->SetInputSequences(*input_sequences);

auto generator = OgaGenerator::Create(*model, *params);

while (!generator->IsDone()) {
generator->ComputeLogits();
generator->GenerateNextToken();
}

auto logits = generator->GetOutput("logits");
output_shape = logits->Shape();
output_size = static_cast<size_t>(std::accumulate(output_shape.begin(), output_shape.end(), 1LL,
std::multiplies<int64_t>()));
base_output.reserve(output_size);
std::span<const OutputType> src(reinterpret_cast<const OutputType*>(logits->Data()), output_size);
std::copy(src.begin(), src.end(), std::back_inserter(base_output));
}
// Run scenario with an adapter
// We are expecting a difference in output
{
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 20);
Expand All @@ -404,13 +440,24 @@ TEST(CAPITests, AdaptersTest) {
generator->ComputeLogits();
generator->GenerateNextToken();
}

auto logits = generator->GetOutput("logits");
const auto shape = logits->Shape();
// Expecting the same shape
ASSERT_TRUE(std::equal(output_shape.begin(), output_shape.end(), shape.begin(), shape.end()));

const auto size = static_cast<size_t>(std::accumulate(shape.begin(), shape.end(), 1LL,
std::multiplies<int64_t>()));
ASSERT_EQ(output_size, size);
std::span<const OutputType> src(reinterpret_cast<const OutputType*>(logits->Data()), size);
ASSERT_FALSE(std::equal(base_output.begin(), base_output.end(), src.begin(), src.end()));
}

// Unload the adapter. Will error out if the adapter is still active.
// So, the generator must go out of scope before the adapter can be unloaded.
adapters->UnloadAdapter("adapters_a_and_b");
#endif
}
#endif

TEST(CAPITests, AdaptersTestMultipleAdapters) {
#if TEST_PHI2
Expand Down
Loading

0 comments on commit 147a311

Please sign in to comment.