From 67629d7ca9f568d44bc47d73f256888b94d69696 Mon Sep 17 00:00:00 2001
From: Stefan Hengl
Date: Tue, 4 Jun 2024 13:09:48 +0200
Subject: [PATCH 1/8] ranking: add IDF to BM25 score calculation

So far, we haven't included IDF in our BM25 score function. Zoekt uses a
trigram index and hence doesn't compute document frequency during indexing.
We could add this information to the index, but it is not immediately obvious
how to tokenize code in a way that is compatible with tokens from a natural
language query.

Here we calculate the document frequency at query time, under the assumption
that we visit all documents containing any of the query terms.

Test plan:
- Updated unit test
- Context evaluation improved from 60/89 to 63/89
---
 build/scoring_test.go | 16 ++++-----
 eval.go | 41 ++++++++++++++++++----
 score.go | 79 ++++++++++++++++++++++++++++++++-----------
 3 files changed, 103 insertions(+), 33 deletions(-)

diff --git a/build/scoring_test.go b/build/scoring_test.go
index e4e2e51e..0f3a2f2a 100644
--- a/build/scoring_test.go
+++ b/build/scoring_test.go
@@ -77,8 +77,8 @@ func TestBM25(t *testing.T) {
 query: &query.Substring{Pattern: "example"},
 content: exampleJava,
 language: "Java",
- // bm25-score:1.69 (sum-tf: 7.00, length-ratio: 2.00)
- wantScore: 1.82,
+ // bm25-score:1.26 <- sum-termFrequencyScore: 10.00, length-ratio: 2.00
+ wantScore: 1.26,
 },
 {
 // Matches only on content
 fileName: "example.java",
@@ -89,8 +89,8 @@
 query: &query.Substring{
 Pattern: "example",
 Content: true,
 }},
 content: exampleJava,
 language: "Java",
- // bm25-score:5.75 (sum-tf: 56.00, length-ratio: 2.00)
- wantScore: 5.75,
+ // bm25-score:3.99 <- sum-termFrequencyScore: 56.00, length-ratio: 2.00
+ wantScore: 3.99,
 },
 {
 // Matches only on filename
@@ -98,16 +98,16 @@
 fileName: "example.java",
 query: &query.Substring{Pattern: "java"},
 content: exampleJava,
 language: "Java",
- // bm25-score:1.07 (sum-tf: 2.00, length-ratio: 2.00)
- wantScore: 1.55,
+ // bm25-score:1.07 <- sum-termFrequencyScore: 5.00, length-ratio: 2.00
+ wantScore: 1.07,
 },
 {
 // Matches only on filename, and content is missing
 fileName: "a/b/c/config.go",
 query: &query.Substring{Pattern: "config.go"},
 language: "Go",
- // bm25-score:1.91 (sum-tf: 2.00, length-ratio: 0.00)
- wantScore: 2.08,
+ // bm25-score:1.44 <- sum-termFrequencyScore: 5.00, length-ratio: 0.00
+ wantScore: 1.44,
 },
 }

diff --git a/eval.go b/eval.go
index 0d8ec91b..7f892d0b 100644
--- a/eval.go
+++ b/eval.go
@@ -197,6 +197,15 @@ func (d *indexData) Search(ctx context.Context, q query.Q, opts *SearchOptions)
 docCount := uint32(len(d.fileBranchMasks))
 lastDoc := int(-1)

+ // document frequency per term
+ df := make(termDocumentFrequency)
+
+ // term frequency scores per document
+ var tfs termFrequencyScore
+
+ // used to track intermediate scores for BM25 scoring.
+ resFiles := fileMatchesWithScores{}
+
 nextFileMatch:
 for {
 canceled := false
@@ -318,7 +327,11 @@ nextFileMatch:
 }

 if opts.UseBM25Scoring {
- d.scoreFileUsingBM25(&fileMatch, nextDoc, finalCands, opts)
+ // For BM25 scoring, the calculation of the score is split in two parts. Here we
+ // calculate the term frequency scores for the current document. Since we don't
+ // store document frequencies in the index, we have to defer the calculation of
+ // IDF and the final BM25 score to after the whole shard has been processed.
+ tfs = d.calculateTermFrequencyScore(&fileMatch, nextDoc, finalCands, df, opts) } else { // Use the standard, non-experimental scoring method by default d.scoreFile(&fileMatch, nextDoc, mt, known, opts) @@ -339,16 +352,32 @@ nextFileMatch: repoMatchCount += len(fileMatch.LineMatches) repoMatchCount += matchedChunkRanges - if opts.DebugScore { - fileMatch.Debug = fmt.Sprintf("score:%.2f <- %s", fileMatch.Score, fileMatch.Debug) - } - - res.Files = append(res.Files, fileMatch) + resFiles.addFileMatch(fileMatch, tfs) res.Stats.MatchCount += len(fileMatch.LineMatches) res.Stats.MatchCount += matchedChunkRanges res.Stats.FileCount++ } + // Calculate final BM25 score for all file matches in the shard. We assume that + // we have seen all documents containing any of the terms in the query so that + // df correctly reflects the document frequencies. This is true, for example, if + // all terms in the query are ORed together. + if opts.UseBM25Scoring { + resFiles.scoreFilesUsingBM25(df, len(d.boundaries)) + } + + res.Files = resFiles.fileMatches + + if opts.DebugScore { + prefix := "score" + if opts.UseBM25Scoring { + prefix = "bm25-score" + } + for i, fileMatch := range res.Files { + res.Files[i].Debug = fmt.Sprintf("%s: %.2f <- %s", prefix, fileMatch.Score, fileMatch.Debug) + } + } + for _, md := range d.repoMetaData { r := md addRepo(&res, &r) diff --git a/score.go b/score.go index 9bcf1bbc..0c264fe5 100644 --- a/score.go +++ b/score.go @@ -39,13 +39,6 @@ func (m *FileMatch) addScore(what string, computed float64, raw float64, debugSc m.Score += computed } -func (m *FileMatch) addBM25Score(score float64, sumTf float64, L float64, debugScore bool) { - if debugScore { - m.Debug += fmt.Sprintf("bm25-score:%.2f (sum-tf: %.2f, length-ratio: %.2f)", score, sumTf, L) - } - m.Score += score -} - // scoreFile computes a score for the file match using various scoring signals, like // whether there's an exact match on a symbol, the number of query clauses that matched, etc. func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) { @@ -115,16 +108,20 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn } } -// scoreFileUsingBM25 computes a score for the file match using an approximation to BM25, the most common scoring -// algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. It implements all parts of the formula -// except inverse document frequency (idf), since we don't have access to global term frequency statistics. +// calculateTermFrequencyScore computes the TF score per term for the file match +// according to BM25, the most common scoring algorithm for text search: +// https://en.wikipedia.org/wiki/Okapi_BM25. We defer the calculation of the +// full bm25 score to after we have finished searching the shard, because we can +// only calculate the inverse document frequency (idf) after we have seen all +// documents. // -// Filename matches count twice as much as content matches. This mimics a common text search strategy where you -// 'boost' matches on document titles. +// Filename matches count more than content matches. This mimics a common text +// search strategy where you 'boost' matches on document titles. // -// This scoring strategy ignores all other signals including document ranks. This keeps things simple for now, -// since BM25 is not normalized and can be tricky to combine with other scoring signals. 
-func (d *indexData) scoreFileUsingBM25(fileMatch *FileMatch, doc uint32, cands []*candidateMatch, opts *SearchOptions) { +// This scoring strategy ignores all other signals including document ranks. +// This keeps things simple for now, since BM25 is not normalized and can be +// tricky to combine with other scoring signals. +func (d *indexData) calculateTermFrequencyScore(fileMatch *FileMatch, doc uint32, cands []*candidateMatch, df termDocumentFrequency, opts *SearchOptions) termFrequencyScore { // Treat each candidate match as a term and compute the frequencies. For now, ignore case // sensitivity and treat filenames and symbols the same as content. termFreqs := map[string]int{} @@ -153,12 +150,56 @@ func (d *indexData) scoreFileUsingBM25(fileMatch *FileMatch, doc uint32, cands [ // Use standard parameter defaults (used in Lucene and academic papers) k, b := 1.2, 0.75 sumTf := 0.0 // Just for debugging - score := 0.0 - for _, freq := range termFreqs { + + tfs := make(termFrequencyScore) + + for term, freq := range termFreqs { tf := float64(freq) sumTf += tf - score += ((k + 1.0) * tf) / (k*(1.0-b+b*L) + tf) + + // Invariant: the keys of df are the union of the keys of tfs over all files. + df[term] += 1 + tfs[term] = ((k + 1.0) * tf) / (k*(1.0-b+b*L) + tf) } - fileMatch.addBM25Score(score, sumTf, L, opts.DebugScore) + if opts.DebugScore { + fileMatch.Debug = fmt.Sprintf("sum-termFrequencyScore: %.2f, length-ratio: %.2f", sumTf, L) + } + + return tfs +} + +// idf computes the inverse document frequency for a term. nq is the number of +// documents that contain the term and documentCount is the total number of +// documents in the corpus. +func idf(nq, documentCount int) float64 { + return math.Log(1.0 + ((float64(documentCount) - float64(nq) + 0.5) / (float64(nq) + 0.5))) +} + +// termDocumentFrequency is a map "term" -> "number of documents that contain the term" +type termDocumentFrequency map[string]int + +// termFrequencyScore is a map "term" -> "term frequency score" +type termFrequencyScore map[string]float64 + +// fileMatchesWithScores is a helper type that is used to store the file matches +// along with internal scoring information. 
+type fileMatchesWithScores struct { + fileMatches []FileMatch + tfs []termFrequencyScore +} + +func (m *fileMatchesWithScores) addFileMatch(fm FileMatch, tfs termFrequencyScore) { + m.fileMatches = append(m.fileMatches, fm) + m.tfs = append(m.tfs, tfs) +} + +func (m *fileMatchesWithScores) scoreFilesUsingBM25(df termDocumentFrequency, documentCount int) { + for i := range m.fileMatches { + score := 0.0 + for term, tfScore := range m.tfs[i] { + score += idf(df[term], documentCount) * tfScore + } + m.fileMatches[i].Score = score + } } From 827d00fcd31544d0df50dfba8e7a8117cb003559 Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Wed, 5 Jun 2024 15:34:56 +0200 Subject: [PATCH 2/8] fix numFiles, use d.numDocs() --- build/scoring_test.go | 16 ++++++++-------- eval.go | 2 +- score.go | 3 +-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/build/scoring_test.go b/build/scoring_test.go index 0f3a2f2a..8dc2f2f9 100644 --- a/build/scoring_test.go +++ b/build/scoring_test.go @@ -77,8 +77,8 @@ func TestBM25(t *testing.T) { query: &query.Substring{Pattern: "example"}, content: exampleJava, language: "Java", - // bm25-score:1.26 <- sum-termFrequencyScore: 10.00, length-ratio: 2.00 - wantScore: 1.26, + // bm25-score: 0.57 <- sum-termFrequencyScore: 10.00, length-ratio: 1.00 + wantScore: 0.57, }, { // Matches only on content fileName: "example.java", @@ -89,8 +89,8 @@ func TestBM25(t *testing.T) { }}, content: exampleJava, language: "Java", - // bm25-score:3.99 <- sum-termFrequencyScore: 56.00, length-ratio: 2.00 - wantScore: 3.99, + // bm25-score: 1.75 <- sum-termFrequencyScore: 56.00, length-ratio: 1.00 + wantScore: 1.75, }, { // Matches only on filename @@ -98,16 +98,16 @@ func TestBM25(t *testing.T) { query: &query.Substring{Pattern: "java"}, content: exampleJava, language: "Java", - // bm25-score:1.07 <- sum-termFrequencyScore: 5.00, length-ratio: 2.00 - wantScore: 1.07, + // bm25-score: 0.51 <- sum-termFrequencyScore: 5.00, length-ratio: 1.00 + wantScore: 0.51, }, { // Matches only on filename, and content is missing fileName: "a/b/c/config.go", query: &query.Substring{Pattern: "config.go"}, language: "Go", - // bm25-score:1.44 <- sum-termFrequencyScore: 5.00, length-ratio: 0.00 - wantScore: 1.44, + // bm25-score: 0.60 <- sum-termFrequencyScore: 5.00, length-ratio: 0.00 + wantScore: 0.60, }, } diff --git a/eval.go b/eval.go index 7f892d0b..83b4d12c 100644 --- a/eval.go +++ b/eval.go @@ -363,7 +363,7 @@ nextFileMatch: // df correctly reflects the document frequencies. This is true, for example, if // all terms in the query are ORed together. if opts.UseBM25Scoring { - resFiles.scoreFilesUsingBM25(df, len(d.boundaries)) + resFiles.scoreFilesUsingBM25(df, int(d.numDocs())) } res.Files = resFiles.fileMatches diff --git a/score.go b/score.go index 0c264fe5..4d1517e3 100644 --- a/score.go +++ b/score.go @@ -138,8 +138,7 @@ func (d *indexData) calculateTermFrequencyScore(fileMatch *FileMatch, doc uint32 // Compute the file length ratio. Usually the calculation would be based on terms, but using // bytes should work fine, as we're just computing a ratio. fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) - numFiles := len(d.boundaries) - averageFileLength := float64(d.boundaries[numFiles-1]) / float64(numFiles) + averageFileLength := float64(d.boundaries[d.numDocs()]) / float64(d.numDocs()) // This is very unlikely, but explicitly guard against division by zero. 
if averageFileLength == 0 { From af9a0a4bc861750531f2ee90ddac53feaedeb7cc Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Thu, 6 Jun 2024 09:41:09 +0200 Subject: [PATCH 3/8] perform full BM25 calc in scoreFilesUsingBM25 --- eval.go | 26 ++++++------------ score.go | 83 ++++++++++++++++++++++++++++++-------------------------- 2 files changed, 52 insertions(+), 57 deletions(-) diff --git a/eval.go b/eval.go index 83b4d12c..855c119c 100644 --- a/eval.go +++ b/eval.go @@ -200,8 +200,8 @@ func (d *indexData) Search(ctx context.Context, q query.Q, opts *SearchOptions) // document frequency per term df := make(termDocumentFrequency) - // term frequency scores per document - var tfs termFrequencyScore + // term frequency per document + var tf termFrequencies // used to track intermediate scores for BM25 scoring. resFiles := fileMatchesWithScores{} @@ -328,10 +328,10 @@ nextFileMatch: if opts.UseBM25Scoring { // For BM25 scoring, the calculation of the score is split in two parts. Here we - // calculate the term frequency scores for the current document. Since we don't - // store document frequencies in the index, we have to defer the calculation of - // IDF and the final BM25 score to after the whole shard has been processed. - tfs = d.calculateTermFrequencyScore(&fileMatch, nextDoc, finalCands, df, opts) + // calculate the term frequencies for the current document. Since we don't store + // document frequencies in the index, we have to defer the calculation of the + // final BM25 score to after the whole shard has been processed. + tf = d.calculateTermFrequency(nextDoc, finalCands, df) } else { // Use the standard, non-experimental scoring method by default d.scoreFile(&fileMatch, nextDoc, mt, known, opts) @@ -352,7 +352,7 @@ nextFileMatch: repoMatchCount += len(fileMatch.LineMatches) repoMatchCount += matchedChunkRanges - resFiles.addFileMatch(fileMatch, tfs) + resFiles.addFileMatch(fileMatch, tf) res.Stats.MatchCount += len(fileMatch.LineMatches) res.Stats.MatchCount += matchedChunkRanges res.Stats.FileCount++ @@ -363,21 +363,11 @@ nextFileMatch: // df correctly reflects the document frequencies. This is true, for example, if // all terms in the query are ORed together. if opts.UseBM25Scoring { - resFiles.scoreFilesUsingBM25(df, int(d.numDocs())) + d.scoreFilesUsingBM25(&resFiles, df, opts) } res.Files = resFiles.fileMatches - if opts.DebugScore { - prefix := "score" - if opts.UseBM25Scoring { - prefix = "bm25-score" - } - for i, fileMatch := range res.Files { - res.Files[i].Debug = fmt.Sprintf("%s: %.2f <- %s", prefix, fileMatch.Score, fileMatch.Debug) - } - } - for _, md := range d.repoMetaData { r := md addRepo(&res, &r) diff --git a/score.go b/score.go index 4d1517e3..20179501 100644 --- a/score.go +++ b/score.go @@ -104,7 +104,7 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn addScore("repo-rank", scoreRepoRankFactor*float64(md.Rank)/maxUInt16) if opts.DebugScore { - fileMatch.Debug = strings.TrimSuffix(fileMatch.Debug, ", ") + fileMatch.Debug = fmt.Sprintf("score: %.2f <- %s", fileMatch.Score, strings.TrimSuffix(fileMatch.Debug, ", ")) } } @@ -121,13 +121,17 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn // This scoring strategy ignores all other signals including document ranks. // This keeps things simple for now, since BM25 is not normalized and can be // tricky to combine with other scoring signals. 
-func (d *indexData) calculateTermFrequencyScore(fileMatch *FileMatch, doc uint32, cands []*candidateMatch, df termDocumentFrequency, opts *SearchOptions) termFrequencyScore { +func (d *indexData) calculateTermFrequency(doc uint32, cands []*candidateMatch, df termDocumentFrequency) termFrequencies { // Treat each candidate match as a term and compute the frequencies. For now, ignore case // sensitivity and treat filenames and symbols the same as content. termFreqs := map[string]int{} for _, cand := range cands { term := string(cand.substrLowered) + if _, ok := termFreqs[term]; !ok { + df[term] += 1 + } + if cand.fileName { termFreqs[term] += 5 } else { @@ -135,37 +139,11 @@ func (d *indexData) calculateTermFrequencyScore(fileMatch *FileMatch, doc uint32 } } - // Compute the file length ratio. Usually the calculation would be based on terms, but using - // bytes should work fine, as we're just computing a ratio. - fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) - averageFileLength := float64(d.boundaries[d.numDocs()]) / float64(d.numDocs()) - - // This is very unlikely, but explicitly guard against division by zero. - if averageFileLength == 0 { - averageFileLength++ - } - L := fileLength / averageFileLength - - // Use standard parameter defaults (used in Lucene and academic papers) - k, b := 1.2, 0.75 - sumTf := 0.0 // Just for debugging - - tfs := make(termFrequencyScore) - - for term, freq := range termFreqs { - tf := float64(freq) - sumTf += tf - - // Invariant: the keys of df are the union of the keys of tfs over all files. - df[term] += 1 - tfs[term] = ((k + 1.0) * tf) / (k*(1.0-b+b*L) + tf) - } - - if opts.DebugScore { - fileMatch.Debug = fmt.Sprintf("sum-termFrequencyScore: %.2f, length-ratio: %.2f", sumTf, L) + return termFrequencies{ + doc: doc, + termFreqs: termFreqs, } - return tfs } // idf computes the inverse document frequency for a term. nq is the number of @@ -178,27 +156,54 @@ func idf(nq, documentCount int) float64 { // termDocumentFrequency is a map "term" -> "number of documents that contain the term" type termDocumentFrequency map[string]int -// termFrequencyScore is a map "term" -> "term frequency score" -type termFrequencyScore map[string]float64 +type termFrequencies struct { + doc uint32 + termFreqs map[string]int +} // fileMatchesWithScores is a helper type that is used to store the file matches // along with internal scoring information. type fileMatchesWithScores struct { fileMatches []FileMatch - tfs []termFrequencyScore + tf []termFrequencies } -func (m *fileMatchesWithScores) addFileMatch(fm FileMatch, tfs termFrequencyScore) { +func (m *fileMatchesWithScores) addFileMatch(fm FileMatch, tf termFrequencies) { m.fileMatches = append(m.fileMatches, fm) - m.tfs = append(m.tfs, tfs) + m.tf = append(m.tf, tf) } -func (m *fileMatchesWithScores) scoreFilesUsingBM25(df termDocumentFrequency, documentCount int) { +func (d *indexData) scoreFilesUsingBM25(m *fileMatchesWithScores, df termDocumentFrequency, opts *SearchOptions) { + // Use standard parameter defaults (used in Lucene and academic papers) + k, b := 1.2, 0.75 + + averageFileLength := float64(d.boundaries[d.numDocs()]) / float64(d.numDocs()) + // This is very unlikely, but explicitly guard against division by zero. + if averageFileLength == 0 { + averageFileLength++ + } + for i := range m.fileMatches { score := 0.0 - for term, tfScore := range m.tfs[i] { - score += idf(df[term], documentCount) * tfScore + + // Compute the file length ratio. 
Usually the calculation would be based on terms, but using + // bytes should work fine, as we're just computing a ratio. + doc := m.tf[i].doc + fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) + + L := fileLength / averageFileLength + + sumTF := 0 // Just for debugging + for term, f := range m.tf[i].termFreqs { + sumTF += f + tfScore := ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f)) + score += idf(df[term], int(d.numDocs())) * tfScore } + m.fileMatches[i].Score = score + + if opts.DebugScore { + m.fileMatches[i].Debug = fmt.Sprintf("bm25-score: %.2f <- sum-termFrequencies: %d, length-ratio: %.2f", score, sumTF, L) + } } } From 793e69cd8e19c12570c09c45f4098f33ac061eda Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Thu, 6 Jun 2024 13:00:16 +0200 Subject: [PATCH 4/8] FileMatch --- api.go | 6 +++ api_proto.go | 46 +++++++++++++++- api_proto_test.go | 17 +++--- api_test.go | 2 +- .../grpc/server/server_test.go | 9 ++-- eval.go | 28 ++++------ index_test.go | 4 +- read_test.go | 2 +- score.go | 54 ++++++------------- 9 files changed, 94 insertions(+), 74 deletions(-) diff --git a/api.go b/api.go index 192b6a83..a8c65c12 100644 --- a/api.go +++ b/api.go @@ -84,6 +84,12 @@ type FileMatch struct { // RepositoryID is a Sourcegraph extension. This is the ID of Repository in // Sourcegraph. RepositoryID uint32 `json:",omitempty"` + + doc uint32 + + // termFrequencies is a map from term to term frequency. We use this to + // calculate the BM25 score of a file match. + termFrequencies map[string]int } func (m *FileMatch) sizeBytes() (sz uint64) { diff --git a/api_proto.go b/api_proto.go index 368b689d..e315ff0a 100644 --- a/api_proto.go +++ b/api_proto.go @@ -17,10 +17,12 @@ package zoekt // import "github.com/sourcegraph/zoekt" import ( "math/rand" "reflect" + "testing/quick" - proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + + proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" ) func FileMatchFromProto(p *proto.FileMatch) FileMatch { @@ -83,6 +85,28 @@ func (m *FileMatch) ToProto() *proto.FileMatch { } } +func (*FileMatch) Generate(rng *rand.Rand, _ int) reflect.Value { + var f FileMatch + v := &FileMatch{ + FileName: gen(f.FileName, rng), + Repository: gen(f.Repository, rng), + SubRepositoryName: gen(f.SubRepositoryName, rng), + SubRepositoryPath: gen(f.SubRepositoryPath, rng), + Version: gen(f.Version, rng), + Language: gen(f.Language, rng), + Debug: gen(f.Debug, rng), + Branches: gen(f.Branches, rng), + LineMatches: gen(f.LineMatches, rng), + ChunkMatches: gen(f.ChunkMatches, rng), + Content: gen(f.Content, rng), + Checksum: gen(f.Checksum, rng), + Score: gen(f.Score, rng), + RepositoryPriority: gen(f.RepositoryPriority, rng), + RepositoryID: gen(f.RepositoryID, rng), + } + return reflect.ValueOf(v) +} + func ChunkMatchFromProto(p *proto.ChunkMatch) ChunkMatch { ranges := make([]Range, len(p.GetRanges())) for i, r := range p.GetRanges() { @@ -400,6 +424,20 @@ func (sr *SearchResult) ToStreamProto() *proto.StreamSearchResponse { return &proto.StreamSearchResponse{ResponseChunk: sr.ToProto()} } +func (*SearchResult) Generate(rng *rand.Rand, _ int) reflect.Value { + fm := &FileMatch{} + + var s SearchResult + v := &SearchResult{ + Stats: gen(s.Stats, rng), + Progress: gen(s.Progress, rng), + Files: []FileMatch{*gen(fm, rng)}, + RepoURLs: gen(s.RepoURLs, rng), + LineFragments: gen(s.LineFragments, rng), + } + return reflect.ValueOf(v) 
+} + func RepositoryBranchFromProto(p *proto.RepositoryBranch) RepositoryBranch { return RepositoryBranch{ Name: p.GetName(), @@ -728,3 +766,9 @@ func (s *SearchOptions) ToProto() *proto.SearchOptions { UseBm25Scoring: s.UseBM25Scoring, } } + +func gen[T any](sample T, r *rand.Rand) T { + var t T + v, _ := quick.Value(reflect.TypeOf(t), r) + return v.Interface().(T) +} diff --git a/api_proto_test.go b/api_proto_test.go index 79d93957..391a03f2 100644 --- a/api_proto_test.go +++ b/api_proto_test.go @@ -27,18 +27,23 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - webproto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" "google.golang.org/protobuf/proto" + webproto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" + fuzz "github.com/AdaLogics/go-fuzz-headers" ) func TestProtoRoundtrip(t *testing.T) { t.Run("FileMatch", func(t *testing.T) { - f := func(f1 FileMatch) bool { + f := func(f1 *FileMatch) bool { p1 := f1.ToProto() f2 := FileMatchFromProto(p1) - return reflect.DeepEqual(f1, f2) + if diff := cmp.Diff(f1, &f2, cmpopts.IgnoreUnexported(FileMatch{})); diff != "" { + fmt.Printf("got diff: %s", diff) + return false + } + return true } if err := quick.Check(f, nil); err != nil { t.Fatal(err) @@ -398,12 +403,6 @@ func (RepoListField) Generate(rng *rand.Rand, _ int) reflect.Value { } } -func gen[T any](sample T, r *rand.Rand) T { - var t T - v, _ := quick.Value(reflect.TypeOf(t), r) - return v.Interface().(T) -} - // This is a real search result that is intended to be a reasonable representative // for serialization benchmarks. // Generated by modifying the code to dump the proto to a file, then running a diff --git a/api_test.go b/api_test.go index 87ad4167..61be75ff 100644 --- a/api_test.go +++ b/api_test.go @@ -146,7 +146,7 @@ func TestMatchSize(t *testing.T) { size int }{{ v: FileMatch{}, - size: 256, + size: 264, }, { v: ChunkMatch{}, size: 112, diff --git a/cmd/zoekt-webserver/grpc/server/server_test.go b/cmd/zoekt-webserver/grpc/server/server_test.go index eae99d60..49cd6670 100644 --- a/cmd/zoekt-webserver/grpc/server/server_test.go +++ b/cmd/zoekt-webserver/grpc/server/server_test.go @@ -12,7 +12,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" "go.uber.org/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -21,6 +20,8 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" + "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" + "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/internal/mockSearcher" "github.com/sourcegraph/zoekt/query" @@ -108,11 +109,11 @@ func TestClientServer(t *testing.T) { } func TestFuzzGRPCChunkSender(t *testing.T) { - validateResult := func(input zoekt.SearchResult) error { + validateResult := func(input *zoekt.SearchResult) error { clientStream, serverStream := newPairedSearchStream(t) sender := gRPCChunkSender(serverStream) - sender.Send(&input) + sender.Send(input) allResponses := readAllStream(t, clientStream) if len(allResponses) == 0 { @@ -185,7 +186,7 @@ func TestFuzzGRPCChunkSender(t *testing.T) { } var lastErr error - if err := quick.Check(func(r zoekt.SearchResult) bool { + if err := quick.Check(func(r *zoekt.SearchResult) bool { lastErr = validateResult(r) return lastErr == nil diff --git a/eval.go b/eval.go index 855c119c..3a1aa16d 100644 --- a/eval.go +++ b/eval.go @@ -200,12 +200,6 @@ func (d *indexData) Search(ctx 
context.Context, q query.Q, opts *SearchOptions) // document frequency per term df := make(termDocumentFrequency) - // term frequency per document - var tf termFrequencies - - // used to track intermediate scores for BM25 scoring. - resFiles := fileMatchesWithScores{} - nextFileMatch: for { canceled := false @@ -294,6 +288,7 @@ nextFileMatch: FileName: string(d.fileName(nextDoc)), Checksum: d.getChecksum(nextDoc), Language: d.languageMap[d.getLanguage(nextDoc)], + doc: nextDoc, } if s := d.subRepos[nextDoc]; s > 0 { @@ -328,10 +323,11 @@ nextFileMatch: if opts.UseBM25Scoring { // For BM25 scoring, the calculation of the score is split in two parts. Here we - // calculate the term frequencies for the current document. Since we don't store - // document frequencies in the index, we have to defer the calculation of the - // final BM25 score to after the whole shard has been processed. - tf = d.calculateTermFrequency(nextDoc, finalCands, df) + // calculate the term frequencies for the current document and update the + // document frequencies. Since we don't store document frequencies in the index, + // we have to defer the calculation of the final BM25 score to after the whole + // shard has been processed. + calculateTermFrequency(&fileMatch, finalCands, df) } else { // Use the standard, non-experimental scoring method by default d.scoreFile(&fileMatch, nextDoc, mt, known, opts) @@ -352,22 +348,20 @@ nextFileMatch: repoMatchCount += len(fileMatch.LineMatches) repoMatchCount += matchedChunkRanges - resFiles.addFileMatch(fileMatch, tf) + res.Files = append(res.Files, fileMatch) res.Stats.MatchCount += len(fileMatch.LineMatches) res.Stats.MatchCount += matchedChunkRanges res.Stats.FileCount++ } - // Calculate final BM25 score for all file matches in the shard. We assume that - // we have seen all documents containing any of the terms in the query so that - // df correctly reflects the document frequencies. This is true, for example, if + // Calculate BM25 score for all file matches in the shard. We assume that we + // have seen all documents containing any of the terms in the query so that df + // correctly reflects the document frequencies. This is true, for example, if // all terms in the query are ORed together. 
if opts.UseBM25Scoring { - d.scoreFilesUsingBM25(&resFiles, df, opts) + d.scoreFilesUsingBM25(res.Files, df, opts) } - res.Files = resFiles.fileMatches - for _, md := range d.repoMetaData { r := md addRepo(&res, &r) diff --git a/index_test.go b/index_test.go index 05401215..6091235b 100644 --- a/index_test.go +++ b/index_test.go @@ -223,7 +223,7 @@ func TestNewlines(t *testing.T) { }}, }} - if diff := cmp.Diff(matches, want); diff != "" { + if diff := cmp.Diff(matches, want, cmpopts.IgnoreUnexported(FileMatch{})); diff != "" { t.Fatal(diff) } }) @@ -248,7 +248,7 @@ func TestNewlines(t *testing.T) { }}, }} - if diff := cmp.Diff(want, matches); diff != "" { + if diff := cmp.Diff(want, matches, cmpopts.IgnoreUnexported(FileMatch{})); diff != "" { t.Fatal(diff) } }) diff --git a/read_test.go b/read_test.go index 9e7acd13..5b8682b4 100644 --- a/read_test.go +++ b/read_test.go @@ -306,7 +306,7 @@ func TestReadSearch(t *testing.T) { continue } - if d := cmp.Diff(want.FileMatches[j], res.Files); d != "" { + if d := cmp.Diff(want.FileMatches[j], res.Files, cmpopts.IgnoreUnexported(FileMatch{})); d != "" { t.Errorf("matches for %s on %s (-want +got)\n%s", q, name, d) } } diff --git a/score.go b/score.go index 20179501..36901cd7 100644 --- a/score.go +++ b/score.go @@ -108,20 +108,11 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn } } -// calculateTermFrequencyScore computes the TF score per term for the file match -// according to BM25, the most common scoring algorithm for text search: -// https://en.wikipedia.org/wiki/Okapi_BM25. We defer the calculation of the -// full bm25 score to after we have finished searching the shard, because we can -// only calculate the inverse document frequency (idf) after we have seen all -// documents. +// calculateTermFrequency computes the term frequency for the file match. // // Filename matches count more than content matches. This mimics a common text // search strategy where you 'boost' matches on document titles. -// -// This scoring strategy ignores all other signals including document ranks. -// This keeps things simple for now, since BM25 is not normalized and can be -// tricky to combine with other scoring signals. -func (d *indexData) calculateTermFrequency(doc uint32, cands []*candidateMatch, df termDocumentFrequency) termFrequencies { +func calculateTermFrequency(fileMatch *FileMatch, cands []*candidateMatch, df termDocumentFrequency) { // Treat each candidate match as a term and compute the frequencies. For now, ignore case // sensitivity and treat filenames and symbols the same as content. termFreqs := map[string]int{} @@ -139,11 +130,7 @@ func (d *indexData) calculateTermFrequency(doc uint32, cands []*candidateMatch, } } - return termFrequencies{ - doc: doc, - termFreqs: termFreqs, - } - + fileMatch.termFrequencies = termFreqs } // idf computes the inverse document frequency for a term. nq is the number of @@ -156,24 +143,13 @@ func idf(nq, documentCount int) float64 { // termDocumentFrequency is a map "term" -> "number of documents that contain the term" type termDocumentFrequency map[string]int -type termFrequencies struct { - doc uint32 - termFreqs map[string]int -} - -// fileMatchesWithScores is a helper type that is used to store the file matches -// along with internal scoring information. 
-type fileMatchesWithScores struct { - fileMatches []FileMatch - tf []termFrequencies -} - -func (m *fileMatchesWithScores) addFileMatch(fm FileMatch, tf termFrequencies) { - m.fileMatches = append(m.fileMatches, fm) - m.tf = append(m.tf, tf) -} - -func (d *indexData) scoreFilesUsingBM25(m *fileMatchesWithScores, df termDocumentFrequency, opts *SearchOptions) { +// scoreFilesUsingBM25 computes the score according to BM25, the most common +// scoring algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. +// +// This scoring strategy ignores all other signals including document ranks. +// This keeps things simple for now, since BM25 is not normalized and can be +// tricky to combine with other scoring signals. +func (d *indexData) scoreFilesUsingBM25(fileMatches []FileMatch, df termDocumentFrequency, opts *SearchOptions) { // Use standard parameter defaults (used in Lucene and academic papers) k, b := 1.2, 0.75 @@ -183,27 +159,27 @@ func (d *indexData) scoreFilesUsingBM25(m *fileMatchesWithScores, df termDocumen averageFileLength++ } - for i := range m.fileMatches { + for i := range fileMatches { score := 0.0 // Compute the file length ratio. Usually the calculation would be based on terms, but using // bytes should work fine, as we're just computing a ratio. - doc := m.tf[i].doc + doc := fileMatches[i].doc fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) L := fileLength / averageFileLength sumTF := 0 // Just for debugging - for term, f := range m.tf[i].termFreqs { + for term, f := range fileMatches[i].termFrequencies { sumTF += f tfScore := ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f)) score += idf(df[term], int(d.numDocs())) * tfScore } - m.fileMatches[i].Score = score + fileMatches[i].Score = score if opts.DebugScore { - m.fileMatches[i].Debug = fmt.Sprintf("bm25-score: %.2f <- sum-termFrequencies: %d, length-ratio: %.2f", score, sumTF, L) + fileMatches[i].Debug = fmt.Sprintf("bm25-score: %.2f <- sum-termFrequencies: %d, length-ratio: %.2f", score, sumTF, L) } } } From 0b3c8f00aca8f0ff4f533af83c9b68d040a6329f Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Fri, 7 Jun 2024 11:36:23 +0200 Subject: [PATCH 5/8] move df out of loop --- score.go | 9 ++++----- score_test.go | 51 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) create mode 100644 score_test.go diff --git a/score.go b/score.go index 36901cd7..34464c4a 100644 --- a/score.go +++ b/score.go @@ -118,11 +118,6 @@ func calculateTermFrequency(fileMatch *FileMatch, cands []*candidateMatch, df te termFreqs := map[string]int{} for _, cand := range cands { term := string(cand.substrLowered) - - if _, ok := termFreqs[term]; !ok { - df[term] += 1 - } - if cand.fileName { termFreqs[term] += 5 } else { @@ -130,6 +125,10 @@ func calculateTermFrequency(fileMatch *FileMatch, cands []*candidateMatch, df te } } + for term := range termFreqs { + df[term] += 1 + } + fileMatch.termFrequencies = termFreqs } diff --git a/score_test.go b/score_test.go new file mode 100644 index 00000000..37811aaa --- /dev/null +++ b/score_test.go @@ -0,0 +1,51 @@ +package zoekt + +import ( + "maps" + "testing" +) + +func TestCalculateTermFrequency(t *testing.T) { + cases := []struct { + cands []*candidateMatch + wantDF termDocumentFrequency + wantTermFrequencies map[string]int + }{{ + cands: []*candidateMatch{ + {substrLowered: []byte("foo")}, + {substrLowered: []byte("foo")}, + {substrLowered: []byte("bar")}, + { + substrLowered: []byte("bas"), + fileName: true, + }, + }, + 
wantDF: termDocumentFrequency{ + "foo": 1, + "bar": 1, + "bas": 1, + }, + wantTermFrequencies: map[string]int{ + "foo": 2, + "bar": 1, + "bas": 5, + }, + }, + } + + for _, c := range cases { + t.Run("", func(t *testing.T) { + fm := FileMatch{} + df := make(termDocumentFrequency) + calculateTermFrequency(&fm, c.cands, df) + + if !maps.Equal(df, c.wantDF) { + t.Errorf("got %v, want %v", df, c.wantDF) + } + + if !maps.Equal(fm.termFrequencies, c.wantTermFrequencies) { + t.Errorf("got %v, want %v", fm.termFrequencies, c.wantTermFrequencies) + } + }) + } +} From a9f21adb6fd7970a832a967fddfdc35c941105fa Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Fri, 7 Jun 2024 12:49:08 +0200 Subject: [PATCH 6/8] revert to auxiliary slice --- api.go | 12 +++++++++--- eval.go | 15 ++++++++++++--- score.go | 18 ++++++++++++------ score_test.go | 6 +++--- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/api.go b/api.go index a8c65c12..d33da279 100644 --- a/api.go +++ b/api.go @@ -952,9 +952,15 @@ type SearchOptions struct { // will be used. This option is temporary and is only exposed for testing/ tuning purposes. DocumentRanksWeight float64 - // EXPERIMENTAL. If true, use text-search style scoring instead of the default scoring formula. - // The scoring algorithm treats each match in a file as a term and computes an approximation to - // BM25. When enabled, all other scoring signals are ignored, including document ranks. + // EXPERIMENTAL. If true, use text-search style scoring instead of the default + // scoring formula. The scoring algorithm treats each match in a file as a term + // and computes an approximation to BM25. + // + // The calculation of IDF assumes that Zoekt visits all documents containing any + // of the query terms during evaluation. This is true, for example, if all query + // terms are ORed together. + // + // When enabled, all other scoring signals are ignored, including document ranks. UseBM25Scoring bool // Trace turns on opentracing for this request if true and if the Jaeger address was provided as diff --git a/eval.go b/eval.go index 3a1aa16d..5b53ec65 100644 --- a/eval.go +++ b/eval.go @@ -200,6 +200,9 @@ func (d *indexData) Search(ctx context.Context, q query.Q, opts *SearchOptions) // document frequency per term df := make(termDocumentFrequency) + // term frequency per file match + var tfs []termFrequency + nextFileMatch: for { canceled := false @@ -288,7 +291,6 @@ nextFileMatch: FileName: string(d.fileName(nextDoc)), Checksum: d.getChecksum(nextDoc), Language: d.languageMap[d.getLanguage(nextDoc)], - doc: nextDoc, } if s := d.subRepos[nextDoc]; s > 0 { @@ -321,13 +323,14 @@ nextFileMatch: fileMatch.LineMatches = cp.fillMatches(finalCands, opts.NumContextLines, fileMatch.Language, opts.DebugScore) } + var tf map[string]int if opts.UseBM25Scoring { // For BM25 scoring, the calculation of the score is split in two parts. Here we // calculate the term frequencies for the current document and update the // document frequencies. Since we don't store document frequencies in the index, // we have to defer the calculation of the final BM25 score to after the whole // shard has been processed. 
- calculateTermFrequency(&fileMatch, finalCands, df) + tf = calculateTermFrequency(finalCands, df) } else { // Use the standard, non-experimental scoring method by default d.scoreFile(&fileMatch, nextDoc, mt, known, opts) @@ -348,7 +351,13 @@ nextFileMatch: repoMatchCount += len(fileMatch.LineMatches) repoMatchCount += matchedChunkRanges + // Invariant: tfs[i] belongs to res.Files[i] + tfs = append(tfs, termFrequency{ + doc: nextDoc, + tf: tf, + }) res.Files = append(res.Files, fileMatch) + res.Stats.MatchCount += len(fileMatch.LineMatches) res.Stats.MatchCount += matchedChunkRanges res.Stats.FileCount++ @@ -359,7 +368,7 @@ nextFileMatch: // correctly reflects the document frequencies. This is true, for example, if // all terms in the query are ORed together. if opts.UseBM25Scoring { - d.scoreFilesUsingBM25(res.Files, df, opts) + d.scoreFilesUsingBM25(res.Files, tfs, df, opts) } for _, md := range d.repoMetaData { diff --git a/score.go b/score.go index 34464c4a..a2579df2 100644 --- a/score.go +++ b/score.go @@ -112,7 +112,7 @@ func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, kn // // Filename matches count more than content matches. This mimics a common text // search strategy where you 'boost' matches on document titles. -func calculateTermFrequency(fileMatch *FileMatch, cands []*candidateMatch, df termDocumentFrequency) { +func calculateTermFrequency(cands []*candidateMatch, df termDocumentFrequency) map[string]int { // Treat each candidate match as a term and compute the frequencies. For now, ignore case // sensitivity and treat filenames and symbols the same as content. termFreqs := map[string]int{} @@ -129,7 +129,7 @@ func calculateTermFrequency(fileMatch *FileMatch, cands []*candidateMatch, df te df[term] += 1 } - fileMatch.termFrequencies = termFreqs + return termFreqs } // idf computes the inverse document frequency for a term. nq is the number of @@ -142,13 +142,19 @@ func idf(nq, documentCount int) float64 { // termDocumentFrequency is a map "term" -> "number of documents that contain the term" type termDocumentFrequency map[string]int +// termFrequency stores the term frequencies for doc. +type termFrequency struct { + doc uint32 + tf map[string]int +} + // scoreFilesUsingBM25 computes the score according to BM25, the most common // scoring algorithm for text search: https://en.wikipedia.org/wiki/Okapi_BM25. // // This scoring strategy ignores all other signals including document ranks. // This keeps things simple for now, since BM25 is not normalized and can be // tricky to combine with other scoring signals. -func (d *indexData) scoreFilesUsingBM25(fileMatches []FileMatch, df termDocumentFrequency, opts *SearchOptions) { +func (d *indexData) scoreFilesUsingBM25(fileMatches []FileMatch, tfs []termFrequency, df termDocumentFrequency, opts *SearchOptions) { // Use standard parameter defaults (used in Lucene and academic papers) k, b := 1.2, 0.75 @@ -158,18 +164,18 @@ func (d *indexData) scoreFilesUsingBM25(fileMatches []FileMatch, df termDocument averageFileLength++ } - for i := range fileMatches { + for i := range tfs { score := 0.0 // Compute the file length ratio. Usually the calculation would be based on terms, but using // bytes should work fine, as we're just computing a ratio. 
- doc := fileMatches[i].doc + doc := tfs[i].doc fileLength := float64(d.boundaries[doc+1] - d.boundaries[doc]) L := fileLength / averageFileLength sumTF := 0 // Just for debugging - for term, f := range fileMatches[i].termFrequencies { + for term, f := range tfs[i].tf { sumTF += f tfScore := ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f)) score += idf(df[term], int(d.numDocs())) * tfScore diff --git a/score_test.go b/score_test.go index 37811aaa..2e3b1384 100644 --- a/score_test.go +++ b/score_test.go @@ -37,14 +37,14 @@ func TestCalculateTermFrequency(t *testing.T) { t.Run("", func(t *testing.T) { fm := FileMatch{} df := make(termDocumentFrequency) - calculateTermFrequency(&fm, c.cands, df) + tf := calculateTermFrequency(c.cands, df) if !maps.Equal(df, c.wantDF) { t.Errorf("got %v, want %v", df, c.wantDF) } - if !maps.Equal(fm.termFrequencies, c.wantTermFrequencies) { - t.Errorf("got %v, want %v", fm.termFrequencies, c.wantTermFrequencies) + if !maps.Equal(tf, c.wantTermFrequencies) { + t.Errorf("got %v, want %v", fm, c.wantTermFrequencies) } }) } From 43cd75253fe2551ab1ad7bee92658cf6c24eb633 Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Fri, 7 Jun 2024 12:53:52 +0200 Subject: [PATCH 7/8] revert change to FileMatch --- api.go | 6 --- api_proto.go | 46 +------------------ api_proto_test.go | 17 +++---- api_test.go | 2 +- .../grpc/server/server_test.go | 9 ++-- index_test.go | 4 +- read_test.go | 2 +- 7 files changed, 18 insertions(+), 68 deletions(-) diff --git a/api.go b/api.go index d33da279..1e478dfa 100644 --- a/api.go +++ b/api.go @@ -84,12 +84,6 @@ type FileMatch struct { // RepositoryID is a Sourcegraph extension. This is the ID of Repository in // Sourcegraph. RepositoryID uint32 `json:",omitempty"` - - doc uint32 - - // termFrequencies is a map from term to term frequency. We use this to - // calculate the BM25 score of a file match. 
- termFrequencies map[string]int } func (m *FileMatch) sizeBytes() (sz uint64) { diff --git a/api_proto.go b/api_proto.go index e315ff0a..368b689d 100644 --- a/api_proto.go +++ b/api_proto.go @@ -17,12 +17,10 @@ package zoekt // import "github.com/sourcegraph/zoekt" import ( "math/rand" "reflect" - "testing/quick" + proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" - - proto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" ) func FileMatchFromProto(p *proto.FileMatch) FileMatch { @@ -85,28 +83,6 @@ func (m *FileMatch) ToProto() *proto.FileMatch { } } -func (*FileMatch) Generate(rng *rand.Rand, _ int) reflect.Value { - var f FileMatch - v := &FileMatch{ - FileName: gen(f.FileName, rng), - Repository: gen(f.Repository, rng), - SubRepositoryName: gen(f.SubRepositoryName, rng), - SubRepositoryPath: gen(f.SubRepositoryPath, rng), - Version: gen(f.Version, rng), - Language: gen(f.Language, rng), - Debug: gen(f.Debug, rng), - Branches: gen(f.Branches, rng), - LineMatches: gen(f.LineMatches, rng), - ChunkMatches: gen(f.ChunkMatches, rng), - Content: gen(f.Content, rng), - Checksum: gen(f.Checksum, rng), - Score: gen(f.Score, rng), - RepositoryPriority: gen(f.RepositoryPriority, rng), - RepositoryID: gen(f.RepositoryID, rng), - } - return reflect.ValueOf(v) -} - func ChunkMatchFromProto(p *proto.ChunkMatch) ChunkMatch { ranges := make([]Range, len(p.GetRanges())) for i, r := range p.GetRanges() { @@ -424,20 +400,6 @@ func (sr *SearchResult) ToStreamProto() *proto.StreamSearchResponse { return &proto.StreamSearchResponse{ResponseChunk: sr.ToProto()} } -func (*SearchResult) Generate(rng *rand.Rand, _ int) reflect.Value { - fm := &FileMatch{} - - var s SearchResult - v := &SearchResult{ - Stats: gen(s.Stats, rng), - Progress: gen(s.Progress, rng), - Files: []FileMatch{*gen(fm, rng)}, - RepoURLs: gen(s.RepoURLs, rng), - LineFragments: gen(s.LineFragments, rng), - } - return reflect.ValueOf(v) -} - func RepositoryBranchFromProto(p *proto.RepositoryBranch) RepositoryBranch { return RepositoryBranch{ Name: p.GetName(), @@ -766,9 +728,3 @@ func (s *SearchOptions) ToProto() *proto.SearchOptions { UseBm25Scoring: s.UseBM25Scoring, } } - -func gen[T any](sample T, r *rand.Rand) T { - var t T - v, _ := quick.Value(reflect.TypeOf(t), r) - return v.Interface().(T) -} diff --git a/api_proto_test.go b/api_proto_test.go index 391a03f2..79d93957 100644 --- a/api_proto_test.go +++ b/api_proto_test.go @@ -27,23 +27,18 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/protobuf/proto" - webproto "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" + "google.golang.org/protobuf/proto" fuzz "github.com/AdaLogics/go-fuzz-headers" ) func TestProtoRoundtrip(t *testing.T) { t.Run("FileMatch", func(t *testing.T) { - f := func(f1 *FileMatch) bool { + f := func(f1 FileMatch) bool { p1 := f1.ToProto() f2 := FileMatchFromProto(p1) - if diff := cmp.Diff(f1, &f2, cmpopts.IgnoreUnexported(FileMatch{})); diff != "" { - fmt.Printf("got diff: %s", diff) - return false - } - return true + return reflect.DeepEqual(f1, f2) } if err := quick.Check(f, nil); err != nil { t.Fatal(err) @@ -403,6 +398,12 @@ func (RepoListField) Generate(rng *rand.Rand, _ int) reflect.Value { } } +func gen[T any](sample T, r *rand.Rand) T { + var t T + v, _ := quick.Value(reflect.TypeOf(t), r) + return v.Interface().(T) +} + // This is a real search result that is 
intended to be a reasonable representative // for serialization benchmarks. // Generated by modifying the code to dump the proto to a file, then running a diff --git a/api_test.go b/api_test.go index 61be75ff..87ad4167 100644 --- a/api_test.go +++ b/api_test.go @@ -146,7 +146,7 @@ func TestMatchSize(t *testing.T) { size int }{{ v: FileMatch{}, - size: 264, + size: 256, }, { v: ChunkMatch{}, size: 112, diff --git a/cmd/zoekt-webserver/grpc/server/server_test.go b/cmd/zoekt-webserver/grpc/server/server_test.go index 49cd6670..eae99d60 100644 --- a/cmd/zoekt-webserver/grpc/server/server_test.go +++ b/cmd/zoekt-webserver/grpc/server/server_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" "go.uber.org/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -20,8 +21,6 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" - "github.com/sourcegraph/zoekt/grpc/protos/zoekt/webserver/v1" - "github.com/sourcegraph/zoekt" "github.com/sourcegraph/zoekt/internal/mockSearcher" "github.com/sourcegraph/zoekt/query" @@ -109,11 +108,11 @@ func TestClientServer(t *testing.T) { } func TestFuzzGRPCChunkSender(t *testing.T) { - validateResult := func(input *zoekt.SearchResult) error { + validateResult := func(input zoekt.SearchResult) error { clientStream, serverStream := newPairedSearchStream(t) sender := gRPCChunkSender(serverStream) - sender.Send(input) + sender.Send(&input) allResponses := readAllStream(t, clientStream) if len(allResponses) == 0 { @@ -186,7 +185,7 @@ func TestFuzzGRPCChunkSender(t *testing.T) { } var lastErr error - if err := quick.Check(func(r *zoekt.SearchResult) bool { + if err := quick.Check(func(r zoekt.SearchResult) bool { lastErr = validateResult(r) return lastErr == nil diff --git a/index_test.go b/index_test.go index 6091235b..05401215 100644 --- a/index_test.go +++ b/index_test.go @@ -223,7 +223,7 @@ func TestNewlines(t *testing.T) { }}, }} - if diff := cmp.Diff(matches, want, cmpopts.IgnoreUnexported(FileMatch{})); diff != "" { + if diff := cmp.Diff(matches, want); diff != "" { t.Fatal(diff) } }) @@ -248,7 +248,7 @@ func TestNewlines(t *testing.T) { }}, }} - if diff := cmp.Diff(want, matches, cmpopts.IgnoreUnexported(FileMatch{})); diff != "" { + if diff := cmp.Diff(want, matches); diff != "" { t.Fatal(diff) } }) diff --git a/read_test.go b/read_test.go index 5b8682b4..9e7acd13 100644 --- a/read_test.go +++ b/read_test.go @@ -306,7 +306,7 @@ func TestReadSearch(t *testing.T) { continue } - if d := cmp.Diff(want.FileMatches[j], res.Files, cmpopts.IgnoreUnexported(FileMatch{})); d != "" { + if d := cmp.Diff(want.FileMatches[j], res.Files); d != "" { t.Errorf("matches for %s on %s (-want +got)\n%s", q, name, d) } } From f283039b6aaa2a07fd347c102a50a914bab06601 Mon Sep 17 00:00:00 2001 From: Stefan Hengl Date: Mon, 10 Jun 2024 12:27:16 +0200 Subject: [PATCH 8/8] PR comment --- eval.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/eval.go b/eval.go index 5b53ec65..af637f01 100644 --- a/eval.go +++ b/eval.go @@ -351,11 +351,13 @@ nextFileMatch: repoMatchCount += len(fileMatch.LineMatches) repoMatchCount += matchedChunkRanges - // Invariant: tfs[i] belongs to res.Files[i] - tfs = append(tfs, termFrequency{ - doc: nextDoc, - tf: tf, - }) + if opts.UseBM25Scoring { + // Invariant: tfs[i] belongs to res.Files[i] + tfs = append(tfs, termFrequency{ + doc: nextDoc, + tf: tf, + }) + } res.Files = 
append(res.Files, fileMatch) res.Stats.MatchCount += len(fileMatch.LineMatches)
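
The series above converges on a two-phase BM25 computation: while candidate
matches are iterated, term frequencies are tallied per file and document
frequencies per term; once the whole shard has been processed, IDF and the
final scores can be computed. As a sense check on the idf formula: a term
appearing in 1 of 100 documents gets idf = log(1 + 99.5/1.5) ~ 4.21, while a
term appearing in all 100 gets idf = log(1 + 0.5/100.5) ~ 0.005, so rare terms
dominate the score. The following self-contained Go sketch illustrates the
second phase using the same formula and parameter defaults (k = 1.2, b = 0.75,
byte-based length ratio) as the patches; the names scoreBM25 and fileLengths
and the inputs in main are illustrative stand-ins, not part of the Zoekt API.

package main

import (
	"fmt"
	"math"
)

// termDocumentFrequency maps a term to the number of documents that contain it.
type termDocumentFrequency map[string]int

// termFrequency stores the term frequencies observed in one document.
type termFrequency struct {
	doc uint32
	tf  map[string]int
}

// idf mirrors the patch: log(1 + (N - n(q) + 0.5) / (n(q) + 0.5)), where nq is
// the number of documents containing the term and documentCount is the corpus size.
func idf(nq, documentCount int) float64 {
	return math.Log(1.0 + ((float64(documentCount) - float64(nq) + 0.5) / (float64(nq) + 0.5)))
}

// scoreBM25 combines the per-document term frequencies (phase 1) with the
// corpus-wide document frequencies to produce final BM25 scores (phase 2).
// fileLengths holds the length of each document in bytes.
func scoreBM25(tfs []termFrequency, df termDocumentFrequency, fileLengths []float64, documentCount int) []float64 {
	// Standard parameter defaults (used in Lucene and academic papers).
	k, b := 1.2, 0.75

	var totalLength float64
	for _, l := range fileLengths {
		totalLength += l
	}
	averageFileLength := totalLength / float64(documentCount)
	// Explicitly guard against division by zero, as the patch does.
	if averageFileLength == 0 {
		averageFileLength++
	}

	scores := make([]float64, len(tfs))
	for i, t := range tfs {
		// Length ratio of this file relative to the corpus average.
		L := fileLengths[t.doc] / averageFileLength
		score := 0.0
		for term, f := range t.tf {
			tfScore := ((k + 1.0) * float64(f)) / (k*(1.0-b+b*L) + float64(f))
			score += idf(df[term], documentCount) * tfScore
		}
		scores[i] = score
	}
	return scores
}

func main() {
	// Phase 1 would tally these while iterating matches: "foo" occurs in both
	// documents, "bar" only in the first.
	df := termDocumentFrequency{"foo": 2, "bar": 1}
	tfs := []termFrequency{
		{doc: 0, tf: map[string]int{"foo": 2, "bar": 1}},
		{doc: 1, tf: map[string]int{"foo": 1}},
	}
	// Phase 2: df is complete, so IDF and the final scores can be computed.
	fmt.Println(scoreBM25(tfs, df, []float64{100, 200}, 2))
}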