From d13b6bf91709222c167a3b21dff8a556be7898cc Mon Sep 17 00:00:00 2001
From: rexwang8
Date: Fri, 25 Oct 2024 12:11:09 -0400
Subject: [PATCH] Test: Add Uint32 tests to dataset_tokenizer

---
 .../dataset_tokenizer_test.go | 204 ++++++++++++++++--
 gpt_bpe_test.go               |  12 ++
 js/js.go                      |   6 +-
 3 files changed, 205 insertions(+), 17 deletions(-)

diff --git a/cmd/dataset_tokenizer/dataset_tokenizer_test.go b/cmd/dataset_tokenizer/dataset_tokenizer_test.go
index c0299cf..9e83316 100644
--- a/cmd/dataset_tokenizer/dataset_tokenizer_test.go
+++ b/cmd/dataset_tokenizer/dataset_tokenizer_test.go
@@ -3,7 +3,6 @@ package main
 import (
 	"bufio"
 	"bytes"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
@@ -19,6 +18,7 @@ import (
 	"github.com/aws/aws-sdk-go/service/s3"
 	"github.com/stretchr/testify/assert"
 	"github.com/wbrown/gpt_bpe"
+	"github.com/wbrown/gpt_bpe/types"
 )
 
 type SanitizerTest struct {
@@ -66,24 +66,11 @@ var sanitizerTests = SanitizerTests{
 
 const corpusPath = "../../resources/frankenstein.txt"
 
-func TokensFromBin(bin *[]byte) *gpt_bpe.Tokens {
-	tokens := make(gpt_bpe.Tokens, 0)
-	buf := bytes.NewReader(*bin)
-	for {
-		var token gpt_bpe.Token
-		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
-			break
-		}
-		tokens = append(tokens, token)
-	}
-	return &tokens
-}
-
 // DecodeBuffer
 // Decode Tokens from a byte array into a string.
 func DecodeBuffer(encoded *[]byte) (text string) {
 	// First convert our bytearray into a uint32 `Token` array.
-	tokens := TokensFromBin(encoded)
+	tokens := types.TokensFromBin(encoded)
 	// Decode our tokens into a string.
 	var enc *gpt_bpe.GPTEncoder
 	encoderString := "gpt2"
@@ -736,3 +723,190 @@ func TestListObjectsRecursively(t *testing.T) {
 
 	wg.Wait() // Wait for all goroutines to finish
 }
+
+func TestUInt16WithNoEnforce(t *testing.T) {
+	// Test that with Uint32 enforcement disabled, a Uint16 tokenizer
+	// works as intended with no padding.
+
+	textsTokenizer := NewTextsTokenizer()
+	textsTokenizer.ContextSize = 2048
+	textsTokenizer.TokenizerId = "gpt2"
+	textsTokenizer.EndOfText = ""
+
+	// Test data
+	testString := "The quick brown fox jumps over the lazy dog."
+	expectedTokens := types.Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256}
+	// Create a test file in the system temp directory
+	tempDir := os.TempDir()
+	testFile := tempDir + "/test.txt"
+	f, err := os.Create(testFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Write test string to file
+	_, err = f.WriteString(testString)
+	if err != nil {
+		log.Fatal(err)
+	}
+	f.Close()
+	defer os.Remove(testFile)
+
+	reorderPaths := ""
+	sampling := 100
+	outputFile := "base.chunk"
+	defer os.Remove(outputFile)
+
+	enc, tokErr := textsTokenizer.InitTokenizer()
+	if tokErr != nil {
+		log.Fatal(tokErr)
+	}
+
+	if texts, err := ReadTexts(
+		testFile, false,
+		reorderPaths,
+		1,
+	); err != nil {
+		log.Fatal(err)
+	} else {
+		begin := time.Now()
+		contexts, tokErr := textsTokenizer.TokenizeTexts(
+			texts, "./test", enc,
+		)
+		if tokErr != nil {
+			log.Fatal(tokErr)
+		}
+
+		total, writeErr := WriteContexts(
+			outputFile,
+			contexts,
+			enc,
+			sampling,
+			false,
+			false,
+			false,
+		)
+		if writeErr != nil {
+			log.Fatal(writeErr)
+		}
+		duration := time.Since(begin).Seconds()
+		log.Printf(
+			"%d tokens in %0.2fs, %0.2f tokens/s", total,
+			duration, float64(total)/duration,
+		)
+	}
+	// Read the encoded tokens from the output file
+	binaryData, err := os.ReadFile(outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Convert to Tokens array
+	tokens := types.TokensFromBin(&binaryData)
+
+	if len(*tokens) != len(expectedTokens) {
+		t.Fatalf(
+			"Expected %d tokens, but got %d", len(expectedTokens),
+			len(*tokens),
+		)
+	}
+	if !assert.ObjectsAreEqual(expectedTokens, *tokens) {
+		t.Fatalf("Expected tokens: %v, but got: %v", expectedTokens, *tokens)
+	}
+
+	// Verify the encoded tokens
+	assert.Equal(t, expectedTokens, *tokens)
+}
+
+func TestUInt16WithEnforce(t *testing.T) {
+	// Test that with Uint32 enforcement enabled, a Uint16 tokenizer
+	// works as intended, with each 16-bit token padded to 32 bits,
+	// i.e. X, 0, Y, 0, Z, 0.
+
+	textsTokenizer := NewTextsTokenizer()
+	textsTokenizer.ContextSize = 2048
+	textsTokenizer.TokenizerId = "gpt2"
+	textsTokenizer.EndOfText = ""
+
+	// Test data
+	testString := "The quick brown fox jumps over the lazy dog."
+	expectedTokens := types.Tokens{464, 0, 2068, 0, 7586, 0, 21831, 0, 18045, 0, 625, 0, 262, 0, 16931, 0, 3290, 0, 13, 0, 50256, 0}
+	// Create a test file in the system temp directory
+	tempDir := os.TempDir()
+	testFile := tempDir + "/test.txt"
+	f, err := os.Create(testFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+	// Write test string to file
+	_, err = f.WriteString(testString)
+	if err != nil {
+		log.Fatal(err)
+	}
+	f.Close()
+	defer os.Remove(testFile)
+
+	reorderPaths := ""
+	sampling := 100
+	outputFile := "base.chunk"
+	defer os.Remove(outputFile)
+
+	enc, tokErr := textsTokenizer.InitTokenizer()
+	if tokErr != nil {
+		log.Fatal(tokErr)
+	}
+
+	if texts, err := ReadTexts(
+		testFile, false,
+		reorderPaths,
+		1,
+	); err != nil {
+		log.Fatal(err)
+	} else {
+		begin := time.Now()
+		contexts, tokErr := textsTokenizer.TokenizeTexts(
+			texts, "./test", enc,
+		)
+		if tokErr != nil {
+			log.Fatal(tokErr)
+		}
+
+		total, writeErr := WriteContexts(
+			outputFile,
+			contexts,
+			enc,
+			sampling,
+			false,
+			true,
+			false,
+		)
+		if writeErr != nil {
+			log.Fatal(writeErr)
+		}
+		duration := time.Since(begin).Seconds()
+		log.Printf(
+			"%d tokens in %0.2fs, %0.2f tokens/s", total,
+			duration, float64(total)/duration,
+		)
+	}
+	// Read the encoded tokens from the output file
+	binaryData, err := os.ReadFile(outputFile)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	// Convert to Tokens array
+	tokens := types.TokensFromBin(&binaryData)
+
+	if len(*tokens) != len(expectedTokens) {
+		t.Fatalf(
+			"Expected %d tokens, but got %d", len(expectedTokens),
+			len(*tokens),
+		)
+	}
+	if !assert.ObjectsAreEqual(expectedTokens, *tokens) {
+		t.Fatalf("Expected tokens: %v, but got: %v", expectedTokens, *tokens)
+	}
+
+	// Verify the encoded tokens
+	assert.Equal(t, expectedTokens, *tokens)
+}
diff --git a/gpt_bpe_test.go b/gpt_bpe_test.go
index 436874a..86d4859 100644
--- a/gpt_bpe_test.go
+++ b/gpt_bpe_test.go
@@ -539,6 +539,18 @@ func TestGPTEncoder_Encode(t *testing.T) {
 	}
 }
 
+func TestGPTEncode(t *testing.T) {
+	// Check that the GPTEncoder encodes a simple sentence into the expected tokens.
+	strin := "The quick brown fox jumps over the lazy dog."
+	expected := Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13}
+	encoded := gpt2Encoder.Encode(&strin)
+	fmt.Printf("Encoded tokens: ")
+	for _, token := range *encoded {
+		fmt.Printf("%v, ", token)
+	}
+	assert.Equal(t, expected, *encoded)
+}
+
 func TestGPTEncoder_StreamingEncode(t *testing.T) {
 	// This test is to check if the GPTEncoder is able to encode the tokens correctly
 	start := time.Now()
diff --git a/js/js.go b/js/js.go
index 990c238..c224565 100644
--- a/js/js.go
+++ b/js/js.go
@@ -3,9 +3,11 @@ package main
 //go:generate gopherjs build --minify
 
 import (
+	"log"
+
 	"github.com/gopherjs/gopherjs/js"
 	"github.com/wbrown/gpt_bpe"
-	"log"
+	"github.com/wbrown/gpt_bpe/types"
 )
 
 var encoder gpt_bpe.GPTEncoder
@@ -15,7 +17,7 @@ func Tokenize(text string) gpt_bpe.Tokens {
 }
 
 func Decode(arr []byte) string {
-	tokens := gpt_bpe.TokensFromBin(&arr)
+	tokens := types.TokensFromBin(&arr)
 	return encoder.Decode(tokens)
 }
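
Note for reviewers: the tests above rely on types.TokensFromBin to turn the binary chunk written by WriteContexts back into tokens, but the patch only shows the removal of the old 16-bit, test-local helper, not the new implementation. The sketch below is an assumed 32-bit analogue of that removed helper, included purely for context; the actual signature and width handling of types.TokensFromBin live in the gpt_bpe/types package and may differ.

// Sketch only: assumed 32-bit analogue of the removed test-local TokensFromBin.
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

// tokensFromBin32 reads little-endian uint32 values until the buffer is exhausted.
func tokensFromBin32(bin []byte) []uint32 {
	tokens := make([]uint32, 0, len(bin)/4)
	buf := bytes.NewReader(bin)
	for {
		var token uint32
		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
			break
		}
		tokens = append(tokens, token)
	}
	return tokens
}

func main() {
	// 464 and 2068 ("The", " quick") as little-endian uint32 values; read as
	// uint16 pairs this is the "X, 0, Y, 0" layout the enforce test expects.
	bin := []byte{0xD0, 0x01, 0x00, 0x00, 0x14, 0x08, 0x00, 0x00}
	fmt.Println(tokensFromBin32(bin)) // [464 2068]
}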