Test: Add Uint32 tests to dataset_tokenizer
Rexwang8 committed Oct 25, 2024
1 parent baeedfc commit d13b6bf
Showing 3 changed files with 205 additions and 17 deletions.
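This commit removes the local TokensFromBin helper from the test file (first hunk below) and points all call sites at types.TokensFromBin instead. For reference, here is a minimal sketch of what the moved helper presumably looks like, modeled directly on the removed local version; the Token and Tokens definitions shown here are assumptions, chosen as 16-bit because the tests below read the chunk files back two bytes at a time:

package types

import (
	"bytes"
	"encoding/binary"
)

// Assumed definitions; the real types package may differ.
type Token uint16
type Tokens []Token

// TokensFromBin reads little-endian Tokens from a byte slice until the
// buffer is exhausted, mirroring the helper removed from the test file.
func TokensFromBin(bin *[]byte) *Tokens {
	tokens := make(Tokens, 0)
	buf := bytes.NewReader(*bin)
	for {
		var token Token
		if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
			break
		}
		tokens = append(tokens, token)
	}
	return &tokens
}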
204 changes: 189 additions & 15 deletions cmd/dataset_tokenizer/dataset_tokenizer_test.go
@@ -3,7 +3,6 @@ package main
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -19,6 +18,7 @@ import (
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
"github.com/wbrown/gpt_bpe"
"github.com/wbrown/gpt_bpe/types"
)

type SanitizerTest struct {
@@ -66,24 +66,11 @@ var sanitizerTests = SanitizerTests{

const corpusPath = "../../resources/frankenstein.txt"

func TokensFromBin(bin *[]byte) *gpt_bpe.Tokens {
tokens := make(gpt_bpe.Tokens, 0)
buf := bytes.NewReader(*bin)
for {
var token gpt_bpe.Token
if err := binary.Read(buf, binary.LittleEndian, &token); err != nil {
break
}
tokens = append(tokens, token)
}
return &tokens
}

// DecodeBuffer
// Decode Tokens from a byte array into a string.
func DecodeBuffer(encoded *[]byte) (text string) {
// First convert our bytearray into a uint32 `Token` array.
tokens := TokensFromBin(encoded)
tokens := types.TokensFromBin(encoded)
// Decode our tokens into a string.
var enc *gpt_bpe.GPTEncoder
encoderString := "gpt2"
@@ -736,3 +723,190 @@ func TestListObjectsRecursively(t *testing.T) {

wg.Wait() // Wait for all goroutines to finish
}

func TestUInt16WithNoEnforce(t *testing.T) {
// Test that, with Uint32 enforcement disabled, a Uint16 tokenizer
// works as intended and writes each token with no padding
// (see the sketch after this test).

textsTokenizer := NewTextsTokenizer()
textsTokenizer.ContextSize = 2048
textsTokenizer.TokenizerId = "gpt2"
textsTokenizer.EndOfText = ""

// Test data
testString := "The quick brown fox jumps over the lazy dog."
expectedTokens := types.Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256}
// Generate temp directory and test file
tempDir := os.TempDir()
testFile := tempDir + "/test.txt"
f, err := os.Create(testFile)
if err != nil {
log.Fatal(err)
}
// Write test string to file
_, err = f.WriteString(testString)
if err != nil {
log.Fatal(err)
}
f.Close()
defer os.Remove(testFile)

reorderPaths := ""
sampling := 100
outputFile := "base.chunk"
defer os.Remove(outputFile)

enc, tokErr := textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

if texts, err := ReadTexts(
testFile, false,
reorderPaths,
1,
); err != nil {
log.Fatal(err)
} else {
begin := time.Now()
contexts, tokErr := textsTokenizer.TokenizeTexts(
texts, "./test", enc,
)
if tokErr != nil {
log.Fatal(tokErr)
}

total, writeErr := WriteContexts(
outputFile,
contexts,
enc,
sampling,
false,
false,
false,
)
if writeErr != nil {
log.Fatal(writeErr)
}
duration := time.Since(begin).Seconds()
log.Printf(
"%d tokens in %0.2fs, %0.2f tokens/s", total,
duration, float64(total)/duration,
)
}
// Read the encoded tokens from the output file
binaryData, err := os.ReadFile(outputFile)
if err != nil {
log.Fatal(err)
}

// Convert to Tokens array
tokens := types.TokensFromBin(&binaryData)

if len(*tokens) != len(expectedTokens) {
t.Fatalf(
"Expected %d tokens, but got %d", len(expectedTokens),
len(*tokens),
)
}

// Verify the encoded tokens match the expected values.
assert.Equal(t, expectedTokens, *tokens)
}
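With enforcement off, the chunk file holds two bytes per token. The following standalone sketch shows the byte-level round trip the assertions above rely on; it assumes WriteContexts stores each gpt2 token as a little-endian uint16 in this mode, which is what the expected token list implies:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	// The tokens TestUInt16WithNoEnforce expects to read back.
	tokens := []uint16{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256}

	// Pack them the way an unpadded (Uint16) chunk file would store them.
	var buf bytes.Buffer
	for _, tok := range tokens {
		binary.Write(&buf, binary.LittleEndian, tok)
	}
	fmt.Println(buf.Len(), "bytes, i.e. 2 per token")

	// Reading the same bytes back two at a time recovers the sequence.
	decoded := make([]uint16, len(tokens))
	binary.Read(bytes.NewReader(buf.Bytes()), binary.LittleEndian, decoded)
	fmt.Println(decoded)
}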

func TestUInt16WithEnforce(t *testing.T) {
// Test that, with Uint32 enforcement enabled, a Uint16 tokenizer
// works as intended and pads each token to 32 bits, so the output
// reads back through a Uint16 reader as X, 0, Y, 0, Z, 0, ...
// (see the sketch after this test).

textsTokenizer := NewTextsTokenizer()
textsTokenizer.ContextSize = 2048
textsTokenizer.TokenizerId = "gpt2"
textsTokenizer.EndOfText = ""

// Test data
testString := "The quick brown fox jumps over the lazy dog."
expectedTokens := types.Tokens{464, 0, 2068, 0, 7586, 0, 21831, 0, 18045, 0, 625, 0, 262, 0, 16931, 0, 3290, 0, 13, 0, 50256, 0}
// Generate temp directory and test file
tempDir := os.TempDir()
testFile := tempDir + "/test.txt"
f, err := os.Create(testFile)
if err != nil {
log.Fatal(err)
}
// Write test string to file
_, err = f.WriteString(testString)
if err != nil {
log.Fatal(err)
}
f.Close()
defer os.Remove(testFile)

reorderPaths := ""
sampling := 100
outputFile := "base.chunk"
defer os.Remove(outputFile)

enc, tokErr := textsTokenizer.InitTokenizer()
if tokErr != nil {
log.Fatal(tokErr)
}

if texts, err := ReadTexts(
testFile, false,
reorderPaths,
1,
); err != nil {
log.Fatal(err)
} else {
begin := time.Now()
contexts, tokErr := textsTokenizer.TokenizeTexts(
texts, "./test", enc,
)
if tokErr != nil {
log.Fatal(tokErr)
}

total, writeErr := WriteContexts(
outputFile,
contexts,
enc,
sampling,
false,
true,
false,
)
if writeErr != nil {
log.Fatal(writeErr)
}
duration := time.Since(begin).Seconds()
log.Printf(
"%d tokens in %0.2fs, %0.2f tokens/s", total,
duration, float64(total)/duration,
)
}
// Read the encoded tokens from the output file
binaryData, err := os.ReadFile(outputFile)
if err != nil {
log.Fatal(err)
}

// Convert to Tokens array
tokens := types.TokensFromBin(&binaryData)

if len(*tokens) != len(expectedTokens) {
t.Fatalf(
"Expected %d tokens, but got %d", len(expectedTokens),
len(*tokens),
)
}

// Verify the encoded tokens match the expected values.
assert.Equal(t, expectedTokens, *tokens)
}
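The alternating zeros in expectedTokens come from reading a 32-bit-per-token file through a 16-bit reader: each enforced uint32 carries the token in its low half and zero in its high half. Here is a small standalone sketch of that effect, assuming enforcement simply widens every token to a little-endian uint32 before writing:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	// The underlying gpt2 tokens for the test sentence.
	tokens := []uint32{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13, 50256}

	// Write each token as a 4-byte little-endian uint32 (the enforced form).
	var buf bytes.Buffer
	for _, tok := range tokens {
		binary.Write(&buf, binary.LittleEndian, tok)
	}

	// Re-reading the same bytes as uint16s yields token, 0, token, 0, ...
	// because every gpt2 token fits in the low 16 bits.
	asUint16 := make([]uint16, buf.Len()/2)
	binary.Read(bytes.NewReader(buf.Bytes()), binary.LittleEndian, asUint16)
	fmt.Println(asUint16)
}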
12 changes: 12 additions & 0 deletions gpt_bpe_test.go
@@ -539,6 +539,18 @@ func TestGPTEncoder_Encode(t *testing.T) {
}
}

func TestGPTEncode(t *testing.T) {
// This test checks that the GPTEncoder encodes a simple sentence
// into the expected GPT-2 token ids.
strin := "The quick brown fox jumps over the lazy dog."
expected := Tokens{464, 2068, 7586, 21831, 18045, 625, 262, 16931, 3290, 13}
encoded := gpt2Encoder.Encode(&strin)
fmt.Printf("Encoded tokens: ")
for _, token := range *encoded {
fmt.Printf("%v, ", token)
}
fmt.Println()
assert.Equal(t, expected, *encoded)
}
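As a quick sanity check on the expected values, a decode round trip could sit alongside this test. This is only a sketch: it reuses the package-level gpt2Encoder above and assumes the Decode(*Tokens) string method seen in js/js.go below:

func TestGPTEncodeRoundTrip(t *testing.T) {
	strin := "The quick brown fox jumps over the lazy dog."
	encoded := gpt2Encoder.Encode(&strin)
	// Decoding the encoded tokens should reproduce the original sentence.
	decoded := gpt2Encoder.Decode(encoded)
	assert.Equal(t, strin, decoded)
}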

func TestGPTEncoder_StreamingEncode(t *testing.T) {
// This test is to check if the GPTEncoder is able to encode the tokens correctly
start := time.Now()
6 changes: 4 additions & 2 deletions js/js.go
@@ -3,9 +3,11 @@ package main
//go:generate gopherjs build --minify

import (
"log"

"github.com/gopherjs/gopherjs/js"
"github.com/wbrown/gpt_bpe"
"log"
"github.com/wbrown/gpt_bpe/types"
)

var encoder gpt_bpe.GPTEncoder
@@ -15,7 +17,7 @@ func Tokenize(text string) gpt_bpe.Tokens {
}

func Decode(arr []byte) string {
tokens := gpt_bpe.TokensFromBin(&arr)
tokens := types.TokensFromBin(&arr)
return encoder.Decode(tokens)
}

