Skip to content

Commit

Permalink
Add new textsplitter option: LenFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
whyiug20231206 committed Feb 8, 2024
1 parent 677c14c commit 8734b60
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
13 changes: 13 additions & 0 deletions textsplitter/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type Options struct {
ChunkSize int
ChunkOverlap int
Separators []string
LenFunc func(string) int
ModelName string
EncodingName string
AllowedSpecial []string
Expand All @@ -20,6 +21,7 @@ func DefaultOptions() Options {
ChunkSize: _defaultTokenChunkSize,
ChunkOverlap: _defaultTokenChunkOverlap,
Separators: []string{"\n\n", "\n", " ", ""},
LenFunc: defaultLenFunc,

ModelName: _defaultTokenModelName,
EncodingName: _defaultTokenEncoding,
Expand Down Expand Up @@ -52,6 +54,13 @@ func WithSeparators(separators []string) Option {
}
}

// WithLenFunc sets the function a text splitter uses to measure the length of
// a piece of text, for example utf8.RuneCountInString to count runes rather
// than bytes.
//
// A nil lenFunc is ignored: the length function already present on the Options
// is left untouched. The splitter calls the length function unconditionally,
// and calling a nil function value panics, so storing nil would only defer the
// failure to split time.
func WithLenFunc(lenFunc func(string) int) Option {
	return func(o *Options) {
		if lenFunc == nil {
			// Keep the previously configured length function.
			return
		}
		o.LenFunc = lenFunc
	}
}

// WithModelName sets the model name for a text splitter.
func WithModelName(modelName string) Option {
return func(o *Options) {
Expand Down Expand Up @@ -107,3 +116,7 @@ func WithReferenceLinks(referenceLinks bool) Option {
o.ReferenceLinks = referenceLinks
}
}

// defaultLenFunc is the length function used when none is supplied. It
// reports the size of s in bytes, so a multi-byte UTF-8 character counts
// once per byte rather than once per rune.
func defaultLenFunc(s string) int {
	byteCount := len(s)
	return byteCount
}
8 changes: 5 additions & 3 deletions textsplitter/recursive_character.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type RecursiveCharacter struct {
Separators []string
ChunkSize int
ChunkOverlap int
LenFunc func(string) int
}

// NewRecursiveCharacter creates a new recursive character splitter with default values. By
Expand All @@ -25,6 +26,7 @@ func NewRecursiveCharacter(opts ...Option) RecursiveCharacter {
Separators: options.Separators,
ChunkSize: options.ChunkSize,
ChunkOverlap: options.ChunkOverlap,
LenFunc: options.LenFunc,
}

return s
Expand All @@ -50,13 +52,13 @@ func (s RecursiveCharacter) SplitText(text string) ([]string, error) {

// Merge the splits, recursively splitting larger texts.
for _, split := range splits {
if len(split) < s.ChunkSize {
if s.LenFunc(split) < s.ChunkSize {
goodSplits = append(goodSplits, split)
continue
}

if len(goodSplits) > 0 {
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap)
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap, s.LenFunc)

finalChunks = append(finalChunks, mergedText...)
goodSplits = make([]string, 0)
Expand All @@ -74,7 +76,7 @@ func (s RecursiveCharacter) SplitText(text string) ([]string, error) {
}

if len(goodSplits) > 0 {
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap)
mergedText := mergeSplits(goodSplits, separator, s.ChunkSize, s.ChunkOverlap, s.LenFunc)
finalChunks = append(finalChunks, mergedText...)
}

Expand Down
16 changes: 16 additions & 0 deletions textsplitter/recursive_character_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package textsplitter

import (
"testing"
"unicode/utf8"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -17,8 +18,20 @@ func TestRecursiveCharacterSplitter(t *testing.T) {
chunkSize int
separators []string
expectedDocs []schema.Document
lenFunc func(string) int
}
testCases := []testCase{
{
text: "哈里森\n很高兴遇见你\n欢迎你来中国",
chunkOverlap: 0,
chunkSize: 10,
separators: []string{"\n\n", "\n", " "},
lenFunc: utf8.RuneCountInString,
expectedDocs: []schema.Document{
{PageContent: "哈里森\n很高兴遇见你", Metadata: map[string]any{}},
{PageContent: "欢迎你来中国", Metadata: map[string]any{}},
},
},
{
text: "Hi, Harrison. \nI am glad to meet you",
chunkOverlap: 1,
Expand Down Expand Up @@ -102,6 +115,9 @@ Bye!
splitter.ChunkOverlap = tc.chunkOverlap
splitter.ChunkSize = tc.chunkSize
splitter.Separators = tc.separators
if tc.lenFunc != nil {
splitter.LenFunc = tc.lenFunc
}

docs, err := CreateDocuments(splitter, []string{tc.text}, nil)
require.NoError(t, err)
Expand Down
16 changes: 8 additions & 8 deletions textsplitter/split_documents.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ func joinDocs(docs []string, separator string) string {
}

// mergeSplits merges smaller splits into splits that are closer to the chunkSize.
func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int) []string { //nolint:cyclop
func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap int, lenFunc func(string) int) []string { //nolint:cyclop
docs := make([]string, 0)
currentDoc := make([]string, 0)
total := 0

for _, split := range splits {
totalWithSplit := total + len(split)
totalWithSplit := total + lenFunc(split)
if len(currentDoc) != 0 {
totalWithSplit += len(separator)
totalWithSplit += lenFunc(separator)
}

maybePrintWarning(total, chunkSize)
Expand All @@ -86,19 +86,19 @@ func mergeSplits(splits []string, separator string, chunkSize int, chunkOverlap
docs = append(docs, doc)
}

for shouldPop(chunkOverlap, chunkSize, total, len(split), len(separator), len(currentDoc)) {
total -= len(currentDoc[0]) //nolint:gosec
for shouldPop(chunkOverlap, chunkSize, total, lenFunc(split), lenFunc(separator), len(currentDoc)) {
total -= lenFunc(currentDoc[0]) //nolint:gosec
if len(currentDoc) > 1 {
total -= len(separator)
total -= lenFunc(separator)
}
currentDoc = currentDoc[1:] //nolint:gosec
}
}

currentDoc = append(currentDoc, split)
total += len(split)
total += lenFunc(split)
if len(currentDoc) > 1 {
total += len(separator)
total += lenFunc(separator)
}
}

Expand Down

0 comments on commit 8734b60

Please sign in to comment.