diff --git a/go/gen/fb/serial/fileidentifiers.go b/go/gen/fb/serial/fileidentifiers.go index 64cb36d9f4..787247159b 100644 --- a/go/gen/fb/serial/fileidentifiers.go +++ b/go/gen/fb/serial/fileidentifiers.go @@ -41,6 +41,7 @@ const StashListFileID = "SLST" const StashFileID = "STSH" const StatisticFileID = "STAT" const DoltgresRootValueFileID = "DGRV" +const VectorIndexNodeFileID = "IVFF" const MessageTypesKind int = 27 diff --git a/go/gen/fb/serial/vectorindexnode.go b/go/gen/fb/serial/vectorindexnode.go new file mode 100644 index 0000000000..0c98242919 --- /dev/null +++ b/go/gen/fb/serial/vectorindexnode.go @@ -0,0 +1,316 @@ +// Copyright 2022-2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated by the FlatBuffers compiler. DO NOT EDIT. + +package serial + +import ( + flatbuffers "github.com/dolthub/flatbuffers/v23/go" +) + +type VectorIndexNode struct { + _tab flatbuffers.Table +} + +func InitVectorIndexNodeRoot(o *VectorIndexNode, buf []byte, offset flatbuffers.UOffsetT) error { + n := flatbuffers.GetUOffsetT(buf[offset:]) + return o.Init(buf, n+offset) +} + +func TryGetRootAsVectorIndexNode(buf []byte, offset flatbuffers.UOffsetT) (*VectorIndexNode, error) { + x := &VectorIndexNode{} + return x, InitVectorIndexNodeRoot(x, buf, offset) +} + +func TryGetSizePrefixedRootAsVectorIndexNode(buf []byte, offset flatbuffers.UOffsetT) (*VectorIndexNode, error) { + x := &VectorIndexNode{} + return x, InitVectorIndexNodeRoot(x, buf, offset+flatbuffers.SizeUint32) +} + +func (rcv *VectorIndexNode) Init(buf []byte, i flatbuffers.UOffsetT) error { + rcv._tab.Bytes = buf + rcv._tab.Pos = i + if VectorIndexNodeNumFields < rcv.Table().NumFields() { + return flatbuffers.ErrTableHasUnknownFields + } + return nil +} + +func (rcv *VectorIndexNode) Table() flatbuffers.Table { + return rcv._tab +} + +func (rcv *VectorIndexNode) KeyItems(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) KeyItemsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) KeyItemsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateKeyItems(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(4)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) KeyOffsets(j int) uint16 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetUint16(a + flatbuffers.UOffsetT(j*2)) + } + return 0 +} + +func (rcv *VectorIndexNode) KeyOffsetsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 
+} + +func (rcv *VectorIndexNode) MutateKeyOffsets(j int, n uint16) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(6)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateUint16(a+flatbuffers.UOffsetT(j*2), n) + } + return false +} + +func (rcv *VectorIndexNode) ValueItems(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) ValueItemsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) ValueItemsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateValueItems(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(8)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) ValueOffsets(j int) uint16 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetUint16(a + flatbuffers.UOffsetT(j*2)) + } + return 0 +} + +func (rcv *VectorIndexNode) ValueOffsetsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateValueOffsets(j int, n uint16) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(10)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateUint16(a+flatbuffers.UOffsetT(j*2), n) + } + return false +} + +func (rcv *VectorIndexNode) AddressArray(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) AddressArrayLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) AddressArrayBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateAddressArray(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(12)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) SubtreeCounts(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *VectorIndexNode) SubtreeCountsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *VectorIndexNode) SubtreeCountsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *VectorIndexNode) MutateSubtreeCounts(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(14)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +func (rcv *VectorIndexNode) TreeCount() uint64 { + o := flatbuffers.UOffsetT(rcv._tab.Offset(16)) + if o != 0 { + return rcv._tab.GetUint64(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv 
*VectorIndexNode) MutateTreeCount(n uint64) bool { + return rcv._tab.MutateUint64Slot(16, n) +} + +func (rcv *VectorIndexNode) TreeLevel() byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.GetByte(o + rcv._tab.Pos) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateTreeLevel(n byte) bool { + return rcv._tab.MutateByteSlot(18, n) +} + +const VectorIndexNodeNumFields = 8 + +func VectorIndexNodeStart(builder *flatbuffers.Builder) { + builder.StartObject(VectorIndexNodeNumFields) +} +func VectorIndexNodeAddKeyItems(builder *flatbuffers.Builder, keyItems flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(keyItems), 0) +} +func VectorIndexNodeStartKeyItemsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddKeyOffsets(builder *flatbuffers.Builder, keyOffsets flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(keyOffsets), 0) +} +func VectorIndexNodeStartKeyOffsetsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(2, numElems, 2) +} +func VectorIndexNodeAddValueItems(builder *flatbuffers.Builder, valueItems flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(valueItems), 0) +} +func VectorIndexNodeStartValueItemsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddValueOffsets(builder *flatbuffers.Builder, valueOffsets flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(valueOffsets), 0) +} +func VectorIndexNodeStartValueOffsetsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(2, numElems, 2) +} +func VectorIndexNodeAddAddressArray(builder *flatbuffers.Builder, addressArray flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(addressArray), 0) +} +func VectorIndexNodeStartAddressArrayVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddSubtreeCounts(builder *flatbuffers.Builder, subtreeCounts flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(subtreeCounts), 0) +} +func VectorIndexNodeStartSubtreeCountsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func VectorIndexNodeAddTreeCount(builder *flatbuffers.Builder, treeCount uint64) { + builder.PrependUint64Slot(6, treeCount, 0) +} +func VectorIndexNodeAddTreeLevel(builder *flatbuffers.Builder, treeLevel byte) { + builder.PrependByteSlot(7, treeLevel, 0) +} +func VectorIndexNodeEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { + return builder.EndObject() +} diff --git a/go/go.mod b/go/go.mod index 0615498e9a..48e8433d66 100644 --- a/go/go.mod +++ b/go/go.mod @@ -169,4 +169,4 @@ require ( replace github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi => ./gen/proto/dolt/services/eventsapi -go 1.22.2 +go 1.23.0 diff --git a/go/go.work b/go/go.work index 0535a0ff63..ee98ff0bcb 100644 --- a/go/go.work +++ b/go/go.work @@ -1,6 +1,6 @@ -go 1.22.5 +go 1.23.0 -toolchain go1.22.7 +toolchain go1.23.2 use ( . 
diff --git a/go/serial/generate.sh b/go/serial/generate.sh
index 267d926bcb..b7c9010ad1 100755
--- a/go/serial/generate.sh
+++ b/go/serial/generate.sh
@@ -37,7 +37,8 @@ fi
     stat.fbs \
     table.fbs \
     tag.fbs \
-    workingset.fbs
+    workingset.fbs \
+    vectorindexnode.fbs
 
 # prefix files with copyright header
 for FILE in $GEN_DIR/*.go;
diff --git a/go/serial/vectorindexnode.fbs b/go/serial/vectorindexnode.fbs
new file mode 100644
index 0000000000..ee362b20f5
--- /dev/null
+++ b/go/serial/vectorindexnode.fbs
@@ -0,0 +1,57 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+namespace serial;
+
+// VectorIndexNode is a node that makes up a vector index. Every key contains a vector value,
+// and keys are organized according to their proximity to their parent node.
+// Currently, this is a subset of the fields used in ProllyMap. However, it uses its own message type
+// to limit the risk that lookup/mutation algorithms written for ProllyMap nodes accidentally run on VectorIndex nodes.
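+//
+// For example, three keys of lengths 3, 5, and 2 would be concatenated into a single 10-byte
+// |key_items| array, with |key_offsets| = [0, 3, 8, 10] (see the field comments below).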
+table VectorIndexNode {
+    // sorted array of key items
+    key_items:[ubyte] (required);
+    // item offsets for |key_items|
+    // first offset is 0, last offset is len(key_items)
+    key_offsets:[uint16] (required);
+    // item type for |key_items|
+    // key_type:ItemType;
+
+    // array of value items, ordered by paired key
+    value_items:[ubyte];
+    // item offsets for |value_items|
+    // first offset is 0, last offset is len(value_items)
+    value_offsets:[uint16];
+    // item type for |value_items|
+    // value_type:ItemType;
+
+    // array of chunk addresses
+    //  - subtree addresses for internal prolly tree nodes
+    //  - unused for leaf nodes, which store their values inline in |value_items|
+    address_array:[ubyte] (required);
+
+    // array of varint encoded subtree counts
+    // see: go/store/prolly/message/varint.go
+    subtree_counts:[ubyte];
+    // total count of prolly tree
+    tree_count:uint64;
+    // prolly tree level, 0 for leaf nodes
+    tree_level:uint8;
+}
+
+
+// KEEP THIS IN SYNC WITH fileidentifiers.go
+file_identifier "IVFF";
+
+root_type VectorIndexNode;
+
diff --git a/go/store/prolly/message/message.go b/go/store/prolly/message/message.go
index 4554f5650f..19521d54a9 100644
--- a/go/store/prolly/message/message.go
+++ b/go/store/prolly/message/message.go
@@ -39,6 +39,8 @@ func UnpackFields(msg serial.Message) (keys, values ItemAccess, level, count uin
 	switch serial.GetFileID(msg) {
 	case serial.ProllyTreeNodeFileID:
 		return getProllyMapKeysAndValues(msg)
+	case serial.VectorIndexNodeFileID:
+		return getVectorIndexKeysAndValues(msg)
 	case serial.AddressMapFileID:
 		keys, err = getAddressMapKeys(msg)
 		if err != nil {
@@ -96,6 +98,8 @@ func WalkAddresses(ctx context.Context, msg serial.Message, cb func(ctx context.
 	switch id {
 	case serial.ProllyTreeNodeFileID:
 		return walkProllyMapAddresses(ctx, msg, cb)
+	case serial.VectorIndexNodeFileID:
+		return walkVectorIndexAddresses(ctx, msg, cb)
 	case serial.AddressMapFileID:
 		return walkAddressMapAddresses(ctx, msg, cb)
 	case serial.MergeArtifactsFileID:
@@ -114,6 +118,8 @@ func GetTreeCount(msg serial.Message) (int, error) {
 	switch id {
 	case serial.ProllyTreeNodeFileID:
 		return getProllyMapTreeCount(msg)
+	case serial.VectorIndexNodeFileID:
+		return getVectorIndexTreeCount(msg)
 	case serial.AddressMapFileID:
 		return getAddressMapTreeCount(msg)
 	case serial.MergeArtifactsFileID:
diff --git a/go/store/prolly/message/vector_index.go b/go/store/prolly/message/vector_index.go
new file mode 100644
index 0000000000..f8fa423164
--- /dev/null
+++ b/go/store/prolly/message/vector_index.go
@@ -0,0 +1,212 @@
+// Copyright 2022 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package message
+
+import (
+	"context"
+	"encoding/binary"
+	"fmt"
+
+	fb "github.com/dolthub/flatbuffers/v23/go"
+
+	"github.com/dolthub/dolt/go/gen/fb/serial"
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/pool"
+)
+
+const (
+	// These constants are mirrored from serial.VectorIndexNode.
+	// They are only as stable as the flatbuffers schema that defines them.
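+	// (A flatbuffers table field N has vtable offset 4 + 2*N, so field 0, key_items, is at
+	// voffset 4, and each subsequent field is 2 bytes further along.)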
+	vectorIvfKeyItemBytesVOffset      fb.VOffsetT = 4
+	vectorIvfKeyOffsetsVOffset        fb.VOffsetT = 6
+	vectorIvfValueItemBytesVOffset    fb.VOffsetT = 8
+	vectorIvfValueOffsetsVOffset      fb.VOffsetT = 10
+	vectorIvfAddressArrayBytesVOffset fb.VOffsetT = 12
+)
+
+var vectorIvfFileID = []byte(serial.VectorIndexNodeFileID)
+
+func NewVectorIndexSerializer(pool pool.BuffPool) VectorIndexSerializer {
+	return VectorIndexSerializer{pool: pool}
+}
+
+type VectorIndexSerializer struct {
+	pool pool.BuffPool
+}
+
+var _ Serializer = VectorIndexSerializer{}
+
+func (s VectorIndexSerializer) Serialize(keys, values [][]byte, subtrees []uint64, level int) serial.Message {
+	var (
+		keyTups, keyOffs fb.UOffsetT
+		valTups, valOffs fb.UOffsetT
+		refArr, cardArr  fb.UOffsetT
+	)
+
+	keySz, valSz, bufSz := estimateVectorIndexSize(keys, values, subtrees)
+	b := getFlatbufferBuilder(s.pool, bufSz)
+
+	// serialize keys and offsets
+	keyTups = writeItemBytes(b, keys, keySz)
+	serial.VectorIndexNodeStartKeyOffsetsVector(b, len(keys)+1)
+	keyOffs = writeItemOffsets(b, keys, keySz)
+
+	if level == 0 {
+		// serialize value tuples for leaf nodes
+		valTups = writeItemBytes(b, values, valSz)
+		serial.VectorIndexNodeStartValueOffsetsVector(b, len(values)+1)
+		valOffs = writeItemOffsets(b, values, valSz)
+	} else {
+		// serialize child refs and subtree counts for internal nodes
+		refArr = writeItemBytes(b, values, valSz)
+		cardArr = writeCountArray(b, subtrees)
+	}
+
+	// populate the node's vtable
+	serial.VectorIndexNodeStart(b)
+	serial.VectorIndexNodeAddKeyItems(b, keyTups)
+	serial.VectorIndexNodeAddKeyOffsets(b, keyOffs)
+	if level == 0 {
+		serial.VectorIndexNodeAddValueItems(b, valTups)
+		serial.VectorIndexNodeAddValueOffsets(b, valOffs)
+		serial.VectorIndexNodeAddTreeCount(b, uint64(len(keys)))
+	} else {
+		serial.VectorIndexNodeAddAddressArray(b, refArr)
+		serial.VectorIndexNodeAddSubtreeCounts(b, cardArr)
+		serial.VectorIndexNodeAddTreeCount(b, sumSubtrees(subtrees))
+	}
+	serial.VectorIndexNodeAddTreeLevel(b, uint8(level))
+
+	return serial.FinishMessage(b, serial.VectorIndexNodeEnd(b), vectorIvfFileID)
+}
+
+func getVectorIndexKeysAndValues(msg serial.Message) (keys, values ItemAccess, level, count uint16, err error) {
+	var pm serial.VectorIndexNode
+	err = serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return
+	}
+	keys.bufStart = lookupVectorOffset(vectorIvfKeyItemBytesVOffset, pm.Table())
+	keys.bufLen = uint16(pm.KeyItemsLength())
+	keys.offStart = lookupVectorOffset(vectorIvfKeyOffsetsVOffset, pm.Table())
+	keys.offLen = uint16(pm.KeyOffsetsLength() * uint16Size)
+
+	count = (keys.offLen / 2) - 1
+	level = uint16(pm.TreeLevel())
+
+	vv := pm.ValueItemsBytes()
+	if vv != nil {
+		values.bufStart = lookupVectorOffset(vectorIvfValueItemBytesVOffset, pm.Table())
+		values.bufLen = uint16(pm.ValueItemsLength())
+		values.offStart = lookupVectorOffset(vectorIvfValueOffsetsVOffset, pm.Table())
+		values.offLen = uint16(pm.ValueOffsetsLength() * uint16Size)
+	} else {
+		values.bufStart = lookupVectorOffset(vectorIvfAddressArrayBytesVOffset, pm.Table())
+		values.bufLen = uint16(pm.AddressArrayLength())
+		values.itemWidth = hash.ByteLen
+	}
+	return
+}
+
+func walkVectorIndexAddresses(ctx context.Context, msg serial.Message, cb func(ctx context.Context, addr hash.Hash) error) error {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return err
+	}
+	arr := pm.AddressArrayBytes()
+	for i := 0; i < len(arr)/hash.ByteLen; i++ {
+		addr := hash.New(arr[i*hash.ByteLen : (i+1)*hash.ByteLen])
+		if err := cb(ctx, addr); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func getVectorIndexCount(msg serial.Message) (uint16, error) {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return 0, err
+	}
+	return uint16(pm.KeyOffsetsLength() - 1), nil
+}
+
+func getVectorIndexTreeLevel(msg serial.Message) (int, error) {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return 0, err
+	}
+	return int(pm.TreeLevel()), nil
+}
+
+func getVectorIndexTreeCount(msg serial.Message) (int, error) {
+	var pm serial.VectorIndexNode
+	err := serial.InitVectorIndexNodeRoot(&pm, msg, serial.MessagePrefixSz)
+	if err != nil {
+		return 0, err
+	}
+	return int(pm.TreeCount()), nil
+}
+
+func getVectorIndexSubtrees(msg serial.Message) ([]uint64, error) {
+	sz, err := getVectorIndexCount(msg)
+	if err != nil {
+		return nil, err
+	}
+
+	var pm serial.VectorIndexNode
+	n := fb.GetUOffsetT(msg[serial.MessagePrefixSz:])
+	err = pm.Init(msg, serial.MessagePrefixSz+n)
+	if err != nil {
+		return nil, err
+	}
+
+	counts := make([]uint64, sz)
+
+	return decodeVarints(pm.SubtreeCountsBytes(), counts), nil
+}
+
+// estimateVectorIndexSize returns the exact size of the tuple vectors for keys and values,
+// and an estimate of the overall size of the final flatbuffer.
+func estimateVectorIndexSize(keys, values [][]byte, subtrees []uint64) (int, int, int) {
+	var keySz, valSz, bufSz int
+	for i := range keys {
+		keySz += len(keys[i])
+		valSz += len(values[i])
+	}
+	subtreesSz := len(subtrees) * binary.MaxVarintLen64
+
+	// constraints enforced upstream
+	if keySz > int(MaxVectorOffset) {
+		panic(fmt.Sprintf("key vector exceeds size limit ( %d > %d )", keySz, MaxVectorOffset))
+	}
+	if valSz > int(MaxVectorOffset) {
+		panic(fmt.Sprintf("value vector exceeds size limit ( %d > %d )", valSz, MaxVectorOffset))
+	}
+
+	bufSz += keySz + valSz               // tuples
+	bufSz += subtreesSz                  // subtree counts
+	bufSz += len(keys)*2 + len(values)*2 // offsets
+	bufSz += 8 + 1 + 1 + 1               // metadata
+	bufSz += 72                          // vtable (approx)
+	bufSz += 100                         // padding?
+	bufSz += serial.MessagePrefixSz
+
+	return keySz, valSz, bufSz
+}
diff --git a/go/store/prolly/proximity_map.go b/go/store/prolly/proximity_map.go
new file mode 100644
index 0000000000..d1b61043a9
--- /dev/null
+++ b/go/store/prolly/proximity_map.go
@@ -0,0 +1,553 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package prolly
+
+import (
+	"context"
+	"io"
+	"iter"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/dolthub/go-mysql-server/sql/expression"
+
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/prolly/message"
+	"github.com/dolthub/dolt/go/store/prolly/tree"
+	"github.com/dolthub/dolt/go/store/val"
+)
+
+// ProximityMap wraps a tree.ProximityMap but operates on typed Tuples instead of raw bytestrings.
+type ProximityMap struct {
+	tuples  tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]
+	keyDesc val.TupleDesc
+	valDesc val.TupleDesc
+}
+
+// Count returns the number of key-value pairs in the Map.
+func (m ProximityMap) Count() (int, error) {
+	return m.tuples.Count()
+}
+
+// Get searches for the key-value pair keyed by |query| and passes the results to the callback.
+// If |query| is not present in the map, a nil key-value pair is passed.
+func (m ProximityMap) Get(ctx context.Context, query interface{}, cb tree.KeyValueFn[val.Tuple, val.Tuple]) (err error) {
+	return m.tuples.GetExact(ctx, query, cb)
+}
+
+func (m ProximityMap) GetClosest(ctx context.Context, query interface{}, cb tree.KeyValueDistanceFn[val.Tuple, val.Tuple], limit int) (err error) {
+	return m.tuples.GetClosest(ctx, query, cb, limit)
+}
+
+// NewProximityMap creates a new ProximityMap from a supplied root node.
+func NewProximityMap(ctx context.Context, ns tree.NodeStore, node tree.Node, keyDesc val.TupleDesc, valDesc val.TupleDesc, distanceType expression.DistanceType) ProximityMap {
+	tuples := tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]{
+		Root:         node,
+		NodeStore:    ns,
+		Order:        keyDesc,
+		DistanceType: distanceType,
+		Convert: func(bytes []byte) []float64 {
+			h, _ := keyDesc.GetJSONAddr(0, bytes)
+			doc := tree.NewJSONDoc(h, ns)
+			jsonWrapper, err := doc.ToIndexedJSONDocument(ctx)
+			if err != nil {
+				panic(err)
+			}
+			floats, err := sql.ConvertToVector(jsonWrapper)
+			if err != nil {
+				panic(err)
+			}
+			return floats
+		},
+	}
+	return ProximityMap{
+		tuples:  tuples,
+		keyDesc: keyDesc,
+		valDesc: valDesc,
+	}
+}
+
+var levelMapKeyDesc = val.NewTupleDescriptor(
+	val.Type{Enc: val.Uint8Enc, Nullable: false},
+	val.Type{Enc: val.ByteStringEnc, Nullable: false},
+)
+
+// NewProximityMapFromTuples creates a proximityMapBuilder, which builds a ProximityMap from
+// a list of key-value pairs supplied via Insert and finalized with Flush.
+func NewProximityMapFromTuples(ctx context.Context, ns tree.NodeStore, distanceType expression.DistanceType, keyDesc val.TupleDesc, valDesc val.TupleDesc, logChunkSize uint8) (proximityMapBuilder, error) {
+
+	emptyLevelMap, err := NewMapFromTuples(ctx, ns, levelMapKeyDesc, valDesc)
+	if err != nil {
+		return proximityMapBuilder{}, err
+	}
+	mutableLevelMap := newMutableMap(emptyLevelMap)
+	return proximityMapBuilder{
+		ns:                    ns,
+		vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool()),
+		distanceType:          distanceType,
+		keyDesc:               keyDesc,
+		valDesc:               valDesc,
+		logChunkSize:          logChunkSize,
+
+		maxLevel: 0,
+		levelMap: mutableLevelMap,
+	}, nil
+}
+
+// proximityMapBuilder is effectively a namespace for helper functions used in creating a ProximityMap.
+// It holds the parameters of the operation.
+// Each node has an average of 2^|logChunkSize| key-value pairs.
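+// (A key's level comes from tree.DeterministicHashLevel: with logChunkSize = 8, a key is
+// promoted above level 0 with probability 1/256, so leaf nodes average ~256 entries.)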
+type proximityMapBuilder struct {
+	ns                    tree.NodeStore
+	vectorIndexSerializer message.VectorIndexSerializer
+	distanceType          expression.DistanceType
+	keyDesc, valDesc      val.TupleDesc
+	logChunkSize          uint8
+
+	maxLevel uint8
+	levelMap *MutableMap
+}
+
+func (b *proximityMapBuilder) Insert(ctx context.Context, key, value []byte) error {
+	keyLevel := tree.DeterministicHashLevel(b.logChunkSize, key)
+	if keyLevel > b.maxLevel {
+		b.maxLevel = keyLevel
+	}
+
+	levelMapKeyBuilder := val.NewTupleBuilder(levelMapKeyDesc)
+	levelMapKeyBuilder.PutUint8(0, 255-keyLevel)
+	levelMapKeyBuilder.PutByteString(1, key)
+	return b.levelMap.Put(ctx, levelMapKeyBuilder.Build(b.ns.Pool()), value)
+}
+
+func (b *proximityMapBuilder) makeRootNode(ctx context.Context, keys, values [][]byte, subtrees []uint64, level int) (ProximityMap, error) {
+	rootMsg := b.vectorIndexSerializer.Serialize(keys, values, subtrees, level)
+	rootNode, err := tree.NodeFromBytes(rootMsg)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+	_, err = b.ns.Write(ctx, rootNode)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+
+	return NewProximityMap(ctx, b.ns, rootNode, b.keyDesc, b.valDesc, b.distanceType), nil
+}
+
+func (b *proximityMapBuilder) Flush(ctx context.Context) (ProximityMap, error) {
+	// The algorithm for building a ProximityMap's tree requires us to start at the root and build out to the leaf nodes.
+	// Given that our trees are Merkle Trees, this presents an obvious problem.
+	// Our solution is to create the final tree by applying a series of transformations to intermediate trees.
+
+	// Note: when talking about tree levels, we use "level" when counting from the leaves, and "depth" when counting
+	// from the root. In a tree with 5 levels, the root is level 4 (and depth 0), while the leaves are level 0 (and depth 4).
+
+	// The process looks like this:
+	// Step 1: Create `levelMap`, a map from (indexLevel, keyBytes) -> values
+	//   - indexLevel: the highest level at which the vector appears
+	//   - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector)
+	//   - values: the ProximityMap value tuple
+	//
+	// Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap
+	//   The pathMap at depth `i` has the schema (vectorAddrs[1]...vectorAddrs[i], keyBytes) -> value
+	//   and contains a row for every vector that first appears at depth i.
+	//   - vectorAddrs: the path of vectors visited when walking from the root to the depth where the vector first appears.
+	//   - keyBytes: a bytestring containing the bytes of the ProximityMap key (which includes the vector)
+	//   - values: the ProximityMap value tuple
+	//
+	//   These maps must be built in order, from shallowest to deepest.
+	//
+	// Step 3: Create an iter over each `pathMap` created in the previous step, and walk the shape of the final ProximityMap,
+	// generating Nodes as we go.
+	//
+	// Currently, the intermediate trees are created using the standard NodeStore. This means that the nodes of these
+	// trees will inevitably be written out to disk when the NodeStore flushes, despite the fact that we know they
+	// won't be needed once we finish building the ProximityMap. This could potentially be avoided by creating a
+	// separate in-memory NodeStore for these values.
+
+	// Check if index is empty.
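+	// (An empty builder produces a single empty leaf node at level 0.)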
+	if !b.levelMap.HasEdits() {
+		return b.makeRootNode(ctx, nil, nil, nil, 0)
+	}
+
+	// Step 1: Create `levelMap`, a map from (indexLevel, keyBytes) -> values
+	// We want the index to be sorted by level (descending), so currently we store the level in the map as
+	// 255 - the actual level. TODO: Implement a ReverseIter for MutableMap and use that instead.
+
+	if b.maxLevel == 0 {
+		// index is a single node.
+		// assuming that the keys are already sorted, we can return them unmodified.
+		levelMapIter, err := b.levelMap.IterAll(ctx)
+		if err != nil {
+			return ProximityMap{}, err
+		}
+		var keys, values [][]byte
+		for {
+			key, value, err := levelMapIter.Next(ctx)
+			if err == io.EOF {
+				break
+			}
+			if err != nil {
+				return ProximityMap{}, err
+			}
+			originalKey, _ := levelMapKeyDesc.GetBytes(1, key)
+			keys = append(keys, originalKey)
+			values = append(values, value)
+		}
+		return b.makeRootNode(ctx, keys, values, nil, 0)
+	}
+
+	// Step 2: Create `pathMaps`, a list of maps, each corresponding to a different level of the ProximityMap
+	pathMaps, err := b.makePathMaps(ctx, b.levelMap)
+	if err != nil {
+		return ProximityMap{}, err
+	}
+
+	// Step 3: Create an iter over each `pathMap` created in the previous step, and walk the shape of the final ProximityMap,
+	// generating Nodes as we go.
+	return b.makeProximityMapFromPathMaps(ctx, pathMaps)
+}
+
+// makePathMaps creates a set of prolly maps, each of which corresponds to a different level in the to-be-built ProximityMap
+func (b *proximityMapBuilder) makePathMaps(ctx context.Context, mutableLevelMap *MutableMap) ([]*MutableMap, error) {
+	levelMapIter, err := mutableLevelMap.IterAll(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// The first element of levelMap tells us the height of the tree.
+	levelMapKey, levelMapValue, err := levelMapIter.Next(ctx)
+	if err != nil {
+		return nil, err
+	}
+	maxLevel, _ := mutableLevelMap.keyDesc.GetUint8(0, levelMapKey)
+	maxLevel = 255 - maxLevel
+
+	// Create every val.TupleBuilder and MutableMap that we will need
+	// pathMaps[i] is the pathMap for level i (and depth maxLevel - i)
+	pathMaps, keyTupleBuilders, prefixTupleBuilders, err := b.createInitialPathMaps(ctx, maxLevel)
+	if err != nil {
+		return nil, err
+	}
+
+	// Next, visit each key-value pair in decreasing order of level / increasing order of depth.
+	// When visiting a pair from depth `i`, we use each of the previous `i` pathMaps to compute a path of `i` index keys.
+	// This path dictates that pair's location in the final ProximityMap.
+	for {
+		level, _ := mutableLevelMap.keyDesc.GetUint8(0, levelMapKey)
+		level = 255 - level // we currently store the level as 255 - the actual level for sorting purposes.
+		depth := int(maxLevel - level)
+
+		keyTupleBuilder := keyTupleBuilders[level]
+		var hashPath []hash.Hash
+		keyToInsert, _ := mutableLevelMap.keyDesc.GetBytes(1, levelMapKey)
+		vectorHashToInsert, _ := b.keyDesc.GetJSONAddr(0, keyToInsert)
+		vectorToInsert, err := getVectorFromHash(ctx, b.ns, vectorHashToInsert)
+		if err != nil {
+			return nil, err
+		}
+		// Compute the path that this row will have in the vector index, starting at the root.
+		// A key-value pair at depth D will have a path of D prior keys.
+		// This path is computed in steps, by performing a lookup in each of the prior pathMaps.
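+		// For example, in a tree with maxLevel 2, a level-0 key's path is [r, c]: r is the
+		// closest root-level vector, and c is the closest vector among the level-1 keys under r.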
+		for pathDepth := 0; pathDepth < depth; pathDepth++ {
+			lookupLevel := int(maxLevel) - pathDepth
+			prefixTupleBuilder := prefixTupleBuilders[lookupLevel]
+			pathMap := pathMaps[lookupLevel]
+
+			pathMapIter, err := b.getNextPathSegmentCandidates(ctx, pathMap, prefixTupleBuilder, hashPath)
+			if err != nil {
+				return nil, err
+			}
+
+			// Create an iterator that yields every candidate vector
+			nextCandidate, stopIter := iter.Pull2(func(yield func(hash.Hash, error) bool) {
+				if pathDepth != 0 {
+					firstCandidate := hashPath[pathDepth-1]
+					if !yield(firstCandidate, nil) {
+						return
+					}
+				}
+				for {
+					pathMapKey, _, err := pathMapIter.Next(ctx)
+					if err == io.EOF {
+						return
+					}
+					if err != nil {
+						yield(hash.Hash{}, err)
+						return
+					}
+					originalKey, _ := pathMap.keyDesc.GetBytes(pathDepth, pathMapKey)
+					candidateVectorHash, _ := b.keyDesc.GetJSONAddr(0, originalKey)
+					if !yield(candidateVectorHash, nil) {
+						return
+					}
+				}
+			})
+
+			closestVectorHash, err := b.getClosestVector(ctx, vectorToInsert, nextCandidate)
+			stopIter()
+			if err != nil {
+				return nil, err
+			}
+
+			hashPath = append(hashPath, closestVectorHash)
+		}
+
+		// Once we have the path for this key, we turn it into a tuple and add it to the next pathMap.
+		for i, h := range hashPath {
+			keyTupleBuilder.PutJSONAddr(i, h)
+		}
+		keyTupleBuilder.PutByteString(depth, keyToInsert)
+
+		err = pathMaps[level].Put(ctx, keyTupleBuilder.Build(b.ns.Pool()), levelMapValue)
+		if err != nil {
+			return nil, err
+		}
+
+		levelMapKey, levelMapValue, err = levelMapIter.Next(ctx)
+		if err == io.EOF {
+			return pathMaps, nil
+		}
+		if err != nil {
+			return nil, err
+		}
+	}
+}
+
+// createInitialPathMaps creates a list of MutableMaps that will eventually store a single level of the to-be-built ProximityMap
+func (b *proximityMapBuilder) createInitialPathMaps(ctx context.Context, maxLevel uint8) (pathMaps []*MutableMap, keyTupleBuilders, prefixTupleBuilders []*val.TupleBuilder, err error) {
+	keyTupleBuilders = make([]*val.TupleBuilder, maxLevel+1)
+	prefixTupleBuilders = make([]*val.TupleBuilder, maxLevel+1)
+	pathMaps = make([]*MutableMap, maxLevel+1)
+
+	// Make a type slice for the maximum depth pathMap: each other slice we need is a subslice of this one.
+	pathMapKeyDescTypes := make([]val.Type, maxLevel+1)
+	for i := uint8(0); i < maxLevel; i++ {
+		pathMapKeyDescTypes[i] = val.Type{Enc: val.JSONAddrEnc, Nullable: false}
+	}
+	pathMapKeyDescTypes[maxLevel] = val.Type{Enc: val.ByteStringEnc, Nullable: false}
+
+	for i := uint8(0); i <= maxLevel; i++ {
+		pathMapKeyDesc := val.NewTupleDescriptor(pathMapKeyDescTypes[i:]...)
+
+		emptyPathMap, err := NewMapFromTuples(ctx, b.ns, pathMapKeyDesc, b.valDesc)
+		if err != nil {
+			return nil, nil, nil, err
+		}
+		pathMaps[i] = newMutableMap(emptyPathMap)
+
+		keyTupleBuilders[i] = val.NewTupleBuilder(pathMapKeyDesc)
+		prefixTupleBuilders[i] = val.NewTupleBuilder(val.NewTupleDescriptor(pathMapKeyDescTypes[i:maxLevel]...))
+	}
+
+	return pathMaps, keyTupleBuilders, prefixTupleBuilders, nil
+}
+
+// getNextPathSegmentCandidates takes a list of keys, representing a path into the ProximityMap from the root.
+// It returns an iter over all possible keys that could be the next path segment.
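+// Since pathMap keys are tuples of (vectorAddrs..., keyBytes), the candidates are exactly the
+// entries whose key begins with |currentPath|, retrieved below with a prefix range scan.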
+func (b *proximityMapBuilder) getNextPathSegmentCandidates(ctx context.Context, pathMap *MutableMap, prefixTupleBuilder *val.TupleBuilder, currentPath []hash.Hash) (MapIter, error) {
+	for tupleElem := 0; tupleElem < len(currentPath); tupleElem++ {
+		prefixTupleBuilder.PutJSONAddr(tupleElem, currentPath[tupleElem])
+	}
+	prefixTuple := prefixTupleBuilder.Build(b.ns.Pool())
+
+	prefixRange := PrefixRange(prefixTuple, prefixTupleBuilder.Desc)
+	return pathMap.IterRange(ctx, prefixRange)
+}
+
+// getClosestVector iterates over a range of candidate vectors and returns the hash of the one
+// closest to |targetVector|.
+func (b *proximityMapBuilder) getClosestVector(ctx context.Context, targetVector []float64, nextCandidate func() (candidate hash.Hash, err error, valid bool)) (hash.Hash, error) {
+	// First call to nextCandidate is guaranteed to be valid because there's at least one vector in the set.
+	// (non-root nodes inherit the first vector from their parent)
+	candidateVectorHash, err, _ := nextCandidate()
+	if err != nil {
+		return hash.Hash{}, err
+	}
+
+	candidateVector, err := getVectorFromHash(ctx, b.ns, candidateVectorHash)
+	if err != nil {
+		return hash.Hash{}, err
+	}
+	closestVectorHash := candidateVectorHash
+	closestDistance, err := b.distanceType.Eval(targetVector, candidateVector)
+	if err != nil {
+		return hash.Hash{}, err
+	}
+
+	for {
+		candidateVectorHash, err, valid := nextCandidate()
+		if err != nil {
+			return hash.Hash{}, err
+		}
+		if !valid {
+			return closestVectorHash, nil
+		}
+		candidateVector, err = getVectorFromHash(ctx, b.ns, candidateVectorHash)
+		if err != nil {
+			return hash.Hash{}, err
+		}
+		candidateDistance, err := b.distanceType.Eval(targetVector, candidateVector)
+		if err != nil {
+			return hash.Hash{}, err
+		}
+		if candidateDistance < closestDistance {
+			closestVectorHash = candidateVectorHash
+			closestDistance = candidateDistance
+		}
+	}
+}
+
+// makeProximityMapFromPathMaps builds a ProximityMap from a list of maps, each of which corresponds to a different tree level.
+func (b *proximityMapBuilder) makeProximityMapFromPathMaps(ctx context.Context, pathMaps []*MutableMap) (proximityMap ProximityMap, err error) {
+	maxLevel := len(pathMaps) - 1
+
+	// We create a chain of vectorIndexChunker objects, with the chunker for the leaf level at the tail.
+	// Because the root node has no parent, the logic is slightly different. We don't make a vectorIndexChunker for it.
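+	// The chain is built bottom-up: the level-0 (leaf) chunker is created first with no child,
+	// and each higher-level chunker wraps the previous one as its childChunker.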
+ var chunker *vectorIndexChunker + for i, pathMap := range pathMaps[:maxLevel] { + chunker, err = newVectorIndexChunker(ctx, pathMap, maxLevel-i, chunker) + if err != nil { + return ProximityMap{}, err + } + } + + rootPathMap := pathMaps[maxLevel] + topLevelPathMapIter, err := rootPathMap.IterAll(ctx) + if err != nil { + return ProximityMap{}, err + } + var topLevelKeys [][]byte + var topLevelValues [][]byte + var topLevelSubtrees []uint64 + for { + key, value, err := topLevelPathMapIter.Next(ctx) + if err == io.EOF { + break + } + if err != nil { + return ProximityMap{}, err + } + originalKey, _ := rootPathMap.keyDesc.GetBytes(0, key) + path, _ := b.keyDesc.GetJSONAddr(0, originalKey) + _, nodeCount, nodeHash, err := chunker.Next(ctx, b.ns, b.vectorIndexSerializer, path, originalKey, value, maxLevel-1, 1, b.keyDesc) + if err != nil { + return ProximityMap{}, err + } + topLevelKeys = append(topLevelKeys, originalKey) + topLevelValues = append(topLevelValues, nodeHash[:]) + topLevelSubtrees = append(topLevelSubtrees, nodeCount) + } + return b.makeRootNode(ctx, topLevelKeys, topLevelValues, topLevelSubtrees, maxLevel) +} + +// vectorIndexChunker is a stateful chunker that iterates over |pathMap|, a map that contains an element +// for every key-value pair for a given level of a ProximityMap, and provides the path of keys to reach +// that pair from the root. It uses this iterator to build each of the ProximityMap nodes for that level. +type vectorIndexChunker struct { + pathMap *MutableMap + pathMapIter MapIter + lastPathSegment hash.Hash + lastKey []byte + lastValue []byte + lastSubtreeCount uint64 + childChunker *vectorIndexChunker + atEnd bool +} + +func newVectorIndexChunker(ctx context.Context, pathMap *MutableMap, depth int, childChunker *vectorIndexChunker) (*vectorIndexChunker, error) { + pathMapIter, err := pathMap.IterAll(ctx) + if err != nil { + return nil, err + } + firstKey, firstValue, err := pathMapIter.Next(ctx) + if err == io.EOF { + // In rare situations, there aren't any vectors at a given level. 
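+		// (This can happen when no key hashes to exactly this level; the chunker then starts
+		// at EOF and its nodes will contain only keys inherited from parent levels.)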
+ return &vectorIndexChunker{ + pathMap: pathMap, + pathMapIter: pathMapIter, + childChunker: childChunker, + atEnd: true, + }, nil + } + if err != nil { + return nil, err + } + lastPathSegment, _ := pathMap.keyDesc.GetJSONAddr(depth-1, firstKey) + originalKey, _ := pathMap.keyDesc.GetBytes(depth, firstKey) + return &vectorIndexChunker{ + pathMap: pathMap, + pathMapIter: pathMapIter, + childChunker: childChunker, + lastKey: originalKey, + lastValue: firstValue, + lastPathSegment: lastPathSegment, + atEnd: false, + }, nil +} + +func (c *vectorIndexChunker) Next(ctx context.Context, ns tree.NodeStore, serializer message.VectorIndexSerializer, parentPathSegment hash.Hash, parentKey val.Tuple, parentValue val.Tuple, level, depth int, originalKeyDesc val.TupleDesc) (tree.Node, uint64, hash.Hash, error) { + indexMapKeys := [][]byte{parentKey} + var indexMapValues [][]byte + var indexMapSubtrees []uint64 + subtreeSum := uint64(0) + if c.childChunker != nil { + _, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, parentPathSegment, parentKey, parentValue, level-1, depth+1, originalKeyDesc) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + indexMapValues = append(indexMapValues, nodeHash[:]) + indexMapSubtrees = append(indexMapSubtrees, childCount) + subtreeSum += childCount + } else { + indexMapValues = append(indexMapValues, parentValue) + subtreeSum++ + } + + for { + if c.atEnd || c.lastPathSegment != parentPathSegment { + msg := serializer.Serialize(indexMapKeys, indexMapValues, indexMapSubtrees, level) + node, err := tree.NodeFromBytes(msg) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + nodeHash, err := ns.Write(ctx, node) + return node, subtreeSum, nodeHash, err + } + vectorHash, _ := originalKeyDesc.GetJSONAddr(0, c.lastKey) + if c.childChunker != nil { + _, childCount, nodeHash, err := c.childChunker.Next(ctx, ns, serializer, vectorHash, c.lastKey, c.lastValue, level-1, depth+1, originalKeyDesc) + if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } + c.lastValue = nodeHash[:] + indexMapSubtrees = append(indexMapSubtrees, childCount) + subtreeSum += childCount + } else { + subtreeSum++ + } + indexMapKeys = append(indexMapKeys, c.lastKey) + indexMapValues = append(indexMapValues, c.lastValue) + + nextKey, nextValue, err := c.pathMapIter.Next(ctx) + if err == io.EOF { + c.atEnd = true + } else if err != nil { + return tree.Node{}, 0, hash.Hash{}, err + } else { + c.lastPathSegment, _ = c.pathMap.keyDesc.GetJSONAddr(depth-1, nextKey) + c.lastKey, _ = c.pathMap.keyDesc.GetBytes(depth, nextKey) + c.lastValue = nextValue + } + } +} + +func getJsonValueFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) (interface{}, error) { + return tree.NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx) +} + +func getVectorFromHash(ctx context.Context, ns tree.NodeStore, h hash.Hash) ([]float64, error) { + otherValue, err := getJsonValueFromHash(ctx, ns, h) + if err != nil { + return nil, err + } + return sql.ConvertToVector(otherValue) +} diff --git a/go/store/prolly/proximity_map_test.go b/go/store/prolly/proximity_map_test.go new file mode 100644 index 0000000000..014ee8a5f9 --- /dev/null +++ b/go/store/prolly/proximity_map_test.go @@ -0,0 +1,208 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package prolly
+
+import (
+	"context"
+	"os"
+	"testing"
+
+	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/dolthub/go-mysql-server/sql/expression"
+	"github.com/dolthub/go-mysql-server/sql/types"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/dolthub/dolt/go/store/hash"
+	"github.com/dolthub/dolt/go/store/pool"
+	"github.com/dolthub/dolt/go/store/prolly/tree"
+	"github.com/dolthub/dolt/go/store/val"
+)
+
+func newJsonValue(t *testing.T, v interface{}) sql.JSONWrapper {
+	doc, _, err := types.JSON.Convert(v)
+	require.NoError(t, err)
+	return doc.(sql.JSONWrapper)
+}
+
+// newJsonDocument writes a JSON value to the node store and returns its address hash.
+func newJsonDocument(t *testing.T, ctx context.Context, ns tree.NodeStore, v interface{}) hash.Hash {
+	doc := newJsonValue(t, v)
+	root, err := tree.SerializeJsonToAddr(ctx, ns, doc)
+	require.NoError(t, err)
+	return root.HashOf()
+}
+
+func createProximityMap(t *testing.T, ctx context.Context, ns tree.NodeStore, vectors []interface{}, pks []int64, logChunkSize uint8) (ProximityMap, [][]byte, [][]byte) {
+	bp := pool.NewBuffPool()
+
+	count := len(vectors)
+	require.Equal(t, count, len(pks))
+
+	kd := val.NewTupleDescriptor(
+		val.Type{Enc: val.JSONAddrEnc, Nullable: true},
+	)
+
+	vd := val.NewTupleDescriptor(
+		val.Type{Enc: val.Int64Enc, Nullable: true},
+	)
+
+	distanceType := expression.DistanceL2Squared{}
+
+	builder, err := NewProximityMapFromTuples(ctx, ns, distanceType, kd, vd, logChunkSize)
+	require.NoError(t, err)
+
+	keys := make([][]byte, count)
+	values := make([][]byte, count)
+	keyBuilder := val.NewTupleBuilder(kd)
+	valueBuilder := val.NewTupleBuilder(vd)
+	for i, vector := range vectors {
+		keyBuilder.PutJSONAddr(0, newJsonDocument(t, ctx, ns, vector))
+		nextKey := keyBuilder.Build(bp)
+		keys[i] = nextKey
+
+		valueBuilder.PutInt64(0, pks[i])
+		nextValue := valueBuilder.Build(bp)
+		values[i] = nextValue
+
+		err = builder.Insert(ctx, nextKey, nextValue)
+		require.NoError(t, err)
+	}
+
+	m, err := builder.Flush(ctx)
+	require.NoError(t, err)
+
+	mapCount, err := m.Count()
+	require.NoError(t, err)
+	require.Equal(t, count, mapCount)
+
+	return m, keys, values
+}
+
+func TestEmptyProximityMap(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	createProximityMap(t, ctx, ns, nil, nil, 10)
+}
+
+func TestSingleEntryProximityMap(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[1.0]"}, []int64{1}, 10)
+	matches := 0
+	vectorHash, _ := m.keyDesc.GetJSONAddr(0, keys[0])
+	vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx)
+	require.NoError(t, err)
+	err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error {
+		require.Equal(t, val.Tuple(keys[0]), foundKey)
+		require.Equal(t, val.Tuple(values[0]), foundValue)
+		matches++
+		return nil
+	})
+	require.NoError(t, err)
+	require.Equal(t, matches, 1)
+}
+
+func TestDoubleEntryProximityMapGetExact(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10)
+	matches := 0
+	for i, key := range keys {
+		vectorHash, _ := m.keyDesc.GetJSONAddr(0, key)
+		vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx)
+		require.NoError(t, err)
+		err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error {
+			require.Equal(t, val.Tuple(key), foundKey)
+			require.Equal(t, val.Tuple(values[i]), foundValue)
+			matches++
+			return nil
+		})
+		require.NoError(t, err)
+	}
+	require.Equal(t, matches, len(keys))
+}
+
+func TestDoubleEntryProximityMapGetClosest(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	m, keys, values := createProximityMap(t, ctx, ns, []interface{}{"[0.0, 6.0]", "[3.0, 4.0]"}, []int64{1, 2}, 10)
+	matches := 0
+
+	cb := func(foundKey val.Tuple, foundValue val.Tuple, distance float64) error {
+		require.Equal(t, val.Tuple(keys[1]), foundKey)
+		require.Equal(t, val.Tuple(values[1]), foundValue)
+		require.InDelta(t, distance, 25.0, 0.1)
+		matches++
+		return nil
+	}
+
+	err := m.GetClosest(ctx, newJsonValue(t, "[0.0, 0.0]"), cb, 1)
+	require.NoError(t, err)
+	require.Equal(t, matches, 1)
+}
+
+func TestMultilevelProximityMap(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	keyStrings := []interface{}{
+		"[0.0, 1.0]",
+		"[3.0, 4.0]",
+		"[5.0, 6.0]",
+		"[7.0, 8.0]",
+	}
+	valueStrings := []int64{1, 2, 3, 4}
+	m, keys, values := createProximityMap(t, ctx, ns, keyStrings, valueStrings, 1)
+	matches := 0
+	for i, key := range keys {
+		vectorHash, _ := m.keyDesc.GetJSONAddr(0, key)
+		vectorDoc, err := tree.NewJSONDoc(vectorHash, ns).ToIndexedJSONDocument(ctx)
+		require.NoError(t, err)
+		err = m.Get(ctx, vectorDoc, func(foundKey val.Tuple, foundValue val.Tuple) error {
+			require.Equal(t, val.Tuple(key), foundKey)
+			require.Equal(t, val.Tuple(values[i]), foundValue)
+			matches++
+			return nil
+		})
+		require.NoError(t, err)
+	}
+	require.Equal(t, matches, len(keys))
+}
+
+func TestInsertOrderIndependence(t *testing.T) {
+	ctx := context.Background()
+	ns := tree.NewTestNodeStore()
+	keyStrings1 := []interface{}{
+		"[0.0, 1.0]",
+		"[3.0, 4.0]",
+		"[5.0, 6.0]",
+		"[7.0, 8.0]",
+	}
+	valueStrings1 := []int64{1, 2, 3, 4}
+	keyStrings2 := []interface{}{
+		"[7.0, 8.0]",
+		"[5.0, 6.0]",
+		"[3.0, 4.0]",
+		"[0.0, 1.0]",
+	}
+	valueStrings2 := []int64{4, 3, 2, 1}
+	m1, _, _ := createProximityMap(t, ctx, ns, keyStrings1, valueStrings1, 1)
+	m2, _, _ := createProximityMap(t, ctx, ns, keyStrings2, valueStrings2, 1)
+
+	if !assert.Equal(t, m1.tuples.Root.HashOf(), m2.tuples.Root.HashOf(), "trees have different hashes") {
+		require.NoError(t, tree.OutputProllyNodeBytes(os.Stdout, m1.tuples.Root))
+		require.NoError(t, tree.OutputProllyNodeBytes(os.Stdout, m2.tuples.Root))
+	}
+}
diff --git a/go/store/prolly/tree/chunker.go b/go/store/prolly/tree/chunker.go
index 04056f2835..26ec974e78 100644
--- a/go/store/prolly/tree/chunker.go
+++ b/go/store/prolly/tree/chunker.go
@@ -40,9 +40,10 @@ type chunker[S message.Serializer] struct {
 	level int
 	done  bool
 
-	splitter   nodeSplitter
-	builder    *nodeBuilder[S]
-	serializer S
+	splitterFactory splitterFactory
+	splitter        nodeSplitter
+	builder         *nodeBuilder[S]
+	serializer      S
 
 	ns NodeStore
 }
@@ -58,20 +59,25 @@ func newEmptyChunker[S message.Serializer](ctx context.Context, ns NodeStore, se
 }
 
 func newChunker[S message.Serializer](ctx context.Context, cur *cursor, level int, ns NodeStore, serializer S) (*chunker[S], error) {
+	return newChunkerWithSplitterFactory(ctx, cur, level, ns, serializer, defaultSplitterFactory)
+}
+
+func newChunkerWithSplitterFactory[S message.Serializer](ctx context.Context, cur *cursor, level int, ns NodeStore, serializer S, splitterFactory splitterFactory) (*chunker[S], error) {
 	// |cur| will be nil if this is a new Node, implying this is a new tree, or the tree has grown in height relative
 	// to its original chunked form.
 
-	splitter := defaultSplitterFactory(uint8(level % 256))
+	splitter := splitterFactory(uint8(level % 256))
 	builder := newNodeBuilder(serializer, level)
 
 	sc := &chunker[S]{
-		cur:        cur,
-		parent:     nil,
-		level:      level,
-		splitter:   splitter,
-		builder:    builder,
-		serializer: serializer,
-		ns:         ns,
+		cur:             cur,
+		parent:          nil,
+		level:           level,
+		splitter:        splitter,
+		splitterFactory: splitterFactory,
+		builder:         builder,
+		serializer:      serializer,
+		ns:              ns,
 	}
 
 	if cur != nil {
@@ -355,7 +361,7 @@ func (tc *chunker[S]) createParentChunker(ctx context.Context) (err error) {
 		parent = tc.cur.parent
 	}
 
-	tc.parent, err = newChunker(ctx, parent, tc.level+1, tc.ns, tc.serializer)
+	tc.parent, err = newChunkerWithSplitterFactory(ctx, parent, tc.level+1, tc.ns, tc.serializer, tc.splitterFactory)
 	if err != nil {
 		return err
 	}
diff --git a/go/store/prolly/tree/node_splitter.go b/go/store/prolly/tree/node_splitter.go
index 5714dc0bd2..874aecbcb2 100644
--- a/go/store/prolly/tree/node_splitter.go
+++ b/go/store/prolly/tree/node_splitter.go
@@ -25,6 +25,7 @@ import (
 	"crypto/sha512"
 	"encoding/binary"
 	"math"
+	"math/bits"
 
 	"github.com/kch42/buzhash"
 	"github.com/zeebo/xxh3"
@@ -111,7 +112,7 @@ func newRollingHashSplitter(salt uint8) nodeSplitter {
 
 var _ splitterFactory = newRollingHashSplitter
 
-// Append implements NodeSplitter
+// Append implements nodeSplitter
 func (sns *rollingHashSplitter) Append(key, value Item) (err error) {
 	for _, byt := range key {
 		_ = sns.hashByte(byt)
@@ -146,12 +147,12 @@ func (sns *rollingHashSplitter) hashByte(b byte) bool {
 	return sns.crossedBoundary
 }
 
-// CrossedBoundary implements NodeSplitter
+// CrossedBoundary implements nodeSplitter
 func (sns *rollingHashSplitter) CrossedBoundary() bool {
 	return sns.crossedBoundary
 }
 
-// Reset implements NodeSplitter
+// Reset implements nodeSplitter
 func (sns *rollingHashSplitter) Reset() {
 	sns.crossedBoundary = false
 	sns.offset = 0
@@ -264,3 +265,8 @@ func saltFromLevel(level uint8) (salt uint64) {
 	full := sha512.Sum512([]byte{level})
 	return binary.LittleEndian.Uint64(full[:8])
 }
+
+// DeterministicHashLevel assigns a tree level to |key| by counting the leading zero bits in
+// its 32-bit hash: every |leadingZerosPerLevel| leading zeros add one level, so a key reaches
+// level L with probability 2^-(L*|leadingZerosPerLevel|).
+func DeterministicHashLevel(leadingZerosPerLevel uint8, key Item) uint8 {
+	h := xxHash32(key, levelSalt[1])
+	return uint8(bits.LeadingZeros32(h)) / leadingZerosPerLevel
+}
diff --git a/go/store/prolly/tree/proximity_map.go b/go/store/prolly/tree/proximity_map.go
new file mode 100644
index 0000000000..f3b6f75150
--- /dev/null
+++ b/go/store/prolly/tree/proximity_map.go
@@ -0,0 +1,184 @@
+// Copyright 2024 Dolthub, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package tree + +import ( + "context" + "fmt" + "math" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + + "github.com/dolthub/dolt/go/store/hash" +) + +type KeyValueDistanceFn[K, V ~[]byte] func(key K, value V, distance float64) error + +// ProximityMap is a static Prolly Tree where the position of a key in the tree is based on proximity, as opposed to a traditional ordering. +// O provides the ordering only within a node. +type ProximityMap[K, V ~[]byte, O Ordering[K]] struct { + Root Node + NodeStore NodeStore + DistanceType expression.DistanceType + Convert func([]byte) []float64 + Order O +} + +func (t ProximityMap[K, V, O]) Count() (int, error) { + return t.Root.TreeCount() +} + +func (t ProximityMap[K, V, O]) Height() int { + return t.Root.Level() + 1 +} + +func (t ProximityMap[K, V, O]) HashOf() hash.Hash { + return t.Root.HashOf() +} + +func (t ProximityMap[K, V, O]) WalkAddresses(ctx context.Context, cb AddressCb) error { + return WalkAddresses(ctx, t.Root, t.NodeStore, cb) +} + +func (t ProximityMap[K, V, O]) WalkNodes(ctx context.Context, cb NodeCb) error { + return WalkNodes(ctx, t.Root, t.NodeStore, cb) +} + +// GetExact searches for an exact vector in the index, calling |cb| with the matching key-value pairs. +func (t ProximityMap[K, V, O]) GetExact(ctx context.Context, query interface{}, cb KeyValueFn[K, V]) (err error) { + nd := t.Root + + queryVector, err := sql.ConvertToVector(query) + if err != nil { + return err + } + + // Find the child with the minimum distance. + + for { + var closestKey K + var closestIdx int + distance := math.Inf(1) + + for i := 0; i < int(nd.count); i++ { + k := nd.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + if err != nil { + return err + } + if newDistance < distance { + closestIdx = i + distance = newDistance + closestKey = []byte(k) + } + } + + if nd.IsLeaf() { + return cb(closestKey, []byte(nd.GetValue(closestIdx))) + } + + nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx)) + if err != nil { + return err + } + } +} + +func (t ProximityMap[K, V, O]) Has(ctx context.Context, query K) (ok bool, err error) { + err = t.GetExact(ctx, query, func(_ K, _ V) error { + ok = true + return nil + }) + return ok, err +} + +// GetClosest performs an approximate nearest neighbors search. It finds |limit| vectors that are close to the query vector, +// and calls |cb| with the matching key-value pairs. 
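+// The search is greedy: at each level it descends only into the child whose key vector is
+// nearest the query, so the result is approximate rather than guaranteed-exact.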
+func (t ProximityMap[K, V, O]) GetClosest(ctx context.Context, query interface{}, cb KeyValueDistanceFn[K, V], limit int) (err error) { + if limit != 1 { + return fmt.Errorf("currently only limit = 1 (find single closest vector) is supported for ProximityMap") + } + + queryVector, err := sql.ConvertToVector(query) + if err != nil { + return err + } + + nd := t.Root + + var closestKey K + var closestIdx int + distance := math.Inf(1) + + for { + for i := 0; i < int(nd.count); i++ { + k := nd.GetKey(i) + newDistance, err := t.DistanceType.Eval(t.Convert(k), queryVector) + if err != nil { + return err + } + if newDistance < distance { + closestIdx = i + distance = newDistance + closestKey = []byte(k) + } + } + + if nd.IsLeaf() { + return cb(closestKey, []byte(nd.GetValue(closestIdx)), distance) + } + + nd, err = fetchChild(ctx, t.NodeStore, nd.getAddress(closestIdx)) + if err != nil { + return err + } + } +} + +func (t ProximityMap[K, V, O]) IterAll(ctx context.Context) (*OrderedTreeIter[K, V], error) { + c, err := newCursorAtStart(ctx, t.NodeStore, t.Root) + if err != nil { + return nil, err + } + + s, err := newCursorPastEnd(ctx, t.NodeStore, t.Root) + if err != nil { + return nil, err + } + + stop := func(curr *cursor) bool { + return curr.compare(s) >= 0 + } + + if stop(c) { + // empty range + return &OrderedTreeIter[K, V]{curr: nil}, nil + } + + return &OrderedTreeIter[K, V]{curr: c, stop: stop, step: c.advance}, nil +} + +func getJsonValueFromHash(ctx context.Context, ns NodeStore, h hash.Hash) (interface{}, error) { + return NewJSONDoc(h, ns).ToIndexedJSONDocument(ctx) +} + +func getVectorFromHash(ctx context.Context, ns NodeStore, h hash.Hash) ([]float64, error) { + otherValue, err := getJsonValueFromHash(ctx, ns, h) + if err != nil { + return nil, err + } + return sql.ConvertToVector(otherValue) +} diff --git a/integration-tests/go-sql-server-driver/go.mod b/integration-tests/go-sql-server-driver/go.mod index 57b39afe4e..5695eee519 100644 --- a/integration-tests/go-sql-server-driver/go.mod +++ b/integration-tests/go-sql-server-driver/go.mod @@ -1,8 +1,8 @@ module github.com/dolthub/dolt/integration-tests/go-sql-server-driver -go 1.22.5 +go 1.23.0 -toolchain go1.22.7 +toolchain go1.23.2 require ( github.com/dolthub/dolt/go v0.40.4