Skip to content

Commit

Permalink
Tokenizer's Interfaces Cleanup (dotnet#7001)
Browse files Browse the repository at this point in the history
* Tokenizer's Interfaces Cleanup

* Address the feedback

* Optimization
  • Loading branch information
tarekgh authored Feb 16, 2024
1 parent 64523e8 commit 4635a86
Show file tree
Hide file tree
Showing 11 changed files with 470 additions and 226 deletions.
99 changes: 58 additions & 41 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st

(Dictionary<string, int>? vocab1, Vec<(string, string)> merges) = ReadFile(vocabFile, mergesFile);
Vocab = vocab1 ?? new Dictionary<string, int>();
Cache = new Cache<string, Word>();

VocabReverse = new();

Expand Down Expand Up @@ -146,23 +147,33 @@ public Bpe(string vocabFile, string? mergesFile, string? unknownToken = null, st
/// Tokenize a sequence string to a list of tokens.
/// </summary>
/// <param name="sequence">The sequence to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <returns>The list of tokens generated from the sequence tokenization.</returns>
public override IReadOnlyList<Token> Tokenize(string sequence)
public override IReadOnlyList<Token> Tokenize(string sequence, bool isSpecialToken = false)
{
if (sequence.Length == 0)
{
return EmptyTokensList;
}

if (!Dropout.HasValue)
{
return TokenizeWithCache(sequence);
}
return TokenizeWithCache(sequence);
}

Word word = MergeWord(sequence);
/// <summary>
/// Tokenize a split sequence string to a list of Ids and add them to the accumulatedIds list.
/// </summary>
/// <param name="sequence">The sequence to split.</param>
/// <param name="isSpecialToken">Indicate if the token is a special token.</param>
/// <param name="accumulatedIds">The list of accumulated tokenized Ids.</param>
public override void TokenizeToIds(string sequence, bool isSpecialToken, IList<int> accumulatedIds) => TokenizeToIdsWithCache(sequence, accumulatedIds);

return WordToTokens(ref word);
}
/// <summary>
/// Get the number of tokens that the input sequence will be encoded to.
/// </summary>
/// <param name="sequence">The text to tokenize.</param>
/// <param name="isSpecialToken">Indicate if the token is special token.</param>
/// <returns>The number of tokens that the input sequence will be encoded to.</returns>
public override int CountTokens(string sequence, bool isSpecialToken) => TokenizeToIdsWithCache(sequence, null);

/// <summary>
/// Map the token to tokenized Id.
Expand Down Expand Up @@ -195,14 +206,6 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
return null;
}

/// <summary>
/// Map the tokenized Id to the token.
/// </summary>
/// <param name="id">The Id to map to the token.</param>
/// <param name="skipSpecialTokens">Indicate if want to skip the special tokens during the decoding.</param>
/// <returns>The mapped token of the Id.</returns>
public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException();

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand Down Expand Up @@ -332,7 +335,7 @@ internal string CharToString(char c)

internal Word MergeWord(string w)
{
Word word = Word.WithCapacity((int)w.Length);
Word word = Word.WithCapacity(w.Length);
(int Id, int Len)? unk = null;
int i = 0;

Expand All @@ -344,7 +347,7 @@ internal Word MergeWord(string w)
if (Char.IsHighSurrogate(w[i]) && i < w.Length - 1 && Char.IsLowSurrogate(w[i + 1]))
{
length = 2;
s = w.Substring(i, (int)length);
s = w.Substring(i, length);
}
else
{
Expand Down Expand Up @@ -403,7 +406,7 @@ internal Word MergeWord(string w)
}
}

i += (int)length;
i += length;
}

if (unk.HasValue)
Expand All @@ -415,45 +418,59 @@ internal Word MergeWord(string w)
return word;
}

// internal Word.Enumerator WordToTokens(Word word) => word.GetIterator(VocabReverse);
internal List<Token> WordToTokens(ref Word word)
internal List<Token> WordToTokens(ref Word word) => word.ToTokens(VocabReverse);

internal List<Token> TokenizeWithCache(string sequence)
{
List<Token> tokens = new(word.SymbolsCount);
Word word;
if (Cache is not null)
{
if (Cache.TryGet(sequence, out word))
{
return WordToTokens(ref word);
}

foreach (Token token in word.GetIterator(VocabReverse))
word = MergeWord(sequence);
Cache.Set(sequence, word);
}
else
{
tokens.Add(token);
word = MergeWord(sequence);
}

return tokens;
return WordToTokens(ref word);
}

internal List<Token> TokenizeWithCache(string sequence)
internal int WordToIds(ref Word word, IList<int>? accumulatedIds)
{
if (Cache is not null)
if (accumulatedIds is not null)
{
Word? hit = Cache.Get(sequence);
if (hit.HasValue)
{
Word w = hit.Value;
return WordToTokens(ref w);
}
word.PopulateIds(accumulatedIds);
}

Word word = MergeWord(sequence);
List<Token> tokens = WordToTokens(ref word);
return word.SymbolsCount;
}

internal int TokenizeToIdsWithCache(string sequence, IList<int>? accumulatedIds)
{
Word word;

if (Cache is not null)
{
if (Cache.TryGet(sequence, out Word hit))
{
return WordToIds(ref hit, accumulatedIds);
}

word = MergeWord(sequence);
Cache.Set(sequence, word);
}
else
{
word = MergeWord(sequence);
}

return tokens;
}

public override bool IsValidChar(char ch)
{
throw new NotImplementedException();
return WordToIds(ref word, accumulatedIds);
}

internal static readonly List<Token> EmptyTokensList = new();
Expand Down
19 changes: 7 additions & 12 deletions src/Microsoft.ML.Tokenizers/Model/Cache.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@

namespace Microsoft.ML.Tokenizers
{
internal sealed class Cache<TKey, TValue> where TKey : notnull
internal sealed class Cache<TKey, TValue> where TKey : notnull where TValue : notnull
{
internal Cache() : this(Bpe.DefaultCacheCapacity) { }

internal Cache(int capacity)
{
Capacity = capacity;
Map = new Dictionary<TKey, TValue>((int)Capacity);
Map = new Dictionary<TKey, TValue>(Capacity);
}

private readonly ReaderWriterLockSlim _cacheLock = new ReaderWriterLockSlim();
Expand All @@ -25,7 +25,7 @@ internal Cache(int capacity)

internal int Capacity { get; set; }

internal void Fresh() => Map = new Dictionary<TKey, TValue>((int)Capacity);
internal void Fresh() => Map = new Dictionary<TKey, TValue>(Capacity);

internal void Clear()
{
Expand Down Expand Up @@ -56,27 +56,22 @@ internal List<TValue> GetValues(IEnumerable<TKey> keys)
return values;
}

internal TValue? Get(TKey key)
internal bool TryGet(TKey key, out TValue value)
{
_cacheLock.EnterReadLock();
try
{
if (Map.TryGetValue(key, out TValue? value))
{
return value;
}
return Map.TryGetValue(key, out value!);
}
finally { _cacheLock.ExitReadLock(); }

return default;
}

internal void SetValues(IEnumerable<(TKey, TValue)> enteries)
internal void SetValues(IEnumerable<(TKey, TValue)> entries)
{
_cacheLock.EnterWriteLock();
try
{
foreach ((TKey, TValue) entry in enteries)
foreach ((TKey, TValue) entry in entries)
{
if (Capacity <= Map.Count)
{
Expand Down
Loading

0 comments on commit 4635a86

Please sign in to comment.