From ac81c4acd6ba91edf33715018a9b84e94c4c538a Mon Sep 17 00:00:00 2001
From: Andrew Kimball
Date: Wed, 6 Nov 2024 15:10:26 -0800
Subject: [PATCH] vecindex: implement C-SPANN search
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add the VectorIndex type, which will implement the C-SPANN algorithm.
C-SPANN adapts Microsoft’s SPANN and SPFresh algorithms to work well
with CockroachDB’s unique distributed architecture. This PR implements
K-means tree search with a test implementation of bottom-up tree
construction. Later PRs will include code to incrementally build the
tree.

Epic: CRDB-42943

Release note: None
---
 pkg/sql/vecindex/BUILD.bazel                  |  16 +-
 pkg/sql/vecindex/quantize/rabitq_test.go      |   2 +
 pkg/sql/vecindex/testdata/delete.ddt          |  90 +++
 pkg/sql/vecindex/testdata/search-features.ddt |  72 ++
 pkg/sql/vecindex/testdata/search.ddt          | 186 +++++
 .../vecindex/vecstore/in_memory_store_test.go |  10 +-
 pkg/sql/vecindex/vecstore/partition_test.go   |   4 +-
 pkg/sql/vecindex/vecstore/search_set.go       |  53 +-
 pkg/sql/vecindex/vecstore/search_set_test.go  |  31 +-
 pkg/sql/vecindex/vector_index.go              | 633 ++++++++++++++++++
 pkg/sql/vecindex/vector_index_test.go         | 562 ++++++++++++++++
 pkg/util/vector/BUILD.bazel                   |   1 +
 pkg/util/vector/vector_set.go                 |  29 +-
 pkg/util/vector/vector_set_test.go            |   5 +
 14 files changed, 1641 insertions(+), 53 deletions(-)
 create mode 100644 pkg/sql/vecindex/testdata/delete.ddt
 create mode 100644 pkg/sql/vecindex/testdata/search-features.ddt
 create mode 100644 pkg/sql/vecindex/testdata/search.ddt
 create mode 100644 pkg/sql/vecindex/vector_index.go
 create mode 100644 pkg/sql/vecindex/vector_index_test.go

diff --git a/pkg/sql/vecindex/BUILD.bazel b/pkg/sql/vecindex/BUILD.bazel
index e52fd865fce0..66cae89cc595 100644
--- a/pkg/sql/vecindex/BUILD.bazel
+++ b/pkg/sql/vecindex/BUILD.bazel
@@ -8,11 +8,16 @@ filegroup(
 go_library(
     name = "vecindex",
-    srcs = ["kmeans.go"],
+    srcs = [
+        "kmeans.go",
+        "vector_index.go",
+    ],
     importpath = "github.com/cockroachdb/cockroach/pkg/sql/vecindex",
     visibility = ["//visibility:public"],
     deps = [
         "//pkg/sql/vecindex/internal",
+        "//pkg/sql/vecindex/quantize",
+        "//pkg/sql/vecindex/vecstore",
         "//pkg/util/num32",
         "//pkg/util/vector",
         "@com_github_cockroachdb_errors//:errors",
@@ -21,14 +26,21 @@ go_library(
 go_test(
     name = "vecindex_test",
-    srcs = ["kmeans_test.go"],
+    srcs = [
+        "kmeans_test.go",
+        "vector_index_test.go",
+    ],
     data = glob(["testdata/**"]),
     embed = [":vecindex"],
     deps = [
         "//pkg/sql/vecindex/internal",
+        "//pkg/sql/vecindex/quantize",
         "//pkg/sql/vecindex/testutils",
+        "//pkg/sql/vecindex/vecstore",
         "//pkg/util/num32",
         "//pkg/util/vector",
+        "@com_github_cockroachdb_datadriven//:datadriven",
+        "@com_github_cockroachdb_errors//:errors",
         "@com_github_stretchr_testify//require",
         "@org_gonum_v1_gonum//floats/scalar",
         "@org_gonum_v1_gonum//stat",
diff --git a/pkg/sql/vecindex/quantize/rabitq_test.go b/pkg/sql/vecindex/quantize/rabitq_test.go
index 0abe34e5501e..b8d7da704071 100644
--- a/pkg/sql/vecindex/quantize/rabitq_test.go
+++ b/pkg/sql/vecindex/quantize/rabitq_test.go
@@ -106,8 +106,10 @@ func TestRaBitQuantizerEdge(t *testing.T) {
 		vectors := vector.MakeSet(141)
 		vectors.AddUndefined(2)
+		zeros := vectors.At(0)
 		ones := vectors.At(1)
 		for i := 0; i < len(ones); i++ {
+			zeros[i] = 0
 			ones[i] = 1
 		}
 		quantizedSet := quantizer.Quantize(ctx, &vectors).(*RaBitQuantizedVectorSet)
diff --git a/pkg/sql/vecindex/testdata/delete.ddt b/pkg/sql/vecindex/testdata/delete.ddt
new file mode 100644
index 
000000000000..a7fca90faff4 --- /dev/null +++ b/pkg/sql/vecindex/testdata/delete.ddt @@ -0,0 +1,90 @@ +# ---------- +# Test deleting vectors from primary index, but not from secondary index. +# ---------- +new-index min-partition-size=1 max-partition-size=3 beam-size=2 +vec1: (1, 2) +vec2: (7, 4) +vec3: (4, 3) +vec4: (5, 5) +---- +• 1 (3.1667, 3) +│ +├───• 2 (5.3333, 4) +│ │ +│ ├───• vec2 (7, 4) +│ ├───• vec3 (4, 3) +│ └───• vec4 (5, 5) +│ +└───• 3 (1, 2) + │ + └───• vec1 (1, 2) + +# Delete vector from primary index, but not from secondary index. +delete not-found +vec3 +---- +• 1 (3.1667, 3) +│ +├───• 2 (5.3333, 4) +│ │ +│ ├───• vec2 (7, 4) +│ ├───• vec3 (MISSING) +│ └───• vec4 (5, 5) +│ +└───• 3 (1, 2) + │ + └───• vec1 (1, 2) + +# Ensure deleted vector is not returned by search. This should enqueue a fixup +# that removes the vector from the index. +search max-results=1 +(4, 3) +---- +vec4: 5 (centroid=1.0541) +4 leaf vectors, 6 vectors, 2 full vectors, 3 partitions + +# Again, with higher max results. +search max-results=2 +(4, 3) +---- +vec4: 5 (centroid=1.0541) +vec2: 10 (centroid=1.6667) +4 leaf vectors, 6 vectors, 4 full vectors, 3 partitions + +# Vector should now be gone from the index. +# TODO(andyk): This will be true once fixups are added. +format-tree +---- +• 1 (3.1667, 3) +│ +├───• 2 (5.3333, 4) +│ │ +│ ├───• vec2 (7, 4) +│ ├───• vec3 (MISSING) +│ └───• vec4 (5, 5) +│ +└───• 3 (1, 2) + │ + └───• vec1 (1, 2) + +# Delete all vectors from one branch of the tree. +delete not-found +vec1 +---- +• 1 (3.1667, 3) +│ +├───• 2 (5.3333, 4) +│ │ +│ ├───• vec2 (7, 4) +│ ├───• vec3 (MISSING) +│ └───• vec4 (5, 5) +│ +└───• 3 (1, 2) + │ + └───• vec1 (MISSING) + +# Search the empty branch. +search max-results=1 beam-size=1 +(1, 2) +---- +1 leaf vectors, 3 vectors, 1 full vectors, 2 partitions diff --git a/pkg/sql/vecindex/testdata/search-features.ddt b/pkg/sql/vecindex/testdata/search-features.ddt new file mode 100644 index 000000000000..0486bfda97c3 --- /dev/null +++ b/pkg/sql/vecindex/testdata/search-features.ddt @@ -0,0 +1,72 @@ +# Load 500 512-dimension features and search them. Use small partition size to +# ensure a deeper tree. + +new-index dims=512 min-partition-size=2 max-partition-size=8 quality-samples=4 beam-size=2 load-features=500 hide-tree +---- +Created index with 500 vectors with 512 dimensions. + +# Start with 1 result and default beam size of 2. +search max-results=1 use-feature=9999 +---- +vec441: 0.4646 (centroid=0.382) +9 leaf vectors, 39 vectors, 2 full vectors, 5 partitions + +# Search for additional results. +search max-results=3 use-feature=9999 +---- +vec441: 0.4646 (centroid=0.382) +vec99: 0.6356 (centroid=0.382) +vec296: 0.7638 (centroid=0.5962) +9 leaf vectors, 39 vectors, 6 full vectors, 5 partitions + +# Use a larger beam size. +search max-results=6 use-feature=9999 beam-size=8 +---- +vec74: 0.4155 (centroid=0.5092) +vec195: 0.4359 (centroid=0.5127) +vec441: 0.4646 (centroid=0.382) +vec77: 0.4894 (centroid=0.4286) +vec355: 0.5821 (centroid=0.4617) +vec328: 0.6032 (centroid=0.5276) +58 leaf vectors, 123 vectors, 14 full vectors, 15 partitions + +# Turn off re-ranking, which results in increased inaccuracy. 
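+# Without re-ranking, reported distances are estimates computed from quantized
+# vectors, so each result carries a ± error bound. For example, vec195's exact
+# squared distance of 0.4359 (from the previous search) falls within the
+# 0.4179 ±0.0264 estimate reported below.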
+search max-results=6 use-feature=9999 beam-size=8 skip-rerank +---- +vec195: 0.4179 ±0.0264 (centroid=0.5127) +vec74: 0.4322 ±0.0263 (centroid=0.5092) +vec441: 0.4657 ±0.0215 (centroid=0.382) +vec77: 0.4881 ±0.0221 (centroid=0.4286) +vec355: 0.5658 ±0.0238 (centroid=0.4617) +vec415: 0.6142 ±0.0302 (centroid=0.5306) +58 leaf vectors, 123 vectors, 0 full vectors, 15 partitions + +# Return top 25 results. +search max-results=25 use-feature=9999 beam-size=8 +---- +vec74: 0.4155 (centroid=0.5092) +vec195: 0.4359 (centroid=0.5127) +vec441: 0.4646 (centroid=0.382) +vec77: 0.4894 (centroid=0.4286) +vec355: 0.5821 (centroid=0.4617) +vec328: 0.6032 (centroid=0.5276) +vec389: 0.6183 (centroid=0.5267) +vec415: 0.6298 (centroid=0.5306) +vec99: 0.6356 (centroid=0.382) +vec267: 0.6742 (centroid=0.526) +vec6: 0.685 (centroid=0.6015) +vec485: 0.6867 (centroid=0.362) +vec236: 0.687 (centroid=0.5071) +vec198: 0.6885 (centroid=0.5094) +vec65: 0.6898 (centroid=0.4403) +vec146: 0.6901 (centroid=0.5601) +vec282: 0.7197 (centroid=0.4023) +vec410: 0.728 (centroid=0.4261) +vec356: 0.7341 (centroid=0.4352) +vec439: 0.7428 (centroid=0.6023) +vec116: 0.7462 (centroid=0.4643) +vec273: 0.7555 (centroid=0.5226) +vec453: 0.7735 (centroid=0.3571) +vec233: 0.7737 (centroid=0.5502) +vec331: 0.7793 (centroid=0.4871) +58 leaf vectors, 123 vectors, 44 full vectors, 15 partitions diff --git a/pkg/sql/vecindex/testdata/search.ddt b/pkg/sql/vecindex/testdata/search.ddt new file mode 100644 index 000000000000..8435c1b1acf5 --- /dev/null +++ b/pkg/sql/vecindex/testdata/search.ddt @@ -0,0 +1,186 @@ +# ---------- +# Construct new index with only root-level vectors. +# ---------- +new-index min-partition-size=1 max-partition-size=4 beam-size=2 +vec1: (1, 2) +vec2: (7, 4) +vec3: (4, 3) +---- +• 1 (4, 3) +│ +├───• vec1 (1, 2) +├───• vec2 (7, 4) +└───• vec3 (4, 3) + +# Search for vector that has exact match. +search +(7, 4) +---- +vec2: 0 (centroid=3.1623) +3 leaf vectors, 3 vectors, 3 full vectors, 1 partitions + +# Search for vector with no exact match. +search max-results=2 +(3, 5) +---- +vec3: 5 (centroid=0) +vec1: 13 (centroid=3.1623) +3 leaf vectors, 3 vectors, 3 full vectors, 1 partitions + +# ---------- +# Construct new index with multiple levels. +# ---------- +new-index min-partition-size=1 max-partition-size=4 beam-size=2 +vec1: (1, 2) +vec2: (7, 4) +vec3: (4, 3) +vec4: (-4, 5) +vec5: (1, 11) +vec6: (1, -6) +vec7: (0, 4) +vec8: (-2, 8) +vec9: (2, 8) +vec10: (0, 3) +vec11: (1, 1) +vec12: (5, 4) +vec13: (6, 2) +---- +• 1 (2.0556, 1.1389) +│ +├───• 8 (1, -6) +│ │ +│ └───• 4 (1, -6) +│ │ +│ └───• vec6 (1, -6) +│ +├───• 9 (5.5, 3.25) +│ │ +│ ├───• 3 (6.5, 3) +│ │ │ +│ │ ├───• vec2 (7, 4) +│ │ └───• vec13 (6, 2) +│ │ +│ └───• 5 (4.5, 3.5) +│ │ +│ ├───• vec3 (4, 3) +│ └───• vec12 (5, 4) +│ +└───• 10 (-0.3333, 6.1667) + │ + ├───• 2 (-3, 6.5) + │ │ + │ ├───• vec4 (-4, 5) + │ └───• vec8 (-2, 8) + │ + ├───• 6 (0.5, 2.5) + │ │ + │ ├───• vec1 (1, 2) + │ ├───• vec7 (0, 4) + │ ├───• vec10 (0, 3) + │ └───• vec11 (1, 1) + │ + └───• 7 (1.5, 9.5) + │ + ├───• vec5 (1, 11) + └───• vec9 (2, 8) + +# Search for closest vectors with beam-size=1. +search max-results=2 beam-size=1 +(1, 6) +---- +vec9: 5 (centroid=1.5811) +vec5: 25 (centroid=1.5811) +2 leaf vectors, 8 vectors, 2 full vectors, 3 partitions + +# Search for closest vectors with beam-size=2. 
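+# The wider beam visits 4 partitions instead of 3, surfacing vec7 (distance 5,
+# tied with vec9) that the beam-size=1 search above missed.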
+search max-results=2 beam-size=2 +(1, 6) +---- +vec9: 5 (centroid=1.5811) +vec7: 5 (centroid=1.5811) +6 leaf vectors, 12 vectors, 3 full vectors, 4 partitions + +# ---------- +# Construct new index with duplicate vectors. +# ---------- +new-index min-partition-size=1 max-partition-size=4 beam-size=2 +vec1: (-3, 5) +vec2: (10, -10) +vec3: (4, 9) +vec4: (1, 1) +vec5: (4, 9) +vec6: (6, 2) +---- +• 1 (3.6667, 2.6667) +│ +├───• 2 (8, -4) +│ │ +│ ├───• vec2 (10, -10) +│ └───• vec6 (6, 2) +│ +├───• 3 (-1, 3) +│ │ +│ ├───• vec1 (-3, 5) +│ └───• vec4 (1, 1) +│ +└───• 4 (4, 9) + │ + ├───• vec3 (4, 9) + └───• vec5 (4, 9) + +# Ensure that search result returns multiple keys. +search max-results=3 +(5, 10) +---- +vec3: 2 (centroid=0) +vec5: 2 (centroid=0) +vec1: 89 (centroid=2.8284) +4 leaf vectors, 7 vectors, 4 full vectors, 3 partitions + +# ---------- +# Construct new index with duplicate keys. This can happen when a vector is +# updated in the primary index, but it cannot be found in the secondary index. +# ---------- +new-index min-partition-size=1 max-partition-size=4 beam-size=2 +vec1: (1, 2) +vec2: (7, 4) +vec3: (4, 3) +vec4: (-4, 5) +vec1: (10, 5) +vec1: (12, 7) +---- +• 1 (2.7222, 4.2778) +│ +├───• 2 (9.6667, 5.3333) +│ │ +│ ├───• vec2 (7, 4) +│ ├───• vec1 (12, 7) +│ └───• vec1 (12, 7) +│ +├───• 3 (2.5, 2.5) +│ │ +│ ├───• vec1 (12, 7) +│ └───• vec3 (4, 3) +│ +└───• 4 (-4, 5) + │ + └───• vec4 (-4, 5) + +# Ensure that search result doesn't contain duplicates. +search max-results=5 +(8, 9) +---- +vec1: 20 (centroid=1.5811) +vec2: 26 (centroid=2.9814) +vec3: 52 (centroid=1.5811) +5 leaf vectors, 8 vectors, 3 full vectors, 3 partitions + +# Do not rerank results. This may cause a different vec1 duplicate to be +# returned. +search max-results=5 skip-rerank +(8, 9) +---- +vec2: 26.3282 ±16.9822 (centroid=2.9814) +vec3: 49.8042 ±19.0394 (centroid=1.5811) +vec1: 100.1958 ±19.0394 (centroid=1.5811) +5 leaf vectors, 8 vectors, 0 full vectors, 3 partitions diff --git a/pkg/sql/vecindex/vecstore/in_memory_store_test.go b/pkg/sql/vecindex/vecstore/in_memory_store_test.go index b2819b80566d..322877991cea 100644 --- a/pkg/sql/vecindex/vecstore/in_memory_store_test.go +++ b/pkg/sql/vecindex/vecstore/in_memory_store_test.go @@ -105,7 +105,7 @@ func TestInMemoryStore(t *testing.T) { result2 := SearchResult{QuerySquaredDistance: 13, ErrorBound: 0, CentroidDistance: 5, ParentPartitionKey: 1, ChildKey: childKey30} results := searchSet.PopResults() roundResults(results, 4) - require.Equal(t, []SearchResult{result1, result2}, results) + require.Equal(t, SearchResults{result1, result2}, results) require.Equal(t, 3, partitionCounts[0]) }) @@ -156,7 +156,7 @@ func TestInMemoryStore(t *testing.T) { require.NoError(t, err) require.Equal(t, Level(2), level) result3 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 0, ParentPartitionKey: 1, ChildKey: childKey2} - require.Equal(t, []SearchResult{result3}, searchSet.PopResults()) + require.Equal(t, SearchResults{result3}, searchSet.PopResults()) require.Equal(t, 1, partitionCounts[0]) }) @@ -192,7 +192,7 @@ func TestInMemoryStore(t *testing.T) { require.Equal(t, Level(1), level) result4 := SearchResult{QuerySquaredDistance: 1, ErrorBound: 0, CentroidDistance: 2.23606797749979, ParentPartitionKey: 2, ChildKey: childKey10} result5 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 1, ParentPartitionKey: 2, ChildKey: childKey40} - require.Equal(t, []SearchResult{result4, result5}, searchSet.PopResults()) + require.Equal(t, 
SearchResults{result4, result5}, searchSet.PopResults()) require.Equal(t, 3, partitionCounts[0]) }) @@ -217,7 +217,7 @@ func TestInMemoryStore(t *testing.T) { require.Equal(t, Level(1), level) result4 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 5, ParentPartitionKey: 2, ChildKey: childKey30} result5 := SearchResult{QuerySquaredDistance: 5, ErrorBound: 0, CentroidDistance: 4.61, ParentPartitionKey: 3, ChildKey: childKey50} - require.Equal(t, []SearchResult{result4, result5}, roundResults(searchSet.PopResults(), 2)) + require.Equal(t, SearchResults{result4, result5}, roundResults(searchSet.PopResults(), 2)) require.Equal(t, []int{3, 2}, partitionCounts) }) @@ -282,7 +282,7 @@ func TestInMemoryStoreConcurrency(t *testing.T) { ctx2, txn2, []PartitionKey{RootKey}, vector.T{0, 0}, &searchSet, partitionCounts) require.NoError(t, err) result1 := SearchResult{QuerySquaredDistance: 25, ErrorBound: 0, CentroidDistance: 5, ParentPartitionKey: RootKey, ChildKey: childKey10} - require.Equal(t, []SearchResult{result1}, searchSet.PopResults()) + require.Equal(t, SearchResults{result1}, searchSet.PopResults()) require.Equal(t, 1, partitionCounts[0]) wait.Done() diff --git a/pkg/sql/vecindex/vecstore/partition_test.go b/pkg/sql/vecindex/vecstore/partition_test.go index a871c1babe65..3809044ccf39 100644 --- a/pkg/sql/vecindex/vecstore/partition_test.go +++ b/pkg/sql/vecindex/vecstore/partition_test.go @@ -49,7 +49,7 @@ func TestPartition(t *testing.T) { result2 := SearchResult{QuerySquaredDistance: 13, ErrorBound: 0, CentroidDistance: 0.3333, ParentPartitionKey: 1, ChildKey: childKey40} result3 := SearchResult{QuerySquaredDistance: 17, ErrorBound: 0, CentroidDistance: 1.6667, ParentPartitionKey: 1, ChildKey: childKey20} results := roundResults(searchSet.PopResults(), 4) - require.Equal(t, []SearchResult{result1, result2, result3}, results) + require.Equal(t, SearchResults{result1, result2, result3}, results) // Find method. require.Equal(t, 2, partition.Find(childKey30)) @@ -68,7 +68,7 @@ func TestPartition(t *testing.T) { require.Equal(t, []ChildKey{}, partition.ChildKeys()) } -func roundResults(results []SearchResult, prec int) []SearchResult { +func roundResults(results SearchResults, prec int) SearchResults { for i := range results { result := &results[i] result.QuerySquaredDistance = float32(scalar.Round(float64(result.QuerySquaredDistance), prec)) diff --git a/pkg/sql/vecindex/vecstore/search_set.go b/pkg/sql/vecindex/vecstore/search_set.go index 65158aa19594..6da969f0b1cf 100644 --- a/pkg/sql/vecindex/vecstore/search_set.go +++ b/pkg/sql/vecindex/vecstore/search_set.go @@ -12,6 +12,20 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/vector" ) +// SearchResults is a list of search results from the search set. +type SearchResults []SearchResult + +// Sort re-orders the results in-place, by their distance. +func (s *SearchResults) Sort() { + results := searchResultHeap(*s) + if len(results) > 1 { + // Sort the results in-place. + sort.Slice(results, func(i int, j int) bool { + return !results.Less(i, j) + }) + } +} + // SearchResult contains a set of results from searching partitions for data // vectors that are nearest to a query vector. type SearchResult struct { @@ -57,7 +71,7 @@ func (h searchResultHeap) Less(i, j int) bool { return true } if distance1 == distance2 && h[i].ErrorBound < h[j].ErrorBound { - // If distance is equal, higher error bound sorts first. + // If distance is equal, lower error bound sorts first. 
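+		// A lower bound means the estimated distance is more reliable, so it
+		// wins the tie.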
return true
 	}
 	return false
 }
@@ -181,34 +195,25 @@ func (ss *SearchSet) Add(candidate *SearchResult) {
 }
 
 // AddAll includes a set of candidates in the search set.
-func (ss *SearchSet) AddAll(candidates []SearchResult) {
+func (ss *SearchSet) AddAll(candidates SearchResults) {
 	for i := range candidates {
 		ss.Add(&candidates[i])
 	}
 }
 
-// PopResults removes the nearest candidates by distance from the set and
-// returns them as a sorted list.
-func (ss *SearchSet) PopResults() []SearchResult {
-	return ss.popAndSort(&ss.results)
+// PopResults removes all results from the set and returns them in order of
+// their distance.
+func (ss *SearchSet) PopResults() SearchResults {
+	results := ss.PopUnsortedResults()
+	results.Sort()
+	return results
 }
 
-// PopExtraResults removes extra potential candidates from the set and returns
-// them as a sorted list.
-func (ss *SearchSet) PopExtraResults() []SearchResult {
-	return ss.popAndSort(&ss.extraResults)
-}
-
-func (ss *SearchSet) popAndSort(results *searchResultHeap) []SearchResult {
-	if len(*results) > 1 {
-		// Sort the results. This invalidates the heap invariant, so do not
-		// attempt to reuse the slice.
-		sort.Slice(*results, func(i int, j int) bool {
-			return !results.Less(i, j)
-		})
-	}
-
-	popped := *results
-	*results = nil
-	return popped
+// PopUnsortedResults removes all results from the set, including any extra
+// results, and returns them in unsorted order.
+func (ss *SearchSet) PopUnsortedResults() SearchResults {
+	popped := append(ss.results, ss.extraResults...)
+	ss.results = nil
+	ss.extraResults = nil
+	return SearchResults(popped)
 }
diff --git a/pkg/sql/vecindex/vecstore/search_set_test.go b/pkg/sql/vecindex/vecstore/search_set_test.go
index 0807497cc62d..8b218c62ead9 100644
--- a/pkg/sql/vecindex/vecstore/search_set_test.go
+++ b/pkg/sql/vecindex/vecstore/search_set_test.go
@@ -75,7 +75,6 @@ func TestSearchSet(t *testing.T) {
 	// Empty.
 	searchSet := SearchSet{MaxResults: 3, MaxExtraResults: 7}
 	require.Nil(t, searchSet.PopResults())
-	require.Nil(t, searchSet.PopExtraResults())
 
 	// Exceed max results, outside of error bounds.
 	result1 := SearchResult{
@@ -90,41 +89,35 @@ func TestSearchSet(t *testing.T) {
 	searchSet.Add(&result2)
 	searchSet.Add(&result3)
 	searchSet.Add(&result4)
-	require.Equal(t, []SearchResult{result3, result1, result4}, searchSet.PopResults())
-	require.Nil(t, searchSet.PopExtraResults())
+	require.Equal(t, SearchResults{result3, result1, result4}, searchSet.PopResults())
 
 	// Exceed max results, but within error bounds.
 	result5 := SearchResult{
 		QuerySquaredDistance: 6, ErrorBound: 1.5, CentroidDistance: 50, ParentPartitionKey: 500, ChildKey: ChildKey{PrimaryKey: []byte{50}}}
 	result6 := SearchResult{
 		QuerySquaredDistance: 5, ErrorBound: 1, CentroidDistance: 60, ParentPartitionKey: 600, ChildKey: ChildKey{PrimaryKey: []byte{60}}}
-	searchSet.AddAll([]SearchResult{result1, result2, result3, result4, result5, result6})
-	require.Equal(t, []SearchResult{result3, result1, result4}, searchSet.PopResults())
-	require.Equal(t, []SearchResult{result6, result5}, searchSet.PopExtraResults())
+	searchSet.AddAll(SearchResults{result1, result2, result3, result4, result5, result6})
+	require.Equal(t, SearchResults{result3, result1, result4, result6, result5}, searchSet.PopResults())
 
 	// Don't allow extra results.
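+	// With MaxExtraResults at zero, candidates within the error bounds of the
+	// last result (result6, result5 above) are dropped rather than returned.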
otherSet := SearchSet{MaxResults: 3} - otherSet.AddAll([]SearchResult{result1, result2, result3, result4, result5, result6}) - require.Equal(t, []SearchResult{result3, result1, result4}, otherSet.PopResults()) - require.Nil(t, otherSet.PopExtraResults()) + otherSet.AddAll(SearchResults{result1, result2, result3, result4, result5, result6}) + require.Equal(t, SearchResults{result3, result1, result4}, otherSet.PopResults()) // Add better results that invalidate farther candidates. result7 := SearchResult{ QuerySquaredDistance: 4, ErrorBound: 1.5, CentroidDistance: 70, ParentPartitionKey: 700, ChildKey: ChildKey{PrimaryKey: []byte{70}}} - searchSet.AddAll([]SearchResult{result1, result2, result3, result4, result5, result6, result7}) - require.Equal(t, []SearchResult{result3, result1, result7}, searchSet.PopResults()) - require.Equal(t, []SearchResult{result4, result6, result5}, searchSet.PopExtraResults()) + searchSet.AddAll(SearchResults{result1, result2, result3, result4, result5, result6, result7}) + require.Equal(t, SearchResults{result3, result1, result7, result4, result6, result5}, searchSet.PopResults()) result8 := SearchResult{ QuerySquaredDistance: 0.5, ErrorBound: 0.5, CentroidDistance: 80, ParentPartitionKey: 800, ChildKey: ChildKey{PrimaryKey: []byte{80}}} - searchSet.AddAll([]SearchResult{result1, result2, result3, result4}) - searchSet.AddAll([]SearchResult{result5, result6, result7, result8}) - require.Equal(t, []SearchResult{result8, result3, result1}, searchSet.PopResults()) - require.Equal(t, []SearchResult{result7, result4}, searchSet.PopExtraResults()) + searchSet.AddAll(SearchResults{result1, result2, result3, result4}) + searchSet.AddAll(SearchResults{result5, result6, result7, result8}) + require.Equal(t, SearchResults{result8, result3, result1, result7, result4}, searchSet.PopResults()) // Allow one extra result. otherSet.MaxExtraResults = 1 - otherSet.AddAll([]SearchResult{result1, result2, result3, result4, result5, result6, result7}) - require.Equal(t, []SearchResult{result3, result1, result7}, otherSet.PopResults()) - require.Equal(t, []SearchResult{result4}, otherSet.PopExtraResults()) + otherSet.AddAll(SearchResults{result1, result2, result3, result4, result5, result6, result7}) + require.Equal(t, SearchResults{result3, result1, result7, result4}, otherSet.PopResults()) } diff --git a/pkg/sql/vecindex/vector_index.go b/pkg/sql/vecindex/vector_index.go new file mode 100644 index 000000000000..7108e722e7d4 --- /dev/null +++ b/pkg/sql/vecindex/vector_index.go @@ -0,0 +1,633 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package vecindex + +import ( + "bytes" + "context" + "math" + "strconv" + "strings" + + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/internal" + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/quantize" + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore" + "github.com/cockroachdb/cockroach/pkg/util/num32" + "github.com/cockroachdb/cockroach/pkg/util/vector" + "github.com/cockroachdb/errors" +) + +// RerankMultiplier is multiplied by MaxResults to calculate the maximum number +// of search results that will be reranked with the original full-size vectors. +const RerankMultiplier = 10 + +// DeletedMultiplier increases the number of results that will be reranked, in +// order to account for vectors that may have been deleted in the primary index. 
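+//
+// For example, with error bounds enabled and searchSet.MaxResults = 10, the
+// last level of the search keeps ceil(10 * 1.2) = 12 candidates, plus up to
+// 12 * 10 = 120 extra candidates that fall within error bounds (see
+// searchHelper).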
+const DeletedMultiplier = 1.2 + +// VectorIndexOptions specifies options that control how the index will be +// built, as well as default options for how it will be searched. A given search +// operation can specify SearchOptions to override the default behavior. +type VectorIndexOptions struct { + // MinPartitionSize specifies the size below which a partition will be merged + // into other partitions at the same level. + MinPartitionSize int + // MaxPartitionSize specifies the size above which a partition will be split. + MaxPartitionSize int + // BaseBeamSize is the default number of child partitions that will be + // searched at each level during insert, delete, and search operations. + // Adaptive search will automatically decrease or increase this value as the + // search proceeds. + BaseBeamSize int + // QualitySamples is the number of search results that are used as samples for + // calculating search quality statistics. Adaptive search uses these stats to + // determine how many partitions to search at each level. + QualitySamples int + // DisableAdaptiveSearch can be set to true to disable adaptive search. This + // is useful for testing and benchmarking. + DisableAdaptiveSearch bool + // DisableErrorBounds can be set to true to disable using error bounds to + // limit the number of results that need to be reranked. This is useful for + // testing and benchmarking. + DisableErrorBounds bool +} + +// SearchOptions specifies options that apply to a particular search operation +// over the vector index. +type SearchOptions struct { + // BaseBeamSize is the default number of child partitions that will be + // searched at each level. Adaptive search will automatically decrease or + // increase this value as the search proceeds. If zero, then it defaults to + // the BaseBeamSize value from VectorIndexOptions. + BaseBeamSize int + // SkipRerank does not rerank search results using the original full-size + // vectors. While this speeds up the search, it can also significantly + // reduce accuracy. It is currently only used for testing. + SkipRerank bool + // ReturnVectors specifies whether to return the original full-size vectors + // in search results. + ReturnVectors bool +} + +// searchContext contains per-thread state needed during index search +// operations. Fields in the context are set at the beginning of an index +// operation and passed down the call stack. +type searchContext struct { + Ctx context.Context + Workspace internal.Workspace + Txn vecstore.Txn + Options SearchOptions + + // Level of the tree from which search results are returned. For the Search + // operation, this is always LeafLevel, but inserts and splits/merges can + // search at intermediate levels of the tree. + Level vecstore.Level + + // Original is the original, full-size vector that was passed to the top-level + // method on VectorIndex. + Original vector.T + + // Randomized is the original vector after being processed by the quantizer's + // RandomizeVector method. It can have a different number of dimensions than + // Original and different values in those dimensions. + Randomized vector.T + + tempKeys []vecstore.PartitionKey + tempCounts []int + tempVectorsWithKeys []vecstore.VectorWithKey +} + +// VectorIndex implements the C-SPANN algorithm, which adapts Microsoft’s SPANN +// and SPFresh algorithms to work well with CockroachDB’s unique distributed +// architecture. 
This enables CockroachDB to efficiently answer approximate
+// nearest neighbor (ANN) queries with high accuracy, low latency, and fresh
+// results, even with millions or billions of indexed vectors. In a departure
+// from SPANN (and similar to Google's ScaNN), C-SPANN packs hundreds of vectors
+// into each partition, and indexes partitions using a K-means tree.
+//
+// VectorIndex is thread-safe. There should typically be only one VectorIndex
+// instance in the process for each index.
+type VectorIndex struct {
+	// options specify how the index will be built and searched, by default.
+	options VectorIndexOptions
+	// store is the interface with the component that transactionally stores
+	// partitions and vectors.
+	store vecstore.Store
+	// rootQuantizer quantizes vectors in the root partition.
+	rootQuantizer quantize.Quantizer
+	// quantizer quantizes vectors in every partition except the root.
+	quantizer quantize.Quantizer
+}
+
+// NewVectorIndex constructs a new vector index instance. Typically, only one
+// VectorIndex instance should be created for each index in the process.
+func NewVectorIndex(
+	ctx context.Context,
+	store vecstore.Store,
+	quantizer quantize.Quantizer,
+	options *VectorIndexOptions,
+) (*VectorIndex, error) {
+	vi := &VectorIndex{
+		options:       *options,
+		store:         store,
+		rootQuantizer: quantize.NewUnQuantizer(quantizer.GetRandomDims()),
+		quantizer:     quantizer,
+	}
+	if vi.options.MinPartitionSize == 0 {
+		vi.options.MinPartitionSize = 16
+	}
+	if vi.options.MaxPartitionSize == 0 {
+		vi.options.MaxPartitionSize = 128
+	}
+	if vi.options.BaseBeamSize == 0 {
+		vi.options.BaseBeamSize = 8
+	}
+	if vi.options.QualitySamples == 0 {
+		vi.options.QualitySamples = 16
+	}
+
+	return vi, nil
+}
+
+// CreateRoot creates an empty root partition in the store. This should only be
+// called once when the index is first created.
+func (vi *VectorIndex) CreateRoot(ctx context.Context, txn vecstore.Txn) error {
+	// Use the UnQuantizer because vectors in the root are not quantized.
+	dims := vi.rootQuantizer.GetRandomDims()
+	vectors := vector.MakeSet(dims)
+	rootQuantizedSet := vi.rootQuantizer.Quantize(ctx, &vectors)
+	rootPartition := vecstore.NewPartition(
+		vi.rootQuantizer, rootQuantizedSet, []vecstore.ChildKey{}, vecstore.LeafLevel)
+	return vi.store.SetRootPartition(ctx, txn, rootPartition)
+}
+
+// Search finds vectors in the index that are closest to the given query vector
+// and returns them in the search set. Set searchSet.MaxResults to limit the
+// number of results.
+func (vi *VectorIndex) Search(
+	ctx context.Context,
+	txn vecstore.Txn,
+	queryVector vector.T,
+	searchSet *vecstore.SearchSet,
+	options SearchOptions,
+) error {
+	searchCtx := searchContext{
+		Txn:      txn,
+		Original: queryVector,
+		Level:    vecstore.LeafLevel,
+		Options:  options,
+	}
+
+	searchCtx.Ctx = internal.WithWorkspace(ctx, &searchCtx.Workspace)
+
+	// Randomize the vector if required by the quantizer.
+	tempRandomized := searchCtx.Workspace.AllocVector(vi.quantizer.GetRandomDims())
+	defer searchCtx.Workspace.FreeVector(tempRandomized)
+	vi.quantizer.RandomizeVector(ctx, queryVector, tempRandomized, false /* invert */)
+	searchCtx.Randomized = tempRandomized
+
+	return vi.searchHelper(&searchCtx, searchSet, true /* allowRetry */)
+}
+
+// searchHelper contains the core search logic for the K-means tree. It begins
+// at the root and proceeds downwards, breadth-first. At each level of the tree,
+// it searches the subset of partitions that have centroids nearest to the query
+// vector.
Using estimated distance calculations, the search finds the nearest +// quantized data vectors within these partitions. If at an interior level, +// these data vectors are the quantized representation of centroids in the next +// level down, and the search continues there. If at the leaf level, then these +// data vectors are the quantized representation of the original vectors that +// were inserted into the tree. The original, full-size vectors are fetched from +// the primary index and used to re-rank candidate search results. +func (vi *VectorIndex) searchHelper( + searchCtx *searchContext, searchSet *vecstore.SearchSet, allowRetry bool, +) error { + // Return enough search results to: + // 1. Ensure that the number of results requested by the caller is respected. + // 2. Ensure that there are enough samples for calculating stats. + // 3. Ensure that there are enough results for adaptive querying to dynamically + // expand the beam size (up to 4x the base beam size). + maxResults := max( + searchSet.MaxResults, vi.options.QualitySamples, searchCtx.Options.BaseBeamSize*4) + subSearchSet := vecstore.SearchSet{MaxResults: maxResults} + searchCtx.tempKeys = ensureSliceLen(searchCtx.tempKeys, 1) + searchCtx.tempKeys[0] = vecstore.RootKey + searchLevel, err := vi.searchChildPartitions(searchCtx, &subSearchSet, searchCtx.tempKeys) + if err != nil { + return err + } + + if searchLevel < searchCtx.Level { + // This should only happen when inserting into the root. + if searchLevel != searchCtx.Level-1 { + panic(errors.AssertionFailedf("caller passed invalid level %d", searchCtx.Level)) + } + if searchCtx.Options.ReturnVectors { + panic(errors.AssertionFailedf("ReturnVectors=true not supported for this case")) + } + searchSet.Add(&vecstore.SearchResult{ + ChildKey: vecstore.ChildKey{PartitionKey: vecstore.RootKey}, + }) + return nil + } + + for { + results := subSearchSet.PopUnsortedResults() + if len(results) == 0 { + // This should never happen, as it means that interior partition(s) + // have no children. The vector deletion logic should prevent that. + panic(errors.AssertionFailedf( + "interior partition(s) on level %d has no children", searchLevel)) + } + + var zscore float64 + if searchLevel > vecstore.LeafLevel { + // Compute the z-score of the candidate results list. + // TODO(andyk): Track z-score stats. + zscore = 0 + } + + if searchLevel <= searchCtx.Level { + if searchLevel != searchCtx.Level { + // This indicates index corruption, since each lower level should + // be one less than its parent level. + panic(errors.AssertionFailedf("somehow skipped to level %d when searching for level %d", + searchLevel, searchCtx.Level)) + } + + // Aggregate all stats from searching lower levels of the tree. + searchSet.Stats.Add(&subSearchSet.Stats) + + results = vi.pruneDuplicates(results) + if !searchCtx.Options.SkipRerank || searchCtx.Options.ReturnVectors { + // Re-rank search results with full vectors. + searchSet.Stats.FullVectorCount += len(results) + results, err = vi.rerankSearchResults(searchCtx, results) + if err != nil { + return err + } + } + searchSet.AddAll(results) + break + } + + // Calculate beam size for searching next level. + beamSize := searchCtx.Options.BaseBeamSize + if beamSize == 0 { + beamSize = vi.options.BaseBeamSize + } + + if !vi.options.DisableAdaptiveSearch { + // Look at variance in result distances to calculate the beam size for + // the next level. The less variance there is, the larger the beam size. 
+ // The intuition is that the closer the distances are to one another, the + // more densely packed are the vectors, and the more partitions they're + // likely to be spread across. + tempBeamSize := float64(beamSize) * math.Pow(2, -zscore) + tempBeamSize = max(min(tempBeamSize, float64(beamSize)*4), float64(beamSize)/2) + + if searchLevel > vecstore.LeafLevel+1 { + // Use progressively smaller beam size for higher levels, since + // each contains exponentially fewer partitions. + tempBeamSize /= math.Pow(2, float64(searchLevel-(vecstore.LeafLevel+1))) + } + + beamSize = int(math.Ceil(tempBeamSize)) + } + beamSize = max(beamSize, 1) + + searchLevel-- + if searchLevel == searchCtx.Level { + // Searching the last level, so return enough search results to: + // 1. Ensure that the number of results requested by the caller is + // respected. + // 2. Ensure there are enough samples for re-ranking to work well, even + // if there are deleted vectors. + if !vi.options.DisableErrorBounds { + subSearchSet.MaxResults = int(math.Ceil(float64(searchSet.MaxResults) * DeletedMultiplier)) + subSearchSet.MaxExtraResults = subSearchSet.MaxResults * RerankMultiplier + } else { + subSearchSet.MaxResults = searchSet.MaxResults * RerankMultiplier / 2 + subSearchSet.MaxExtraResults = 0 + } + } + + // Search up to beamSize child partitions. + results.Sort() + keyCount := min(beamSize, len(results)) + searchCtx.tempKeys = ensureSliceLen(searchCtx.tempKeys, keyCount) + for i := 0; i < keyCount; i++ { + searchCtx.tempKeys[i] = results[i].ChildKey.PartitionKey + } + + _, err = vi.searchChildPartitions(searchCtx, &subSearchSet, searchCtx.tempKeys) + if errors.Is(err, vecstore.ErrPartitionNotFound) { + // The cached root partition must be stale, so retry the search. + if !allowRetry { + // This indicates index corruption, since it should only require + // a single retry to handle the case where the root partition is + // stale. There should be no other cases that a partition cannot + // be found. + panic(errors.AssertionFailedf("partition cannot be found even though root is not stale")) + } + return vi.searchHelper(searchCtx, searchSet, false /* allowRetry */) + } else if err != nil { + return err + } + } + + return nil +} + +// searchChildPartitions searches the set of requested partitions for the query +// vector and adds the closest matches to the given search set. +func (vi *VectorIndex) searchChildPartitions( + searchCtx *searchContext, searchSet *vecstore.SearchSet, partitionKeys []vecstore.PartitionKey, +) (level vecstore.Level, err error) { + searchCtx.tempCounts = ensureSliceLen(searchCtx.tempCounts, len(partitionKeys)) + level, err = vi.store.SearchPartitions( + searchCtx.Ctx, searchCtx.Txn, partitionKeys, searchCtx.Randomized, + searchSet, searchCtx.tempCounts) + if err != nil { + return 0, err + } + + for i := 0; i < len(searchCtx.tempCounts); i++ { + count := searchCtx.tempCounts[i] + searchSet.Stats.SearchedPartition(level, count) + // TODO(andyk): Enqueue a split/merge fixup for the partition. + } + + return level, nil +} + +// pruneDuplicates removes candidates with duplicate child keys. This is rare, +// but it can happen when a vector updated in the primary index cannot be +// located in the secondary index. +func (vi *VectorIndex) pruneDuplicates(candidates []vecstore.SearchResult) []vecstore.SearchResult { + if len(candidates) <= 1 { + // No possibility of duplicates. + return candidates + } + + if candidates[0].ChildKey.PrimaryKey == nil { + // Only leaf partitions can have duplicates. 
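+		// Interior partitions reference children by partition key, which the
+		// store assigns uniquely, so duplicates can only arise from primary keys.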
return candidates
+	}
+
+	dups := make(map[string]bool, len(candidates))
+	for i := 0; i < len(candidates); i++ {
+		key := candidates[i].ChildKey.PrimaryKey
+		if _, ok := dups[string(key)]; ok {
+			// Found duplicate, so remove it by replacing it with the last
+			// candidate.
+			candidates[i] = candidates[len(candidates)-1]
+			candidates = candidates[:len(candidates)-1]
+			i--
+			continue
+		}
+		dups[string(key)] = true
+	}
+	return candidates
+}
+
+// rerankSearchResults updates the given set of candidates with their exact
+// distances from the query vector. It does this by fetching the original
+// full-size vectors from the store, in order to re-rank the top candidates for
+// extra search result accuracy.
+func (vi *VectorIndex) rerankSearchResults(
+	searchCtx *searchContext, candidates []vecstore.SearchResult,
+) ([]vecstore.SearchResult, error) {
+	if len(candidates) == 0 {
+		return candidates, nil
+	}
+
+	// Fetch the full vectors from the store.
+	candidates, err := vi.getRerankVectors(searchCtx, candidates)
+	if err != nil {
+		return candidates, err
+	}
+
+	queryVector := searchCtx.Randomized
+	if searchCtx.Level == vecstore.LeafLevel {
+		// Leaf vectors haven't been randomized, so compare with the original query
+		// vector if available, or un-randomize the randomized vector. The original
+		// vector is not available in some cases where split/merge needs to move
+		// vectors between partitions.
+		if searchCtx.Original != nil {
+			queryVector = searchCtx.Original
+		} else {
+			queryVector = searchCtx.Workspace.AllocVector(vi.quantizer.GetOriginalDims())
+			defer searchCtx.Workspace.FreeVector(queryVector)
+			vi.quantizer.RandomizeVector(
+				searchCtx.Ctx, queryVector, searchCtx.Randomized, true /* invert */)
+		}
+	}
+
+	// Compute exact distances for the vectors.
+	for i := range candidates {
+		candidate := &candidates[i]
+		candidate.QuerySquaredDistance = num32.L2SquaredDistance(candidate.Vector, queryVector)
+		candidate.ErrorBound = 0
+	}
+
+	return candidates, nil
+}
+
+// getRerankVectors updates the given search candidates with the original
+// full-size vectors from the store. If a candidate's vector has been deleted
+// from the primary index, that candidate is removed from the list of
+// candidates that's returned.
+func (vi *VectorIndex) getRerankVectors(
+	searchCtx *searchContext, candidates []vecstore.SearchResult,
+) ([]vecstore.SearchResult, error) {
+	// Prepare vector references.
+	searchCtx.tempVectorsWithKeys = ensureSliceLen(searchCtx.tempVectorsWithKeys, len(candidates))
+	for i := 0; i < len(candidates); i++ {
+		searchCtx.tempVectorsWithKeys[i].Key = candidates[i].ChildKey
+	}
+
+	// The store is expected to fetch the vectors in parallel.
+	err := vi.store.GetFullVectors(searchCtx.Ctx, searchCtx.Txn, searchCtx.tempVectorsWithKeys)
+	if err != nil {
+		return nil, err
+	}
+
+	for i := 0; i < len(candidates); i++ {
+		candidates[i].Vector = searchCtx.tempVectorsWithKeys[i].Vector
+
+		// Exclude deleted vectors from results.
+		if candidates[i].Vector == nil {
+			// Vector was deleted, so add fixup to delete it.
+			// TODO(andyk): Enqueue a delete of a vector.
+
+			// Move the last candidate to the current position and reduce size
+			// of slice by one.
+			searchCtx.tempVectorsWithKeys[i] = searchCtx.tempVectorsWithKeys[len(candidates)-1]
+			candidates[i] = candidates[len(candidates)-1]
+			candidates = candidates[:len(candidates)-1]
+			i--
+		}
+	}
+
+	return candidates, nil
+}
+
+// FormatOptions modifies the behavior of the Format method.
+type FormatOptions struct { + // PrimaryKeyStrings, if true, indicates that primary key bytes should be + // interpreted as strings. This is used for testing scenarios. + PrimaryKeyStrings bool +} + +// Format formats the vector index as a tree-formatted string similar to this, +// for testing and debugging purposes: +// +// • 1 (4, 3) +// │ +// ├───• vec1 (1, 2) +// ├───• vec2 (7, 4) +// └───• vec3 (4, 3) +// +// Vectors with many dimensions are abbreviated like (5, -1, ..., 2, 8), and +// values are rounded to 4 decimal places. Centroids are printed next to +// partition keys. +func (vi *VectorIndex) Format( + ctx context.Context, txn vecstore.Txn, options FormatOptions, +) (str string, err error) { + var buf bytes.Buffer + + // Format each number to 4 decimal places, removing unnecessary trailing + // zeros. + formatFloat := func(value float32) string { + s := strconv.FormatFloat(float64(value), 'f', 4, 32) + if strings.Contains(s, ".") { + s = strings.TrimRight(s, "0") + s = strings.TrimRight(s, ".") + } + return s + } + + writeVector := func(vector vector.T) { + buf.WriteByte('(') + if len(vector) > 4 { + // Show first 2 numbers, '...', and last 2 numbers. + buf.WriteString(formatFloat(vector[0])) + buf.WriteString(", ") + buf.WriteString(formatFloat(vector[1])) + buf.WriteString(", ..., ") + buf.WriteString(formatFloat(vector[len(vector)-2])) + buf.WriteString(", ") + buf.WriteString(formatFloat(vector[len(vector)-1])) + } else { + // Show all numbers if there are 4 or fewer. + for i, val := range vector { + if i != 0 { + buf.WriteString(", ") + } + buf.WriteString(formatFloat(val)) + } + } + buf.WriteByte(')') + } + + writePrimaryKey := func(key vecstore.PrimaryKey) { + if options.PrimaryKeyStrings { + buf.WriteString(string(key)) + } else { + for i, b := range key { + if i != 0 { + buf.WriteByte(' ') + } + buf.WriteString(strconv.FormatUint(uint64(b), 10)) + } + } + } + + var helper func(partitionKey vecstore.PartitionKey, parentPrefix string, childPrefix string) error + helper = func(partitionKey vecstore.PartitionKey, parentPrefix string, childPrefix string) error { + partition, err := vi.store.GetPartition(ctx, txn, partitionKey) + if err != nil { + return err + } + // Get centroid for the partition and un-randomize it so that it displays + // the original vector. 
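+		// Passing invert=true reverses the randomization that was applied
+		// when the centroid was computed over randomized vectors.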
+ random := partition.Centroid() + original := make(vector.T, len(random)) + vi.quantizer.RandomizeVector(ctx, original, random, true /* invert */) + buf.WriteString(parentPrefix) + buf.WriteString("• ") + buf.WriteString(strconv.FormatInt(int64(partitionKey), 10)) + buf.WriteByte(' ') + writeVector(original) + buf.WriteByte('\n') + + if partition.Count() == 0 { + return nil + } + + buf.WriteString(childPrefix) + buf.WriteString("│\n") + + for i, childKey := range partition.ChildKeys() { + isLastChild := (i == partition.Count()-1) + if isLastChild { + parentPrefix = childPrefix + "└───" + } else { + parentPrefix = childPrefix + "├───" + } + + if partition.Level() == vecstore.LeafLevel { + refs := []vecstore.VectorWithKey{{Key: childKey}} + if err = vi.store.GetFullVectors(ctx, txn, refs); err != nil { + return err + } + buf.WriteString(parentPrefix) + buf.WriteString("• ") + writePrimaryKey(childKey.PrimaryKey) + if refs[0].Vector != nil { + buf.WriteByte(' ') + writeVector(refs[0].Vector) + } else { + buf.WriteString(" (MISSING)") + } + buf.WriteByte('\n') + + if isLastChild && strings.TrimSpace(childPrefix) != "" { + buf.WriteString(strings.TrimRight(childPrefix, " ")) + buf.WriteByte('\n') + } + } else { + nextChildPrefix := childPrefix + if isLastChild { + nextChildPrefix += " " + } else { + nextChildPrefix += "│ " + } + if err = helper(childKey.PartitionKey, parentPrefix, nextChildPrefix); err != nil { + return err + } + } + } + + return nil + } + + if err = helper(vecstore.RootKey, "", ""); err != nil { + return "", err + } + return buf.String(), nil +} + +// ensureSliceLen returns a slice of the given length and generic type. If the +// existing slice has enough capacity, that slice is returned after adjusting +// its length. Otherwise, a new, larger slice is allocated. +func ensureSliceLen[T any](s []T, l int) []T { + if cap(s) < l { + return make([]T, l, max(l*3/2, 16)) + } + return s[:l] +} diff --git a/pkg/sql/vecindex/vector_index_test.go b/pkg/sql/vecindex/vector_index_test.go new file mode 100644 index 000000000000..bae2f5e73992 --- /dev/null +++ b/pkg/sql/vecindex/vector_index_test.go @@ -0,0 +1,562 @@ +// Copyright 2024 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package vecindex + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "strconv" + "strings" + "testing" + + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/internal" + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/quantize" + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/testutils" + "github.com/cockroachdb/cockroach/pkg/sql/vecindex/vecstore" + "github.com/cockroachdb/cockroach/pkg/util/num32" + "github.com/cockroachdb/cockroach/pkg/util/vector" + "github.com/cockroachdb/datadriven" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/require" +) + +func TestDataDriven(t *testing.T) { + ctx := internal.WithWorkspace(context.Background(), &internal.Workspace{}) + state := testState{T: t, Ctx: ctx} + datadriven.Walk(t, "testdata", func(t *testing.T, path string) { + if !strings.HasSuffix(path, ".ddt") { + // Skip files that are not data-driven tests. 
+ return + } + datadriven.RunTest(t, path, func(t *testing.T, d *datadriven.TestData) string { + switch d.Cmd { + case "new-index": + return state.NewIndex(d) + + case "format-tree": + return state.FormatTree(d) + + case "search": + return state.Search(d) + + case "delete": + return state.Delete(d) + } + + t.Fatalf("unknown cmd: %s", d.Cmd) + return "" + }) + }) +} + +type testState struct { + T *testing.T + Ctx context.Context + Quantizer quantize.Quantizer + InMemStore *vecstore.InMemoryStore + Index *VectorIndex + Features vector.Set +} + +func (s *testState) NewIndex(d *datadriven.TestData) string { + var err error + dims := 2 + hideTree := false + count := 0 + options := VectorIndexOptions{} + for _, arg := range d.CmdArgs { + switch arg.Key { + case "min-partition-size": + require.Len(s.T, arg.Vals, 1) + options.MinPartitionSize, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "max-partition-size": + require.Len(s.T, arg.Vals, 1) + options.MaxPartitionSize, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "quality-samples": + require.Len(s.T, arg.Vals, 1) + options.QualitySamples, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "load-features": + require.Len(s.T, arg.Vals, 1) + count, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "dims": + require.Len(s.T, arg.Vals, 1) + dims, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "beam-size": + require.Len(s.T, arg.Vals, 1) + options.BaseBeamSize, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "hide-tree": + require.Len(s.T, arg.Vals, 0) + hideTree = true + } + } + + s.Quantizer = quantize.NewRaBitQuantizer(dims, 42) + s.InMemStore = vecstore.NewInMemoryStore(dims, 42) + s.Index, err = NewVectorIndex(s.Ctx, s.InMemStore, s.Quantizer, &options) + require.NoError(s.T, err) + + txn := beginTransaction(s.Ctx, s.T, s.InMemStore) + defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn) + + // Insert empty root partition. + require.NoError(s.T, s.Index.CreateRoot(s.Ctx, txn)) + + vectors := vector.MakeSet(dims) + childKeys := make([]vecstore.ChildKey, 0, count) + if count != 0 { + // Load features. + s.Features = testutils.LoadFeatures(s.T, 10000) + vectors = s.Features + vectors.SplitAt(count) + for i := 0; i < count; i++ { + key := vecstore.PrimaryKey(fmt.Sprintf("vec%d", i)) + childKeys = append(childKeys, vecstore.ChildKey{PrimaryKey: key}) + } + } else { + // Parse vectors. + for _, line := range strings.Split(d.Input, "\n") { + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + parts := strings.Split(line, ":") + require.Len(s.T, parts, 2) + + vectors.Add(s.parseVector(parts[1])) + key := vecstore.PrimaryKey(parts[0]) + childKeys = append(childKeys, vecstore.ChildKey{PrimaryKey: key}) + } + } + + // Insert vectors into the store. + for i := 0; i < vectors.Count; i++ { + s.InMemStore.InsertVector(txn, childKeys[i].PrimaryKey, vectors.At(i)) + } + + // Build the tree, bottom-up. 
+ s.buildTree(txn, vectors, childKeys, options.MaxPartitionSize) + + if hideTree { + return fmt.Sprintf("Created index with %d vectors with %d dimensions.\n", + vectors.Count, vectors.Dims) + } + + tree, err := s.Index.Format(s.Ctx, txn, FormatOptions{PrimaryKeyStrings: true}) + require.NoError(s.T, err) + return tree +} + +func (s *testState) FormatTree(d *datadriven.TestData) string { + txn := beginTransaction(s.Ctx, s.T, s.InMemStore) + defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn) + + tree, err := s.Index.Format(s.Ctx, txn, FormatOptions{PrimaryKeyStrings: true}) + require.NoError(s.T, err) + return tree +} + +func (s *testState) Search(d *datadriven.TestData) string { + txn := beginTransaction(s.Ctx, s.T, s.InMemStore) + defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn) + + var vector vector.T + searchSet := vecstore.SearchSet{MaxResults: 1} + options := SearchOptions{} + + var err error + for _, arg := range d.CmdArgs { + switch arg.Key { + case "use-feature": + require.Len(s.T, arg.Vals, 1) + offset, err := strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + vector = s.Features.At(offset) + + case "max-results": + require.Len(s.T, arg.Vals, 1) + searchSet.MaxResults, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "beam-size": + require.Len(s.T, arg.Vals, 1) + options.BaseBeamSize, err = strconv.Atoi(arg.Vals[0]) + require.NoError(s.T, err) + + case "skip-rerank": + require.Len(s.T, arg.Vals, 0) + options.SkipRerank = true + } + } + + if vector == nil { + // Parse input as the vector to search for. + vector = s.parseVector(d.Input) + } + + err = s.Index.Search(s.Ctx, txn, vector, &searchSet, options) + require.NoError(s.T, err) + + var buf bytes.Buffer + results := searchSet.PopResults() + for i := range results { + result := &results[i] + var errorBound string + if result.ErrorBound != 0 { + errorBound = fmt.Sprintf("±%s ", formatFloat(result.ErrorBound)) + } + fmt.Fprintf(&buf, "%s: %s %s(centroid=%s)\n", + string(result.ChildKey.PrimaryKey), formatFloat(result.QuerySquaredDistance), + errorBound, formatFloat(result.CentroidDistance)) + } + + buf.WriteString(fmt.Sprintf("%d leaf vectors, ", searchSet.Stats.QuantizedLeafVectorCount)) + buf.WriteString(fmt.Sprintf("%d vectors, ", searchSet.Stats.QuantizedVectorCount)) + buf.WriteString(fmt.Sprintf("%d full vectors, ", searchSet.Stats.FullVectorCount)) + buf.WriteString(fmt.Sprintf("%d partitions", searchSet.Stats.PartitionCount)) + + return buf.String() +} + +func (s *testState) Delete(d *datadriven.TestData) string { + notFound := false + for _, arg := range d.CmdArgs { + switch arg.Key { + case "not-found": + require.Len(s.T, arg.Vals, 0) + notFound = true + } + } + + txn := beginTransaction(s.Ctx, s.T, s.InMemStore) + defer commitTransaction(s.Ctx, s.T, s.InMemStore, txn) + + // Get root in order to acquire partition lock. + _, err := s.InMemStore.GetPartition(s.Ctx, txn, vecstore.RootKey) + require.NoError(s.T, err) + + if notFound { + for _, line := range strings.Split(d.Input, "\n") { + line = strings.TrimSpace(line) + if len(line) == 0 { + continue + } + + // Simulate case where the vector is deleted in the primary index, but + // it cannot be found in the secondary index. + s.InMemStore.DeleteVector(txn, []byte(line)) + } + } + + // TODO(andyk): Add code to delete vector from index. + + tree, err := s.Index.Format(s.Ctx, txn, FormatOptions{PrimaryKeyStrings: true}) + require.NoError(s.T, err) + return tree +} + +// buildTree uses the K-means++ algorithm to build a K-means tree. 
Unlike the +// incremental algorithm, this builds the tree from the complete set of initial +// vectors. To start, the leaf level is built from the input vectors, with the +// number of partitions derived from "maxPartitionSize". Once the leaf level has +// been partitioned, the next higher level is built from the centroids of the +// leaf partitions. And so on, up to the root of the tree. +// +// TODO(andyk): Use the incremental algorithm instead, once it's ready. This +// alternate implementation is useful for testing and benchmarking. How much +// more accurate is it than the incremental version? +func (s *testState) buildTree( + txn vecstore.Txn, vectors vector.Set, childKeys []vecstore.ChildKey, maxPartitionSize int, +) { + rng := rand.New(rand.NewSource(42)) + level := vecstore.LeafLevel + + // Randomize vectors. + randomized := vector.MakeSet(vectors.Dims) + randomized.AddUndefined(vectors.Count) + for i := 0; i < vectors.Count; i++ { + s.Quantizer.RandomizeVector(s.Ctx, vectors.At(i), randomized.At(i), false /* invert */) + } + + // Partition each level of the tree. + for randomized.Count > maxPartitionSize { + n := randomized.Count * 2 / maxPartitionSize + randomized, childKeys = s.partitionVectors(txn, level, randomized, childKeys, n, rng) + level++ + } + + unQuantizer := quantize.NewUnQuantizer(randomized.Dims) + quantizedSet := unQuantizer.Quantize(s.Ctx, &randomized) + root := vecstore.NewPartition(unQuantizer, quantizedSet, childKeys, level) + err := s.InMemStore.SetRootPartition(s.Ctx, txn, root) + require.NoError(s.T, err) +} + +// partitionVectors partitions the given full-size vectors at one level of the +// tree using the K-means++ algorithm. +func (s *testState) partitionVectors( + txn vecstore.Txn, + level vecstore.Level, + vectors vector.Set, + childKeys []vecstore.ChildKey, + numPartitions int, + rng *rand.Rand, +) (centroids vector.Set, partitionKeys []vecstore.ChildKey) { + centroids = vector.MakeSet(vectors.Dims) + centroids.AddUndefined(numPartitions) + partitionKeys = make([]vecstore.ChildKey, numPartitions) + + // Run K-means on the input vectors. + km := kmeans{Rand: rng} + partitionOffsets := make([]uint64, vectors.Count) + km.Partition(s.Ctx, &vectors, ¢roids, partitionOffsets) + + // Construct the partitions and insert them into the store. + tempVectors := vector.MakeSet(vectors.Dims) + for partitionIdx := 0; partitionIdx < numPartitions; partitionIdx++ { + var partitionChildKeys []vecstore.ChildKey + for vectorIdx := 0; vectorIdx < vectors.Count; vectorIdx++ { + if partitionIdx == int(partitionOffsets[vectorIdx]) { + tempVectors.Add(vectors.At(vectorIdx)) + partitionChildKeys = append(partitionChildKeys, childKeys[vectorIdx]) + } + } + + quantizedSet := s.Quantizer.Quantize(s.Ctx, &tempVectors) + partition := vecstore.NewPartition(s.Quantizer, quantizedSet, partitionChildKeys, level) + + partitionKey, err := s.InMemStore.InsertPartition(s.Ctx, txn, partition) + require.NoError(s.T, err) + partitionKeys[partitionIdx] = vecstore.ChildKey{PartitionKey: partitionKey} + + tempVectors.Clear() + } + + return centroids, partitionKeys +} + +// parseVector parses a vector string in this form: (1.5, 6, -4). +func (s *testState) parseVector(str string) vector.T { + // Remove parentheses and split by commas. + str = strings.TrimSpace(str) + str = strings.TrimPrefix(str, "(") + str = strings.TrimSuffix(str, ")") + elems := strings.Split(str, ",") + + // Construct the vector. 
+// parseVector parses a vector string in this form: (1.5, 6, -4).
+func (s *testState) parseVector(str string) vector.T {
+	// Remove parentheses and split by commas.
+	str = strings.TrimSpace(str)
+	str = strings.TrimPrefix(str, "(")
+	str = strings.TrimSuffix(str, ")")
+	elems := strings.Split(str, ",")
+
+	// Construct the vector.
+	vec := make(vector.T, len(elems))
+	for i, elem := range elems {
+		elem = strings.TrimSpace(elem)
+		value, err := strconv.ParseFloat(elem, 32)
+		require.NoError(s.T, err)
+		vec[i] = float32(value)
+	}
+
+	return vec
+}
+
+func formatFloat(value float32) string {
+	s := strconv.FormatFloat(float64(value), 'f', 4, 32)
+	if strings.Contains(s, ".") {
+		s = strings.TrimRight(s, "0")
+		s = strings.TrimRight(s, ".")
+	}
+	return s
+}
+
+// kmeans implements the K-means++ algorithm:
+// http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf
+type kmeans struct {
+	MaxIterations int
+	Rand          *rand.Rand
+
+	workspace    *internal.Workspace
+	vectors      *vector.Set
+	oldCentroids *vector.Set
+	newCentroids *vector.Set
+	partitions   []uint64
+}
+
+// Partition divides the input vectors into partitions. The caller is expected
+// to allocate the "centroids" set with length equal to the desired number of
+// partitions. Partition will write the centroid of each calculated partition
+// into the set. In addition, the caller allocates the "partitions" slice with
+// length equal to the number of input vectors. For each input vector, Partition
+// will write the index of its partition into the corresponding entry in
+// "partitions".
+func (km *kmeans) Partition(
+	ctx context.Context, vectors *vector.Set, centroids *vector.Set, partitions []uint64,
+) {
+	if vectors.Count != len(partitions) {
+		panic(errors.AssertionFailedf("vector count %d must equal partitions length %d",
+			vectors.Count, len(partitions)))
+	}
+
+	km.workspace = internal.WorkspaceFromContext(ctx)
+	km.vectors = vectors
+	km.newCentroids = centroids
+	km.partitions = partitions
+
+	tempOldCentroids := km.workspace.AllocVectorSet(centroids.Count, vectors.Dims)
+	defer km.workspace.FreeVectorSet(tempOldCentroids)
+	km.oldCentroids = &tempOldCentroids
+
+	km.selectInitialCentroids()
+
+	maxIterations := km.MaxIterations
+	if maxIterations == 0 {
+		maxIterations = 32
+	}
+
+	for i := 0; i < maxIterations; i++ {
+		km.computeNewCentroids()
+
+		// Check if the algorithm has converged.
+		done := true
+		for centroidIdx := 0; centroidIdx < km.oldCentroids.Count; centroidIdx++ {
+			distance := num32.L2SquaredDistance(
+				km.oldCentroids.At(centroidIdx), km.newCentroids.At(centroidIdx))
+			if distance > 1e-4 {
+				done = false
+				break
+			}
+		}
+		if done {
+			break
+		}
+
+		// Swap old and new centroid slices.
+		km.oldCentroids, km.newCentroids = km.newCentroids, km.oldCentroids
+
+		// Re-assign vectors to one of the partitions.
+		km.assignPartitions()
+	}
+}
+
+// assignPartitions re-assigns each input vector to the partition with the
+// closest centroid in "km.oldCentroids".
+func (km *kmeans) assignPartitions() {
+	vectorCount := km.vectors.Count
+	centroidCount := km.oldCentroids.Count
+
+	// Assign each vector to the partition with the closest centroid.
+	for vecIdx := 0; vecIdx < vectorCount; vecIdx++ {
+		var shortest float32
+		shortestIdx := -1
+		for centroidIdx := 0; centroidIdx < centroidCount; centroidIdx++ {
+			distance := num32.L2SquaredDistance(km.vectors.At(vecIdx), km.oldCentroids.At(centroidIdx))
+			if shortestIdx == -1 || distance < shortest {
+				shortest = distance
+				shortestIdx = centroidIdx
+			}
+		}
+		km.partitions[vecIdx] = uint64(shortestIdx)
+	}
+}
+
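+// The assignment rule above can be restated as a standalone helper, which is
+// convenient for unit-testing it in isolation. This is a hypothetical sketch
+// for illustration, not part of the change:
+//
+//	func nearestCentroid(vec vector.T, centroids *vector.Set) int {
+//		best, bestDist := -1, float32(0)
+//		for i := 0; i < centroids.Count; i++ {
+//			if d := num32.L2SquaredDistance(vec, centroids.At(i)); best == -1 || d < bestDist {
+//				best, bestDist = i, d
+//			}
+//		}
+//		return best
+//	}
+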
+// computeNewCentroids calculates a new centroid for each partition from the
+// vectors that have been assigned to that partition, and stores the resulting
+// centroids in "km.newCentroids".
+func (km *kmeans) computeNewCentroids() {
+	centroidCount := km.newCentroids.Count
+	vectorCount := km.vectors.Count
+
+	tempPartitionCounts := km.workspace.AllocUint64s(centroidCount)
+	defer km.workspace.FreeUint64s(tempPartitionCounts)
+	for i := 0; i < centroidCount; i++ {
+		tempPartitionCounts[i] = 0
+	}
+
+	// Calculate new centroids.
+	num32.Zero(km.newCentroids.Data)
+	for vecIdx := 0; vecIdx < vectorCount; vecIdx++ {
+		centroidIdx := int(km.partitions[vecIdx])
+		num32.Add(km.newCentroids.At(centroidIdx), km.vectors.At(vecIdx))
+		tempPartitionCounts[centroidIdx]++
+	}
+
+	// Divide each centroid by the count of vectors in its partition.
+	for centroidIdx := 0; centroidIdx < centroidCount; centroidIdx++ {
+		num32.Scale(1.0/float32(tempPartitionCounts[centroidIdx]), km.newCentroids.At(centroidIdx))
+	}
+}
+
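+// In vector form, computeNewCentroids sets each new centroid to the mean of
+// its assigned vectors: c_j = (1/|P_j|) * Σ_{v ∈ P_j} v. Note the implicit
+// assumption that every partition is non-empty: an empty partition would
+// scale a zeroed centroid by 1/0, producing NaNs. The K-means++ seeding below
+// makes that unlikely for the small partition counts used in these tests.
+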
+// selectInitialCentroids sets "km.oldCentroids" to random input vectors chosen
+// using the K-means++ algorithm.
+func (km *kmeans) selectInitialCentroids() {
+	count := km.vectors.Count
+	tempVectorDistances := km.workspace.AllocFloats(count)
+	defer km.workspace.FreeFloats(tempVectorDistances)
+
+	// Randomly select the first centroid from the vector set.
+	var offset int
+	if km.Rand != nil {
+		offset = km.Rand.Intn(count)
+	} else {
+		offset = rand.Intn(count)
+	}
+	copy(km.oldCentroids.At(0), km.vectors.At(offset))
+
+	// Select the remaining centroids. The loop makes one final pass over the
+	// distances after the last centroid has been selected, so that
+	// "km.partitions" reflects all of the centroids, then exits before trying
+	// to select a centroid beyond the end of the set.
+	selected := 0
+	for {
+		// Calculate shortest distance from each vector to one of the already
+		// selected centroids.
+		var distanceSum float32
+		for vecIdx := 0; vecIdx < count; vecIdx++ {
+			distance := num32.L2SquaredDistance(km.vectors.At(vecIdx), km.oldCentroids.At(selected))
+			if selected == 0 || distance < tempVectorDistances[vecIdx] {
+				tempVectorDistances[vecIdx] = distance
+				km.partitions[vecIdx] = uint64(selected)
+			}
+			distanceSum += tempVectorDistances[vecIdx]
+		}
+
+		selected++
+		if selected >= km.oldCentroids.Count {
+			break
+		}
+
+		// Calculate probability of each vector becoming the next centroid, with
+		// the probability being proportional to the vector's shortest distance
+		// to one of the already selected centroids.
+		var cum, rnd float32
+		if km.Rand != nil {
+			rnd = km.Rand.Float32() * distanceSum
+		} else {
+			rnd = rand.Float32() * distanceSum
+		}
+		offset = 0
+		for offset < len(tempVectorDistances)-1 {
+			// Fall back to the last vector if floating-point rounding keeps
+			// the cumulative sum below rnd.
+			cum += tempVectorDistances[offset]
+			if rnd < cum {
+				break
+			}
+			offset++
+		}
+
+		copy(km.oldCentroids.At(selected), km.vectors.At(offset))
+	}
+}
+
+func beginTransaction(ctx context.Context, t *testing.T, store vecstore.Store) vecstore.Txn {
+	txn, err := store.BeginTransaction(ctx)
+	require.NoError(t, err)
+	return txn
+}
+
+func commitTransaction(ctx context.Context, t *testing.T, store vecstore.Store, txn vecstore.Txn) {
+	err := store.CommitTransaction(ctx, txn)
+	require.NoError(t, err)
+}
diff --git a/pkg/util/vector/BUILD.bazel b/pkg/util/vector/BUILD.bazel
index 94b392950481..12acda3e4e60 100644
--- a/pkg/util/vector/BUILD.bazel
+++ b/pkg/util/vector/BUILD.bazel
@@ -14,6 +14,7 @@ go_library(
     deps = [
         "//pkg/sql/pgwire/pgcode",
         "//pkg/sql/pgwire/pgerror",
+        "//pkg/util/buildutil",
         "//pkg/util/encoding",
         "//pkg/util/num32",
         "@com_github_cockroachdb_errors//:errors",
diff --git a/pkg/util/vector/vector_set.go b/pkg/util/vector/vector_set.go
index 5073d87a15d7..62271ae633d9 100644
--- a/pkg/util/vector/vector_set.go
+++ b/pkg/util/vector/vector_set.go
@@ -8,6 +8,7 @@ package vector
 import (
 	"slices"
 
+	"github.com/cockroachdb/cockroach/pkg/util/buildutil"
 	"github.com/cockroachdb/cockroach/pkg/util/num32"
 	"github.com/cockroachdb/errors"
 )
@@ -47,7 +48,9 @@ func (vs *Set) AsMatrix() num32.Matrix {
 	}
 }
 
-// At returns the vector at the given offset in the set.
+// At returns the vector at the given offset in the set. The returned vector is
+// intended for transient use, since mutations to the vector set can invalidate
+// the reference.
 //
 //gcassert:inline
 func (vs *Set) At(offset int) T {
@@ -106,6 +109,24 @@ func (vs *Set) AddUndefined(count int) {
 	vs.Data = slices.Grow(vs.Data, count*vs.Dims)
 	vs.Count += count
 	vs.Data = vs.Data[:vs.Count*vs.Dims]
+	if buildutil.CrdbTestBuild {
+		// Write non-zero values to undefined memory.
+		for i := len(vs.Data) - count*vs.Dims; i < len(vs.Data); i++ {
+			vs.Data[i] = 0xBADF00D
+		}
+	}
+}
+
+// Clear empties the set so that it has zero vectors.
+func (vs *Set) Clear() {
+	if buildutil.CrdbTestBuild {
+		// Write non-zero values to cleared memory.
+		for i := 0; i < len(vs.Data); i++ {
+			vs.Data[i] = 0xBADF00D
+		}
+	}
+	vs.Data = vs.Data[:0]
+	vs.Count = 0
 }
 
 // ReplaceWithLast removes the vector at the given offset from the set,
@@ -115,6 +136,12 @@ func (vs *Set) ReplaceWithLast(offset int) {
 	targetStart := offset * vs.Dims
 	sourceEnd := len(vs.Data)
 	copy(vs.Data[targetStart:targetStart+vs.Dims], vs.Data[sourceEnd-vs.Dims:sourceEnd])
+	if buildutil.CrdbTestBuild {
+		// Write non-zero values to the vacated memory at the end of the set.
+		for i := sourceEnd - vs.Dims; i < sourceEnd; i++ {
+			vs.Data[i] = 0xBADF00D
+		}
+	}
 	vs.Data = vs.Data[:sourceEnd-vs.Dims]
 	vs.Count--
 }
diff --git a/pkg/util/vector/vector_set_test.go b/pkg/util/vector/vector_set_test.go
index d51d36f2ecb6..8063acf6e931 100644
--- a/pkg/util/vector/vector_set_test.go
+++ b/pkg/util/vector/vector_set_test.go
@@ -51,6 +51,11 @@ func TestVectorSet(t *testing.T) {
 	require.Equal(t, 5, vs.Count)
 	require.Equal(t, []float32{1, 2, 4, 4, 6, 6, 1, 2, 3, 1}, vs.Data)
 
+	// Clear.
+	vs.Clear()
+	require.Equal(t, 2, vs.Dims)
+	require.Equal(t, 0, vs.Count)
+
 	vs3 := MakeSetFromRawData(vs.Data, 2)
 	require.Equal(t, vs, vs3)