diff --git a/api_test.go b/api_test.go index ab13f145d..ca6b47d56 100644 --- a/api_test.go +++ b/api_test.go @@ -152,7 +152,7 @@ func TestMatchSize(t *testing.T) { size: 112, }, { v: candidateMatch{}, - size: 72, + size: 80, }, { v: candidateChunk{}, size: 40, diff --git a/contentprovider.go b/contentprovider.go index 15156cab0..621af4594 100644 --- a/contentprovider.go +++ b/contentprovider.go @@ -660,6 +660,11 @@ func (p *contentProvider) candidateMatchScore(ms []*candidateMatch, language str } } + if m.scoreWeight != 1 { // should we be using epsilon comparison here? + score.score = score.score * m.scoreWeight + score.what += fmt.Sprintf("boost:%.2f, ", m.scoreWeight) + } + if score.score > maxScore.score { maxScore.score = score.score maxScore.what = score.what diff --git a/eval.go b/eval.go index 7808f733b..9237ca4e4 100644 --- a/eval.go +++ b/eval.go @@ -420,7 +420,7 @@ nextFileMatch: // whether there's an exact match on a symbol, the number of query clauses that matched, etc. func (d *indexData) scoreFile(fileMatch *FileMatch, doc uint32, mt matchTree, known map[matchTree]bool, opts *SearchOptions) { atomMatchCount := 0 - visitMatches(mt, known, func(mt matchTree) { + visitMatchAtoms(mt, known, func(mt matchTree) { atomMatchCount++ }) @@ -544,6 +544,13 @@ func (m sortByOffsetSlice) Less(i, j int) bool { return m[i].byteOffset < m[j].byteOffset } +func setScoreWeight(scoreWeight float64, cm []*candidateMatch) []*candidateMatch { + for _, m := range cm { + m.scoreWeight = scoreWeight + } + return cm +} + // Gather matches from this document. This never returns a mixture of // filename/content matches: if there are content matches, all // filename matches are trimmed from the result. The matches are @@ -554,18 +561,20 @@ func (m sortByOffsetSlice) Less(i, j int) bool { // but adjacent matches will remain. func gatherMatches(mt matchTree, known map[matchTree]bool, merge bool) []*candidateMatch { var cands []*candidateMatch - visitMatches(mt, known, func(mt matchTree) { + visitMatches(mt, known, 1, func(mt matchTree, scoreWeight float64) { + // TODO apply scoreWeight to candidates + _ = scoreWeight if smt, ok := mt.(*substrMatchTree); ok { - cands = append(cands, smt.current...) + cands = append(cands, setScoreWeight(scoreWeight, smt.current)...) } if rmt, ok := mt.(*regexpMatchTree); ok { - cands = append(cands, rmt.found...) + cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...) } if rmt, ok := mt.(*wordMatchTree); ok { - cands = append(cands, rmt.found...) + cands = append(cands, setScoreWeight(scoreWeight, rmt.found)...) } if smt, ok := mt.(*symbolRegexpMatchTree); ok { - cands = append(cands, smt.found...) + cands = append(cands, setScoreWeight(scoreWeight, smt.found)...) } }) @@ -649,7 +658,7 @@ func (d *indexData) branchIndex(docID uint32) int { // returns all branches containing docID. func (d *indexData) gatherBranches(docID uint32, mt matchTree, known map[matchTree]bool) []string { var mask uint64 - visitMatches(mt, known, func(mt matchTree) { + visitMatchAtoms(mt, known, func(mt matchTree) { bq, ok := mt.(*branchQueryMatchTree) if !ok { return diff --git a/matchiter.go b/matchiter.go index 68c6e4856..98bf6b1ca 100644 --- a/matchiter.go +++ b/matchiter.go @@ -27,6 +27,8 @@ type candidateMatch struct { substrBytes []byte substrLowered []byte + scoreWeight float64 + file uint32 symbolIdx uint32 diff --git a/matchtree.go b/matchtree.go index 102e0bc26..2ddcc3281 100644 --- a/matchtree.go +++ b/matchtree.go @@ -170,6 +170,11 @@ type fileNameMatchTree struct { child matchTree } +type boostMatchTree struct { + child matchTree + weight float64 +} + // Don't visit this subtree for collecting matches. type noVisitMatchTree struct { matchTree @@ -392,6 +397,10 @@ func (t *fileNameMatchTree) prepare(doc uint32) { t.child.prepare(doc) } +func (t *boostMatchTree) prepare(doc uint32) { + t.child.prepare(doc) +} + func (t *substrMatchTree) prepare(nextDoc uint32) { t.matchIterator.prepare(nextDoc) t.current = t.matchIterator.candidates() @@ -455,6 +464,10 @@ func (t *fileNameMatchTree) nextDoc() uint32 { return t.child.nextDoc() } +func (t *boostMatchTree) nextDoc() uint32 { + return t.child.nextDoc() +} + func (t *branchQueryMatchTree) nextDoc() uint32 { var start uint32 if t.firstDone { @@ -515,6 +528,10 @@ func (t *fileNameMatchTree) String() string { return fmt.Sprintf("f(%v)", t.child) } +func (t *boostMatchTree) String() string { + return fmt.Sprintf("boost(%f, %v)", t.weight, t.child) +} + func (t *substrMatchTree) String() string { f := "" if t.fileName { @@ -556,6 +573,8 @@ func visitMatchTree(t matchTree, f func(matchTree)) { visitMatchTree(s.child, f) case *fileNameMatchTree: visitMatchTree(s.child, f) + case *boostMatchTree: + visitMatchTree(s.child, f) case *symbolSubstrMatchTree: visitMatchTree(s.substrMatchTree, f) case *symbolRegexpMatchTree: @@ -575,33 +594,41 @@ func updateMatchTreeStats(mt matchTree, stats *Stats) { }) } +func visitMatchAtoms(t matchTree, known map[matchTree]bool, f func(matchTree)) { + visitMatches(t, known, 1, func(mt matchTree, _ float64) { + f(mt) + }) +} + // visitMatches visits all atoms which can contribute matches. Note: This // skips noVisitMatchTree. -func visitMatches(t matchTree, known map[matchTree]bool, f func(matchTree)) { +func visitMatches(t matchTree, known map[matchTree]bool, weight float64, f func(matchTree, float64)) { switch s := t.(type) { case *andMatchTree: for _, ch := range s.children { if known[ch] { - visitMatches(ch, known, f) + visitMatches(ch, known, weight, f) } } case *andLineMatchTree: - visitMatches(&s.andMatchTree, known, f) + visitMatches(&s.andMatchTree, known, weight, f) case *orMatchTree: for _, ch := range s.children { if known[ch] { - visitMatches(ch, known, f) + visitMatches(ch, known, weight, f) } } + case *boostMatchTree: + visitMatches(s.child, known, weight*s.weight, f) case *symbolSubstrMatchTree: - visitMatches(s.substrMatchTree, known, f) + visitMatches(s.substrMatchTree, known, weight, f) case *notMatchTree: case *noVisitMatchTree: // don't collect into negative trees. case *fileNameMatchTree: // We will just gather the filename if we do not visit this tree. default: - f(s) + f(s, weight) } } @@ -876,6 +903,10 @@ func (t *fileNameMatchTree) matches(cp *contentProvider, cost int, known map[mat return evalMatchTree(cp, cost, known, t.child) } +func (t *boostMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState { + return evalMatchTree(cp, cost, known, t.child) +} + func (t *substrMatchTree) matches(cp *contentProvider, cost int, known map[matchTree]bool) matchesState { if t.contEvaluated { return matchesStateForSlice(t.current) @@ -1288,6 +1319,8 @@ func pruneMatchTree(mt matchTree) (matchTree, error) { } case *fileNameMatchTree: mt.child, err = pruneMatchTree(mt.child) + case *boostMatchTree: + mt.child, err = pruneMatchTree(mt.child) case *andLineMatchTree: child, err := pruneMatchTree(&mt.andMatchTree) if err != nil { diff --git a/query/query.go b/query/query.go index 6a8c5dd3a..28d28d19b 100644 --- a/query/query.go +++ b/query/query.go @@ -386,6 +386,19 @@ func (q *Type) String() string { } } +// Boost scales the contribution to score of descendents. +type Boost struct { + Child Q + // Weight will multiply the score of its descendents. Weights less than 1 + // will give less importance while values greater than 1 will give more + // importance. + Weight float64 +} + +func (q *Boost) String() string { + return fmt.Sprintf("(boost %f %s)", q.Weight, q.Child) +} + // Substring is the most basic query: a query for a substring. type Substring struct { Pattern string