From a35ccdc3bf8dbb9ec6521d852277967317a7f6a5 Mon Sep 17 00:00:00 2001
From: Julie Tibshirani <julie.tibshirani@sourcegraph.com>
Date: Mon, 6 May 2024 12:42:41 -0700
Subject: [PATCH] Avoid overlapping trigrams in distanceHitIterator

---
 bits.go           |  4 ++--
 index_test.go     |  4 ++--
 indexdata.go      | 48 +++++++++++++++++++++++++++++++++++------------
 indexdata_test.go | 28 +++++++++++++++++++++++++++
 4 files changed, 68 insertions(+), 16 deletions(-)

diff --git a/bits.go b/bits.go
index 0b4901745..7d1363102 100644
--- a/bits.go
+++ b/bits.go
@@ -110,7 +110,7 @@ func (n ngram) String() string {
 type runeNgramOff struct {
 	ngram ngram
 	// index is the original index inside of the returned array of splitNGrams
-	index uint32
+	index int
 }
 
 func (a runeNgramOff) Compare(b runeNgramOff) int {
@@ -149,7 +149,7 @@ func splitNGrams(str []byte) []runeNgramOff {
 		ng := runesToNGram(runeGram)
 		result = append(result, runeNgramOff{
 			ngram: ng,
-			index: uint32(len(result)),
+			index: len(result),
 		})
 	}
 	return result
diff --git a/index_test.go b/index_test.go
index 8608b5bfc..05401215f 100644
--- a/index_test.go
+++ b/index_test.go
@@ -441,7 +441,7 @@ func TestSearchStats(t *testing.T) {
 			Want: Stats{
 				FilesLoaded:        1,
 				ContentBytesLoaded: 22,
-				IndexBytesLoaded:   8,
+				IndexBytesLoaded:   10,
 				NgramMatches:       3, // we look at doc 1, because it's max(0,1) due to AND
 				NgramLookups:       104,
 				MatchCount:         2,
@@ -556,7 +556,7 @@ func TestSearchStats(t *testing.T) {
 			}},
 			Want: Stats{
 				ContentBytesLoaded: 33, // we still have to run regex since "app" matches two documents
-				IndexBytesLoaded:   8,
+				IndexBytesLoaded:   10,
 				FilesConsidered:    2, // important that we don't check 3 to ensure we are using the index
 				FilesLoaded:        2,
 				MatchCount:         0, // even though there is a match it doesn't align with a symbol
diff --git a/indexdata.go b/indexdata.go
index ddb8a95c1..2fdd6ef5a 100644
--- a/indexdata.go
+++ b/indexdata.go
@@ -336,9 +336,31 @@ func min2Index(xs []uint32) (idx0, idx1 int) {
 	return
 }
 
-// minFrequencyNgramOffsets returns the two lowest frequency ngrams to pass to
-// the distance iterator. If they have the same frequency, we maximise the
-// distance between them. first will always have a smaller index than last.
+// findSelectiveNgrams returns two ngrams to pass to the distance iterator, chosen to
+// produce a small file intersection. It finds the two lowest frequency ngrams, making
+// sure to maximize the distance between them in case of ties. It avoids overlapping
+// trigrams to keep their intersection as small as possible.
+//
+// Invariant: first will always have a smaller index than last.
+func findSelectiveNgrams(ngramOffs []runeNgramOff, indexMap []int, frequencies []uint32) (first, last runeNgramOff) {
+	first, last = minFrequencyNgramOffsets(ngramOffs, frequencies)
+
+	// If the trigrams are overlapping, then try to shift one to reduce overlap.
+	// This is guaranteed to produce a smaller intersection.
+	if last.index-first.index < ngramSize {
+		newFirstIndex := max(last.index-ngramSize, 0)
+		if newFirstIndex != first.index {
+			first = ngramOffs[indexMap[newFirstIndex]]
+		}
+
+		newLastIndex := min(first.index+ngramSize, len(ngramOffs)-1)
+		if newLastIndex != last.index {
+			last = ngramOffs[indexMap[newLastIndex]]
+		}
+	}
+	return
+}
+
 func minFrequencyNgramOffsets(ngramOffs []runeNgramOff, frequencies []uint32) (first, last runeNgramOff) {
 	firstI, lastI := min2Index(frequencies)
 	// If the frequencies are equal lets maximise distance in the query
@@ -357,13 +379,15 @@ func minFrequencyNgramOffsets(ngramOffs []runeNgramOff, frequencies []uint32) (f
 			}
 		}
 	}
+
 	first = ngramOffs[firstI]
 	last = ngramOffs[lastI]
-	// Ensure first appears before last to make distance logic below clean.
+
+	// Ensure first appears before last as a helpful invariant.
 	if first.index > last.index {
 		last, first = first, last
 	}
-	return first, last
+	return
 }
 
 func (data *indexData) ngrams(filename bool) btreeIndex {
@@ -412,9 +436,10 @@ func (d *indexData) iterateNgrams(query *query.Substring) (*ngramIterationResult
 	// bucket (which can cause disk IO).
 	slices.SortFunc(ngramOffs, runeNgramOff.Compare)
 	frequencies := make([]uint32, 0, len(ngramOffs))
+	indexMap := make([]int, len(ngramOffs))
 	ngramLookups := 0
 	ngrams := d.ngrams(query.FileName)
-	for _, o := range ngramOffs {
+	for i, o := range ngramOffs {
 		var freq uint32
 		if query.CaseSensitive {
 			freq = ngrams.Get(o.ngram).sz
@@ -438,15 +463,14 @@ func (d *indexData) iterateNgrams(query *query.Substring) (*ngramIterationResult
 		}
 
 		frequencies = append(frequencies, freq)
+		indexMap[o.index] = i
 	}
 
-	// first and last are now the smallest trigram posting lists to iterate
-	// through.
-	first, last := minFrequencyNgramOffsets(ngramOffs, frequencies)
+	first, last := findSelectiveNgrams(ngramOffs, indexMap, frequencies)
 
 	iter := &ngramDocIterator{
-		leftPad:      first.index,
-		rightPad:     uint32(utf8.RuneCountInString(str)) - first.index,
+		leftPad:      uint32(first.index),
+		rightPad:     uint32(utf8.RuneCountInString(str) - first.index),
 		ngramLookups: ngramLookups,
 	}
 	if query.FileName {
@@ -456,7 +480,7 @@ func (d *indexData) iterateNgrams(query *query.Substring) (*ngramIterationResult
 	}
 
 	if first != last {
-		runeDist := last.index - first.index
+		runeDist := uint32(last.index - first.index)
 		i, err := d.newDistanceTrigramIter(first.ngram, last.ngram, runeDist, query.CaseSensitive, query.FileName)
 		if err != nil {
 			return nil, err
diff --git a/indexdata_test.go b/indexdata_test.go
index d4f8e1182..1e8b07966 100644
--- a/indexdata_test.go
+++ b/indexdata_test.go
@@ -72,3 +72,31 @@ func TestMinFrequencyNgramOffsets(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+func TestFindSelectiveNGrams(t *testing.T) {
+	if err := quick.Check(func(s string, maxFreq uint16) bool {
+		ngramOffs := splitNGrams([]byte(s))
+		if len(ngramOffs) == 0 {
+			return true
+		}
+
+		slices.SortFunc(ngramOffs, runeNgramOff.Compare)
+		indexMap := make([]int, len(ngramOffs))
+		for i, n := range ngramOffs {
+			indexMap[n.index] = i
+		}
+
+		frequencies := genFrequencies(ngramOffs, int(maxFreq))
+		x0, x1 := findSelectiveNgrams(ngramOffs, indexMap, frequencies)
+
+		if len(ngramOffs) <= 1 {
+			return true
+		}
+
+		// Just assert the invariant that x0 is before x1. This test mostly checks
+		// for out-of-bounds errors.
+		return x0.index < x1.index
+	}, nil); err != nil {
+		t.Fatal(err)
+	}
+}