Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support BM25 scoring for chunk matches #889

Merged
merged 2 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -135,6 +136,22 @@ func (m *FileMatch) sizeBytes() (sz uint64) {
return
}

// addScore increments the score of the FileMatch by the computed score. If
// debugScore is true, it also adds a debug string to the FileMatch. If raw is
// -1, it is ignored. Otherwise, it is added to the debug string.
func (m *FileMatch) addScore(what string, computed float64, raw float64, debugScore bool) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this here from score.go so all FileMatch methods are in the same place.

if computed != 0 && debugScore {
var b strings.Builder
fmt.Fprintf(&b, "%s", what)
if raw != -1 {
fmt.Fprintf(&b, "(%s)", strconv.FormatFloat(raw, 'f', -1, 64))
}
fmt.Fprintf(&b, ":%.2f, ", computed)
m.Debug += b.String()
}
m.Score += computed
}

// ChunkMatch is a set of non-overlapping matches within a contiguous range of
// lines in the file.
type ChunkMatch struct {
Expand Down Expand Up @@ -671,7 +688,7 @@ func (r *Repository) UnmarshalJSON(data []byte) error {
// Normalize the repo score within [0, maxUint16), with the midpoint at 5,000.
// This means popular repos (roughly ones with over 5,000 stars) see diminishing
// returns from more stars.
r.Rank = uint16(r.priority / (5000.0 + r.priority) * maxUInt16)
r.Rank = uint16(r.priority / (5000.0 + r.priority) * math.MaxUint16)
}
}

Expand All @@ -687,7 +704,7 @@ func monthsSince1970(t time.Time) uint16 {
return 0
}
months := int(t.Year()-1970)*12 + int(t.Month()-1)
return uint16(min(months, maxUInt16))
return uint16(min(months, math.MaxUint16))
}

// MergeMutable will merge x into r. mutated will be true if it made any
Expand Down Expand Up @@ -976,7 +993,8 @@ type SearchOptions struct {

// EXPERIMENTAL. If true, use text-search style scoring instead of the default
// scoring formula. The scoring algorithm treats each match in a file as a term
// and computes an approximation to BM25.
// and computes an approximation to BM25. When enabled, BM25 scoring is used for
// the overall FileMatch score, as well as individual LineMatch and ChunkMatch scores.
//
// The calculation of IDF assumes that Zoekt visits all documents containing any
// of the query terms during evaluation. This is true, for example, if all query
Expand Down
17 changes: 15 additions & 2 deletions build/scoring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,21 @@ func TestBM25(t *testing.T) {
language: "Java",
// bm25-score: 1.81 <- sum-termFrequencyScore: 116.00, length-ratio: 1.00
wantScore: 1.81,
// line 3: public class InnerClasses {
wantBestLineMatch: 3,
// line 54: private static <A, B> B runInnerInterface(InnerInterface<A, B> fn, A a) {
wantBestLineMatch: 54,
}, {
// Another content-only match
fileName: "example.java",
query: &query.And{Children: []query.Q{
&query.Substring{Pattern: "system"},
&query.Substring{Pattern: "time"},
}},
content: exampleJava,
language: "Java",
// bm25-score: 0.96 <- sum-termFrequencies: 12, length-ratio: 1.00
wantScore: 0.96,
// line 59: if (System.nanoTime() > System.currentTimeMillis()) {
wantBestLineMatch: 59,
},
{
// Matches only on filename
Expand Down
184 changes: 22 additions & 162 deletions contentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ package zoekt

import (
"bytes"
"fmt"
"log"
"path"
"slices"
"sort"
"strings"
"unicode"
"unicode/utf8"

Expand Down Expand Up @@ -145,7 +143,7 @@ func (p *contentProvider) findOffset(filename bool, r uint32) uint32 {
//
// Note: the byte slices may be backed by mmapped data, so before being
// returned by the API it needs to be copied.
func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch {
func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []LineMatch {
var filenameMatches []*candidateMatch
contentMatches := make([]*candidateMatch, 0, len(ms))

Expand All @@ -160,16 +158,16 @@ func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int,
// If there are any content matches, we only return these and skip filename matches.
if len(contentMatches) > 0 {
contentMatches = breakMatchesOnNewlines(contentMatches, p.data(false))
return p.fillContentMatches(contentMatches, numContextLines, language, debug)
return p.fillContentMatches(contentMatches, numContextLines, language, opts)
}

// Otherwise, we return a single line containing the filematch match.
bestMatch, _ := p.candidateMatchScore(filenameMatches, language, debug)
lineScore, _ := p.scoreLine(filenameMatches, language, -1 /* must pass -1 for filenames */, opts)
res := LineMatch{
Line: p.id.fileName(p.idx),
FileName: true,
Score: bestMatch.score,
DebugScore: bestMatch.debugScore,
Score: lineScore.score,
DebugScore: lineScore.debugScore,
}

for _, m := range ms {
Expand All @@ -192,7 +190,7 @@ func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int,
//
// Note: the byte slices may be backed by mmapped data, so before being
// returned by the API it needs to be copied.
func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch {
func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []ChunkMatch {
var filenameMatches []*candidateMatch
contentMatches := make([]*candidateMatch, 0, len(ms))

Expand All @@ -206,11 +204,11 @@ func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines

// If there are any content matches, we only return these and skip filename matches.
if len(contentMatches) > 0 {
return p.fillContentChunkMatches(contentMatches, numContextLines, language, debug)
return p.fillContentChunkMatches(contentMatches, numContextLines, language, opts)
}

// Otherwise, we return a single chunk representing the filename match.
bestMatch, _ := p.candidateMatchScore(filenameMatches, language, debug)
lineScore, _ := p.scoreLine(filenameMatches, language, -1 /* must pass -1 for filenames */, opts)
fileName := p.id.fileName(p.idx)
ranges := make([]Range, 0, len(ms))
for _, m := range ms {
Expand All @@ -233,12 +231,12 @@ func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines
ContentStart: Location{ByteOffset: 0, LineNumber: 1, Column: 1},
Ranges: ranges,
FileName: true,
Score: bestMatch.score,
DebugScore: bestMatch.debugScore,
Score: lineScore.score,
DebugScore: lineScore.debugScore,
}}
}

func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch {
func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []LineMatch {
var result []LineMatch
for len(ms) > 0 {
m := ms[0]
Expand Down Expand Up @@ -296,16 +294,17 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin
finalMatch.After = p.newlines().getLines(data, num+1, num+1+numContextLines)
}

bestMatch, symbolInfo := p.candidateMatchScore(lineCands, language, debug)
finalMatch.Score = bestMatch.score
finalMatch.DebugScore = bestMatch.debugScore
lineScore, symbolInfo := p.scoreLine(lineCands, language, num, opts)
finalMatch.Score = lineScore.score
finalMatch.DebugScore = lineScore.debugScore

for i, m := range lineCands {
fragment := LineFragmentMatch{
Offset: m.byteOffset,
LineOffset: int(m.byteOffset) - lineStart,
MatchLength: int(m.byteMatchSz),
}

if i < len(symbolInfo) && symbolInfo[i] != nil {
fragment.SymbolInfo = symbolInfo[i]
}
Expand All @@ -317,8 +316,7 @@ func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLin
return result
}

func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch {
newlines := p.newlines()
func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numContextLines int, language string, opts *SearchOptions) []ChunkMatch {
data := p.data(false)

// columnHelper prevents O(len(ms) * len(data)) lookups for all columns.
Expand All @@ -332,11 +330,10 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte
sort.Sort((sortByOffsetSlice)(ms))
}

newlines := p.newlines()
chunks := chunkCandidates(ms, newlines, numContextLines)
chunkMatches := make([]ChunkMatch, 0, len(chunks))
for _, chunk := range chunks {
bestMatch, symbolInfo := p.candidateMatchScore(chunk.candidates, language, debug)

ranges := make([]Range, 0, len(chunk.candidates))
for _, cm := range chunk.candidates {
startOffset := cm.byteOffset
Expand All @@ -363,14 +360,7 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte
}
firstLineStart := newlines.lineStart(firstLineNumber)

bestLineMatch := 0
if bestMatch.match != nil {
bestLineMatch = newlines.atOffset(bestMatch.match.byteOffset)
if debug {
bestMatch.debugScore = fmt.Sprintf("%s, (line: %d)", bestMatch.debugScore, bestLineMatch)
}
}

chunkScore, symbolInfo := p.scoreChunk(chunk.candidates, language, opts)
chunkMatches = append(chunkMatches, ChunkMatch{
Content: newlines.getLines(data, firstLineNumber, int(chunk.lastLine)+numContextLines+1),
ContentStart: Location{
Expand All @@ -381,9 +371,9 @@ func (p *contentProvider) fillContentChunkMatches(ms []*candidateMatch, numConte
FileName: false,
Ranges: ranges,
SymbolInfo: symbolInfo,
BestLineMatch: uint32(bestLineMatch),
Score: bestMatch.score,
DebugScore: bestMatch.debugScore,
BestLineMatch: uint32(chunkScore.bestLine),
Score: chunkScore.score,
DebugScore: chunkScore.debugScore,
})
}
return chunkMatches
Expand All @@ -405,6 +395,7 @@ type candidateChunk struct {
// output invariants: if you flatten candidates the input invariant is retained.
func chunkCandidates(ms []*candidateMatch, newlines newlines, numContextLines int) []candidateChunk {
var chunks []candidateChunk

for _, m := range ms {
startOffset := m.byteOffset
endOffset := m.byteOffset + m.byteMatchSz
Expand Down Expand Up @@ -536,10 +527,6 @@ const (
scoreKindMatch = 100.0
scoreFactorAtomMatch = 400.0

// File-only scoring signals. For now these are also bounded ~9000 to give them
// equal weight with the query-dependent signals.
scoreFileRankFactor = 9000.0

// Used for ordering line and chunk matches within a file.
scoreLineOrderFactor = 1.0

Expand Down Expand Up @@ -643,133 +630,6 @@ func (p *contentProvider) findSymbol(cm *candidateMatch) (DocumentSection, *Symb
return sec, si, true
}

// calculateTermFrequency computes the term frequency for the file match.
// Notes:
// * Filename matches count more than content matches. This mimics a common text search strategy to 'boost' matches on document titles.
// * Symbol matches also count more than content matches, to reward matches on symbol definitions.
func (p *contentProvider) calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int {
// Treat each candidate match as a term and compute the frequencies. For now, ignore case
// sensitivity and treat filenames and symbols the same as content.
termFreqs := map[string]int{}
for _, m := range cands {
term := string(m.substrLowered)
if m.fileName || p.matchesSymbol(m) {
termFreqs[term] += 5
} else {
termFreqs[term]++
}
}

for term := range termFreqs {
df[term] += 1
}
return termFreqs
}

// scoredMatch holds the score information for a candidate match.
type scoredMatch struct {
score float64
debugScore string
match *candidateMatch
}

// candidateMatchScore scores all candidate matches and returns the best-scoring match plus its score information.
// Invariant: there should be at least one input candidate, len(ms) > 0.
func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language string, debug bool) (scoredMatch, []*Symbol) {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was moved to score.go and renamed to scoreLine.

score := 0.0
what := ""

addScore := func(w string, s float64) {
if s != 0 && debug {
what += fmt.Sprintf("%s:%.2f, ", w, s)
}
score += s
}

filename := p.data(true)
var symbolInfo []*Symbol

var bestMatch scoredMatch
for i, m := range ms {
data := p.data(m.fileName)

endOffset := m.byteOffset + m.byteMatchSz
startBoundary := m.byteOffset < uint32(len(data)) && (m.byteOffset == 0 || byteClass(data[m.byteOffset-1]) != byteClass(data[m.byteOffset]))
endBoundary := endOffset > 0 && (endOffset == uint32(len(data)) || byteClass(data[endOffset-1]) != byteClass(data[endOffset]))

score = 0
what = ""

if startBoundary && endBoundary {
addScore("WordMatch", scoreWordMatch)
} else if startBoundary || endBoundary {
addScore("PartialWordMatch", scorePartialWordMatch)
}

if m.fileName {
sep := bytes.LastIndexByte(data, '/')
startMatch := int(m.byteOffset) == sep+1
endMatch := endOffset == uint32(len(data))
if startMatch && endMatch {
addScore("Base", scoreBase)
} else if startMatch || endMatch {
addScore("EdgeBase", (scoreBase+scorePartialBase)/2)
} else if sep < int(m.byteOffset) {
addScore("InnerBase", scorePartialBase)
}
} else if sec, si, ok := p.findSymbol(m); ok {
startMatch := sec.Start == m.byteOffset
endMatch := sec.End == endOffset
if startMatch && endMatch {
addScore("Symbol", scoreSymbol)
} else if startMatch || endMatch {
addScore("EdgeSymbol", (scoreSymbol+scorePartialSymbol)/2)
} else {
addScore("OverlapSymbol", scorePartialSymbol)
}

// Score based on symbol data
if si != nil {
symbolKind := ctags.ParseSymbolKind(si.Kind)
sym := sectionSlice(data, sec)

addScore(fmt.Sprintf("kind:%s:%s", language, si.Kind), scoreSymbolKind(language, filename, sym, symbolKind))

// This is from a symbol tree, so we need to store the symbol
// information.
if m.symbol {
if symbolInfo == nil {
symbolInfo = make([]*Symbol, len(ms))
}
// findSymbols does not hydrate in Sym. So we need to store it.
si.Sym = string(sym)
symbolInfo[i] = si
}
}
}

// scoreWeight != 1 means it affects score
if !epsilonEqualsOne(m.scoreWeight) {
score = score * m.scoreWeight
if debug {
what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight)
}
}

if score > bestMatch.score {
bestMatch.score = score
bestMatch.debugScore = what
bestMatch.match = m
}
}

if debug {
bestMatch.debugScore = fmt.Sprintf("score:%.2f <- %s", bestMatch.score, strings.TrimSuffix(bestMatch.debugScore, ", "))
}

return bestMatch, symbolInfo
}

// sectionSlice will return data[sec.Start:sec.End] but will clip Start and
// End such that it won't be out of range.
func sectionSlice(data []byte, sec DocumentSection) []byte {
Expand Down
Loading
Loading