Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ranking: add IDF to BM25 score calculation #788

Merged
merged 8 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -946,9 +946,15 @@ type SearchOptions struct {
// will be used. This option is temporary and is only exposed for testing/ tuning purposes.
DocumentRanksWeight float64

// EXPERIMENTAL. If true, use text-search style scoring instead of the default scoring formula.
// The scoring algorithm treats each match in a file as a term and computes an approximation to
// BM25. When enabled, all other scoring signals are ignored, including document ranks.
// EXPERIMENTAL. If true, use text-search style scoring instead of the default
// scoring formula. The scoring algorithm treats each match in a file as a term
// and computes an approximation to BM25.
//
// The calculation of IDF assumes that Zoekt visits all documents containing any
// of the query terms during evaluation. This is true, for example, if all query
// terms are ORed together.
//
// When enabled, all other scoring signals are ignored, including document ranks.
UseBM25Scoring bool

// Trace turns on opentracing for this request if true and if the Jaeger address was provided as
Expand Down
16 changes: 8 additions & 8 deletions build/scoring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ func TestBM25(t *testing.T) {
query: &query.Substring{Pattern: "example"},
content: exampleJava,
language: "Java",
// bm25-score:1.69 (sum-tf: 7.00, length-ratio: 2.00)
wantScore: 1.82,
// bm25-score: 0.57 <- sum-termFrequencyScore: 10.00, length-ratio: 1.00
wantScore: 0.57,
}, {
// Matches only on content
fileName: "example.java",
Expand All @@ -89,25 +89,25 @@ func TestBM25(t *testing.T) {
}},
content: exampleJava,
language: "Java",
// bm25-score:5.75 (sum-tf: 56.00, length-ratio: 2.00)
wantScore: 5.75,
// bm25-score: 1.75 <- sum-termFrequencyScore: 56.00, length-ratio: 1.00
wantScore: 1.75,
},
{
// Matches only on filename
fileName: "example.java",
query: &query.Substring{Pattern: "java"},
content: exampleJava,
language: "Java",
// bm25-score:1.07 (sum-tf: 2.00, length-ratio: 2.00)
wantScore: 1.55,
// bm25-score: 0.51 <- sum-termFrequencyScore: 5.00, length-ratio: 1.00
wantScore: 0.51,
},
{
// Matches only on filename, and content is missing
fileName: "a/b/c/config.go",
query: &query.Substring{Pattern: "config.go"},
language: "Go",
// bm25-score:1.91 (sum-tf: 2.00, length-ratio: 0.00)
wantScore: 2.08,
// bm25-score: 0.60 <- sum-termFrequencyScore: 5.00, length-ratio: 0.00
wantScore: 0.60,
},
}

Expand Down
32 changes: 28 additions & 4 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ func (d *indexData) Search(ctx context.Context, q query.Q, opts *SearchOptions)
docCount := uint32(len(d.fileBranchMasks))
lastDoc := int(-1)

// document frequency per term
df := make(termDocumentFrequency)

// term frequency per file match
var tfs []termFrequency

nextFileMatch:
for {
canceled := false
Expand Down Expand Up @@ -317,8 +323,14 @@ nextFileMatch:
fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore)
}

var tf map[string]int
if opts.UseBM25Scoring {
d.scoreFileUsingBM25(&fileMatch, nextDoc, finalCands, opts)
// For BM25 scoring, the calculation of the score is split in two parts. Here we
// calculate the term frequencies for the current document and update the
// document frequencies. Since we don't store document frequencies in the index,
// we have to defer the calculation of the final BM25 score to after the whole
// shard has been processed.
tf = calculateTermFrequency(finalCands, df)
} else {
// Use the standard, non-experimental scoring method by default
d.scoreFile(&fileMatch, nextDoc, mt, known, opts)
Expand All @@ -339,16 +351,28 @@ nextFileMatch:
repoMatchCount += len(fileMatch.LineMatches)
repoMatchCount += matchedChunkRanges

if opts.DebugScore {
fileMatch.Debug = fmt.Sprintf("score:%.2f <- %s", fileMatch.Score, fileMatch.Debug)
if opts.UseBM25Scoring {
// Invariant: tfs[i] belongs to res.Files[i]
tfs = append(tfs, termFrequency{
doc: nextDoc,
tf: tf,
})
}
keegancsmith marked this conversation as resolved.
Show resolved Hide resolved

res.Files = append(res.Files, fileMatch)

res.Stats.MatchCount += len(fileMatch.LineMatches)
res.Stats.MatchCount += matchedChunkRanges
res.Stats.FileCount++
}

// Calculate BM25 score for all file matches in the shard. We assume that we
// have seen all documents containing any of the terms in the query so that df
// correctly reflects the document frequencies. This is true, for example, if
// all terms in the query are ORed together.
if opts.UseBM25Scoring {
d.scoreFilesUsingBM25(res.Files, tfs, df, opts)
}

for _, md := range d.repoMetaData {
r := md
addRepo(&res, &r)
Expand Down
94 changes: 60 additions & 34 deletions score.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ func (m *FileMatch) addScore(what string, computed float64, raw float64, debugSc
m.Score += computed
}

func (m *FileMatch) addBM25Score(score float64, sumTf float64, L float64, debugScore bool) {
if debugScore {
m.Debug += fmt.Sprintf("bm25-score:%.2f (sum-tf: %.2f, length-ratio: %.2f)", score, sumTf, L)
}
m.Score += score
}

// scoreFile computes a score for the file match using various scoring signals, like
// whether there's an exact match on a symbol, the number of query clauses that matched, etc.
func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) {
Expand Down Expand Up @@ -111,54 +104,87 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn
addScore("repo-rank", scoreRepoRankFactor*float64(md.Rank)/maxUInt16)

if opts.DebugScore {
fileMatch.Debug = strings.TrimSuffix(fileMatch.Debug, ", ")
fileMatch.Debug = fmt.Sprintf("score: %.2f <- %s", fileMatch.Score, strings.TrimSuffix(fileMatch.Debug, ", "))
}
}

// scoreFileUsingBM25 computes a score for the file match using an approximation to BM25, the most common scoring
// algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. It implements all parts of the formula
// except inverse document frequency (idf), since we don't have access to global term frequency statistics.
//
// Filename matches count twice as much as content matches. This mimics a common text search strategy where you
// 'boost' matches on document titles.
// calculateTermFrequency computes the term frequency for the file match.
//
// This scoring strategy ignores all other signals including document ranks. This keeps things simple for now,
// since BM25 is not normalized and can be tricky to combine with other scoring signals.
func (d *indexData) scoreFileUsingBM25(fileMatch *FileMatch, doc uint32, cands []*candidateMatch, opts *SearchOptions) {
// Filename matches count more than content matches. This mimics a common text
stefanhengl marked this conversation as resolved.
Show resolved Hide resolved
// search strategy where you 'boost' matches on document titles.
func calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int {
// Treat each candidate match as a term and compute the frequencies. For now, ignore case
// sensitivity and treat filenames and symbols the same as content.
termFreqs := map[string]int{}
for _, cand := range cands {
term := string(cand.substrLowered)

if cand.fileName {
termFreqs[term] += 5
} else {
termFreqs[term]++
}
}

// Compute the file length ratio. Usually the calculation would be based on terms, but using
// bytes should work fine, as we're just computing a ratio.
fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc])
numFiles := len(d.boundaries)
stefanhengl marked this conversation as resolved.
Show resolved Hide resolved
averageFileLength := float64(d.boundaries[numFiles-1]) / float64(numFiles)
for term := range termFreqs {
df[term] += 1
jtibshirani marked this conversation as resolved.
Show resolved Hide resolved
}

return termFreqs
}

// idf computes the inverse document frequency for a term. nq is the number of
// documents that contain the term and documentCount is the total number of
// documents in the corpus.
func idf(nq, documentCount int) float64 {
return math.Log(1.0 + ((float64(documentCount) - float64(nq) + 0.5) / (float64(nq) + 0.5)))
}

// termDocumentFrequency is a map "term" -> "number of documents that contain the term"
type termDocumentFrequency map[string]int

// termFrequency stores the term frequencies for doc.
type termFrequency struct {
doc uint32
tf map[string]int
}

// scoreFilesUsingBM25 computes the score according to BM25, the most common
// scoring algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25.
//
// This scoring strategy ignores all other signals including document ranks.
// This keeps things simple for now, since BM25 is not normalized and can be
// tricky to combine with other scoring signals.
func (d *indexData) scoreFilesUsingBM25(fileMatches []FileMatch, tfs []termFrequency, df termDocumentFrequency, opts *SearchOptions) {
// Use standard parameter defaults (used in Lucene and academic papers)
k, b := 1.2, 0.75

averageFileLength := float64(d.boundaries[d.numDocs()]) / float64(d.numDocs())
// This is very unlikely, but explicitly guard against division by zero.
if averageFileLength == 0 {
averageFileLength++
}
L := fileLength / averageFileLength

// Use standard parameter defaults (used in Lucene and academic papers)
k, b := 1.2, 0.75
sumTf := 0.0 // Just for debugging
score := 0.0
for _, freq := range termFreqs {
tf := float64(freq)
sumTf += tf
score += ((k + 1.0) * tf) / (k*(1.0-b+b*L) + tf)
}
for i := range tfs {
score := 0.0

// Compute the file length ratio. Usually the calculation would be based on terms, but using
// bytes should work fine, as we're just computing a ratio.
doc := tfs[i].doc
fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc])

fileMatch.addBM25Score(score, sumTf, L, opts.DebugScore)
L := fileLength / averageFileLength

sumTF := 0 // Just for debugging
for term, f := range tfs[i].tf {
sumTF += f
tfScore := ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f))
score += idf(df[term], int(d.numDocs())) * tfScore
}

fileMatches[i].Score = score

if opts.DebugScore {
fileMatches[i].Debug = fmt.Sprintf("bm25-score: %.2f <- sum-termFrequencies: %d, length-ratio: %.2f", score, sumTF, L)
}
}
}
51 changes: 51 additions & 0 deletions score_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package zoekt

import (
"maps"
"testing"
)

func TestCalculateTermFrequency(t *testing.T) {
cases := []struct {
cands []*candidateMatch
wantDF termDocumentFrequency
wantTermFrequencies map[string]int
}{{
cands: []*candidateMatch{
{substrLowered: []byte("foo")},
{substrLowered: []byte("foo")},
{substrLowered: []byte("bar")},
{
substrLowered: []byte("bas"),
fileName: true,
},
},
wantDF: termDocumentFrequency{
"foo": 1,
"bar": 1,
"bas": 1,
},
wantTermFrequencies: map[string]int{
"foo": 2,
"bar": 1,
"bas": 5,
},
},
}

for _, c := range cases {
t.Run("", func(t *testing.T) {
fm := FileMatch{}
df := make(termDocumentFrequency)
tf := calculateTermFrequency(c.cands, df)

if !maps.Equal(df, c.wantDF) {
t.Errorf("got %v, want %v", df, c.wantDF)
}

if !maps.Equal(tf, c.wantTermFrequencies) {
t.Errorf("got %v, want %v", fm, c.wantTermFrequencies)
}
})
}
}
Loading