From 54e4e84d7b2a34d56eb4e82af80f287cd7d6f19e Mon Sep 17 00:00:00 2001 From: Keegan Carruthers-Smith Date: Sun, 28 Jan 2024 08:55:07 +0200 Subject: [PATCH] score: introduce query.Boost to scale score This commit introduces a new primitive Boost to our query language. It allows boosting (or dampening) the contribution to the score a query atoms will match contribute. To achieve this we introduce boostMatchTree which records this weight. We then adjust the visitMatches to take an initial score weight (1.0), and then each time we recurse through a boostMatchTree the score weight is multiplied by the boost weight. Additionally candidateMatch now has a new field, scoreWeight, which records the weight at time of candidate collection. Without boosting in the query this value will always be 1. Finally when scoring a candidateMatch we take the final score for it and multiply it by scoreWeight. Note: we do not expose a way to set this in the query language, only the query API. Test Plan: This functionality is currently untested. However, none of our existings tests have broken so this is technically safe to land. TODO add testing for boost query. --- api_test.go | 2 +- contentprovider.go | 5 +++++ eval.go | 23 ++++++++++++++++------- matchiter.go | 2 ++ matchtree.go | 45 +++++++++++++++++++++++++++++++++++++++------ query/query.go | 13 +++++++++++++ 6 files changed, 76 insertions(+), 14 deletions(-) diff --git a/api_test.go b/api_test.go index ab13f145d..ca6b47d56 100644 --- a/api_test.go +++ b/api_test.go @@ -152,7 +152,7 @@ func TestMatchSize(t *testing.T) { size: 112, }, { v: candidateMatch{}, - size: 72, + size: 80, }, { v: candidateChunk{}, size: 40, diff --git a/contentprovider.go b/contentprovider.go index 15156cab0..621af4594 100644 --- a/contentprovider.go +++ b/contentprovider.go @@ -660,6 +660,11 @@ func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language str } } + if m.scoreWeight != 1 { // should we be using epsilon comparison here? + score.score = score.score * m.scoreWeight + score.what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight) + } + if score.score > maxScore.score { maxScore.score = score.score maxScore.what = score.what diff --git a/eval.go b/eval.go index 7808f733b..9237ca4e4 100644 --- a/eval.go +++ b/eval.go @@ -420,7 +420,7 @@ nextFileMatch: // whether there's an exact match on a symbol, the number of query clauses that matched, etc. func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) { atomMatchCount := 0 - visitMatches(mt, known, func(mt matchTree) { + visitMatchAtoms(mt, known, func(mt matchTree) { atomMatchCount++ }) @@ -544,6 +544,13 @@ func (m sortByOffsetSlice) Less(i, j int) bool { return m[i].byteOffset < m[j].byteOffset } +func setScoreWeight(scoreWeight float64, cm []*candidateMatch) []*candidateMatch { + for _, m := range cm { + m.scoreWeight = scoreWeight + } + return cm +} + // Gather matches from this document. This never returns a mixture of // filename/content matches: if there are content matches, all // filename matches are trimmed from the result. The matches are @@ -554,18 +561,20 @@ func (m sortByOffsetSlice) Less(i, j int) bool { // but adjacent matches will remain. func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candidateMatch { var cands []*candidateMatch - visitMatches(mt, known, func(mt matchTree) { + visitMatches(mt, known, 1, func(mt matchTree, scoreWeight float64) { + // TODO apply scoreWeight to candidates + _ = scoreWeight if smt, ok := mt.(*substrMatchTree); ok { - cands = append(cands, smt.current...) + cands = append(cands, setScoreWeight(scoreWeight, smt.current)...) } if rmt, ok := mt.(*regexpMatchTree); ok { - cands = append(cands, rmt.found...) + cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...) } if rmt, ok := mt.(*wordMatchTree); ok { - cands = append(cands, rmt.found...) + cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...) } if smt, ok := mt.(*symbolRegexpMatchTree); ok { - cands = append(cands, smt.found...) + cands = append(cands, setScoreWeight(scoreWeight, smt.found)...) } }) @@ -649,7 +658,7 @@ func (d *indexData) branchIndex(docID uint32) int { // returns all branches containing docID. func (d *indexData) gatherBranches(docID uint32, mt matchTree, known map[matchTree]bool) []string { var mask uint64 - visitMatches(mt, known, func(mt matchTree) { + visitMatchAtoms(mt, known, func(mt matchTree) { bq, ok := mt.(*branchQueryMatchTree) if !ok { return diff --git a/matchiter.go b/matchiter.go index 68c6e4856..98bf6b1ca 100644 --- a/matchiter.go +++ b/matchiter.go @@ -27,6 +27,8 @@ type candidateMatch struct { substrBytes []byte substrLowered []byte + scoreWeight float64 + file uint32 symbolIdx uint32 diff --git a/matchtree.go b/matchtree.go index 102e0bc26..2ddcc3281 100644 --- a/matchtree.go +++ b/matchtree.go @@ -170,6 +170,11 @@ type fileNameMatchTree struct { child matchTree } +type boostMatchTree struct { + child matchTree + weight float64 +} + // Don't visit this subtree for collecting matches. type noVisitMatchTree struct { matchTree @@ -392,6 +397,10 @@ func (t *fileNameMatchTree) prepare(doc uint32) { t.child.prepare(doc) } +func (t *boostMatchTree) prepare(doc uint32) { + t.child.prepare(doc) +} + func (t *substrMatchTree) prepare(nextDoc uint32) { t.matchIterator.prepare(nextDoc) t.current = t.matchIterator.candidates() @@ -455,6 +464,10 @@ func (t *fileNameMatchTree) nextDoc() uint32 { return t.child.nextDoc() } +func (t *boostMatchTree) nextDoc() uint32 { + return t.child.nextDoc() +} + func (t *branchQueryMatchTree) nextDoc() uint32 { var start uint32 if t.firstDone { @@ -515,6 +528,10 @@ func (t *fileNameMatchTree) String() string { return fmt.Sprintf("f(%v)", t.child) } +func (t *boostMatchTree) String() string { + return fmt.Sprintf("boost(%f, %v)", t.weight, t.child) +} + func (t *substrMatchTree) String() string { f := "" if t.fileName { @@ -556,6 +573,8 @@ func visitMatchTree(t matchTree, f func(matchTree)) { visitMatchTree(s.child, f) case *fileNameMatchTree: visitMatchTree(s.child, f) + case *boostMatchTree: + visitMatchTree(s.child, f) case *symbolSubstrMatchTree: visitMatchTree(s.substrMatchTree, f) case *symbolRegexpMatchTree: @@ -575,33 +594,41 @@ func updateMatchTreeStats(mt matchTree, stats *Stats) { }) } +func visitMatchAtoms(t matchTree, known map[matchTree]bool, f func(matchTree)) { + visitMatches(t, known, 1, func(mt matchTree, _ float64) { + f(mt) + }) +} + // visitMatches visits all atoms which can contribute matches. Note: This // skips noVisitMatchTree. -func visitMatches(t matchTree, known map[matchTree]bool, f func(matchTree)) { +func visitMatches(t matchTree, known map[matchTree]bool, weight float64, f func(matchTree, float64)) { switch s := t.(type) { case *andMatchTree: for _, ch := range s.children { if known[ch] { - visitMatches(ch, known, f) + visitMatches(ch, known, weight, f) } } case *andLineMatchTree: - visitMatches(&s.andMatchTree, known, f) + visitMatches(&s.andMatchTree, known, weight, f) case *orMatchTree: for _, ch := range s.children { if known[ch] { - visitMatches(ch, known, f) + visitMatches(ch, known, weight, f) } } + case *boostMatchTree: + visitMatches(s.child, known, weight*s.weight, f) case *symbolSubstrMatchTree: - visitMatches(s.substrMatchTree, known, f) + visitMatches(s.substrMatchTree, known, weight, f) case *notMatchTree: case *noVisitMatchTree: // don't collect into negative trees. case *fileNameMatchTree: // We will just gather the filename if we do not visit this tree. default: - f(s) + f(s, weight) } } @@ -876,6 +903,10 @@ func (t *fileNameMatchTree) matches(cp *contentProvider, cost int, known map[mat return evalMatchTree(cp, cost, known, t.child) } +func (t *boostMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState { + return evalMatchTree(cp, cost, known, t.child) +} + func (t *substrMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState { if t.contEvaluated { return matchesStateForSlice(t.current) @@ -1288,6 +1319,8 @@ func pruneMatchTree(mt matchTree) (matchTree, error) { } case *fileNameMatchTree: mt.child, err = pruneMatchTree(mt.child) + case *boostMatchTree: + mt.child, err = pruneMatchTree(mt.child) case *andLineMatchTree: child, err := pruneMatchTree(&mt.andMatchTree) if err != nil { diff --git a/query/query.go b/query/query.go index 6a8c5dd3a..28d28d19b 100644 --- a/query/query.go +++ b/query/query.go @@ -386,6 +386,19 @@ func (q *Type) String() string { } } +// Boost scales the contribution to score of descendents. +type Boost struct { + Child Q + // Weight will multiply the score of its descendents. Weights less than 1 + // will give less importance while values greater than 1 will give more + // importance. + Weight float64 +} + +func (q *Boost) String() string { + return fmt.Sprintf("(boost %f %s)", q.Weight, q.Child) +} + // Substring is the most basic query: a query for a substring. type Substring struct { Pattern string