Skip to content

Commit

Permalink
Introduce the C# Tokenizer API (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani authored Feb 27, 2024
1 parent 0a76f21 commit dacd02e
Show file tree
Hide file tree
Showing 14 changed files with 101,473 additions and 86 deletions.
17 changes: 9 additions & 8 deletions build.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ def validate_cuda_home(cuda_home: str | bytes | os.PathLike | None):
"""Validate the CUDA home paths."""
validated_cuda_home = ""

if cuda_home or os.environ.get("CUDA_HOME"):
if cuda_home or os.environ.get("CUDA_HOME") or (is_windows() and os.environ.get("CUDA_PATH")):
validated_cuda_home = cuda_home if cuda_home else os.getenv("CUDA_HOME")
if not validated_cuda_home and is_windows():
validated_cuda_home = os.getenv("CUDA_PATH")

cuda_home_valid = os.path.exists(validated_cuda_home)

Expand All @@ -98,7 +100,7 @@ def build(
cuda_home: str | bytes | os.PathLike | None = None,
cmake_generator: str | None = None,
ort_home: str | bytes | os.PathLike | None = None,
enable_csharp: bool = False,
skip_csharp: bool = False,
build_dir: str | bytes | os.PathLike | None = None,
):
"""Generates the CMake build tree and builds the project.
Expand Down Expand Up @@ -157,7 +159,7 @@ def build(
make_command = ["cmake", "--build", ".", "--config", config]
run_subprocess(make_command, cwd=build_dir, env=env).check_returncode()

if enable_csharp:
if not skip_csharp:
if not is_windows():
raise RuntimeError("C# API is only supported on Windows.")

Expand Down Expand Up @@ -202,12 +204,11 @@ def build(
parser.add_argument(
"--cuda_home",
help="Path to CUDA home."
"Read from CUDA_HOME environment variable if --use_cuda is true and "
"--cuda_home is not specified.",
"Read from CUDA_HOME or CUDA_PATH environment variable if not specified.",
)
parser.add_argument("--skip_wheel", action="store_true", help="Skip building the Python wheel.")
parser.add_argument("--ort_home", default=None, help="Root directory of onnxruntime.")
parser.add_argument("--enable_csharp", action="store_true", help="Build the C# API.")
parser.add_argument("--skip_csharp", action="store_true", help="Skip building the C# API.")
parser.add_argument("--build_dir", default=None, help="Path to output directory.")
args = parser.parse_args()

Expand All @@ -217,6 +218,6 @@ def build(
cuda_home=args.cuda_home,
cmake_generator=args.cmake_generator,
ort_home=args.ort_home,
enable_csharp=args.enable_csharp,
skip_csharp=args.skip_csharp,
build_dir=args.build_dir,
)
)
5 changes: 5 additions & 0 deletions src/csharp/GeneratorParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public void SetInputIDs(ReadOnlySpan<int> inputIDs, ulong sequenceLength, ulong
}
}

public void SetInputSequences(Sequences sequences)
{
Result.VerifySuccess(NativeMethods.OgaGeneratorParamsSetInputSequences(_generatorParamsHandle, sequences.Handle));
}

~GeneratorParams()
{
Dispose(false);
Expand Down
46 changes: 46 additions & 0 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ internal class NativeLib
UIntPtr /* size_t */ sequenceLength,
UIntPtr /* size_t */ batchSize);


[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGeneratorParamsSetInputSequences(IntPtr /* OgaGeneratorParams* */ generatorParams,
IntPtr /* const OgaSequences* */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
Expand Down Expand Up @@ -83,6 +88,9 @@ internal class NativeLib
public static extern IntPtr /* const in32_t* */ OgaGenerator_GetSequence(IntPtr /* const OgaGenerator* */ generator,
UIntPtr /* size_t */ index);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateSequences(out IntPtr /* OgaSequences** */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroySequences(IntPtr /* OgaSequences* */ sequences);

Expand Down Expand Up @@ -110,5 +118,43 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaSequences** */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateTokenizer(IntPtr /* const OgaModel* */ model,
out IntPtr /* OgaTokenizer** */ tokenizer);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyTokenizer(IntPtr /* OgaTokenizer* */ model);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaTokenizerEncode(IntPtr /* const OgaTokenizer* */ tokenizer,
byte[] /* const char* */ strings,
IntPtr /* OgaSequences* */ sequences);


// This function is used to decode the given token into a string. The caller is responsible for freeing the
// returned string using the OgaDestroyString function when it is no longer needed.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern unsafe IntPtr /* OgaResult* */ OgaTokenizerDecode(IntPtr /* const OgaTokenizer* */ tokenizer,
int* /* const int32_t* */ sequence,
UIntPtr /* size_t */ sequenceLength,
out IntPtr /* const char** */ outStr);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyString(IntPtr /* const char* */ str);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateTokenizerStream(IntPtr /* const OgaTokenizer* */ tokenizer,
out IntPtr /* OgaTokenizerStream** */ tokenizerStream);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyTokenizerStream(IntPtr /* OgaTokenizerStream* */ tokenizerStream);

// This function is used to decode the given token into a string. The returned pointer is freed when the
// OgaTokenizerStream object is destroyed.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaTokenizerStreamDecode(IntPtr /* const OgaTokenizerStream* */ tokenizerStream,
int /* int32_t */ token,
out IntPtr /* const char** */ outStr);
}
}
115 changes: 115 additions & 0 deletions src/csharp/Tokenizer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;
using System.Text;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
public class Tokenizer : IDisposable
{
private IntPtr _tokenizerHandle;
private bool _disposed = false;

public Tokenizer(Model model)
{
Result.VerifySuccess(NativeMethods.OgaCreateTokenizer(model.Handle, out _tokenizerHandle));
}

public Sequences EncodeBatch(string[] strings)
{
Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
try
{
foreach (string str in strings)
{
Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, Utils.ToUtf8(str), nativeSequences));
}

return new Sequences(nativeSequences);
}
catch
{
NativeMethods.OgaDestroySequences(nativeSequences);
throw;
}
}

public string[] DecodeBatch(Sequences sequences)
{
string[] result = new string[sequences.NumSequences];
for (ulong i = 0; i < sequences.NumSequences; i++)
{
result[i] = Decode(sequences[i]);
}

return result;
}

public Sequences Encode(string str)
{
Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
try
{
Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, Utils.ToUtf8(str), nativeSequences));
return new Sequences(nativeSequences);
}
catch
{
NativeMethods.OgaDestroySequences(nativeSequences);
throw;
}
}

public string Decode(ReadOnlySpan<int> sequence)
{
IntPtr outStr = IntPtr.Zero;
unsafe
{
fixed (int* sequencePtr = sequence)
{
Result.VerifySuccess(NativeMethods.OgaTokenizerDecode(_tokenizerHandle, sequencePtr, (UIntPtr)sequence.Length, out outStr));
}
}
try
{
return Utils.FromUtf8(outStr);
}
finally
{
NativeMethods.OgaDestroyString(outStr);
}
}

public TokenizerStream CreateStream()
{
IntPtr tokenizerStreamHandle = IntPtr.Zero;
Result.VerifySuccess(NativeMethods.OgaCreateTokenizerStream(_tokenizerHandle, out tokenizerStreamHandle));
return new TokenizerStream(tokenizerStreamHandle);
}


~Tokenizer()
{
Dispose(false);
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
{
if (_disposed)
{
return;
}
NativeMethods.OgaDestroyTokenizer(_tokenizerHandle);
_tokenizerHandle = IntPtr.Zero;
_disposed = true;
}
}
}
50 changes: 50 additions & 0 deletions src/csharp/TokenizerStream.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntimeGenAI
{
public class TokenizerStream : IDisposable
{
private IntPtr _tokenizerStreamHandle;
private bool _disposed = false;

internal TokenizerStream(IntPtr tokenizerStreamHandle)
{
_tokenizerStreamHandle = tokenizerStreamHandle;
}

internal IntPtr Handle { get { return _tokenizerStreamHandle; } }

public string Decode(int token)
{
IntPtr decodedStr = IntPtr.Zero;
Result.VerifySuccess(NativeMethods.OgaTokenizerStreamDecode(_tokenizerStreamHandle, token, out decodedStr));
return Utils.FromUtf8(decodedStr);
}

~TokenizerStream()
{
Dispose(false);
}

public void Dispose()
{
Dispose(true);
GC.SuppressFinalize(this);
}

protected virtual void Dispose(bool disposing)
{
if (_disposed)
{
return;
}
NativeMethods.OgaDestroyTokenizerStream(_tokenizerStreamHandle);
_tokenizerStreamHandle = IntPtr.Zero;
_disposed = true;
}
}
}
4 changes: 2 additions & 2 deletions src/csharp/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ namespace Microsoft.ML.OnnxRuntimeGenAI
{
internal class Utils
{
public static byte[] EmptyByteArray = new byte[] { 0 };
internal static byte[] EmptyByteArray = new byte[] { 0 };

public static byte[] ToUtf8(string str)
internal static byte[] ToUtf8(string str)
{
if (string.IsNullOrEmpty(str))
return EmptyByteArray;
Expand Down
16 changes: 16 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ const void* OGA_API_CALL OgaBufferGetData(const OgaBuffer* p) {
return p->data_.get();
}

OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out) {
OGA_TRY
*out = reinterpret_cast<OgaSequences*>(std::make_unique<Generators::TokenSequences>().release());
return nullptr;
OGA_CATCH
}

size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* p) {
return reinterpret_cast<const Generators::TokenSequences*>(p)->size();
}
Expand Down Expand Up @@ -214,6 +221,15 @@ void OGA_API_CALL OgaTokenizerDestroyStrings(const char** strings, size_t count)
delete strings;
}

OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* p, const char* str, OgaSequences* sequences) {
OGA_TRY
auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p);
auto& token_sequences = *reinterpret_cast<Generators::TokenSequences*>(sequences);
token_sequences.emplace_back(tokenizer.Encode(str));
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer* p, const int32_t* tokens, size_t token_count, const char** out_string) {
OGA_TRY
auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p);
Expand Down
15 changes: 14 additions & 1 deletion src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ OGA_EXPORT size_t OGA_API_CALL OgaBufferGetDimCount(const OgaBuffer*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaBufferGetDims(const OgaBuffer*, size_t* dims, size_t dim_count);
OGA_EXPORT const void* OGA_API_CALL OgaBufferGetData(const OgaBuffer*);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out);

/*
* \param[in] sequences OgaSequences to be destroyed.
*/
Expand Down Expand Up @@ -145,7 +147,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetMaxLength(OgaGeneratorPa

/*
* \brief Sets the input ids for the generator params. The input ids are used to seed the generation.
* \param[in] params The generator params to set the input ids on.
* \param[in] generator_params The generator params to set the input ids on.
* \param[in] input_ids The input ids array of size input_ids_count = batch_size * sequence_length.
* \param[in] input_ids_count The total number of input ids.
* \param[in] sequence_length The sequence length of the input ids.
Expand All @@ -155,6 +157,12 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetMaxLength(OgaGeneratorPa
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputIDs(OgaGeneratorParams* generator_params, const int32_t* input_ids,
size_t input_ids_count, size_t sequence_length, size_t batch_size);

/*
* \brief Sets the input id sequences for the generator params. The input id sequences are used to seed the generation.
* \param[in] generator_params The generator params to set the input ids on.
* \param[in] sequences The input id sequences.
* \return OgaResult containing the error message if the setting of the input id sequences failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGeneratorParams* generator_params, const OgaSequences* sequences);

OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperInputFeatures(OgaGeneratorParams*, const int32_t* inputs, size_t count);
Expand Down Expand Up @@ -222,6 +230,11 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeBatch(const OgaTokenizer*,
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecodeBatch(const OgaTokenizer*, const OgaSequences* tokens, const char*** out_strings);
OGA_EXPORT void OGA_API_CALL OgaTokenizerDestroyStrings(const char** strings, size_t count);

/* Encodes a single string and adds the encoded sequence of tokens to the OgaSequences. The OgaSequences must be freed with OgaDestroySequences
when it is no longer needed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences);

/* Decode a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecode(const OgaTokenizer*, const int32_t* tokens, size_t token_count, const char** out_string);
Expand Down
Loading

0 comments on commit dacd02e

Please sign in to comment.