Skip to content

Commit

Permalink
Use int32 for index and add functions to get the control words
Browse files Browse the repository at this point in the history
  • Loading branch information
vikesh-raj committed Nov 6, 2020
1 parent 1be0fae commit d8a7484
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 17 deletions.
42 changes: 32 additions & 10 deletions sentencepiece/sentencepiece.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ const unknown string = "<unk>"

type slice struct {
score float32
index int64
index int32
start int
end int
}
Expand Down Expand Up @@ -42,7 +42,7 @@ type trieNode struct {
text string
level int
score float32
index int64
index int32
end bool
children map[rune]trieNode
}
Expand All @@ -60,21 +60,43 @@ func newTrieNode(text string, level int) trieNode {

// Sentencepiece holds the model
type Sentencepiece struct {
root trieNode
lowercase bool
unknown int64
root trieNode
lowercase bool
unknown int32
controlWords map[string]int32
}

// NewEmptySentencepiece creates an empty sentencepiece model
func NewEmptySentencepiece(lowercase bool) Sentencepiece {
return Sentencepiece{root: newTrieNode("", 0), lowercase: lowercase}
return Sentencepiece{
root: newTrieNode("", 0),
lowercase: lowercase,
unknown: 0,
controlWords: make(map[string]int32),
}
}

// SetUnknownIndex sets the index for the unknown id
func (s *Sentencepiece) SetUnknownIndex(index int64) {
func (s *Sentencepiece) SetUnknownIndex(index int32) {
s.unknown = index
}

// GetUnknownIndex gets the index of the unknown id
func (s *Sentencepiece) GetUnknownIndex() int32 {
return s.unknown
}

// SetControlWord sets the index for the given control word
func (s *Sentencepiece) SetControlWord(word string, index int32) {
s.controlWords[word] = index
}

// GetControlWord gets the index for the given control word
func (s *Sentencepiece) GetControlWord(word string) (int32, bool) {
v, ok := s.controlWords[word]
return v, ok
}

// Tokenize tokenizes text into pieces
func (s *Sentencepiece) Tokenize(text string) []Token {
text = normalize(text)
Expand All @@ -91,16 +113,16 @@ func (s *Sentencepiece) Tokenize(text string) []Token {
}

// TokenizeToIDs tokenizes text into ids from the vocab
func (s *Sentencepiece) TokenizeToIDs(text string) []int64 {
func (s *Sentencepiece) TokenizeToIDs(text string) []int32 {
tokens := s.Tokenize(text)
ids := make([]int64, len(tokens))
ids := make([]int32, len(tokens))
for i, token := range tokens {
ids[i] = token.ID
}
return ids
}

func (s *Sentencepiece) insert(word string, score float32, index int64) {
func (s *Sentencepiece) insert(word string, score float32, index int32) {
_, size := utf8.DecodeLastRuneInString(word)
charCount := len(word)
node := &s.root
Expand Down
13 changes: 8 additions & 5 deletions sentencepiece/sentencepiece_proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,19 @@ func NewSentencepieceFromFile(filename string, lowercase bool) (Sentencepiece, e
}

count := 0
unknownIndex := int64(0)
for i, piece := range model.GetPieces() {
typ := piece.GetType()
word := piece.GetPiece()
if word == unknown {
unknownIndex = int64(i)
switch typ {
case ModelProto_SentencePiece_NORMAL:
s.insert(word, piece.GetScore(), int32(i))
case ModelProto_SentencePiece_UNKNOWN:
s.SetUnknownIndex(int32(i))
case ModelProto_SentencePiece_CONTROL:
s.SetControlWord(word, int32(i))
}
s.insert(word, piece.GetScore(), int64(i))
count++
}

s.SetUnknownIndex(unknownIndex)
return s, nil
}
37 changes: 37 additions & 0 deletions sentencepiece/sentencepiece_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,43 @@ func TestTokenizationSPM(t *testing.T) {
}
}

func TestControlWords(t *testing.T) {
sp, err := NewSentencepieceFromFile("test_data/xlnet-base-cased-spiece.model", false)
if err != nil {
t.Errorf("Unable to create sentencepiece")
return
}

unknownIndex := sp.GetUnknownIndex()
if unknownIndex != 0 {
t.Errorf("Unknown index not equal to 0")
}

clsIndex, ok := sp.GetControlWord("<cls>")
if !ok || clsIndex != 3 {
t.Errorf("Control word [CLS] not correct : %d", clsIndex)
}

}

func TestControlWords2(t *testing.T) {
sp, err := NewSentencepieceFromFile("test_data/spm.model", true)
if err != nil {
t.Errorf("Unable to create sentencepiece")
return
}

unknownIndex := sp.GetUnknownIndex()
if unknownIndex != 1 {
t.Errorf("Unknown index not equal to 1")
}

clsIndex, ok := sp.GetControlWord("[CLS]")
if !ok || clsIndex != 2 {
t.Errorf("Control word [CLS] not correct")
}
}

func BenchmarkSentencePiece(b *testing.B) {
sp, err := NewSentencepieceFromFile("test_data/xlnet-base-cased-spiece.model", false)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions sentencepiece/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package sentencepiece

// Token holds a unit of a tokenized word
type Token struct {
ID int64
ID int32
Text string
}

type tokenOffset struct {
id int64
id int32
text string
start int
end int
Expand Down

0 comments on commit d8a7484

Please sign in to comment.