Consider extra special tokens during BPE decoding
matteo-grella committed Sep 25, 2022
1 parent 98c30d2 commit 8db7ba0
Showing 5 changed files with 40 additions and 11 deletions.
1 change: 1 addition & 0 deletions pkg/models/bart/config.go
@@ -58,6 +58,7 @@ type Config struct {
     BadWordsIDs        [][]int        `json:"bad_words_ids"`
     EarlyStopping      bool           `json:"early_stopping"`
     NoRepeatNGramSize  int            `json:"no_repeat_ngram_size"`
+    ExtraSpecialTokens map[int]string `json:"extra_special_tokens"`
     Cybertron struct {
         Training                bool `json:"training"`
         PositionalEncoderOffset int  `json:"positional_encoder_offset"`
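For illustration, a config.json fragment that would populate the new field; the IDs and token strings below are hypothetical. Because the Go field is a map[int]string, encoding/json parses the object's string keys as integers:

{
  "eos_token_id": 2,
  "extra_special_tokens": {
    "50265": "<mask_1>",
    "50266": "<mask_2>"
  }
}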
7 changes: 5 additions & 2 deletions pkg/tasks/text2text/bart/text2text.go
@@ -40,7 +40,7 @@ type Text2Text struct {
 
 type Tokenizer interface {
     Tokenize(text string) ([]int, error)
-    Detokenize(tokenIds []int) string
+    Detokenize(tokenIds []int, stripPaddingTokens bool) string
 }
 
 // LoadText2Text returns a Text2Text loading the model, the embeddings and the tokenizer from a directory.
@@ -98,6 +98,9 @@ func loadBPETokenizer(path string, config bart.Config) (Tokenizer, error) {
     if err != nil {
         return nil, fmt.Errorf("failed to load bpe tokenizer for zero-shot: %w", err)
     }
+    if config.ExtraSpecialTokens != nil {
+        tok.SetExtraSpecialTokens(config.ExtraSpecialTokens)
+    }
     return &BPETokenizer{
         BPETokenizer: tok,
         EosTokenID:   config.EosTokenID,
@@ -138,7 +141,7 @@ func (m *Text2Text) Generate(ctx context.Context, text string, opts *text2text.O
         Scores: make([]float64, len(scores)),
     }
     for i, sequence := range sequences {
-        result.Texts[i], result.Scores[i] = m.Tokenizer.Detokenize(sequence), scores[i]
+        result.Texts[i], result.Scores[i] = m.Tokenizer.Detokenize(sequence, true), scores[i]
     }
     return result, nil
 }
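A minimal sketch of how a caller uses the widened interface; decode is a hypothetical helper assumed to live in the same package, and either tokenizer implementation satisfies Tokenizer:

// decode shows both detokenization modes offered by the new signature.
func decode(tok Tokenizer, text string) (clean, raw string, err error) {
    ids, err := tok.Tokenize(text)
    if err != nil {
        return "", "", err
    }
    clean = tok.Detokenize(ids, true)  // EOS/BOS/PAD/decoder-start IDs stripped
    raw = tok.Detokenize(ids, false)   // every token ID rendered verbatim
    return clean, raw, nil
}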
16 changes: 12 additions & 4 deletions pkg/tasks/text2text/bart/text2text_tokenizer_bpe.go
@@ -8,7 +8,11 @@ import "github.com/nlpodyssey/cybertron/pkg/tokenizers/bpetokenizer"
 
 type BPETokenizer struct {
     *bpetokenizer.BPETokenizer
-    EosTokenID, BosTokenID, PadTokenID, DecoderStartTokenID int
+    EosTokenID           int
+    BosTokenID           int
+    PadTokenID           int
+    DecoderStartTokenID  int
+    ExtraSpecialTokenIDs map[int]string
 }
 
 // Tokenize returns the token IDs of the input text applying the EOS pad token.
@@ -27,8 +31,12 @@ func (m *BPETokenizer) Tokenize(text string) ([]int, error) {
 }
 
 // Detokenize returns the text of the input token IDs removing the padding token.
-func (m *BPETokenizer) Detokenize(tokenIds []int) string {
-    stripBadTokens := func(tokenIds []int) []int {
+func (m *BPETokenizer) Detokenize(tokenIds []int, stripPaddingTokens bool) string {
+    if !stripPaddingTokens {
+        return m.BPETokenizer.Detokenize(tokenIds)
+    }
+
+    stripPaddingTokensFn := func(tokenIds []int) []int {
         result := make([]int, 0, len(tokenIds))
         for _, id := range tokenIds {
             if id == m.EosTokenID || id == m.PadTokenID || id == m.BosTokenID || id == m.DecoderStartTokenID {
@@ -39,5 +47,5 @@ func (m *BPETokenizer) Detokenize(tokenIds []int) string {
         return result
     }
 
-    return m.BPETokenizer.Detokenize(stripBadTokens(tokenIds))
+    return m.BPETokenizer.Detokenize(stripPaddingTokensFn(tokenIds))
 }
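A sketch of the two paths through the new signature, assuming it sits in the same package as the wrapper above; detokenizeBoth is a hypothetical helper and the token IDs are illustrative RoBERTa-style values (BOS=0, PAD=1, EOS=2), not read from a real config:

func detokenizeBoth(bpe *bpetokenizer.BPETokenizer) (stripped, raw string) {
    tok := &BPETokenizer{
        BPETokenizer: bpe,
        BosTokenID:   0,
        PadTokenID:   1,
        EosTokenID:   2,
    }
    ids := []int{0, 713, 16, 10, 1296, 2} // BOS ... EOS (illustrative)
    stripped = tok.Detokenize(ids, true)  // special IDs filtered out first
    raw = tok.Detokenize(ids, false)      // delegates straight to the wrapped tokenizer
    return stripped, raw
}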
11 changes: 9 additions & 2 deletions pkg/tasks/text2text/bart/text2text_tokenizer_sentencepiece.go
@@ -8,7 +8,10 @@ import "github.com/nlpodyssey/cybertron/pkg/tokenizers/sentencepiece"
 
 type SentencePieceTokenizer struct {
     *sentencepiece.Tokenizer
-    EosTokenID, BosTokenID, PadTokenID, DecoderStartTokenID int
+    EosTokenID          int
+    BosTokenID          int
+    PadTokenID          int
+    DecoderStartTokenID int
 }
 
 // Tokenize returns the token IDs of the input text applying the EOS pad token.
@@ -17,7 +20,11 @@ func (m *SentencePieceTokenizer) Tokenize(text string) ([]int, error) {
 }
 
 // Detokenize returns the text of the input token IDs removing the padding token.
-func (m *SentencePieceTokenizer) Detokenize(tokenIds []int) string {
+func (m *SentencePieceTokenizer) Detokenize(tokenIds []int, stripPaddingTokens bool) string {
+    if !stripPaddingTokens {
+        return m.Tokenizer.Detokenize(m.Tokenizer.IDsToTokens(tokenIds))
+    }
+
     stripBadTokens := func(tokenIds []int) []int {
         result := make([]int, 0, len(tokenIds))
         for _, id := range tokenIds {
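The SentencePiece variant mirrors the BPE one, except that its raw path maps IDs back through the model's own piece table (IDsToTokens) rather than an extra-token map. A sketch under the same same-package assumption; rawDetokenize and the two IDs are hypothetical:

func rawDetokenize(sp *sentencepiece.Tokenizer, ids []int) string {
    tok := &SentencePieceTokenizer{Tokenizer: sp, PadTokenID: 0, EosTokenID: 1}
    return tok.Detokenize(ids, false) // nothing stripped: IDs -> pieces -> text
}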
16 changes: 13 additions & 3 deletions pkg/tokenizers/bpetokenizer/tokenizer.go
@@ -24,9 +24,10 @@ import (
 
 // BPETokenizer is a higher-level tokenizer, which includes byte-level pre-tokenization.
 type BPETokenizer struct {
-    preTokenizer *bytelevelpretokenizer.ByteLevelPreTokenizer
-    model        *bpemodel.BPEModel
-    vocab        *vocabulary.Vocabulary
+    preTokenizer         *bytelevelpretokenizer.ByteLevelPreTokenizer
+    model                *bpemodel.BPEModel
+    vocab                *vocabulary.Vocabulary
+    extraSpecialTokenIDs map[int]string
 }
 
 // New returns a new BPETokenizer.
@@ -93,6 +94,10 @@ func NewFromModelFolder(path string) (*BPETokenizer, error) {
     return New(preTokenizer, model, vocab), nil
 }
 
+func (t *BPETokenizer) SetExtraSpecialTokens(extra map[int]string) {
+    t.extraSpecialTokenIDs = extra
+}
+
 // Tokenize performs byte-level pre-tokenization and BPE tokenization.
 func (t *BPETokenizer) Tokenize(text string) ([]tokenizers.StringOffsetsPair, error) {
     pts := pretokenizedstring.FromString(text)
@@ -179,6 +184,11 @@ func (t *BPETokenizer) Encode(text string) (*encodings.Encoding, error) {
 func (t *BPETokenizer) Detokenize(ids []int) string {
     var sb strings.Builder
     for _, id := range ids {
+        if s, ok := t.extraSpecialTokenIDs[id]; ok {
+            sb.WriteString(s)
+            continue
+        }
+
         if s, ok := t.vocab.GetString(id); ok {
             sb.WriteString(s)
         }
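A minimal end-to-end sketch of the new hook, assuming a model folder containing the vocabulary and merges files; the path, ID, and token string are hypothetical:

package main

import (
    "fmt"

    "github.com/nlpodyssey/cybertron/pkg/tokenizers/bpetokenizer"
)

func main() {
    tok, err := bpetokenizer.NewFromModelFolder("/path/to/model")
    if err != nil {
        panic(err)
    }

    // Map IDs that live outside the base vocabulary to their surface form
    // (the ID and string are illustrative).
    tok.SetExtraSpecialTokens(map[int]string{50265: "<mask_1>"})

    // ID 50265 now renders as "<mask_1>" instead of being silently
    // skipped as an unknown vocabulary entry.
    fmt.Println(tok.Detokenize([]int{50265}))
}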
