Skip to content

Commit

Permalink
score: introduce query.Boost to scale score
Browse files Browse the repository at this point in the history
This commit introduces a new primitive Boost to our query language. It
allows boosting (or dampening) the contribution to the score a query
atoms will match contribute.

To achieve this we introduce boostMatchTree which records this weight.
We then adjust the visitMatches to take an initial score weight (1.0),
and then each time we recurse through a boostMatchTree the score weight
is multiplied by the boost weight. Additionally candidateMatch now has a
new field, scoreWeight, which records the weight at time of candidate
collection. Without boosting in the query this value will always be 1.

Finally when scoring a candidateMatch we take the final score for it and
multiply it by scoreWeight.

Note: we do not expose a way to set this in the query language, only the
query API.

Test Plan: This functionality is currently untested. However, none of
our existings tests have broken so this is technically safe to land.
TODO add testing for boost query.
  • Loading branch information
keegancsmith committed Jan 28, 2024
1 parent cdb1665 commit 54e4e84
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 14 deletions.
2 changes: 1 addition & 1 deletion api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func TestMatchSize(t *testing.T) {
size: 112,
}, {
v: candidateMatch{},
size: 72,
size: 80,
}, {
v: candidateChunk{},
size: 40,
Expand Down
5 changes: 5 additions & 0 deletions contentprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,11 @@ func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language str
}
}

if m.scoreWeight != 1 { // should we be using epsilon comparison here?
score.score = score.score * m.scoreWeight
score.what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight)
}

if score.score > maxScore.score {
maxScore.score = score.score
maxScore.what = score.what
Expand Down
23 changes: 16 additions & 7 deletions eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ nextFileMatch:
// whether there's an exact match on a symbol, the number of query clauses that matched, etc.
func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) {
atomMatchCount := 0
visitMatches(mt, known, func(mt matchTree) {
visitMatchAtoms(mt, known, func(mt matchTree) {
atomMatchCount++
})

Expand Down Expand Up @@ -544,6 +544,13 @@ func (m sortByOffsetSlice) Less(i, j int) bool {
return m[i].byteOffset < m[j].byteOffset
}

func setScoreWeight(scoreWeight float64, cm []*candidateMatch) []*candidateMatch {
for _, m := range cm {
m.scoreWeight = scoreWeight
}
return cm
}

// Gather matches from this document. This never returns a mixture of
// filename/content matches: if there are content matches, all
// filename matches are trimmed from the result. The matches are
Expand All @@ -554,18 +561,20 @@ func (m sortByOffsetSlice) Less(i, j int) bool {
// but adjacent matches will remain.
func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candidateMatch {
var cands []*candidateMatch
visitMatches(mt, known, func(mt matchTree) {
visitMatches(mt, known, 1, func(mt matchTree, scoreWeight float64) {
// TODO apply scoreWeight to candidates
_ = scoreWeight
if smt, ok := mt.(*substrMatchTree); ok {
cands = append(cands, smt.current...)
cands = append(cands, setScoreWeight(scoreWeight, smt.current)...)
}
if rmt, ok := mt.(*regexpMatchTree); ok {
cands = append(cands, rmt.found...)
cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
}
if rmt, ok := mt.(*wordMatchTree); ok {
cands = append(cands, rmt.found...)
cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...)
}
if smt, ok := mt.(*symbolRegexpMatchTree); ok {
cands = append(cands, smt.found...)
cands = append(cands, setScoreWeight(scoreWeight, smt.found)...)
}
})

Expand Down Expand Up @@ -649,7 +658,7 @@ func (d *indexData) branchIndex(docID uint32) int {
// returns all branches containing docID.
func (d *indexData) gatherBranches(docID uint32, mt matchTree, known map[matchTree]bool) []string {
var mask uint64
visitMatches(mt, known, func(mt matchTree) {
visitMatchAtoms(mt, known, func(mt matchTree) {
bq, ok := mt.(*branchQueryMatchTree)
if !ok {
return
Expand Down
2 changes: 2 additions & 0 deletions matchiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ type candidateMatch struct {
substrBytes []byte
substrLowered []byte

scoreWeight float64

file uint32
symbolIdx uint32

Expand Down
45 changes: 39 additions & 6 deletions matchtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ type fileNameMatchTree struct {
child matchTree
}

type boostMatchTree struct {
child matchTree
weight float64
}

// Don't visit this subtree for collecting matches.
type noVisitMatchTree struct {
matchTree
Expand Down Expand Up @@ -392,6 +397,10 @@ func (t *fileNameMatchTree) prepare(doc uint32) {
t.child.prepare(doc)
}

func (t *boostMatchTree) prepare(doc uint32) {
t.child.prepare(doc)
}

func (t *substrMatchTree) prepare(nextDoc uint32) {
t.matchIterator.prepare(nextDoc)
t.current = t.matchIterator.candidates()
Expand Down Expand Up @@ -455,6 +464,10 @@ func (t *fileNameMatchTree) nextDoc() uint32 {
return t.child.nextDoc()
}

func (t *boostMatchTree) nextDoc() uint32 {
return t.child.nextDoc()
}

func (t *branchQueryMatchTree) nextDoc() uint32 {
var start uint32
if t.firstDone {
Expand Down Expand Up @@ -515,6 +528,10 @@ func (t *fileNameMatchTree) String() string {
return fmt.Sprintf("f(%v)", t.child)
}

func (t *boostMatchTree) String() string {
return fmt.Sprintf("boost(%f, %v)", t.weight, t.child)
}

func (t *substrMatchTree) String() string {
f := ""
if t.fileName {
Expand Down Expand Up @@ -556,6 +573,8 @@ func visitMatchTree(t matchTree, f func(matchTree)) {
visitMatchTree(s.child, f)
case *fileNameMatchTree:
visitMatchTree(s.child, f)
case *boostMatchTree:
visitMatchTree(s.child, f)
case *symbolSubstrMatchTree:
visitMatchTree(s.substrMatchTree, f)
case *symbolRegexpMatchTree:
Expand All @@ -575,33 +594,41 @@ func updateMatchTreeStats(mt matchTree, stats *Stats) {
})
}

func visitMatchAtoms(t matchTree, known map[matchTree]bool, f func(matchTree)) {
visitMatches(t, known, 1, func(mt matchTree, _ float64) {
f(mt)
})
}

// visitMatches visits all atoms which can contribute matches. Note: This
// skips noVisitMatchTree.
func visitMatches(t matchTree, known map[matchTree]bool, f func(matchTree)) {
func visitMatches(t matchTree, known map[matchTree]bool, weight float64, f func(matchTree, float64)) {
switch s := t.(type) {
case *andMatchTree:
for _, ch := range s.children {
if known[ch] {
visitMatches(ch, known, f)
visitMatches(ch, known, weight, f)
}
}
case *andLineMatchTree:
visitMatches(&s.andMatchTree, known, f)
visitMatches(&s.andMatchTree, known, weight, f)
case *orMatchTree:
for _, ch := range s.children {
if known[ch] {
visitMatches(ch, known, f)
visitMatches(ch, known, weight, f)
}
}
case *boostMatchTree:
visitMatches(s.child, known, weight*s.weight, f)
case *symbolSubstrMatchTree:
visitMatches(s.substrMatchTree, known, f)
visitMatches(s.substrMatchTree, known, weight, f)
case *notMatchTree:
case *noVisitMatchTree:
// don't collect into negative trees.
case *fileNameMatchTree:
// We will just gather the filename if we do not visit this tree.
default:
f(s)
f(s, weight)
}
}

Expand Down Expand Up @@ -876,6 +903,10 @@ func (t *fileNameMatchTree) matches(cp *contentProvider, cost int, known map[mat
return evalMatchTree(cp, cost, known, t.child)
}

func (t *boostMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState {
return evalMatchTree(cp, cost, known, t.child)
}

func (t *substrMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState {
if t.contEvaluated {
return matchesStateForSlice(t.current)
Expand Down Expand Up @@ -1288,6 +1319,8 @@ func pruneMatchTree(mt matchTree) (matchTree, error) {
}
case *fileNameMatchTree:
mt.child, err = pruneMatchTree(mt.child)
case *boostMatchTree:
mt.child, err = pruneMatchTree(mt.child)
case *andLineMatchTree:
child, err := pruneMatchTree(&mt.andMatchTree)
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions query/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,19 @@ func (q *Type) String() string {
}
}

// Boost scales the contribution to score of descendents.
type Boost struct {
Child Q
// Weight will multiply the score of its descendents. Weights less than 1
// will give less importance while values greater than 1 will give more
// importance.
Weight float64
}

func (q *Boost) String() string {
return fmt.Sprintf("(boost %f %s)", q.Weight, q.Child)
}

// Substring is the most basic query: a query for a substring.
type Substring struct {
Pattern string
Expand Down

0 comments on commit 54e4e84

Please sign in to comment.