Skip to content

Commit

Permalink
Ranking: include filename matches in bm25 (#757)
Browse files Browse the repository at this point in the history
BM25 considers a file to be a better match when there are many occurrences of
terms in the file. It's important to count all term occurrences, including
those in other fields like the filename.

For historical reasons, Zoekt trims all filename matches from a result if there
are any content matches. This meant that in BM25 scoring, we didn't account for
filename matches.

This PR refactors the match code so that we only trim filename matches when
assembling the final `FileMatch`. We retain filename matches when creating
`candidateMatch`, which lets BM25 scoring use them. Even without the better
BM25 scoring, I think this refactor makes the code easier to follow.
  • Loading branch information
jtibshirani authored Apr 17, 2024
1 parent 59ab949 commit 43b9225
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 118 deletions.
79 changes: 67 additions & 12 deletions build/scoring_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,59 @@ func TestFileNameMatch(t *testing.T) {
}

for _, c := range cases {
checkScoring(t, c, ctags.UniversalCTags)
checkScoring(t, c, false, ctags.UniversalCTags)
}
}

func TestBM25(t *testing.T) {
exampleJava, err := os.ReadFile("./testdata/example.java")
if err != nil {
t.Fatal(err)
}

cases := []scoreCase{
{
// Matches on both filename and content
fileName: "example.java",
query: &query.Substring{Pattern: "example"},
content: exampleJava,
language: "Java",
// keyword-score:1.63 (sum-tf: 6.00, length-ratio: 2.00)
wantScore: 1.63,
}, {
// Matches only on content
fileName: "example.java",
query: &query.And{Children: []query.Q{
&query.Substring{Pattern: "inner"},
&query.Substring{Pattern: "static"},
&query.Substring{Pattern: "interface"},
}},
content: exampleJava,
language: "Java",
// keyword-score:5.75 (sum-tf: 56.00, length-ratio: 2.00)
wantScore: 5.75,
},
{
// Matches only on filename
fileName: "example.java",
query: &query.Substring{Pattern: "java"},
content: exampleJava,
language: "Java",
// keyword-score:1.07 (sum-tf: 2.00, length-ratio: 2.00)
wantScore: 1.07,
},
{
// Matches only on filename, and content is missing
fileName: "a/b/c/config.go",
query: &query.Substring{Pattern: "config.go"},
language: "Go",
// keyword-score:1.91 (sum-tf: 2.00, length-ratio: 0.00)
wantScore: 1.91,
},
}

for _, c := range cases {
checkScoring(t, c, true, ctags.UniversalCTags)
}
}

Expand Down Expand Up @@ -197,7 +249,7 @@ func TestJava(t *testing.T) {
}

for _, c := range cases {
checkScoring(t, c, ctags.UniversalCTags)
checkScoring(t, c, false, ctags.UniversalCTags)
}
}

Expand Down Expand Up @@ -261,7 +313,7 @@ func TestKotlin(t *testing.T) {
parserType := ctags.UniversalCTags
for _, c := range cases {
t.Run(c.language, func(t *testing.T) {
checkScoring(t, c, parserType)
checkScoring(t, c, false, parserType)
})
}
}
Expand Down Expand Up @@ -318,7 +370,7 @@ func TestCpp(t *testing.T) {
parserType := ctags.UniversalCTags
for _, c := range cases {
t.Run(c.language, func(t *testing.T) {
checkScoring(t, c, parserType)
checkScoring(t, c, false, parserType)
})
}
}
Expand Down Expand Up @@ -350,7 +402,7 @@ func TestPython(t *testing.T) {

for _, parserType := range []ctags.CTagsParserType{ctags.UniversalCTags, ctags.ScipCTags} {
for _, c := range cases {
checkScoring(t, c, parserType)
checkScoring(t, c, false, parserType)
}
}

Expand All @@ -364,7 +416,7 @@ func TestPython(t *testing.T) {
wantScore: 7860,
}

checkScoring(t, scipOnlyCase, ctags.ScipCTags)
checkScoring(t, scipOnlyCase, false, ctags.ScipCTags)
}

func TestRuby(t *testing.T) {
Expand Down Expand Up @@ -402,7 +454,7 @@ func TestRuby(t *testing.T) {

for _, parserType := range []ctags.CTagsParserType{ctags.UniversalCTags, ctags.ScipCTags} {
for _, c := range cases {
checkScoring(t, c, parserType)
checkScoring(t, c, false, parserType)
}
}
}
Expand Down Expand Up @@ -450,7 +502,7 @@ func TestScala(t *testing.T) {

parserType := ctags.UniversalCTags
for _, c := range cases {
checkScoring(t, c, parserType)
checkScoring(t, c, false, parserType)
}
}

Expand Down Expand Up @@ -509,7 +561,7 @@ func Get() {

for _, parserType := range []ctags.CTagsParserType{ctags.UniversalCTags, ctags.ScipCTags} {
for _, c := range cases {
checkScoring(t, c, parserType)
checkScoring(t, c, false, parserType)
}
}
}
Expand All @@ -532,7 +584,7 @@ func skipIfCTagsUnavailable(t *testing.T, parserType ctags.CTagsParserType) {
}
}

func checkScoring(t *testing.T, c scoreCase, parserType ctags.CTagsParserType) {
func checkScoring(t *testing.T, c scoreCase, keywordScoring bool, parserType ctags.CTagsParserType) {
skipIfCTagsUnavailable(t, parserType)

name := c.language
Expand Down Expand Up @@ -572,7 +624,10 @@ func checkScoring(t *testing.T, c scoreCase, parserType ctags.CTagsParserType) {
}
defer ss.Close()

srs, err := ss.Search(context.Background(), c.query, &zoekt.SearchOptions{DebugScore: true})
srs, err := ss.Search(context.Background(), c.query, &zoekt.SearchOptions{
UseKeywordScoring: keywordScoring,
ChunkMatches: true,
DebugScore: true})
if err != nil {
t.Fatal(err)
}
Expand All @@ -582,7 +637,7 @@ func checkScoring(t *testing.T, c scoreCase, parserType ctags.CTagsParserType) {
}

if got := srs.Files[0].Score; math.Abs(got-c.wantScore) > epsilon {
t.Fatalf("score: want %f, got %f\ndebug: %s\ndebugscore: %s", c.wantScore, got, srs.Files[0].Debug, srs.Files[0].LineMatches[0].DebugScore)
t.Fatalf("score: want %f, got %f\ndebug: %s\ndebugscore: %s", c.wantScore, got, srs.Files[0].Debug, srs.Files[0].ChunkMatches[0].DebugScore)
}

if got := srs.Files[0].Language; got != c.language {
Expand Down
134 changes: 79 additions & 55 deletions contentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,82 +137,106 @@ func (p *contentProvider) findOffset(filename bool, r uint32) uint32 {
return byteOff
}

// fillMatches converts the internal candidateMatch slice into our API's LineMatch.
// It only ever returns content XOR filename matches, not both. If there are any
// content matches, these are always returned, and we omit filename matches.
//
// Performance invariant: ms is sorted and non-overlapping.
//
// Note: the byte slices may be backed by mmapped data, so before being
// returned by the API it needs to be copied.
func (p *contentProvider) fillMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch {
var result []LineMatch
if ms[0].fileName {
score, debugScore, _ := p.candidateMatchScore(ms, language, debug)
var filenameMatches []*candidateMatch
contentMatches := ms[:0]

// There is only "line" in a filename.
res := LineMatch{
Line: p.id.fileName(p.idx),
FileName: true,

Score: score,
DebugScore: debugScore,
for _, m := range ms {
if m.fileName {
filenameMatches = append(filenameMatches, m)
} else {
contentMatches = append(contentMatches, m)
}
}

for _, m := range ms {
res.LineFragments = append(res.LineFragments, LineFragmentMatch{
LineOffset: int(m.byteOffset),
MatchLength: int(m.byteMatchSz),
Offset: m.byteOffset,
})
// If there are any content matches, we only return these and skip filename matches.
if len(contentMatches) > 0 {
contentMatches = breakMatchesOnNewlines(contentMatches, p.data(false))
return p.fillContentMatches(contentMatches, numContextLines, language, debug)
}

result = []LineMatch{res}
}
} else {
ms = breakMatchesOnNewlines(ms, p.data(false))
result = p.fillContentMatches(ms, numContextLines, language, debug)
// Otherwise, we return a single line containing the filematch match.
score, debugScore, _ := p.candidateMatchScore(filenameMatches, language, debug)
res := LineMatch{
Line: p.id.fileName(p.idx),
FileName: true,
Score: score,
DebugScore: debugScore,
}

return result
for _, m := range ms {
res.LineFragments = append(res.LineFragments, LineFragmentMatch{
LineOffset: int(m.byteOffset),
MatchLength: int(m.byteMatchSz),
Offset: m.byteOffset,
})
}

return []LineMatch{res}

}

// fillChunkMatches converts the internal candidateMatch slice into our APIs ChunkMatch.
// fillChunkMatches converts the internal candidateMatch slice into our API's ChunkMatch.
// It only ever returns content XOR filename matches, not both. If there are any content
// matches, these are always returned, and we omit filename matches.
//
// Performance invariant: ms is sorted and non-overlapping.
//
// Note: the byte slices may be backed by mmapped data, so before being
// returned by the API it needs to be copied.
func (p *contentProvider) fillChunkMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []ChunkMatch {
var result []ChunkMatch
if ms[0].fileName {
// If the first match is a filename match, there will only be
// one match and the matched content will be the filename.

score, debugScore, _ := p.candidateMatchScore(ms, language, debug)
var filenameMatches []*candidateMatch
contentMatches := ms[:0]

fileName := p.id.fileName(p.idx)
ranges := make([]Range, 0, len(ms))
for _, m := range ms {
ranges = append(ranges, Range{
Start: Location{
ByteOffset: m.byteOffset,
LineNumber: 1,
Column: uint32(utf8.RuneCount(fileName[:m.byteOffset]) + 1),
},
End: Location{
ByteOffset: m.byteOffset + m.byteMatchSz,
LineNumber: 1,
Column: uint32(utf8.RuneCount(fileName[:m.byteOffset+m.byteMatchSz]) + 1),
},
})
for _, m := range ms {
if m.fileName {
filenameMatches = append(filenameMatches, m)
} else {
contentMatches = append(contentMatches, m)
}
}

result = []ChunkMatch{{
Content: fileName,
ContentStart: Location{ByteOffset: 0, LineNumber: 1, Column: 1},
Ranges: ranges,
FileName: true,
// If there are any content matches, we only return these and skip filename matches.
if len(contentMatches) > 0 {
return p.fillContentChunkMatches(contentMatches, numContextLines, language, debug)
}

Score: score,
DebugScore: debugScore,
}}
} else {
result = p.fillContentChunkMatches(ms, numContextLines, language, debug)
// Otherwise, we return a single chunk representing the filename match.
score, debugScore, _ := p.candidateMatchScore(filenameMatches, language, debug)
fileName := p.id.fileName(p.idx)
ranges := make([]Range, 0, len(ms))
for _, m := range ms {
ranges = append(ranges, Range{
Start: Location{
ByteOffset: m.byteOffset,
LineNumber: 1,
Column: uint32(utf8.RuneCount(fileName[:m.byteOffset]) + 1),
},
End: Location{
ByteOffset: m.byteOffset + m.byteMatchSz,
LineNumber: 1,
Column: uint32(utf8.RuneCount(fileName[:m.byteOffset+m.byteMatchSz]) + 1),
},
})
}

return result
return []ChunkMatch{{
Content: fileName,
ContentStart: Location{ByteOffset: 0, LineNumber: 1, Column: 1},
Ranges: ranges,
FileName: true,

Score: score,
DebugScore: debugScore,
}}
}

func (p *contentProvider) fillContentMatches(ms []*candidateMatch, numContextLines int, language string, debug bool) []LineMatch {
Expand Down
Loading

0 comments on commit 43b9225

Please sign in to comment.