From 7ed5f7f00eb5c7c14acb81bbea013bf4eb30e6dd Mon Sep 17 00:00:00 2001 From: Pedram Razavi Date: Mon, 6 Jan 2025 15:37:01 -0800 Subject: [PATCH] textsplitter: add an optional lenFunc to MarkdownTextSplitter Based on the custom lenFuc supported in RecursiveCharacter splitter in 8734b60555715d08b6255b4731dc9bc3317227d2 --- textsplitter/markdown_splitter.go | 16 ++++++--- textsplitter/markdown_splitter_test.go | 45 ++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/textsplitter/markdown_splitter.go b/textsplitter/markdown_splitter.go index 117d9b1e1..03be43f73 100644 --- a/textsplitter/markdown_splitter.go +++ b/textsplitter/markdown_splitter.go @@ -4,7 +4,6 @@ import ( "fmt" "reflect" "strings" - "unicode/utf8" "gitlab.com/golang-commonmark/markdown" ) @@ -25,6 +24,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter { ReferenceLinks: options.ReferenceLinks, HeadingHierarchy: options.KeepHeadingHierarchy, JoinTableRows: options.JoinTableRows, + LenFunc: options.LenFunc, } if sp.SecondSplitter == nil { @@ -36,6 +36,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter { "\n", // new line " ", // space }), + WithLenFunc(options.LenFunc), ) } @@ -57,6 +58,7 @@ type MarkdownTextSplitter struct { ReferenceLinks bool HeadingHierarchy bool JoinTableRows bool + LenFunc func(string) int } // SplitText splits a text into multiple text. @@ -76,6 +78,7 @@ func (sp MarkdownTextSplitter) SplitText(text string) ([]string, error) { joinTableRows: sp.JoinTableRows, hTitleStack: []string{}, hTitlePrependHierarchy: sp.HeadingHierarchy, + lenFunc: sp.LenFunc, } chunks := mc.splitText() @@ -133,6 +136,9 @@ type markdownContext struct { // joinTableRows determines whether a chunk should contain multiple table rows, // or if each row in a table should be split into a separate chunk. joinTableRows bool + + // lenFunc represents the function to calculate the length of a string. + lenFunc func(string) int } // splitText splits Markdown text. @@ -193,6 +199,8 @@ func (mc *markdownContext) clone(startAt, endAt int) *markdownContext { chunkSize: mc.chunkSize, chunkOverlap: mc.chunkOverlap, secondSplitter: mc.secondSplitter, + + lenFunc: mc.lenFunc, } } @@ -438,7 +446,7 @@ func (mc *markdownContext) splitTableRows(header []string, bodies [][]string) { // If we're at the start of the current snippet, or adding the current line would // overflow the chunk size, prepend the header to the line (so that the new chunk // will include the table header). - if len(mc.curSnippet) == 0 || utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(line) >= mc.chunkSize { + if len(mc.curSnippet) == 0 || mc.lenFunc(mc.curSnippet+line) >= mc.chunkSize { line = fmt.Sprintf("%s\n%s", headerMD, line) } @@ -617,7 +625,7 @@ func (mc *markdownContext) joinSnippet(snippet string) { } // check whether current chunk exceeds chunk size, if so, apply to chunks - if utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(snippet) >= mc.chunkSize { + if mc.lenFunc(mc.curSnippet+snippet) >= mc.chunkSize { mc.applyToChunks() mc.curSnippet = snippet } else { @@ -634,7 +642,7 @@ func (mc *markdownContext) applyToChunks() { var chunks []string if mc.curSnippet != "" { // check whether current chunk is over ChunkSize,if so, re-split current chunk - if utf8.RuneCountInString(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap { + if mc.lenFunc(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap { chunks = []string{mc.curSnippet} } else { // split current snippet to chunks diff --git a/textsplitter/markdown_splitter_test.go b/textsplitter/markdown_splitter_test.go index a75cad2a4..758180065 100644 --- a/textsplitter/markdown_splitter_test.go +++ b/textsplitter/markdown_splitter_test.go @@ -4,6 +4,7 @@ import ( "os" "testing" + "github.com/pkoukk/tiktoken-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/schema" @@ -579,3 +580,47 @@ func TestMarkdownHeaderTextSplitter_SplitInline(t *testing.T) { }) } } + +func TestMarkdownHeaderTextSplitter_LenFunc(t *testing.T) { + t.Parallel() + + tokenEncoder, _ := tiktoken.GetEncoding("cl100k_base") + + sampleText := "The quick brown fox jumped over the lazy dog." + tokensPerChunk := len(tokenEncoder.Encode(sampleText, nil, nil)) + + type testCase struct { + markdown string + expectedDocs []schema.Document + } + + testCases := []testCase{ + { + markdown: `# Title` + "\n" + sampleText + "\n" + sampleText, + expectedDocs: []schema.Document{ + { + PageContent: "# Title" + "\n" + sampleText, + Metadata: map[string]any{}, + }, + { + PageContent: "# Title" + "\n" + sampleText, + Metadata: map[string]any{}, + }, + }, + }, + } + + splitter := NewMarkdownTextSplitter( + WithChunkSize(tokensPerChunk+1), + WithChunkOverlap(0), + WithLenFunc(func(s string) int { + return len(tokenEncoder.Encode(s, nil, nil)) + }), + ) + + for _, tc := range testCases { + docs, err := CreateDocuments(splitter, []string{tc.markdown}, nil) + require.NoError(t, err) + assert.Equal(t, tc.expectedDocs, docs) + } +}