-
-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
728 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
// Copyright 2024 Dolthub, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package prolly | ||
|
||
import (
	"context"
	"fmt"

	"github.com/dolthub/dolt/go/store/hash"
	"github.com/dolthub/dolt/go/store/prolly/message"
	"github.com/dolthub/dolt/go/store/prolly/tree"
	"github.com/dolthub/dolt/go/store/val"
	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
)
|
||
// ProximityMap wraps a tree.ProximityMap but operates on typed Tuples instead of raw bytestrings. | ||
type ProximityMap struct { | ||
tuples tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc] | ||
keyDesc val.TupleDesc | ||
valDesc val.TupleDesc | ||
ctx context.Context | ||
} | ||
|
||
// NewProximityMap creates an empty prolly Tree Map | ||
func NewProximityMap(ctx context.Context, node tree.Node, ns tree.NodeStore, keyDesc val.TupleDesc, valDesc val.TupleDesc) ProximityMap { | ||
tuples := tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{ | ||
Root: node, | ||
NodeStore: ns, | ||
Order: keyDesc, | ||
DistanceType: expression.DistanceL2Squared{}, | ||
Convert: func(bytes []byte) []float64 { | ||
h, _ := keyDesc.GetJSONAddr(0, bytes) | ||
doc := tree.NewJSONDoc(h, ns) | ||
jsonWrapper, err := doc.ToIndexedJSONDocument(ctx) | ||
if err != nil { | ||
panic(err) | ||
} | ||
floats, err := sql.ConvertToVector(jsonWrapper) | ||
if err != nil { | ||
panic(err) | ||
} | ||
return floats | ||
}, | ||
} | ||
return ProximityMap{ | ||
tuples: tuples, | ||
keyDesc: keyDesc, | ||
valDesc: valDesc, | ||
} | ||
} | ||
|
||
type VectorIter interface { | ||
Next(ctx context.Context) (k interface{}, v val.Tuple) | ||
} | ||
|
||
func NewProximityMapFromTupleIter(ctx context.Context, ns tree.NodeStore, distanceType expression.DistanceType, keyDesc val.TupleDesc, valDesc val.TupleDesc, keys []val.Tuple, values []val.Tuple, logChunkSize uint8) (ProximityMap, error) { | ||
serializer := message.NewVectorIndexSerializer(ns.Pool()) | ||
ch, err := tree.NewChunkerWithDeterministicSplitter(ctx, nil, 0, ns, serializer, logChunkSize) | ||
|
||
if err != nil { | ||
return ProximityMap{}, err | ||
} | ||
|
||
for i := 0; i < len(keys); i++ { | ||
if err = ch.AddPair(ctx, tree.Item(keys[i]), tree.Item(values[i])); err != nil { | ||
return ProximityMap{}, err | ||
} | ||
} | ||
|
||
root, err := ch.Done(ctx) | ||
if err != nil { | ||
return ProximityMap{}, err | ||
} | ||
|
||
// We now have a map where each node is at the right level, but now we need to sort it. | ||
|
||
getHash := func(tuple []byte) hash.Hash { | ||
h, _ := keyDesc.GetJSONAddr(0, tuple) | ||
return h | ||
} | ||
newRoot, err := tree.FixupProximityMap[val.Tuple, val.TupleDesc](ctx, ns, distanceType, root, getHash, keyDesc) | ||
if err != nil { | ||
return ProximityMap{}, err | ||
} | ||
|
||
return NewProximityMap(ctx, newRoot, ns, keyDesc, valDesc), nil | ||
} | ||
|
||
// Count returns the number of key-value pairs in the Map. | ||
func (m ProximityMap) Count() (int, error) { | ||
return m.tuples.Count() | ||
} | ||
|
||
// Get searches for the key-value pair keyed by |key| and passes the results to the callback. | ||
// If |key| is not present in the map, a nil key-value pair are passed. | ||
func (m ProximityMap) Get(ctx context.Context, query interface{}, cb tree.KeyValueFn[val.Tuple, val.Tuple]) (err error) { | ||
return m.tuples.GetExact(ctx, query, cb) | ||
} | ||
|
||
func (m ProximityMap) GetClosest(ctx context.Context, query interface{}, cb tree.KeyValueDistanceFn[val.Tuple, val.Tuple], limit int) (err error) { | ||
return m.tuples.GetClosest(ctx, query, cb, limit) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
// Copyright 2024 Dolthub, Inc. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
package prolly | ||
|
||
import ( | ||
"context" | ||
"github.com/dolthub/dolt/go/store/hash" | ||
"github.com/dolthub/dolt/go/store/pool" | ||
"github.com/dolthub/dolt/go/store/prolly/tree" | ||
"github.com/dolthub/dolt/go/store/val" | ||
"github.com/dolthub/go-mysql-server/sql" | ||
"github.com/dolthub/go-mysql-server/sql/expression" | ||
"github.com/dolthub/go-mysql-server/sql/types" | ||
"github.com/stretchr/testify/require" | ||
"testing" | ||
) | ||
|
||
func newJsonValue(t *testing.T, v interface{}) sql.JSONWrapper { | ||
doc, _, err := types.JSON.Convert(v) | ||
require.NoError(t, err) | ||
return doc.(sql.JSONWrapper) | ||
} | ||
|
||
// newJsonDocument creates a JSON value from a provided value. | ||
func newJsonDocument(t *testing.T, ctx context.Context, ns tree.NodeStore, v interface{}) hash.Hash { | ||
doc := newJsonValue(t, v) | ||
root, err := tree.SerializeJsonToAddr(ctx, ns, doc) | ||
require.NoError(t, err) | ||
return root.HashOf() | ||
} | ||
|
||
func createProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, vectors []interface{}, pks []int64, logChunkSize uint8) (ProximityMap, []val.Tuple, []val.Tuple) { | ||
bp := pool.NewBuffPool() | ||
|
||
count := len(vectors) | ||
require.Equal(t, count, len(pks)) | ||
|
||
kd := val.NewTupleDescriptor( | ||
val.Type{Enc: val.JSONAddrEnc, Nullable: true}, | ||
) | ||
|
||
vd := val.NewTupleDescriptor( | ||
val.Type{Enc: val.Int64Enc, Nullable: true}, | ||
) | ||
|
||
distanceType := expression.DistanceL2Squared{} | ||
|
||
keys := make([]val.Tuple, count) | ||
keyBuilder := val.NewTupleBuilder(kd) | ||
for i, vector := range vectors { | ||
keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, vector)) | ||
keys[i] = keyBuilder.Build(bp) | ||
} | ||
|
||
valueBuilder := val.NewTupleBuilder(vd) | ||
values := make([]val.Tuple, count) | ||
for i, pk := range pks { | ||
valueBuilder.PutInt64(0, pk) | ||
values[i] = valueBuilder.Build(bp) | ||
} | ||
|
||
m, err := NewProximityMapFromTupleIter(ctx, ns, distanceType, kd, vd, keys, values, logChunkSize) | ||
require.NoError(t, err) | ||
mapCount, err := m.Count() | ||
require.NoError(t, err) | ||
require.Equal(t, count, mapCount) | ||
|
||
return m, keys, values | ||
} | ||
|
||
func TestEmptyProximityMap(t *testing.T) { | ||
ctx := context.Background() | ||
ns := tree.NewTestNodeStore() | ||
createProximityMap(t, ctx, ns, nil, nil, 10) | ||
} | ||
|
||
func TestSingleEntryProximityMap(t *testing.T) { | ||
ctx := context.Background() | ||
ns := tree.NewTestNodeStore() | ||
m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[1.0]"}, []int64{1}, 10) | ||
matches := 0 | ||
vectorHash, _ := m.keyDesc.GetJSONAddr(0, keys[0]) | ||
vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx) | ||
require.NoError(t, err) | ||
err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error { | ||
require.Equal(t, keys[0], foundKey) | ||
require.Equal(t, values[0], foundValue) | ||
matches++ | ||
return nil | ||
}) | ||
require.NoError(t, err) | ||
require.Equal(t, matches, 1) | ||
} | ||
|
||
func TestDoubleEntryProximityMapGetExact(t *testing.T) { | ||
ctx := context.Background() | ||
ns := tree.NewTestNodeStore() | ||
m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10) | ||
matches := 0 | ||
for i, key := range keys { | ||
vectorHash, _ := m.keyDesc.GetJSONAddr(0, key) | ||
vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx) | ||
err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error { | ||
require.Equal(t, key, foundKey) | ||
require.Equal(t, values[i], foundValue) | ||
matches++ | ||
return nil | ||
}) | ||
require.NoError(t, err) | ||
} | ||
require.Equal(t, matches, len(keys)) | ||
} | ||
|
||
func TestDoubleEntryProximityMapGetClosest(t *testing.T) { | ||
ctx := context.Background() | ||
ns := tree.NewTestNodeStore() | ||
m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10) | ||
matches := 0 | ||
|
||
cb := func(foundKey val.Tuple, foundValue val.Tuple, distance float64) error { | ||
require.Equal(t, keys[1], foundKey) | ||
require.Equal(t, values[1], foundValue) | ||
require.InDelta(t, distance, 25.0, 0.1) | ||
matches++ | ||
return nil | ||
} | ||
|
||
err := m.GetClosest(ctx, newJsonValue(t, "[0.0, 0.0]"), cb, 1) | ||
require.NoError(t, err) | ||
require.Equal(t, matches, 1) | ||
} | ||
|
||
func TestMultilevelProximityMap(t *testing.T) { | ||
ctx := context.Background() | ||
ns := tree.NewTestNodeStore() | ||
keyStrings := []interface{}{ | ||
"[0.0, 1.0]", | ||
"[3.0, 4.0]", | ||
"[5.0, 6.0]", | ||
"[7.0, 8.0]", | ||
} | ||
valueStrings := []int64{1, 2, 3, 4} | ||
m, keys, values := createProximityMap(t, ctx, ns, keyStrings, valueStrings, 1) | ||
matches := 0 | ||
for i, key := range keys { | ||
vectorHash, _ := m.keyDesc.GetJSONAddr(0, key) | ||
vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx) | ||
require.NoError(t, err) | ||
err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error { | ||
require.Equal(t, key, foundKey) | ||
require.Equal(t, values[i], foundValue) | ||
matches++ | ||
return nil | ||
}) | ||
require.NoError(t, err) | ||
} | ||
require.Equal(t, matches, len(keys)) | ||
} | ||
|
||
func TestInsertOrderIndependence(t *testing.T) { | ||
ctx := context.Background() | ||
ns := tree.NewTestNodeStore() | ||
keyStrings1 := []interface{}{ | ||
"[0.0, 1.0]", | ||
"[3.0, 4.0]", | ||
"[5.0, 6.0]", | ||
"[7.0, 8.0]", | ||
} | ||
valueStrings1 := []int64{1, 2, 3, 4} | ||
keyStrings2 := []interface{}{ | ||
"[7.0, 8.0]", | ||
"[5.0, 6.0]", | ||
"[3.0, 4.0]", | ||
"[0.0, 1.0]", | ||
} | ||
valueStrings2 := []int64{4, 3, 2, 1} | ||
m1, _, _ := createProximityMap(t, ctx, ns, keyStrings1, valueStrings1, 1) | ||
m2, _, _ := createProximityMap(t, ctx, ns, keyStrings2, valueStrings2, 1) | ||
require.Equal(t, m1.tuples.Root.HashOf(), m2.tuples.Root.HashOf()) | ||
} |
Oops, something went wrong.