From 7c931a68ae3dd02f25fde169581b3f9385e196a4 Mon Sep 17 00:00:00 2001 From: Julie Tibshirani Date: Mon, 13 Jan 2025 14:58:09 -0800 Subject: [PATCH 1/2] Support BM25 scoring for chunk matches --- api.go | 20 ++- build/scoring_test.go | 17 ++- contentprovider.go | 184 ++++------------------------ eval.go | 4 +- index_test.go | 166 ++++++++++++++++++++++++- read.go | 9 +- score.go | 278 +++++++++++++++++++++++++++++++++++++----- 7 files changed, 474 insertions(+), 204 deletions(-) diff --git a/api.go b/api.go index 70e7774bd..b829c6878 100644 --- a/api.go +++ b/api.go @@ -33,6 +33,7 @@ const ( stringHeaderBytes uint64 = 16 pointerSize uint64 = 8 interfaceBytes uint64 = 16 + maxUInt16 = 0xffff ) // FileMatch contains all the matches within a file. @@ -135,6 +136,22 @@ func (m *FileMatch) sizeBytes() (sz uint64) { return } +// addScore increments the score of the FileMatch by the computed score. If +// debugScore is true, it also adds a debug string to the FileMatch. If raw is +// -1, it is ignored. Otherwise, it is added to the debug string. +func (m *FileMatch) addScore(what string, computed float64, raw float64, debugScore bool) { + if computed != 0 && debugScore { + var b strings.Builder + fmt.Fprintf(&b, "%s", what) + if raw != -1 { + fmt.Fprintf(&b, "(%s)", strconv.FormatFloat(raw, 'f', -1, 64)) + } + fmt.Fprintf(&b, ":%.2f, ", computed) + m.Debug += b.String() + } + m.Score += computed +} + // ChunkMatch is a set of non-overlapping matches within a contiguous range of // lines in the file. type ChunkMatch struct { @@ -976,7 +993,8 @@ type SearchOptions struct { // EXPERIMENTAL. If true, use text-search style scoring instead of the default // scoring formula. The scoring algorithm treats each match in a file as a term - // and computes an approximation to BM25. + // and computes an approximation to BM25. When enabled, BM25 scoring is used for + // the overall FileMatch score, as well as individual LineMatch and ChunkMatch scores. 
// // The calculation of IDF assumes that Zoekt visits all documents containing any // of the query terms during evaluation. This is true, for example, if all query diff --git a/build/scoring_test.go b/build/scoring_test.go index ea0bfc638..8f5f49ecd 100644 --- a/build/scoring_test.go +++ b/build/scoring_test.go @@ -94,8 +94,21 @@ func TestBM25(t *testing.T) { language: "Java", // bm25-score: 1.81 <- sum-termFrequencyScore: 116.00, length-ratio: 1.00 wantScore: 1.81, - // line 3: public class InnerClasses { - wantBestLineMatch: 3, + // line 54: private static B runInnerInterface(InnerInterface fn, A a) { + wantBestLineMatch: 54, + }, { + // Another content-only match + fileName: "example.java", + query: &query.And{Children: []query.Q{ + &query.Substring{Pattern: "system"}, + &query.Substring{Pattern: "time"}, + }}, + content: exampleJava, + language: "Java", + // bm25-score: 0.96 <- sum-termFrequencies: 12, length-ratio: 1.00 + wantScore: 0.96, + // line 59: if (System.nanoTime() > System.currentTimeMillis()) { + wantBestLineMatch: 59, }, { // Matches only on filename diff --git a/contentprovider.go b/contentprovider.go index 34600f303..d24ae310e 100644 --- a/contentprovider.go +++ b/contentprovider.go @@ -16,12 +16,10 @@ package zoekt import ( "bytes" - "fmt" "log" "path" "slices" "sort" - "strings" "unicode" "unicode/utf8" @@ -145,7 +143,7 @@ func (p *contentProvider) findOffset(filename bool, r uint32) uint32 { // // Note: the byte slices may be backed by mmapped data, so before being // returned by the API it needs to be copied. 
-func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch { +func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []LineMatch { var filenameMatches []*candidateMatch contentMatches := make([]*candidateMatch, 0, len(ms)) @@ -160,16 +158,16 @@ func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, // If there are any content matches, we only return these and skip filename matches. if len(contentMatches) > 0 { contentMatches = breakMatchesOnNewlines(contentMatches, p.data(false)) - return p.fillContentMatches(contentMatches, numContextLines, language, debug) + return p.fillContentMatches(contentMatches, numContextLines, language, opts) } // Otherwise, we return a single line containing the filematch match. - bestMatch, _ := p.candidateMatchScore(filenameMatches, language, debug) + lineScore, _ := p.scoreLine(filenameMatches, language, -1 /* must pass -1 for filenames */, opts) res := LineMatch{ Line: p.id.fileName(p.idx), FileName: true, - Score: bestMatch.score, - DebugScore: bestMatch.debugScore, + Score: lineScore.score, + DebugScore: lineScore.debugScore, } for _, m := range ms { @@ -192,7 +190,7 @@ func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, // // Note: the byte slices may be backed by mmapped data, so before being // returned by the API it needs to be copied. 
-func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch { +func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []ChunkMatch { var filenameMatches []*candidateMatch contentMatches := make([]*candidateMatch, 0, len(ms)) @@ -206,11 +204,11 @@ func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines // If there are any content matches, we only return these and skip filename matches. if len(contentMatches) > 0 { - return p.fillContentChunkMatches(contentMatches, numContextLines, language, debug) + return p.fillContentChunkMatches(contentMatches, numContextLines, language, opts) } // Otherwise, we return a single chunk representing the filename match. - bestMatch, _ := p.candidateMatchScore(filenameMatches, language, debug) + lineScore, _ := p.scoreLine(filenameMatches, language, -1 /* must pass -1 for filenames */, opts) fileName := p.id.fileName(p.idx) ranges := make([]Range, 0, len(ms)) for _, m := range ms { @@ -233,12 +231,12 @@ func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines ContentStart: Location{ByteOffset: 0, LineNumber: 1, Column: 1}, Ranges: ranges, FileName: true, - Score: bestMatch.score, - DebugScore: bestMatch.debugScore, + Score: lineScore.score, + DebugScore: lineScore.debugScore, }} } -func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch { +func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []LineMatch { var result []LineMatch for len(ms) > 0 { m := ms[0] @@ -296,9 +294,9 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin finalMatch.After = p.newlines().getLines(data, num+1, num+1+numContextLines) } - bestMatch, symbolInfo := p.candidateMatchScore(lineCands, language, debug) 
- finalMatch.Score = bestMatch.score - finalMatch.DebugScore = bestMatch.debugScore + lineScore, symbolInfo := p.scoreLine(lineCands, language, num, opts) + finalMatch.Score = lineScore.score + finalMatch.DebugScore = lineScore.debugScore for i, m := range lineCands { fragment := LineFragmentMatch{ @@ -306,6 +304,7 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin LineOffset: int(m.byteOffset) - lineStart, MatchLength: int(m.byteMatchSz), } + if i < len(symbolInfo) && symbolInfo[i] != nil { fragment.SymbolInfo = symbolInfo[i] } @@ -317,8 +316,7 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin return result } -func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch { - newlines := p.newlines() +func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []ChunkMatch { data := p.data(false) // columnHelper prevents O(len(ms) * len(data)) lookups for all columns. 
@@ -332,11 +330,10 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte sort.Sort((sortByOffsetSlice)(ms)) } + newlines := p.newlines() chunks := chunkCandidates(ms, newlines, numContextLines) chunkMatches := make([]ChunkMatch, 0, len(chunks)) for _, chunk := range chunks { - bestMatch, symbolInfo := p.candidateMatchScore(chunk.candidates, language, debug) - ranges := make([]Range, 0, len(chunk.candidates)) for _, cm := range chunk.candidates { startOffset := cm.byteOffset @@ -363,14 +360,7 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte } firstLineStart := newlines.lineStart(firstLineNumber) - bestLineMatch := 0 - if bestMatch.match != nil { - bestLineMatch = newlines.atOffset(bestMatch.match.byteOffset) - if debug { - bestMatch.debugScore = fmt.Sprintf("%s, (line: %d)", bestMatch.debugScore, bestLineMatch) - } - } - + chunkScore, symbolInfo := p.scoreChunk(chunk.candidates, language, opts) chunkMatches = append(chunkMatches, ChunkMatch{ Content: newlines.getLines(data, firstLineNumber, int(chunk.lastLine)+numContextLines+1), ContentStart: Location{ @@ -381,9 +371,9 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte FileName: false, Ranges: ranges, SymbolInfo: symbolInfo, - BestLineMatch: uint32(bestLineMatch), - Score: bestMatch.score, - DebugScore: bestMatch.debugScore, + BestLineMatch: uint32(chunkScore.bestLine), + Score: chunkScore.score, + DebugScore: chunkScore.debugScore, }) } return chunkMatches @@ -405,6 +395,7 @@ type candidateChunk struct { // output invariants: if you flatten candidates the input invariant is retained. func chunkCandidates(ms []*candidateMatch, newlines newlines, numContextLines int) []candidateChunk { var chunks []candidateChunk + for _, m := range ms { startOffset := m.byteOffset endOffset := m.byteOffset + m.byteMatchSz @@ -536,10 +527,6 @@ const ( scoreKindMatch = 100.0 scoreFactorAtomMatch = 400.0 - // File-only scoring signals. 
For now these are also bounded ~9000 to give them - // equal weight with the query-dependent signals. - scoreFileRankFactor = 9000.0 - // Used for ordering line and chunk matches within a file. scoreLineOrderFactor = 1.0 @@ -643,133 +630,6 @@ func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symb return sec, si, true } -// calculateTermFrequency computes the term frequency for the file match. -// Notes: -// * Filename matches count more than content matches. This mimics a common text search strategy to 'boost' matches on document titles. -// * Symbol matches also count more than content matches, to reward matches on symbol definitions. -func (p *contentProvider) calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int { - // Treat each candidate match as a term and compute the frequencies. For now, ignore case - // sensitivity and treat filenames and symbols the same as content. - termFreqs := map[string]int{} - for _, m := range cands { - term := string(m.substrLowered) - if m.fileName || p.matchesSymbol(m) { - termFreqs[term] += 5 - } else { - termFreqs[term]++ - } - } - - for term := range termFreqs { - df[term] += 1 - } - return termFreqs -} - -// scoredMatch holds the score information for a candidate match. -type scoredMatch struct { - score float64 - debugScore string - match *candidateMatch -} - -// candidateMatchScore scores all candidate matches and returns the best-scoring match plus its score information. -// Invariant: there should be at least one input candidate, len(ms) > 0. 
-func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language string, debug bool) (scoredMatch, []*Symbol) { - score := 0.0 - what := "" - - addScore := func(w string, s float64) { - if s != 0 && debug { - what += fmt.Sprintf("%s:%.2f, ", w, s) - } - score += s - } - - filename := p.data(true) - var symbolInfo []*Symbol - - var bestMatch scoredMatch - for i, m := range ms { - data := p.data(m.fileName) - - endOffset := m.byteOffset + m.byteMatchSz - startBoundary := m.byteOffset < uint32(len(data)) && (m.byteOffset == 0 || byteClass(data[m.byteOffset-1]) != byteClass(data[m.byteOffset])) - endBoundary := endOffset > 0 && (endOffset == uint32(len(data)) || byteClass(data[endOffset-1]) != byteClass(data[endOffset])) - - score = 0 - what = "" - - if startBoundary && endBoundary { - addScore("WordMatch", scoreWordMatch) - } else if startBoundary || endBoundary { - addScore("PartialWordMatch", scorePartialWordMatch) - } - - if m.fileName { - sep := bytes.LastIndexByte(data, '/') - startMatch := int(m.byteOffset) == sep+1 - endMatch := endOffset == uint32(len(data)) - if startMatch && endMatch { - addScore("Base", scoreBase) - } else if startMatch || endMatch { - addScore("EdgeBase", (scoreBase+scorePartialBase)/2) - } else if sep < int(m.byteOffset) { - addScore("InnerBase", scorePartialBase) - } - } else if sec, si, ok := p.findSymbol(m); ok { - startMatch := sec.Start == m.byteOffset - endMatch := sec.End == endOffset - if startMatch && endMatch { - addScore("Symbol", scoreSymbol) - } else if startMatch || endMatch { - addScore("EdgeSymbol", (scoreSymbol+scorePartialSymbol)/2) - } else { - addScore("OverlapSymbol", scorePartialSymbol) - } - - // Score based on symbol data - if si != nil { - symbolKind := ctags.ParseSymbolKind(si.Kind) - sym := sectionSlice(data, sec) - - addScore(fmt.Sprintf("kind:%s:%s", language, si.Kind), scoreSymbolKind(language, filename, sym, symbolKind)) - - // This is from a symbol tree, so we need to store the symbol - // 
information. - if m.symbol { - if symbolInfo == nil { - symbolInfo = make([]*Symbol, len(ms)) - } - // findSymbols does not hydrate in Sym. So we need to store it. - si.Sym = string(sym) - symbolInfo[i] = si - } - } - } - - // scoreWeight != 1 means it affects score - if !epsilonEqualsOne(m.scoreWeight) { - score = score * m.scoreWeight - if debug { - what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight) - } - } - - if score > bestMatch.score { - bestMatch.score = score - bestMatch.debugScore = what - bestMatch.match = m - } - } - - if debug { - bestMatch.debugScore = fmt.Sprintf("score:%.2f <- %s", bestMatch.score, strings.TrimSuffix(bestMatch.debugScore, ", ")) - } - - return bestMatch, symbolInfo -} - // sectionSlice will return data[sec.Start:sec.End] but will clip Start and // End such that it won't be out of range. func sectionSlice(data []byte, sec DocumentSection) []byte { diff --git a/eval.go b/eval.go index c54070f25..cca78752f 100644 --- a/eval.go +++ b/eval.go @@ -327,9 +327,9 @@ nextFileMatch: finalCands := d.gatherMatches(nextDoc, mt, known, shouldMergeMatches) if opts.ChunkMatches { - fileMatch.ChunkMatches = cp.fillChunkMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore) + fileMatch.ChunkMatches = cp.fillChunkMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts) } else { - fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore) + fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts) } var tf map[string]int diff --git a/index_test.go b/index_test.go index bdb92f5a4..223bd81e2 100644 --- a/index_test.go +++ b/index_test.go @@ -45,17 +45,17 @@ func clearScores(r *SearchResult) { } } -func testIndexBuilder(t *testing.T, repo *Repository, docs ...Document) *IndexBuilder { - t.Helper() +func testIndexBuilder(tb testing.TB, repo *Repository, docs ...Document) *IndexBuilder { + tb.Helper() b, err := NewIndexBuilder(repo) 
if err != nil { - t.Fatalf("NewIndexBuilder: %v", err) + tb.Fatalf("NewIndexBuilder: %v", err) } for i, d := range docs { if err := b.Add(d); err != nil { - t.Fatalf("Add %d: %v", i, err) + tb.Fatalf("Add %d: %v", i, err) } } @@ -303,7 +303,7 @@ func searchForTest(t *testing.T, b *IndexBuilder, q query.Q, o ...SearchOptions) return res } -func searcherForTest(t *testing.T, b *IndexBuilder) Searcher { +func searcherForTest(t testing.TB, b *IndexBuilder) Searcher { var buf bytes.Buffer if err := b.Write(&buf); err != nil { t.Fatal(err) @@ -375,13 +375,16 @@ func TestCaseFold(t *testing.T) { func wordsAsSymbols(doc Document) Document { re := regexp.MustCompile(`\b\w{2,}\b`) var symbols []DocumentSection + var symbolsMetadata []*Symbol for _, match := range re.FindAllIndex(doc.Content, -1) { symbols = append(symbols, DocumentSection{ Start: uint32(match[0]), End: uint32(match[1]), }) + symbolsMetadata = append(symbolsMetadata, &Symbol{Kind: "method"}) } doc.Symbols = symbols + doc.SymbolsMetaData = symbolsMetadata return doc } @@ -992,6 +995,88 @@ func TestSearchMatchAllRegexp(t *testing.T) { }) } +func TestSearchBM25MatchScores(t *testing.T) { + ctx := context.Background() + searcher := searcherForTest(t, testIndexBuilder(t, nil, + Document{Name: "f1", Content: []byte("one two three\naaaaaaaaaa\nbbbbbbbb\none two two")}, + Document{Name: "f2", Content: []byte("four five six\naaaaaaaaaa\nbbbbbbbb\nfour five five\nsix six")}, + wordsAsSymbols(Document{Name: "f3", Content: []byte("public static void main")}), + )) + + t.Run("LineMatches", func(t *testing.T) { + q := &query.Substring{Pattern: "two"} + sres, err := searcher.Search(ctx, q, &SearchOptions{UseBM25Scoring: true}) + if err != nil { + t.Fatal(err) + } + matches := sres.Files + if len(matches) != 1 { + t.Fatalf("want 1 file match, got %d", len(matches)) + } + + if len(matches[0].LineMatches) != 2 { + t.Fatalf("want 2 chunk matches, got %d", len(matches[0].ChunkMatches)) + } + + if 
matches[0].LineMatches[0].LineNumber != 4 { + t.Fatalf("want best-scoring line to be line 4, got %d", matches[0].LineMatches[0].LineNumber) + } + }) + + t.Run("ChunkMatches", func(t *testing.T) { + q := &query.Substring{Pattern: "five"} + sres, err := searcher.Search(ctx, q, &SearchOptions{UseBM25Scoring: true, ChunkMatches: true, NumContextLines: 1}) + if err != nil { + t.Fatal(err) + } + + matches := sres.Files + if len(matches) != 1 { + t.Fatalf("want 1 file match, got %d", len(matches)) + } + + if len(matches[0].ChunkMatches) != 2 { + t.Fatalf("want 2 chunk matches, got %d", len(matches[0].ChunkMatches)) + } + + if matches[0].ChunkMatches[0].BestLineMatch != 4 { + t.Fatalf("want best-scoring line to be line 4, got %d", matches[0].ChunkMatches[0].BestLineMatch) + } + }) + + t.Run("ChunkMatches with symbols", func(t *testing.T) { + q := &query.Or{ + Children: []query.Q{ + &query.Symbol{Expr: &query.Substring{Pattern: "main"}}, + &query.Substring{Pattern: "five"}, + }, + } + + sres, err := searcher.Search(ctx, q, &SearchOptions{UseBM25Scoring: true, ChunkMatches: true, NumContextLines: 1}) + if err != nil { + t.Fatal(err) + } + + matches := sres.Files + if len(matches) != 2 { + t.Fatalf("want 2 file match, got %d", len(matches)) + } + + foundSymbolInfo := false + for _, m := range matches { + for _, cm := range m.ChunkMatches { + if len(cm.SymbolInfo) > 0 { + foundSymbolInfo = true + } + } + } + + if !foundSymbolInfo { + t.Fatalf("want symbol info, got none") + } + }) +} + func TestFileRestriction(t *testing.T) { b := testIndexBuilder(t, nil, Document{Name: "banana1", Content: []byte("x orange y")}, @@ -2453,7 +2538,7 @@ func TestIOStats(t *testing.T) { res := searchForTest(t, b, q) // 4096 (content) + 2 (overhead: newlines or doc sections) - if got, want := res.Stats.ContentBytesLoaded, int64(4100); got != want { + if got, want := res.Stats.ContentBytesLoaded, int64(4098); got != want { t.Errorf("got content I/O %d, want %d", got, want) } @@ -2479,6 +2564,38 @@ 
func TestIOStats(t *testing.T) { t.Errorf("got index I/O %d, want %d", got, want) } }) + + t.Run("LineMatches with BM25", func(t *testing.T) { + q := &query.Substring{Pattern: "abc", CaseSensitive: true, Content: true} + res := searchForTest(t, b, q, SearchOptions{UseBM25Scoring: true}) + + // 4096 (content) + 2 (overhead: newlines or doc sections) + if got, want := res.Stats.ContentBytesLoaded, int64(4098); got != want { + t.Errorf("got content I/O %d, want %d", got, want) + } + + // 1024 entries, each 4 bytes apart. 4 fits into single byte + // delta encoded. + if got, want := res.Stats.IndexBytesLoaded, int64(1024); got != want { + t.Errorf("got index I/O %d, want %d", got, want) + } + }) + + t.Run("ChunkMatches with BM25", func(t *testing.T) { + q := &query.Substring{Pattern: "abc", CaseSensitive: true, Content: true} + res := searchForTest(t, b, q, SearchOptions{UseBM25Scoring: true, ChunkMatches: true}) + + // 4096 (content) + 2 (overhead: newlines or doc sections) + if got, want := res.Stats.ContentBytesLoaded, int64(4098); got != want { + t.Errorf("got content I/O %d, want %d", got, want) + } + + // 1024 entries, each 4 bytes apart. 4 fits into single byte + // delta encoded. + if got, want := res.Stats.IndexBytesLoaded, int64(1024); got != want { + t.Errorf("got index I/O %d, want %d", got, want) + } + }) } func TestStartLineAnchor(t *testing.T) { @@ -3781,3 +3898,40 @@ func TestWordSearch(t *testing.T) { } }) } + +// Simple benchmark focused on chunk match scoring. It creates a single file that will have a 1000-line chunk match. +// The benchmark time is expected to be strongly correlated with time spent assembling and scoring this chunk. 
+func BenchmarkScoreChunkMatches(b *testing.B) { + ctx := context.Background() + var builder strings.Builder + for i := 0; i < 1000; i++ { + builder.WriteString(fmt.Sprintf("line-%d one one one two two two three three three four four four five five\n", i)) + } + + searcher := searcherForTest(b, testIndexBuilder(b, nil, + Document{Name: "f1", Content: []byte(builder.String())}, + )) + + q := &query.Or{ + Children: []query.Q{ + &query.Substring{Pattern: "f"}, + &query.Substring{Pattern: "t"}, + }} + + b.Run("score large ChunkMatch", func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + sres, err := searcher.Search(ctx, q, &SearchOptions{ChunkMatches: true, NumContextLines: 1}) + if err != nil { + b.Fatal(err) + } + + matches := sres.Files + if len(matches) == 0 { + b.Fatalf("want file match, got none") + } + } + }) +} diff --git a/read.go b/read.go index 12cb32eb4..189ec64c4 100644 --- a/read.go +++ b/read.go @@ -533,7 +533,14 @@ func (d *indexData) readNewlines(i uint32, buf []uint32) ([]uint32, uint32, erro return nil, 0, err } - return fromSizedDeltas(blob, buf), sec.sz, nil + nl := fromSizedDeltas(blob, buf) + + // can be nil if buf is nil and there are no doc sections. However, we rely + // on it being non-nil to cache the read. + if nl == nil { + nl = make([]uint32, 0) + } + return nl, sec.sz, nil } func (d *indexData) readDocSections(i uint32, buf []DocumentSection) ([]DocumentSection, uint32, error) { diff --git a/score.go b/score.go index 09faad837..45f9737aa 100644 --- a/score.go +++ b/score.go @@ -15,31 +15,252 @@ package zoekt import ( + "bytes" "fmt" "math" - "strconv" "strings" + + "github.com/sourcegraph/zoekt/ctags" ) const ( - maxUInt16 = 0xffff ScoreOffset = 10_000_000 ) -// addScore increments the score of the FileMatch by the computed score. If -// debugScore is true, it also adds a debug string to the FileMatch. If raw is -// -1, it is ignored. Otherwise, it is added to the debug string. 
-func (m *FileMatch) addScore(what string, computed float64, raw float64, debugScore bool) { - if computed != 0 && debugScore { - var b strings.Builder - fmt.Fprintf(&b, "%s", what) - if raw != -1 { - fmt.Fprintf(&b, "(%s)", strconv.FormatFloat(raw, 'f', -1, 64)) +type chunkScore struct { + score float64 + debugScore string + bestLine int +} + +// scoreChunk calculates the score for each line in the chunk based on its candidate matches, and returns the score of +// the best-scoring line, along with its line number. +// Invariant: there should be at least one input candidate, len(ms) > 0. +func (p *contentProvider) scoreChunk(ms []*candidateMatch, language string, opts *SearchOptions) (chunkScore, []*Symbol) { + nl := p.newlines() + + var bestScore lineScore + bestLine := 0 + var symbolInfo []*Symbol + + start := 0 + currentLine := -1 + for i, m := range ms { + lineNumber := -1 + if !m.fileName { + lineNumber = nl.atOffset(m.byteOffset) + } + + // If this match represents a new line, then score the previous line and update 'start'. + if i != 0 && lineNumber != currentLine { + score, si := p.scoreLine(ms[start:i], language, currentLine, opts) + symbolInfo = append(symbolInfo, si...) + if score.score > bestScore.score { + bestScore = score + bestLine = currentLine + } + start = i + } + currentLine = lineNumber + } + + // Make sure to score the last line + line, si := p.scoreLine(ms[start:], language, currentLine, opts) + symbolInfo = append(symbolInfo, si...) + if line.score > bestScore.score { + bestScore = line + bestLine = currentLine + } + + cs := chunkScore{ + score: bestScore.score, + bestLine: bestLine, + } + if opts.DebugScore { + cs.debugScore = fmt.Sprintf("%s, (line: %d)", bestScore.debugScore, bestLine) + } + return cs, symbolInfo +} + +type lineScore struct { + score float64 + debugScore string +} + +// scoreLine calculates a score for the line based on its candidate matches. 
+// Invariants: +// - All candidate matches are assumed to come from the same line in the content. +// - If this line represents a filename, then lineNumber must be -1. +// - There should be at least one input candidate, len(ms) > 0. +func (p *contentProvider) scoreLine(ms []*candidateMatch, language string, lineNumber int, opts *SearchOptions) (lineScore, []*Symbol) { + if opts.UseBM25Scoring { + score, symbolInfo := p.scoreLineBM25(ms, lineNumber) + ls := lineScore{score: score} + if opts.DebugScore { + ls.debugScore = fmt.Sprintf("tfScore:%.2f, ", score) + } + return ls, symbolInfo + } + + score := 0.0 + what := "" + addScore := func(w string, s float64) { + if s != 0 && opts.DebugScore { + what += fmt.Sprintf("%s:%.2f, ", w, s) + } + score += s + } + + filename := p.data(true) + var symbolInfo []*Symbol + + var bestLine lineScore + for i, m := range ms { + data := p.data(m.fileName) + + endOffset := m.byteOffset + m.byteMatchSz + startBoundary := m.byteOffset < uint32(len(data)) && (m.byteOffset == 0 || byteClass(data[m.byteOffset-1]) != byteClass(data[m.byteOffset])) + endBoundary := endOffset > 0 && (endOffset == uint32(len(data)) || byteClass(data[endOffset-1]) != byteClass(data[endOffset])) + + score = 0 + what = "" + + if startBoundary && endBoundary { + addScore("WordMatch", scoreWordMatch) + } else if startBoundary || endBoundary { + addScore("PartialWordMatch", scorePartialWordMatch) + } + + if m.fileName { + sep := bytes.LastIndexByte(data, '/') + startMatch := int(m.byteOffset) == sep+1 + endMatch := endOffset == uint32(len(data)) + if startMatch && endMatch { + addScore("Base", scoreBase) + } else if startMatch || endMatch { + addScore("EdgeBase", (scoreBase+scorePartialBase)/2) + } else if sep < int(m.byteOffset) { + addScore("InnerBase", scorePartialBase) + } + } else if sec, si, ok := p.findSymbol(m); ok { + startMatch := sec.Start == m.byteOffset + endMatch := sec.End == endOffset + if startMatch && endMatch { + addScore("Symbol", scoreSymbol) + 
} else if startMatch || endMatch { + addScore("EdgeSymbol", (scoreSymbol+scorePartialSymbol)/2) + } else { + addScore("OverlapSymbol", scorePartialSymbol) + } + + // Score based on symbol data + if si != nil { + symbolKind := ctags.ParseSymbolKind(si.Kind) + sym := sectionSlice(data, sec) + + addScore(fmt.Sprintf("kind:%s:%s", language, si.Kind), scoreSymbolKind(language, filename, sym, symbolKind)) + + // This is from a symbol tree, so we need to store the symbol + // information. + if m.symbol { + if symbolInfo == nil { + symbolInfo = make([]*Symbol, len(ms)) + } + // findSymbols does not hydrate in Sym. So we need to store it. + si.Sym = string(sym) + symbolInfo[i] = si + } + } + } + + // scoreWeight != 1 means it affects score + if !epsilonEqualsOne(m.scoreWeight) { + score = score * m.scoreWeight + if opts.DebugScore { + what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight) + } + } + + if score > bestLine.score { + bestLine.score = score + bestLine.debugScore = what + } + } + + if opts.DebugScore { + bestLine.debugScore = fmt.Sprintf("score:%.2f <- %s", bestLine.score, strings.TrimSuffix(bestLine.debugScore, ", ")) + } + + return bestLine, symbolInfo +} + +// scoreLineBM25 computes the score of a line according to BM25, the most common scoring algorithm for text search: +// https://en.wikipedia.org/wiki/Okapi_BM25. Compared to the standard scoreLine algorithm, this score rewards multiple +// term matches on a line. +// Notes: +// - This BM25 calculation skips inverse document frequency (idf) to keep the implementation simple. +// - It uses the same calculateTermFrequency method as BM25 file scoring, which boosts filename and symbol matches. +func (p *contentProvider) scoreLineBM25(ms []*candidateMatch, lineNumber int) (float64, []*Symbol) { + // If this is a filename, then don't compute BM25. The score would not be comparable to line scores. 
+ if lineNumber < 0 { + return 0, nil + } + + // Use standard parameter defaults used in Lucene (https://lucene.apache.org/core/10_1_0/core/org/apache/lucene/search/similarities/BM25Similarity.html) + k, b := 1.2, 0.75 + + // Calculate the length ratio of this line. As a heuristic, we assume an average line length of 100 characters. + // Usually the calculation would be based on terms, but using bytes should work fine, as we're just computing a ratio. + data := p.data(false) + nl := p.newlines() + lineLength := len(nl.getLines(data, lineNumber, lineNumber+1)) + L := float64(lineLength) / 100.0 + + score := 0.0 + tfs := p.calculateTermFrequency(ms, termDocumentFrequency{}) + for _, f := range tfs { + score += ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f)) + } + + // Check if any match comes from a symbol match tree, and if so hydrate in symbol information + var symbolInfo []*Symbol + for _, m := range ms { + if m.symbol { + if sec, si, ok := p.findSymbol(m); ok && si != nil { + // findSymbols does not hydrate in Sym. So we need to store it. + sym := sectionSlice(data, sec) + si.Sym = string(sym) + symbolInfo = append(symbolInfo, si) + } + } + } + return score, symbolInfo +} + +// termDocumentFrequency is a map "term" -> "number of documents that contain the term" +type termDocumentFrequency map[string]int + +// calculateTermFrequency computes the term frequency for the file match. +// Notes: +// - Filename matches count more than content matches. This mimics a common text search strategy to 'boost' matches on document titles. +// - Symbol matches also count more than content matches, to reward matches on symbol definitions. +func (p *contentProvider) calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int { + // Treat each candidate match as a term and compute the frequencies. For now, ignore case sensitivity and + // ignore whether the match is a word boundary. 
+ termFreqs := map[string]int{} + for _, m := range cands { + term := string(m.substrLowered) + if m.fileName || p.matchesSymbol(m) { + termFreqs[term] += 5 + } else { + termFreqs[term]++ } - fmt.Fprintf(&b, ":%.2f, ", computed) - m.Debug += b.String() } - m.Score += computed + + for term := range termFreqs { + df[term] += 1 + } + return termFreqs } // scoreFile computes a score for the file match using various scoring signals, like @@ -110,30 +331,20 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn } } -// idf computes the inverse document frequency for a term. nq is the number of -// documents that contain the term and documentCount is the total number of -// documents in the corpus. -func idf(nq, documentCount int) float64 { - return math.Log(1.0 + ((float64(documentCount) - float64(nq) + 0.5) / (float64(nq) + 0.5))) -} - -// termDocumentFrequency is a map "term" -> "number of documents that contain the term" -type termDocumentFrequency map[string]int - // termFrequency stores the term frequencies for doc. type termFrequency struct { doc uint32 tf map[string]int } -// scoreFilesUsingBM25 computes the score according to BM25, the most common -// scoring algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. +// scoreFilesUsingBM25 computes the score according to BM25, the most common scoring algorithm for text search: +// https://en.wikipedia.org/wiki/Okapi_BM25. // -// This scoring strategy ignores all other signals including document ranks. -// This keeps things simple for now, since BM25 is not normalized and can be -// tricky to combine with other scoring signals. +// Unlike standard file scoring, this scoring strategy ignores all other signals including document ranks. This keeps +// things simple for now, since BM25 is not normalized and can be tricky to combine with other scoring signals. 
// idf computes the inverse document frequency for a term. nq is the number of
// documents that contain the term and documentCount is the total number of
// documents in the corpus.
func idf(nq, documentCount int) float64 {
	// Standard BM25 IDF with the +1 inside the log, which keeps the value
	// non-negative even when a term appears in most documents.
	n := float64(documentCount)
	q := float64(nq)
	return math.Log(1.0 + (n-q+0.5)/(q+0.5))
}
- r.Rank = uint16(r.priority / (5000.0 + r.priority) * maxUInt16) + r.Rank = uint16(r.priority / (5000.0 + r.priority) * math.MaxUint16) } } @@ -704,7 +704,7 @@ func monthsSince1970(t time.Time) uint16 { return 0 } months := int(t.Year()-1970)*12 + int(t.Month()-1) - return uint16(min(months, maxUInt16)) + return uint16(min(months, math.MaxUint16)) } // MergeMutable will merge x into r. mutated will be true if it made any diff --git a/score.go b/score.go index 45f9737aa..2559b160c 100644 --- a/score.go +++ b/score.go @@ -114,7 +114,7 @@ func (p *contentProvider) scoreLine(ms []*candidateMatch, language string, lineN filename := p.data(true) var symbolInfo []*Symbol - var bestLine lineScore + var bestScore lineScore for i, m := range ms { data := p.data(m.fileName) @@ -181,17 +181,17 @@ func (p *contentProvider) scoreLine(ms []*candidateMatch, language string, lineN } } - if score > bestLine.score { - bestLine.score = score - bestLine.debugScore = what + if score > bestScore.score { + bestScore.score = score + bestScore.debugScore = what } } if opts.DebugScore { - bestLine.debugScore = fmt.Sprintf("score:%.2f <- %s", bestLine.score, strings.TrimSuffix(bestLine.debugScore, ", ")) + bestScore.debugScore = fmt.Sprintf("score:%.2f <- %s", bestScore.score, strings.TrimSuffix(bestScore.debugScore, ", ")) } - return bestLine, symbolInfo + return bestScore, symbolInfo } // scoreLineBM25 computes the score of a line according to BM25, the most common scoring algorithm for text search: @@ -211,9 +211,8 @@ func (p *contentProvider) scoreLineBM25(ms []*candidateMatch, lineNumber int) (f // Calculate the length ratio of this line. As a heuristic, we assume an average line length of 100 characters. // Usually the calculation would be based on terms, but using bytes should work fine, as we're just computing a ratio. 
- data := p.data(false) nl := p.newlines() - lineLength := len(nl.getLines(data, lineNumber, lineNumber+1)) + lineLength := nl.lineStart(lineNumber+1) - nl.lineStart(lineNumber) L := float64(lineLength) / 100.0 score := 0.0 @@ -228,7 +227,7 @@ func (p *contentProvider) scoreLineBM25(ms []*candidateMatch, lineNumber int) (f if m.symbol { if sec, si, ok := p.findSymbol(m); ok && si != nil { // findSymbols does not hydrate in Sym. So we need to store it. - sym := sectionSlice(data, sec) + sym := sectionSlice(p.data(false), sec) si.Sym = string(sym) symbolInfo = append(symbolInfo, si) }