From 94f6aebe60ab392a515175408394edc0a76fe41c Mon Sep 17 00:00:00 2001 From: Nick Tobey Date: Mon, 13 Jan 2025 14:00:34 -0800 Subject: [PATCH] Store and propagate vector logChunkSize and distanceType --- go/gen/fb/serial/vectorindexnode.go | 17 +++++++++++- go/serial/vectorindexnode.fbs | 6 ++++ go/store/prolly/message/vector_index.go | 15 ++++++++-- go/store/prolly/proximity_map.go | 2 +- go/store/prolly/proximity_map_test.go | 35 ++++++++++++++---------- go/store/prolly/proximity_mutable_map.go | 14 ++++++---- 6 files changed, 65 insertions(+), 24 deletions(-) diff --git a/go/gen/fb/serial/vectorindexnode.go b/go/gen/fb/serial/vectorindexnode.go index d5ae97f540..9928b29b4d 100644 --- a/go/gen/fb/serial/vectorindexnode.go +++ b/go/gen/fb/serial/vectorindexnode.go @@ -276,7 +276,19 @@ func (rcv *VectorIndexNode) MutateLogChunkSize(n byte) bool { return rcv._tab.MutateByteSlot(20, n) } -const VectorIndexNodeNumFields = 9 +func (rcv *VectorIndexNode) DistanceType() DistanceType { + o := flatbuffers.UOffsetT(rcv._tab.Offset(22)) + if o != 0 { + return DistanceType(rcv._tab.GetByte(o + rcv._tab.Pos)) + } + return 0 +} + +func (rcv *VectorIndexNode) MutateDistanceType(n DistanceType) bool { + return rcv._tab.MutateByteSlot(22, byte(n)) +} + +const VectorIndexNodeNumFields = 10 func VectorIndexNodeStart(builder *flatbuffers.Builder) { builder.StartObject(VectorIndexNodeNumFields) @@ -326,6 +338,9 @@ func VectorIndexNodeAddTreeLevel(builder *flatbuffers.Builder, treeLevel byte) { func VectorIndexNodeAddLogChunkSize(builder *flatbuffers.Builder, logChunkSize byte) { builder.PrependByteSlot(8, logChunkSize, 0) } +func VectorIndexNodeAddDistanceType(builder *flatbuffers.Builder, distanceType DistanceType) { + builder.PrependByteSlot(9, byte(distanceType), 0) +} func VectorIndexNodeEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } diff --git a/go/serial/vectorindexnode.fbs b/go/serial/vectorindexnode.fbs index 852b9ba2a1..6705257df1 100644 --- a/go/serial/vectorindexnode.fbs +++ b/go/serial/vectorindexnode.fbs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +include "schema.fbs"; + namespace serial; // VectorIndexNode is a node that makes up a vector index. Every key contains a vector value, @@ -49,6 +51,10 @@ table VectorIndexNode { // may choose to use a different size, or even select the best size for each index. // all nodes in an index must use the same size, and when modifying an existing index, we must use this value. log_chunk_size:uint8; + + // each node encodes the distance function used for the index. This allows lookups without needing to retrieve the + // distance function from the schema. + distance_type:DistanceType; } diff --git a/go/store/prolly/message/vector_index.go b/go/store/prolly/message/vector_index.go index 9839e6efbf..6cad06e87c 100644 --- a/go/store/prolly/message/vector_index.go +++ b/go/store/prolly/message/vector_index.go @@ -18,6 +18,7 @@ import ( "context" "encoding/binary" "fmt" + "github.com/dolthub/go-mysql-server/sql/expression/function/vector" "math" fb "github.com/dolthub/flatbuffers/v23/go" @@ -39,13 +40,22 @@ const ( var vectorIvfFileID = []byte(serial.VectorIndexNodeFileID) -func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8) VectorIndexSerializer { - return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize} +func distanceTypeToEnum(distanceType vector.DistanceType) serial.DistanceType { + switch distanceType.(type) { + case vector.DistanceL2Squared: + return serial.DistanceTypeL2_Squared + } + return serial.DistanceTypeNull +} + +func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8, distanceType vector.DistanceType) VectorIndexSerializer { + return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize, distanceType: distanceType} } type VectorIndexSerializer struct { pool pool.BuffPool logChunkSize uint8 + distanceType vector.DistanceType } var _ Serializer = VectorIndexSerializer{} @@ -91,6 +101,7 @@ func (s VectorIndexSerializer) Serialize(keys, values [][]byte, subtrees []uint6 } serial.VectorIndexNodeAddTreeLevel(b, uint8(level)) serial.VectorIndexNodeAddLogChunkSize(b, s.logChunkSize) + serial.VectorIndexNodeAddDistanceType(b, distanceTypeToEnum(s.distanceType)) return serial.FinishMessage(b, serial.VectorIndexNodeEnd(b), vectorIvfFileID) } diff --git a/go/store/prolly/proximity_map.go b/go/store/prolly/proximity_map.go index 58a49fec56..467196fbf1 100644 --- a/go/store/prolly/proximity_map.go +++ b/go/store/prolly/proximity_map.go @@ -182,7 +182,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType mutableLevelMap := newMutableMap(emptyLevelMap) return ProximityMapBuilder{ ns: ns, - vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize), + vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize, distanceType), distanceType: distanceType, keyDesc: keyDesc, valDesc: valDesc, diff --git a/go/store/prolly/proximity_map_test.go b/go/store/prolly/proximity_map_test.go index f584c3c035..0b1ef47585 100644 --- a/go/store/prolly/proximity_map_test.go +++ b/go/store/prolly/proximity_map_test.go @@ -438,7 +438,9 @@ func TestIncrementalInserts(t *testing.T) { ctx := context.Background() ns := tree.NewTestNodeStore() pb := pool.NewBuffPool() - + logChunkSize := uint8(1) + distanceType := vector.DistanceL2Squared{} + flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType} keyRows1 := [][]interface{}{ {"[0.0, 1.0]"}, {"[3.0, 4.0]"}, @@ -450,7 +452,7 @@ func TestIncrementalInserts(t *testing.T) { valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}} values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1) - m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1) + m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize) l1 := m1.tuples.Root.Level() _ = l1 @@ -473,7 +475,7 @@ func TestIncrementalInserts(t *testing.T) { } // Check that map looks how we expect. - newMap, err := ProximityFlusher{}.Map(ctx, mutableMap) + newMap, err := flusher.Map(ctx, mutableMap) require.NoError(t, err) l2 := m1.tuples.Root.Level() @@ -494,13 +496,16 @@ func TestIncrementalInserts(t *testing.T) { combinedValueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}, {int64(5)}, {int64(6)}, {int64(7)}, {int64(8)}} combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows) - validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 8) + validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize) } func TestIncrementalUpdates(t *testing.T) { ctx := context.Background() ns := tree.NewTestNodeStore() pb := pool.NewBuffPool() + logChunkSize := uint8(1) + distanceType := vector.DistanceL2Squared{} + flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType} keyRows1 := [][]interface{}{ {"[0.0, 1.0]"}, {"[3.0, 4.0]"}, @@ -512,7 +517,7 @@ func TestIncrementalUpdates(t *testing.T) { valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}} values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1) - m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1) + m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize) mutableMap := newProximityMutableMap(m1) @@ -532,7 +537,7 @@ func TestIncrementalUpdates(t *testing.T) { err := mutableMap.Put(ctx, nextKey, nextValue) require.NoError(t, err) - newMap, err := ProximityFlusher{}.Map(ctx, mutableMap) + newMap, err := flusher.Map(ctx, mutableMap) require.NoError(t, err) newCount, err := newMap.Count() @@ -552,7 +557,7 @@ func TestIncrementalUpdates(t *testing.T) { combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(3)}, {int64(4)}} combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows) - validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 4) + validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize) } // update root node @@ -566,7 +571,7 @@ func TestIncrementalUpdates(t *testing.T) { err := mutableMap.Put(ctx, nextKey, nextValue) require.NoError(t, err) - newMap, err := ProximityFlusher{}.Map(ctx, mutableMap) + newMap, err := flusher.Map(ctx, mutableMap) require.NoError(t, err) combinedKeyRows := [][]interface{}{ @@ -579,7 +584,7 @@ func TestIncrementalUpdates(t *testing.T) { combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(6)}, {int64(4)}} combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows) - validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 4) + validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize) } } @@ -588,6 +593,9 @@ func TestIncrementalDeletes(t *testing.T) { ctx := context.Background() ns := tree.NewTestNodeStore() pb := pool.NewBuffPool() + logChunkSize := uint8(1) + distanceType := vector.DistanceL2Squared{} + flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType} keyRows1 := [][]interface{}{ {"[0.0, 1.0]"}, {"[3.0, 4.0]"}, @@ -599,7 +607,6 @@ func TestIncrementalDeletes(t *testing.T) { valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}} values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1) - logChunkSize := uint8(1) m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize) mutableMap := newProximityMutableMap(m1) @@ -616,7 +623,7 @@ func TestIncrementalDeletes(t *testing.T) { err := mutableMap.Put(ctx, nextKey, nil) require.NoError(t, err) - newMap, err := ProximityFlusher{logChunkSize: logChunkSize}.Map(ctx, mutableMap) + newMap, err := flusher.Map(ctx, mutableMap) require.NoError(t, err) combinedKeyRows := [][]interface{}{ @@ -628,7 +635,7 @@ func TestIncrementalDeletes(t *testing.T) { combinedValueRows := [][]interface{}{{int64(2)}, {int64(3)}, {int64(4)}} combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows) - validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 3) + validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize) } @@ -640,7 +647,7 @@ func TestIncrementalDeletes(t *testing.T) { err := mutableMap.Put(ctx, nextKey, nil) require.NoError(t, err) - newMap, err := ProximityFlusher{}.Map(ctx, mutableMap) + newMap, err := flusher.Map(ctx, mutableMap) require.NoError(t, err) combinedKeyRows := [][]interface{}{ @@ -651,7 +658,7 @@ func TestIncrementalDeletes(t *testing.T) { combinedValueRows := [][]interface{}{{int64(2)}, {int64(4)}} combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows) - validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 2) + validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize) } } diff --git a/go/store/prolly/proximity_mutable_map.go b/go/store/prolly/proximity_mutable_map.go index 67389a9592..2233711351 100644 --- a/go/store/prolly/proximity_mutable_map.go +++ b/go/store/prolly/proximity_mutable_map.go @@ -33,6 +33,7 @@ type ProximityMutableMap = GenericMutableMap[ProximityMap, tree.ProximityMap[val type ProximityFlusher struct { logChunkSize uint8 + distanceType vector.DistanceType } var _ MutableMapFlusher[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]] = ProximityFlusher{} @@ -410,7 +411,7 @@ func (f ProximityFlusher) rebuildNode(ctx context.Context, ns tree.NodeStore, no } func (f ProximityFlusher) GetDefaultSerializer(ctx context.Context, mutableMap *GenericMutableMap[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]]) message.Serializer { - return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize) + return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize, f.distanceType) } // newMutableMap returns a new MutableMap. @@ -420,7 +421,7 @@ func newProximityMutableMap(m ProximityMap) *ProximityMutableMap { keyDesc: m.keyDesc, valDesc: m.valDesc, maxPending: defaultMaxPending, - flusher: ProximityFlusher{logChunkSize: m.logChunkSize}, + flusher: ProximityFlusher{logChunkSize: m.logChunkSize, distanceType: m.tuples.DistanceType}, } } @@ -442,7 +443,7 @@ func (f ProximityFlusher) MapInterface(ctx context.Context, mut *ProximityMutabl // TreeMap materializes all pending and applied mutations in the MutableMap. func (f ProximityFlusher) TreeMap(ctx context.Context, mut *ProximityMutableMap) (tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc], error) { - s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize) + s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize, f.distanceType) return mut.flushWithSerializer(ctx, s) } @@ -453,8 +454,9 @@ func (f ProximityFlusher) Map(ctx context.Context, mut *ProximityMutableMap) (Pr return ProximityMap{}, err } return ProximityMap{ - tuples: treeMap, - keyDesc: mut.keyDesc, - valDesc: mut.valDesc, + tuples: treeMap, + keyDesc: mut.keyDesc, + valDesc: mut.valDesc, + logChunkSize: f.logChunkSize, }, nil }