Skip to content

Commit

Permalink
Merge pull request #53 from coreweave/eta/fix-32-bit
Browse files — browse the repository at this point in the history
fix: Fix the handling of 32-bit tokens in several places
  • Loading branch information
wbrown authored Oct 13, 2024
2 parents 698d0bc + 818d932 commit a59353e
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 100 deletions.
97 changes: 66 additions & 31 deletions cmd/dataset_tokenizer/dataset_tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,11 +631,17 @@ func getAndCheckToken(
s = strings.ReplaceAll(s, "\\n", "\n")
token := t.Get(s)
if token == nil {
tokens := t.Encode(&s)
if len(*tokens) != 1 {
return 0, fmt.Errorf("'%s' is not a valid token for %s", s, id)
tokens := *t.Encode(&s)
// Also allow a single "real" token surrounded by an EosToken and/or a BosToken
if len(tokens) == 1 ||
len(tokens) == 2 && tokens[1] == t.EosToken && tokens[0] != t.BosToken {
return tokens[0], nil
} else if len(tokens) == 3 &&
tokens[0] == t.BosToken && tokens[2] == t.EosToken ||
len(tokens) == 2 && tokens[0] == t.BosToken && tokens[1] != t.EosToken {
return tokens[1], nil
} else {
return (*tokens)[0], nil
return 0, fmt.Errorf("'%s' is not a valid token for %s", s, id)
}
} else {
return *token, nil
Expand Down Expand Up @@ -758,11 +764,14 @@ func (tt TextsTokenizer) handleExclusions(
func (tt TextsTokenizer) TokenizeTexts(
texts chan namedRuneReader,
indexPath string,
tokenizerPtr *gpt_bpe.GPTEncoder,
) (chan gpt_bpe.Tokens, error) {
tokenizerPtr, tokErr := tt.InitTokenizer()

if tokErr != nil {
return nil, tokErr
var tokErr error
if tokenizerPtr == nil {
tokenizerPtr, tokErr = tt.InitTokenizer()
if tokErr != nil {
return nil, tokErr
}
}
tokenizer := *tokenizerPtr
var endOfText gpt_bpe.Token
Expand Down Expand Up @@ -853,11 +862,14 @@ func (tt TextsTokenizer) TokenizeTexts(
// that returns tokenized contexts that are fixed and padded out to
// `contextSize`.
func (tt TextsTokenizer) TokenizeTextsToContexts(
texts chan namedRuneReader,
texts chan namedRuneReader, tokenizerPtr *gpt_bpe.GPTEncoder,
) (chan gpt_bpe.Tokens, error) {
tokenizerPtr, tokErr := tt.InitTokenizer()
if tokErr != nil {
return nil, tokErr
var tokErr error
if tokenizerPtr == nil {
tokenizerPtr, tokErr = tt.InitTokenizer()
if tokErr != nil {
return nil, tokErr
}
}
tokenizer := *tokenizerPtr
var padToken, endOfText gpt_bpe.Token
Expand Down Expand Up @@ -888,7 +900,7 @@ func (tt TextsTokenizer) TokenizeTextsToContexts(

var boundary gpt_bpe.Token
if tt.Boundary == "" {
boundary = 65535
boundary = 0xFFFFFFFF
} else {
var boundaryErr error
boundary, boundaryErr = getAndCheckToken(
Expand Down Expand Up @@ -1046,7 +1058,7 @@ func (tt TextsTokenizer) TokenizeTextsToContexts(
// We were given a hard index to use as the chunk boundary,
// and it may not be a complete unicode character, so we
// need to align it to a valid unicode character.
if boundary == 65535 && doUnitrim {
if boundary == 0xFFFFFFFF && doUnitrim {
// Ensure that our next chunk is aligned to valid
// unicode.
_, offset := tokenizer.AlignAndSizeTokens(
Expand Down Expand Up @@ -1116,16 +1128,23 @@ func WriteContexts(
sampling int,
shuffle bool,
enforceUint32 bool,
showContexts bool,
) (int, error) {
totalTokens := 0
var useUint32 bool
// We only use uint32 if we're enforcing it and vocab size is greater than
// 65536.
if encoder != nil {
if enforceUint32 && len(encoder.Encoder) > 65536 {
useUint32 := enforceUint32
// Use uint32 if explicitly requested or if the vocab size is greater than 65536.
if !useUint32 {
if encoder == nil {
return 0, fmt.Errorf("WriteContexts called with unknown encoder; cannot determine output byte width")
} else if len(encoder.Encoder) > 65536 {
useUint32 = true
log.Println("warning: tokenizer vocab too large for 16-bit, outputting as 32-bit")
}
}
if showContexts && encoder == nil {
showContexts = false
log.Println("warning: no encoder info, cannot show contexts")
}

// create file AND filepath if not exists
if err := os.MkdirAll(filepath.Dir(outPath), os.ModePerm); err != nil {
Expand Down Expand Up @@ -1158,6 +1177,11 @@ func WriteContexts(
doKeepSampling := sampling == 100 || (samplingIdx%lcd < skipEveryX)
if doKeepSampling {
sampledContexts <- context
if showContexts {
fmt.Println(len(context))
fmt.Println("======================================")
fmt.Println(encoder.Decode(&context))
}
}
samplingIdx += 1
}
Expand Down Expand Up @@ -1187,7 +1211,10 @@ func WriteContexts(
if !ok {
break
}
binContext := context.ToBin(useUint32)
binContext, err := context.ToBin(useUint32)
if err != nil {
return totalTokens, err
}
// We keep track of the final file position
if endpos == 0 {
// On the first context, we discern the context size and make the
Expand Down Expand Up @@ -1421,7 +1448,7 @@ func main() {
)
enforceUint32 := flag.Bool(
"uint32_enforce", false,
"enforce uint32 tokenization if needed (vocab size > 65535)",
"output tokens as uint32 instead of uint16 (required for vocabs with over 2^16 tokens)",
)

flag.Parse()
Expand Down Expand Up @@ -1513,7 +1540,9 @@ func main() {
)
}
}
if _, tokErr := textsTokenizer.InitTokenizer(); tokErr != nil {

encoder, tokErr := textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

Expand Down Expand Up @@ -1591,13 +1620,19 @@ func main() {
contexts, tokErr = textsTokenizer.TokenizeTexts(
textReaders,
indexFilePath,
encoder,
)
if tokErr != nil {
log.Fatal(tokErr)
}
total, writeErr := WriteContexts(
outputFilePath, contexts,
nil, sampling, false, *enforceUint32,
outputFilePath,
contexts,
encoder,
sampling,
false,
*enforceUint32,
*showContexts,
)
if writeErr != nil {
log.Fatal(writeErr)
Expand All @@ -1611,20 +1646,20 @@ func main() {
var contexts chan gpt_bpe.Tokens
var tokErr error
contexts, tokErr = textsTokenizer.TokenizeTextsToContexts(
textReaders,
textReaders, encoder,
)
if tokErr != nil {
log.Fatal(tokErr)
}
var enc *gpt_bpe.GPTEncoder
if *showContexts {
enc, _ = textsTokenizer.InitTokenizer()
}
var writeErr error
numTokens, writeErr = WriteContexts(
*outputFile, contexts, enc,
*outputFile,
contexts,
encoder,
sampling,
*reorderPaths == "shuffle", *enforceUint32,
*reorderPaths == "shuffle",
*enforceUint32,
*showContexts,
)
if writeErr != nil {
log.Fatal(writeErr)
Expand Down
67 changes: 45 additions & 22 deletions cmd/dataset_tokenizer/dataset_tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,7 @@ func TestEncodeText1(t *testing.T) {
textsTokenizer.BoundaryBegin = false
var enc *gpt_bpe.GPTEncoder

reorderPaths := ""
sampling := 100
enforceUint32 := true // Only if needed
outputFile := "base.chunk"

enc, tokErr := textsTokenizer.InitTokenizer()
Expand All @@ -252,14 +250,19 @@ func TestEncodeText1(t *testing.T) {
}()

begin := time.Now()
contexts, tokErr := textsTokenizer.TokenizeTexts(reader, "./test")
contexts, tokErr := textsTokenizer.TokenizeTexts(reader, "./test", enc)
if tokErr != nil {
log.Fatal("Error tokenizing texts: ", tokErr)
}

total, writeErr := WriteContexts(
outputFile, contexts, enc, sampling, reorderPaths == "shuffle",
enforceUint32,
outputFile,
contexts,
enc,
sampling,
false,
false,
false,
)
if writeErr != nil {
log.Fatal("Error writing contexts: ", writeErr)
Expand Down Expand Up @@ -310,7 +313,8 @@ func TestSampling50(t *testing.T) {
sampling := 100
outputFile := "base.chunk"

if _, tokErr := textsTokenizer.InitTokenizer(); tokErr != nil {
enc, tokErr := textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

Expand All @@ -323,18 +327,22 @@ func TestSampling50(t *testing.T) {
} else {
begin := time.Now()
contexts, tokErr := textsTokenizer.TokenizeTexts(
texts, "./test",
texts, "./test", enc,
)
if tokErr != nil {
log.Fatal(tokErr)
}

var enc *gpt_bpe.GPTEncoder
// *showContexts = true

total, writeErr := WriteContexts(
outputFile, contexts, enc, sampling,
reorderPaths == "shuffle", false,
outputFile,
contexts,
enc,
sampling,
false,
false,
false,
)
all1 += total
if writeErr != nil {
Expand All @@ -361,7 +369,8 @@ func TestSampling50(t *testing.T) {
sampling = 50
outputFile = "samp50.chunk"

if _, tokErr := textsTokenizer.InitTokenizer(); tokErr != nil {
enc, tokErr = textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

Expand All @@ -374,16 +383,20 @@ func TestSampling50(t *testing.T) {
} else {
begin := time.Now()
contexts, tokErr := textsTokenizer.TokenizeTexts(
texts2, "./test",
texts2, "./test", enc,
)
if tokErr != nil {
log.Fatal(tokErr)
}
var enc *gpt_bpe.GPTEncoder
// *showContexts = true

total2, writeErr := WriteContexts(
outputFile, contexts, enc, sampling, reorderPaths == "shuffle",
outputFile,
contexts,
enc,
sampling,
reorderPaths == "shuffle",
false,
false,
)
all2 += total2
Expand Down Expand Up @@ -430,7 +443,8 @@ func TestShuffle(t *testing.T) {
sampling := 100
outputFile := "noshuffle.chunk"

if _, tokErr := textsTokenizer.InitTokenizer(); tokErr != nil {
enc, tokErr := textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

Expand All @@ -442,16 +456,20 @@ func TestShuffle(t *testing.T) {
} else {
begin := time.Now()
contexts, tokErr := textsTokenizer.TokenizeTexts(
texts, "./test",
texts, "./test", enc,
)
if tokErr != nil {
log.Fatal(tokErr)
}
var enc *gpt_bpe.GPTEncoder
// *showContexts = true

total, writeErr := WriteContexts(
outputFile, contexts, enc, sampling, reorderPaths == "shuffle",
outputFile,
contexts,
enc,
sampling,
false,
false,
false,
)
all1 += total
Expand Down Expand Up @@ -479,7 +497,8 @@ func TestShuffle(t *testing.T) {
sampling = 100
outputFile = "shuffle.chunk"

if _, tokErr := textsTokenizer.InitTokenizer(); tokErr != nil {
enc2, tokErr := textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

Expand All @@ -491,16 +510,20 @@ func TestShuffle(t *testing.T) {
} else {
begin := time.Now()
contexts2, tokErr := textsTokenizer.TokenizeTexts(
texts2, "./test",
texts2, "./test", enc2,
)
if tokErr != nil {
log.Fatal(tokErr)
}
var enc2 *gpt_bpe.GPTEncoder
// *showContexts = true

total2, writeErr := WriteContexts(
outputFile, contexts2, enc2, sampling, reorderPaths == "shuffle",
outputFile,
contexts2,
enc2,
sampling,
true,
false,
false,
)
all2 += total2
Expand Down
Loading

0 comments on commit a59353e

Please sign in to comment.