Skip to content

Commit

Permalink
Merge pull request #1096 from pedramr/pedram/text-splitter-markdown-l…
Browse files Browse the repository at this point in the history
…enfunc

textsplitter: add an optional lenFunc to MarkdownTextSplitter
  • Loading branch information
FluffyKebab authored Jan 7, 2025
2 parents 29259e5 + 7ed5f7f commit 71ded3c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
16 changes: 12 additions & 4 deletions textsplitter/markdown_splitter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"fmt"
"reflect"
"strings"
"unicode/utf8"

"gitlab.com/golang-commonmark/markdown"
)
Expand All @@ -25,6 +24,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter {
ReferenceLinks: options.ReferenceLinks,
HeadingHierarchy: options.KeepHeadingHierarchy,
JoinTableRows: options.JoinTableRows,
LenFunc: options.LenFunc,
}

if sp.SecondSplitter == nil {
Expand All @@ -36,6 +36,7 @@ func NewMarkdownTextSplitter(opts ...Option) *MarkdownTextSplitter {
"\n", // new line
" ", // space
}),
WithLenFunc(options.LenFunc),
)
}

Expand All @@ -57,6 +58,7 @@ type MarkdownTextSplitter struct {
ReferenceLinks bool
HeadingHierarchy bool
JoinTableRows bool
LenFunc func(string) int
}

// SplitText splits a text into multiple texts.
Expand All @@ -76,6 +78,7 @@ func (sp MarkdownTextSplitter) SplitText(text string) ([]string, error) {
joinTableRows: sp.JoinTableRows,
hTitleStack: []string{},
hTitlePrependHierarchy: sp.HeadingHierarchy,
lenFunc: sp.LenFunc,
}

chunks := mc.splitText()
Expand Down Expand Up @@ -133,6 +136,9 @@ type markdownContext struct {
// joinTableRows determines whether a chunk should contain multiple table rows,
// or if each row in a table should be split into a separate chunk.
joinTableRows bool

// lenFunc represents the function to calculate the length of a string.
lenFunc func(string) int
}

// splitText splits Markdown text.
Expand Down Expand Up @@ -193,6 +199,8 @@ func (mc *markdownContext) clone(startAt, endAt int) *markdownContext {
chunkSize: mc.chunkSize,
chunkOverlap: mc.chunkOverlap,
secondSplitter: mc.secondSplitter,

lenFunc: mc.lenFunc,
}
}

Expand Down Expand Up @@ -438,7 +446,7 @@ func (mc *markdownContext) splitTableRows(header []string, bodies [][]string) {
// If we're at the start of the current snippet, or adding the current line would
// overflow the chunk size, prepend the header to the line (so that the new chunk
// will include the table header).
if len(mc.curSnippet) == 0 || utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(line) >= mc.chunkSize {
if len(mc.curSnippet) == 0 || mc.lenFunc(mc.curSnippet+line) >= mc.chunkSize {
line = fmt.Sprintf("%s\n%s", headerMD, line)
}

Expand Down Expand Up @@ -617,7 +625,7 @@ func (mc *markdownContext) joinSnippet(snippet string) {
}

// check whether current chunk exceeds chunk size, if so, apply to chunks
if utf8.RuneCountInString(mc.curSnippet)+utf8.RuneCountInString(snippet) >= mc.chunkSize {
if mc.lenFunc(mc.curSnippet+snippet) >= mc.chunkSize {
mc.applyToChunks()
mc.curSnippet = snippet
} else {
Expand All @@ -634,7 +642,7 @@ func (mc *markdownContext) applyToChunks() {
var chunks []string
if mc.curSnippet != "" {
// check whether current chunk is over ChunkSize, if so, re-split current chunk
if utf8.RuneCountInString(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap {
if mc.lenFunc(mc.curSnippet) <= mc.chunkSize+mc.chunkOverlap {
chunks = []string{mc.curSnippet}
} else {
// split current snippet to chunks
Expand Down
45 changes: 45 additions & 0 deletions textsplitter/markdown_splitter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"os"
"testing"

"github.com/pkoukk/tiktoken-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/schema"
Expand Down Expand Up @@ -579,3 +580,47 @@ func TestMarkdownHeaderTextSplitter_SplitInline(t *testing.T) {
})
}
}

// TestMarkdownHeaderTextSplitter_LenFunc checks that the splitter built by
// NewMarkdownTextSplitter sizes chunks with the function passed via
// WithLenFunc. The length function used here counts cl100k_base tokens.
func TestMarkdownHeaderTextSplitter_LenFunc(t *testing.T) {
	t.Parallel()

	// The lookup error used to be discarded; stop here if the encoding is
	// unavailable instead of failing later inside Encode.
	tokenEncoder, err := tiktoken.GetEncoding("cl100k_base")
	require.NoError(t, err)

	sampleText := "The quick brown fox jumped over the lazy dog."
	tokensPerChunk := len(tokenEncoder.Encode(sampleText, nil, nil))

	type testCase struct {
		markdown     string
		expectedDocs []schema.Document
	}

	testCases := []testCase{
		{
			// Two copies of the sample sentence under one heading; the
			// expectation is one document per sentence, each carrying the
			// heading.
			markdown: `# Title` + "\n" + sampleText + "\n" + sampleText,
			expectedDocs: []schema.Document{
				{
					PageContent: "# Title" + "\n" + sampleText,
					Metadata:    map[string]any{},
				},
				{
					PageContent: "# Title" + "\n" + sampleText,
					Metadata:    map[string]any{},
				},
			},
		},
	}

	// The chunk size is one token above the token count of a single sample
	// sentence, and the length function measures strings in tokens.
	splitter := NewMarkdownTextSplitter(
		WithChunkSize(tokensPerChunk+1),
		WithChunkOverlap(0),
		WithLenFunc(func(s string) int {
			return len(tokenEncoder.Encode(s, nil, nil))
		}),
	)

	for _, tc := range testCases {
		docs, err := CreateDocuments(splitter, []string{tc.markdown}, nil)
		require.NoError(t, err)
		assert.Equal(t, tc.expectedDocs, docs)
	}
}

0 comments on commit 71ded3c

Please sign in to comment.