package gpt_bpe

import (
	"strings"
)

// TrimDirection selects which end of a token sequence TrimNewlines
// drops lines from.
type TrimDirection uint

const (
	TrimTop TrimDirection = iota
	TrimBottom
	TrimNone
)
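
// TrimNewlines removes whole lines from one end of `tokens` until the
// sequence fits within `limit` tokens. A sequence already within the
// limit is returned unchanged; an over-limit sequence with TrimNone
// yields an empty result.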
func (encoder *GPTEncoder) TrimNewlines(
	tokens *Tokens,
	direction TrimDirection,
	limit uint,
) (*Tokens, error) {
	var err error
	trimmed := make(Tokens, 0)
	if uint(len(*tokens)) <= limit {
		// Already within the limit; nothing to trim.
		return tokens, err
	} else if direction == TrimNone {
		// Over the limit but told not to trim: return an empty sequence.
		return &trimmed, err
	}
	// Decode to text and split on newlines so the sequence can be
	// rebuilt line by line from the end we want to keep.
	lines := strings.Split(encoder.Decode(tokens), "\n")
	var start, end, step, idx int
	switch direction {
	case TrimTop:
		// Keep the bottom: walk lines from last to first.
		start = len(lines) - 1
		end = -1
		step = -1
	case TrimBottom:
		// Keep the top: walk lines from first to last.
		start = 0
		end = len(lines)
		step = 1
	}
	accTokens := make(Tokens, 0)
	for idx = start; idx != end; idx += step {
		line := lines[idx]
		// Re-attach the newline that strings.Split consumed, on the
		// side facing the text we have already kept.
		switch direction {
		case TrimTop:
			line = "\n" + line
		case TrimBottom:
			line = line + "\n"
		}
		newTokens := encoder.Encode(&line)
		// Stop as soon as adding this line would exceed the limit.
		if len(*newTokens)+len(accTokens) > int(limit) {
			return &accTokens, err
		}
		switch direction {
		case TrimTop:
			accTokens = append(*newTokens, accTokens...)
		case TrimBottom:
			accTokens = append(accTokens, *newTokens...)
		}
	}
	return &accTokens, err
}
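
// A minimal usage sketch for TrimNewlines (it assumes an
// already-constructed GPTEncoder; how the encoder is obtained is
// outside this file's scope, and the literals here are illustrative):
//
//	text := "line one\nline two\nline three\n"
//	tokens := encoder.Encode(&text)
//	// Keep at most 8 tokens, discarding whole lines from the top.
//	kept, err := encoder.TrimNewlines(tokens, TrimTop, 8)
//	if err == nil {
//		fmt.Println(encoder.Decode(kept))
//	}
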
// AlignAndSizeTokens truncates `tokens` to at most `desiredLength`
// tokens, aligned on a valid token boundary, and returns the aligned
// chunk along with the index into `tokens` at which it ends. It assumes
// `tokens` holds at least `desiredLength` entries.
func (encoder *GPTEncoder) AlignAndSizeTokens(
	tokens *Tokens,
	desiredLength int,
) (
	alignedTokens Tokens,
	endAt int,
) {
	chunk := (*tokens)[0:desiredLength]
	// We trim to valid tokens, as we don't want partials
	// that are truncated multi-tokens.
	trimmed := encoder.TrimTokens(&chunk)
	trimmedLength := len(*trimmed)
	isTrimmed := trimmedLength != len(chunk)
	chunk = *trimmed
	idx := trimmedLength
	// We do a decode and re-encode pass, as byte-pair merges can differ
	// at the new boundary and change the size after a trim.
	if isTrimmed {
		decodedChunk := encoder.Decode(&chunk)
		reencodedChunk := encoder.Encode(&decodedChunk)
		chunk = *reencodedChunk
		// See if the round trip shrank the chunk below `desiredLength`;
		// if so, top it up from the tokens we originally cut off.
		roundtripRemainder := desiredLength - len(chunk)
		if roundtripRemainder > 0 {
			addlEnd := idx + roundtripRemainder
			addlTokens := (*tokens)[idx:addlEnd]
			trimmedAddl := encoder.TrimTokens(&addlTokens)
			chunk = append(chunk, *trimmedAddl...)
			idx += len(*trimmedAddl)
			// Another decode/re-encode pass.
			decodedChunk = encoder.Decode(&chunk)
			reencodedChunk = encoder.Encode(&decodedChunk)
			// Drop tokens one by one until we have valid tokens and
			// we fit within `desiredLength`.
			for {
				chunk = *reencodedChunk
				if len(chunk) <= desiredLength &&
					encoder.TokensReady(&chunk) {
					break
				}
				chunk = chunk[:len(chunk)-1]
				idx--
				decodedChunk = encoder.Decode(&chunk)
				reencodedChunk = encoder.Encode(&decodedChunk)
			}
		}
	}
	return chunk, idx
}
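
// A sketch of how AlignAndSizeTokens might be used to cut a long token
// stream into fixed-size, boundary-aligned chunks. The chunking loop,
// the 2048 context size, and the `consume` helper are illustrative
// assumptions, not part of this file:
//
//	contextSize := 2048
//	remaining := *allTokens
//	for len(remaining) >= contextSize {
//		chunk, endAt := encoder.AlignAndSizeTokens(&remaining, contextSize)
//		consume(chunk) // hypothetical consumer of each aligned chunk
//		remaining = remaining[endAt:]
//	}
//	// Any trailing tokens shorter than contextSize stay in `remaining`.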