Skip to content

Commit

Permalink
Address pull request review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed Feb 27, 2024
1 parent 5eb32c3 commit 0bcf8d3
Show file tree
Hide file tree
Showing 9 changed files with 101,023 additions and 247 deletions.
38 changes: 4 additions & 34 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ internal class NativeLib
public static extern IntPtr /* const in32_t* */ OgaGenerator_GetSequence(IntPtr /* const OgaGenerator* */ generator,
UIntPtr /* size_t */ index);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateSequences(out IntPtr /* OgaSequences** */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroySequences(IntPtr /* OgaSequences* */ sequences);

Expand Down Expand Up @@ -116,50 +119,17 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaSequences** */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern unsafe IntPtr /* OgaResult* */ OgaCreateAllocatedStrings(UIntPtr /* size_t */ numStrings,
UIntPtr* /* const size_t* */ stringLengths,
out IntPtr /* OgaStrings** */ strings);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaStringsGetBuffer(IntPtr /* OgaStrings* */ strings,
UIntPtr /* size_t */ index,
out IntPtr /* char** */ str);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyStrings(IntPtr /* OgaStrings* */ strings);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern UIntPtr /* size_t */ OgaStringsGetCount(IntPtr /* const OgaStrings* */ strings);

// This function returns the string at the given index of the OgaStrings object. The returned pointer
// is owned by the OgaStrings object and will be freed when the OgaStrings object is destroyed.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaStringsGetString(IntPtr /* const OgaStrings* */ strings,
UIntPtr /* size_t */ index,
out IntPtr /* const char** */ outStr);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateTokenizer(IntPtr /* const OgaModel* */ model,
out IntPtr /* OgaTokenizer** */ tokenizer);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern void OgaDestroyTokenizer(IntPtr /* OgaTokenizer* */ model);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaTokenizerEncodeBatchStrings(IntPtr /* const OgaTokenizer* */ tokenizer,
IntPtr /* const OgaStrings* */ strings,
out IntPtr /* OgaSequences** */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaTokenizerDecodeBatchStrings(IntPtr /* const OgaTokenizer* */ tokenizer,
IntPtr /* const OgaSequences* */ sequences,
out IntPtr /* OgaStrings** */ strings);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaTokenizerEncode(IntPtr /* const OgaTokenizer* */ tokenizer,
byte[] /* const char* */ strings,
out IntPtr /* OgaSequences** */ sequences);
IntPtr /* OgaSequences* */ sequences);


// This function is used to decode the given token into a string. The caller is responsible for freeing the
Expand Down
63 changes: 23 additions & 40 deletions src/csharp/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,48 @@ public Tokenizer(Model model)

public Sequences EncodeBatch(string[] strings)
{
IntPtr stringArray = IntPtr.Zero;
UIntPtr[] stringLengths = new UIntPtr[strings.Length];
for (int i = 0; i < strings.Length; i++)
{
stringLengths[i] = (UIntPtr)UTF8Encoding.UTF8.GetByteCount(strings[i]);
}

unsafe
{
fixed (UIntPtr* stringLengthsPtr = stringLengths)
{
Result.VerifySuccess(NativeMethods.OgaCreateAllocatedStrings((UIntPtr)strings.Length, stringLengthsPtr, out stringArray));
}
}

Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
try
{
for (ulong i = 0; i < (ulong)strings.Length; i++)
foreach (string str in strings)
{
Result.VerifySuccess(NativeMethods.OgaStringsGetBuffer(stringArray, (UIntPtr)i, out IntPtr buffer));
Utils.ToNativeBuffer(strings[i], buffer, (int)stringLengths[i]);
Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, Utils.ToUtf8(str), nativeSequences));
}
Result.VerifySuccess(NativeMethods.OgaTokenizerEncodeBatchStrings(_tokenizerHandle, stringArray, out IntPtr nativeSequences));

return new Sequences(nativeSequences);
}
finally
catch
{
NativeMethods.OgaDestroyStrings(stringArray);
NativeMethods.OgaDestroySequences(nativeSequences);
throw;
}
}

public string[] DecodeBatch(Sequences sequences)
{
IntPtr stringArray = IntPtr.Zero;
Result.VerifySuccess(NativeMethods.OgaTokenizerDecodeBatchStrings(_tokenizerHandle, sequences.Handle, out stringArray));
try
string[] result = new string[sequences.NumSequences];
for (ulong i = 0; i < sequences.NumSequences; i++)
{
ulong numStrings = NativeMethods.OgaStringsGetCount(stringArray).ToUInt64();
string[] result = new string[numStrings];
for (ulong i = 0; i < numStrings; i++)
{
IntPtr outStr = IntPtr.Zero;
Result.VerifySuccess(NativeMethods.OgaStringsGetString(stringArray, (UIntPtr)i, out outStr));
result[i] = Utils.FromUtf8(outStr);
}
return result;
}
finally
{
NativeMethods.OgaDestroyStrings(stringArray);
IntPtr outStr = IntPtr.Zero;
result[i] = Decode(sequences[i]);
}

return result;
}

public Sequences Encode(string str)
{
IntPtr nativeSequences = IntPtr.Zero;
Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, Utils.ToUtf8(str), out nativeSequences));
return new Sequences(nativeSequences);
Result.VerifySuccess(NativeMethods.OgaCreateSequences(out IntPtr nativeSequences));
try
{
Result.VerifySuccess(NativeMethods.OgaTokenizerEncode(_tokenizerHandle, Utils.ToUtf8(str), nativeSequences));
return new Sequences(nativeSequences);
}
catch
{
NativeMethods.OgaDestroySequences(nativeSequences);
throw;
}
}

public string Decode(ReadOnlySpan<int> sequence)
Expand Down
91 changes: 10 additions & 81 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ const void* OGA_API_CALL OgaBufferGetData(const OgaBuffer* p) {
return p->data_.get();
}

OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out) {
OGA_TRY
*out = reinterpret_cast<OgaSequences*>(std::make_unique<Generators::TokenSequences>().release());
return nullptr;
OGA_CATCH
}

size_t OGA_API_CALL OgaSequencesCount(const OgaSequences* p) {
return reinterpret_cast<const Generators::TokenSequences*>(p)->size();
}
Expand Down Expand Up @@ -214,32 +221,11 @@ void OGA_API_CALL OgaTokenizerDestroyStrings(const char** strings, size_t count)
delete strings;
}

OgaResult* OGA_API_CALL OgaTokenizerEncodeBatchStrings(const OgaTokenizer* p, const OgaStrings* string_array, OgaSequences** out) {
OGA_TRY
auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p);
auto& strings = *reinterpret_cast<const std::vector<std::string>*>(string_array);
*out = reinterpret_cast<OgaSequences*>(std::make_unique<Generators::TokenSequences>(std::move(tokenizer.EncodeBatch(strings))).release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaTokenizerDecodeBatchStrings(const OgaTokenizer* p, const OgaSequences* p_sequences, OgaStrings** out_strings) {
OGA_TRY
auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p);
auto& sequences = *reinterpret_cast<const Generators::TokenSequences*>(p_sequences);

*out_strings = reinterpret_cast<OgaStrings*>(
std::make_unique<std::vector<std::string>>(tokenizer.DecodeBatch(sequences)).release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* p, const char* str, OgaSequences** out) {
OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer* p, const char* str, OgaSequences* sequences) {
OGA_TRY
auto& tokenizer = *reinterpret_cast<const Generators::Tokenizer*>(p);
auto out_tokens = std::make_unique<Generators::TokenSequences>();
out_tokens->emplace_back(std::move(tokenizer.Encode(str)));
*out = reinterpret_cast<OgaSequences*>(out_tokens.release());
auto& token_sequences = *reinterpret_cast<Generators::TokenSequences*>(sequences);
token_sequences.emplace_back(tokenizer.Encode(str));
return nullptr;
OGA_CATCH
}
Expand Down Expand Up @@ -276,63 +262,6 @@ OgaResult* OGA_API_CALL OgaTokenizerStreamDecode(OgaTokenizerStream* p, int32_t
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateStrings(OgaStrings** out) {
OGA_TRY
*out = reinterpret_cast<OgaStrings*>(std::make_unique<std::vector<std::string>>().release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateAllocatedStrings(size_t num_strings, size_t* str_lengths, OgaStrings** out) {
OGA_TRY
auto strings = std::make_unique<std::vector<std::string>>(num_strings);
for (size_t i = 0; i < num_strings; i++) {
strings->at(i).resize(str_lengths[i]);
}
*out = reinterpret_cast<OgaStrings*>(strings.release());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaStringsGetBuffer(OgaStrings* string_array, size_t index, char** out) {
OGA_TRY
*out = &((*reinterpret_cast<std::vector<std::string>*>(string_array))[index])[0];
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaStringsAddString(OgaStrings* string_array, const char* str) {
OGA_TRY
reinterpret_cast<std::vector<std::string>*>(string_array)->push_back(str);
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaStringsAddStrings(OgaStrings* string_array, const char* const* strings, size_t count) {
OGA_TRY
auto& strs = *reinterpret_cast<std::vector<std::string>*>(string_array);
for (size_t i = 0; i < count; i++) {
strs.push_back(strings[i]);
}
return nullptr;
OGA_CATCH
}

size_t OGA_API_CALL OgaStringsGetCount(const OgaStrings* string_array) {
return reinterpret_cast<const std::vector<std::string>*>(string_array)->size();
}

OgaResult* OGA_API_CALL OgaStringsGetString(const OgaStrings* string_array, size_t index, const char** out) {
OGA_TRY
*out = (*reinterpret_cast<const std::vector<std::string>*>(string_array))[index].c_str();
return nullptr;
OGA_CATCH
}

void OGA_API_CALL OgaDestroyStrings(OgaStrings* string_array) {
delete reinterpret_cast<std::vector<std::string>*>(string_array);
}

void OGA_API_CALL OgaDestroyResult(OgaResult* p) {
delete p;
}
Expand Down
61 changes: 5 additions & 56 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ typedef struct OgaBuffer OgaBuffer;
typedef struct OgaSequences OgaSequences;
typedef struct OgaTokenizer OgaTokenizer;
typedef struct OgaTokenizerStream OgaTokenizerStream;
typedef struct OgaStrings OgaStrings;

/*
* \param[in] result OgaResult that contains the error message.
Expand All @@ -70,6 +69,8 @@ OGA_EXPORT size_t OGA_API_CALL OgaBufferGetDimCount(const OgaBuffer*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaBufferGetDims(const OgaBuffer*, size_t* dims, size_t dim_count);
OGA_EXPORT const void* OGA_API_CALL OgaBufferGetData(const OgaBuffer*);

OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateSequences(OgaSequences** out);

/*
* \param[in] sequences OgaSequences to be destroyed.
*/
Expand Down Expand Up @@ -228,12 +229,11 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizer(OgaTokenizer*);
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeBatch(const OgaTokenizer*, const char** strings, size_t count, OgaSequences** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecodeBatch(const OgaTokenizer*, const OgaSequences* tokens, const char*** out_strings);
OGA_EXPORT void OGA_API_CALL OgaTokenizerDestroyStrings(const char** strings, size_t count);
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncodeBatchStrings(const OgaTokenizer*, const OgaStrings* string_array, OgaSequences** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerDecodeBatchStrings(const OgaTokenizer*, const OgaSequences* tokens, OgaStrings** out_strings);

/* Encodes a single string and returns the encoded sequence of tokens. The sequence of tokens must be freed with OgaDestroySequences
/* Encodes a single string and adds the encoded sequence of tokens to the OgaSequences. The OgaSequences must be freed with OgaDestroySequences
when it is no longer needed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerEncode(const OgaTokenizer*, const char* str, OgaSequences* sequences);

/* Decode a single token sequence and returns a null terminated utf8 string. out_string must be freed with OgaDestroyString
*/
Expand All @@ -251,57 +251,6 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyTokenizerStream(OgaTokenizerStream*);
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaTokenizerStreamDecode(OgaTokenizerStream*, int32_t token, const char** out);

/*
* \brief Creates an object of type OgaStrings.
* \return The result of the operation. If the operation is successful, a nullptr is returned.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateStrings(OgaStrings** out);

/*
* \brief Creates an object of type OgaStrings with the given number of strings and the given length of each string.
* \return The result of the operation. If the operation is successful, a nullptr is returned.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateAllocatedStrings(size_t num_strings, size_t* str_lengths, OgaStrings** out);

OGA_EXPORT OgaResult* OGA_API_CALL OgaStringsGetBuffer(OgaStrings* string_array, size_t index, char** out);

/*
* \brief Destroys OgaStrings.
*/
OGA_EXPORT void OGA_API_CALL OgaDestroyStrings(OgaStrings* string_array);

/*
* \brief Adds the given string to the OgaStrings.
* \param[inout] string_array The string array to which the string is to be added
* \param[in] str The string to be added to the OgaStrings.
* \return The result of the operation. If the operation is successful, a nullptr is returned.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaStringsAddString(OgaStrings* string_array, const char* str);

/*
* \brief Adds the given strings to the OgaStrings.
* \param[inout] string_array The string array to which the strings are to be added
* \param[in] strings The strings to be added to the OgaStrings.
* \return The result of the operation. If the operation is successful, a nullptr is returned.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaStringsAddStrings(OgaStrings* string_array, const char* const* strings, size_t count);

/*
* \brief Gets the number of strings in the OgaStrings.
* \param[in] string_array The OgaStrings to get the count of the strings.
* \return The number of strings in the OgaStrings.
*/
OGA_EXPORT size_t OGA_API_CALL OgaStringsGetCount(const OgaStrings* string_array);

/*
* \brief Gets the string at the given index in the OgaStrings.
* \param[in] string_array The OgaStrings to get the string from.
* \param[in] index The index of the string to get.
* \param[out] out The string at the given index in the OgaStrings. The string will be valid until the OgaStrings is destroyed.
* \return The result of the operation. If the operation is successful, a nullptr is returned.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaStringsGetString(const OgaStrings* string_array, size_t index, const char** out);

#ifdef __cplusplus
}
#endif
Loading

0 comments on commit 0bcf8d3

Please sign in to comment.