Store and propagate vector logChunkSize and distanceType

This commit is contained in:
Nick Tobey
2025-01-13 14:00:34 -08:00
parent ae762564e6
commit 94f6aebe60
6 changed files with 65 additions and 24 deletions

View File

@@ -276,7 +276,19 @@ func (rcv *VectorIndexNode) MutateLogChunkSize(n byte) bool {
return rcv._tab.MutateByteSlot(20, n)
}
const VectorIndexNodeNumFields = 9
func (rcv *VectorIndexNode) DistanceType() DistanceType {
o := flatbuffers.UOffsetT(rcv._tab.Offset(22))
if o != 0 {
return DistanceType(rcv._tab.GetByte(o + rcv._tab.Pos))
}
return 0
}
func (rcv *VectorIndexNode) MutateDistanceType(n DistanceType) bool {
return rcv._tab.MutateByteSlot(22, byte(n))
}
const VectorIndexNodeNumFields = 10
func VectorIndexNodeStart(builder *flatbuffers.Builder) {
builder.StartObject(VectorIndexNodeNumFields)
@@ -326,6 +338,9 @@ func VectorIndexNodeAddTreeLevel(builder *flatbuffers.Builder, treeLevel byte) {
func VectorIndexNodeAddLogChunkSize(builder *flatbuffers.Builder, logChunkSize byte) {
builder.PrependByteSlot(8, logChunkSize, 0)
}
func VectorIndexNodeAddDistanceType(builder *flatbuffers.Builder, distanceType DistanceType) {
builder.PrependByteSlot(9, byte(distanceType), 0)
}
func VectorIndexNodeEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()
}

View File

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
include "schema.fbs";
namespace serial;
// VectorIndexNode is a node that makes up a vector index. Every key contains a vector value,
@@ -49,6 +51,10 @@ table VectorIndexNode {
// may choose to use a different size, or even select the best size for each index.
// all nodes in an index must use the same size, and when modifying an existing index, we must use this value.
log_chunk_size:uint8;
// each node encodes the distance function used for the index. This allows lookups without needing to retrieve the
// distance function from the schema.
distance_type:DistanceType;
}

View File

@@ -18,6 +18,7 @@ import (
"context"
"encoding/binary"
"fmt"
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
"math"
fb "github.com/dolthub/flatbuffers/v23/go"
@@ -39,13 +40,22 @@ const (
var vectorIvfFileID = []byte(serial.VectorIndexNodeFileID)
func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8) VectorIndexSerializer {
return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize}
func distanceTypeToEnum(distanceType vector.DistanceType) serial.DistanceType {
switch distanceType.(type) {
case vector.DistanceL2Squared:
return serial.DistanceTypeL2_Squared
}
return serial.DistanceTypeNull
}
func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8, distanceType vector.DistanceType) VectorIndexSerializer {
return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize, distanceType: distanceType}
}
type VectorIndexSerializer struct {
pool pool.BuffPool
logChunkSize uint8
distanceType vector.DistanceType
}
var _ Serializer = VectorIndexSerializer{}
@@ -91,6 +101,7 @@ func (s VectorIndexSerializer) Serialize(keys, values [][]byte, subtrees []uint6
}
serial.VectorIndexNodeAddTreeLevel(b, uint8(level))
serial.VectorIndexNodeAddLogChunkSize(b, s.logChunkSize)
serial.VectorIndexNodeAddDistanceType(b, distanceTypeToEnum(s.distanceType))
return serial.FinishMessage(b, serial.VectorIndexNodeEnd(b), vectorIvfFileID)
}

View File

@@ -182,7 +182,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
mutableLevelMap := newMutableMap(emptyLevelMap)
return ProximityMapBuilder{
ns: ns,
vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize),
vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize, distanceType),
distanceType: distanceType,
keyDesc: keyDesc,
valDesc: valDesc,

View File

@@ -438,7 +438,9 @@ func TestIncrementalInserts(t *testing.T) {
ctx := context.Background()
ns := tree.NewTestNodeStore()
pb := pool.NewBuffPool()
logChunkSize := uint8(1)
distanceType := vector.DistanceL2Squared{}
flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
keyRows1 := [][]interface{}{
{"[0.0, 1.0]"},
{"[3.0, 4.0]"},
@@ -450,7 +452,7 @@ func TestIncrementalInserts(t *testing.T) {
valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1)
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
l1 := m1.tuples.Root.Level()
_ = l1
@@ -473,7 +475,7 @@ func TestIncrementalInserts(t *testing.T) {
}
// Check that map looks how we expect.
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
newMap, err := flusher.Map(ctx, mutableMap)
require.NoError(t, err)
l2 := m1.tuples.Root.Level()
@@ -494,13 +496,16 @@ func TestIncrementalInserts(t *testing.T) {
combinedValueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}, {int64(5)}, {int64(6)}, {int64(7)}, {int64(8)}}
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 8)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
}
func TestIncrementalUpdates(t *testing.T) {
ctx := context.Background()
ns := tree.NewTestNodeStore()
pb := pool.NewBuffPool()
logChunkSize := uint8(1)
distanceType := vector.DistanceL2Squared{}
flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
keyRows1 := [][]interface{}{
{"[0.0, 1.0]"},
{"[3.0, 4.0]"},
@@ -512,7 +517,7 @@ func TestIncrementalUpdates(t *testing.T) {
valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1)
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
mutableMap := newProximityMutableMap(m1)
@@ -532,7 +537,7 @@ func TestIncrementalUpdates(t *testing.T) {
err := mutableMap.Put(ctx, nextKey, nextValue)
require.NoError(t, err)
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
newMap, err := flusher.Map(ctx, mutableMap)
require.NoError(t, err)
newCount, err := newMap.Count()
@@ -552,7 +557,7 @@ func TestIncrementalUpdates(t *testing.T) {
combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(3)}, {int64(4)}}
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 4)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
}
// update root node
@@ -566,7 +571,7 @@ func TestIncrementalUpdates(t *testing.T) {
err := mutableMap.Put(ctx, nextKey, nextValue)
require.NoError(t, err)
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
newMap, err := flusher.Map(ctx, mutableMap)
require.NoError(t, err)
combinedKeyRows := [][]interface{}{
@@ -579,7 +584,7 @@ func TestIncrementalUpdates(t *testing.T) {
combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(6)}, {int64(4)}}
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 4)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
}
}
@@ -588,6 +593,9 @@ func TestIncrementalDeletes(t *testing.T) {
ctx := context.Background()
ns := tree.NewTestNodeStore()
pb := pool.NewBuffPool()
logChunkSize := uint8(1)
distanceType := vector.DistanceL2Squared{}
flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
keyRows1 := [][]interface{}{
{"[0.0, 1.0]"},
{"[3.0, 4.0]"},
@@ -599,7 +607,6 @@ func TestIncrementalDeletes(t *testing.T) {
valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
logChunkSize := uint8(1)
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
mutableMap := newProximityMutableMap(m1)
@@ -616,7 +623,7 @@ func TestIncrementalDeletes(t *testing.T) {
err := mutableMap.Put(ctx, nextKey, nil)
require.NoError(t, err)
newMap, err := ProximityFlusher{logChunkSize: logChunkSize}.Map(ctx, mutableMap)
newMap, err := flusher.Map(ctx, mutableMap)
require.NoError(t, err)
combinedKeyRows := [][]interface{}{
@@ -628,7 +635,7 @@ func TestIncrementalDeletes(t *testing.T) {
combinedValueRows := [][]interface{}{{int64(2)}, {int64(3)}, {int64(4)}}
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 3)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
}
@@ -640,7 +647,7 @@ func TestIncrementalDeletes(t *testing.T) {
err := mutableMap.Put(ctx, nextKey, nil)
require.NoError(t, err)
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
newMap, err := flusher.Map(ctx, mutableMap)
require.NoError(t, err)
combinedKeyRows := [][]interface{}{
@@ -651,7 +658,7 @@ func TestIncrementalDeletes(t *testing.T) {
combinedValueRows := [][]interface{}{{int64(2)}, {int64(4)}}
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 2)
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
}
}

View File

@@ -33,6 +33,7 @@ type ProximityMutableMap = GenericMutableMap[ProximityMap, tree.ProximityMap[val
type ProximityFlusher struct {
logChunkSize uint8
distanceType vector.DistanceType
}
var _ MutableMapFlusher[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]] = ProximityFlusher{}
@@ -410,7 +411,7 @@ func (f ProximityFlusher) rebuildNode(ctx context.Context, ns tree.NodeStore, no
}
func (f ProximityFlusher) GetDefaultSerializer(ctx context.Context, mutableMap *GenericMutableMap[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]]) message.Serializer {
return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize)
return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize, f.distanceType)
}
// newMutableMap returns a new MutableMap.
@@ -420,7 +421,7 @@ func newProximityMutableMap(m ProximityMap) *ProximityMutableMap {
keyDesc: m.keyDesc,
valDesc: m.valDesc,
maxPending: defaultMaxPending,
flusher: ProximityFlusher{logChunkSize: m.logChunkSize},
flusher: ProximityFlusher{logChunkSize: m.logChunkSize, distanceType: m.tuples.DistanceType},
}
}
@@ -442,7 +443,7 @@ func (f ProximityFlusher) MapInterface(ctx context.Context, mut *ProximityMutabl
// TreeMap materializes all pending and applied mutations in the MutableMap.
func (f ProximityFlusher) TreeMap(ctx context.Context, mut *ProximityMutableMap) (tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc], error) {
s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize)
s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize, f.distanceType)
return mut.flushWithSerializer(ctx, s)
}
@@ -453,8 +454,9 @@ func (f ProximityFlusher) Map(ctx context.Context, mut *ProximityMutableMap) (Pr
return ProximityMap{}, err
}
return ProximityMap{
tuples: treeMap,
keyDesc: mut.keyDesc,
valDesc: mut.valDesc,
tuples: treeMap,
keyDesc: mut.keyDesc,
valDesc: mut.valDesc,
logChunkSize: f.logChunkSize,
}, nil
}