mirror of
https://github.com/dolthub/dolt.git
synced 2026-01-06 00:39:40 -06:00
Store and propagate vector logChunkSize and distanceType
This commit is contained in:
@@ -276,7 +276,19 @@ func (rcv *VectorIndexNode) MutateLogChunkSize(n byte) bool {
|
||||
return rcv._tab.MutateByteSlot(20, n)
|
||||
}
|
||||
|
||||
const VectorIndexNodeNumFields = 9
|
||||
func (rcv *VectorIndexNode) DistanceType() DistanceType {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(22))
|
||||
if o != 0 {
|
||||
return DistanceType(rcv._tab.GetByte(o + rcv._tab.Pos))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (rcv *VectorIndexNode) MutateDistanceType(n DistanceType) bool {
|
||||
return rcv._tab.MutateByteSlot(22, byte(n))
|
||||
}
|
||||
|
||||
const VectorIndexNodeNumFields = 10
|
||||
|
||||
func VectorIndexNodeStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(VectorIndexNodeNumFields)
|
||||
@@ -326,6 +338,9 @@ func VectorIndexNodeAddTreeLevel(builder *flatbuffers.Builder, treeLevel byte) {
|
||||
func VectorIndexNodeAddLogChunkSize(builder *flatbuffers.Builder, logChunkSize byte) {
|
||||
builder.PrependByteSlot(8, logChunkSize, 0)
|
||||
}
|
||||
func VectorIndexNodeAddDistanceType(builder *flatbuffers.Builder, distanceType DistanceType) {
|
||||
builder.PrependByteSlot(9, byte(distanceType), 0)
|
||||
}
|
||||
func VectorIndexNodeEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
return builder.EndObject()
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
include "schema.fbs";
|
||||
|
||||
namespace serial;
|
||||
|
||||
// VectorIndexNode is a node that makes up a vector index. Every key contains a vector value,
|
||||
@@ -49,6 +51,10 @@ table VectorIndexNode {
|
||||
// may choose to use a different size, or even select the best size for each index.
|
||||
// all nodes in an index must use the same size, and when modifying an existing index, we must use this value.
|
||||
log_chunk_size:uint8;
|
||||
|
||||
// each node encodes the distance function used for the index. This allows lookups without needing to retrieve the
|
||||
// distance function from the schema.
|
||||
distance_type:DistanceType;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
|
||||
"math"
|
||||
|
||||
fb "github.com/dolthub/flatbuffers/v23/go"
|
||||
@@ -39,13 +40,22 @@ const (
|
||||
|
||||
var vectorIvfFileID = []byte(serial.VectorIndexNodeFileID)
|
||||
|
||||
func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8) VectorIndexSerializer {
|
||||
return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize}
|
||||
func distanceTypeToEnum(distanceType vector.DistanceType) serial.DistanceType {
|
||||
switch distanceType.(type) {
|
||||
case vector.DistanceL2Squared:
|
||||
return serial.DistanceTypeL2_Squared
|
||||
}
|
||||
return serial.DistanceTypeNull
|
||||
}
|
||||
|
||||
func NewVectorIndexSerializer(pool pool.BuffPool, logChunkSize uint8, distanceType vector.DistanceType) VectorIndexSerializer {
|
||||
return VectorIndexSerializer{pool: pool, logChunkSize: logChunkSize, distanceType: distanceType}
|
||||
}
|
||||
|
||||
type VectorIndexSerializer struct {
|
||||
pool pool.BuffPool
|
||||
logChunkSize uint8
|
||||
distanceType vector.DistanceType
|
||||
}
|
||||
|
||||
var _ Serializer = VectorIndexSerializer{}
|
||||
@@ -91,6 +101,7 @@ func (s VectorIndexSerializer) Serialize(keys, values [][]byte, subtrees []uint6
|
||||
}
|
||||
serial.VectorIndexNodeAddTreeLevel(b, uint8(level))
|
||||
serial.VectorIndexNodeAddLogChunkSize(b, s.logChunkSize)
|
||||
serial.VectorIndexNodeAddDistanceType(b, distanceTypeToEnum(s.distanceType))
|
||||
|
||||
return serial.FinishMessage(b, serial.VectorIndexNodeEnd(b), vectorIvfFileID)
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ func NewProximityMapBuilder(ctx context.Context, ns tree.NodeStore, distanceType
|
||||
mutableLevelMap := newMutableMap(emptyLevelMap)
|
||||
return ProximityMapBuilder{
|
||||
ns: ns,
|
||||
vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize),
|
||||
vectorIndexSerializer: message.NewVectorIndexSerializer(ns.Pool(), logChunkSize, distanceType),
|
||||
distanceType: distanceType,
|
||||
keyDesc: keyDesc,
|
||||
valDesc: valDesc,
|
||||
|
||||
@@ -438,7 +438,9 @@ func TestIncrementalInserts(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ns := tree.NewTestNodeStore()
|
||||
pb := pool.NewBuffPool()
|
||||
|
||||
logChunkSize := uint8(1)
|
||||
distanceType := vector.DistanceL2Squared{}
|
||||
flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
|
||||
keyRows1 := [][]interface{}{
|
||||
{"[0.0, 1.0]"},
|
||||
{"[3.0, 4.0]"},
|
||||
@@ -450,7 +452,7 @@ func TestIncrementalInserts(t *testing.T) {
|
||||
valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
|
||||
values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
|
||||
|
||||
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1)
|
||||
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
|
||||
|
||||
l1 := m1.tuples.Root.Level()
|
||||
_ = l1
|
||||
@@ -473,7 +475,7 @@ func TestIncrementalInserts(t *testing.T) {
|
||||
}
|
||||
|
||||
// Check that map looks how we expect.
|
||||
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
|
||||
newMap, err := flusher.Map(ctx, mutableMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
l2 := m1.tuples.Root.Level()
|
||||
@@ -494,13 +496,16 @@ func TestIncrementalInserts(t *testing.T) {
|
||||
combinedValueRows := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}, {int64(5)}, {int64(6)}, {int64(7)}, {int64(8)}}
|
||||
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
|
||||
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 8)
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
|
||||
}
|
||||
|
||||
func TestIncrementalUpdates(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ns := tree.NewTestNodeStore()
|
||||
pb := pool.NewBuffPool()
|
||||
logChunkSize := uint8(1)
|
||||
distanceType := vector.DistanceL2Squared{}
|
||||
flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
|
||||
keyRows1 := [][]interface{}{
|
||||
{"[0.0, 1.0]"},
|
||||
{"[3.0, 4.0]"},
|
||||
@@ -512,7 +517,7 @@ func TestIncrementalUpdates(t *testing.T) {
|
||||
valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
|
||||
values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
|
||||
|
||||
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, 1)
|
||||
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
|
||||
|
||||
mutableMap := newProximityMutableMap(m1)
|
||||
|
||||
@@ -532,7 +537,7 @@ func TestIncrementalUpdates(t *testing.T) {
|
||||
err := mutableMap.Put(ctx, nextKey, nextValue)
|
||||
require.NoError(t, err)
|
||||
|
||||
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
|
||||
newMap, err := flusher.Map(ctx, mutableMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
newCount, err := newMap.Count()
|
||||
@@ -552,7 +557,7 @@ func TestIncrementalUpdates(t *testing.T) {
|
||||
combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(3)}, {int64(4)}}
|
||||
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
|
||||
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 4)
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
|
||||
}
|
||||
|
||||
// update root node
|
||||
@@ -566,7 +571,7 @@ func TestIncrementalUpdates(t *testing.T) {
|
||||
err := mutableMap.Put(ctx, nextKey, nextValue)
|
||||
require.NoError(t, err)
|
||||
|
||||
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
|
||||
newMap, err := flusher.Map(ctx, mutableMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
combinedKeyRows := [][]interface{}{
|
||||
@@ -579,7 +584,7 @@ func TestIncrementalUpdates(t *testing.T) {
|
||||
combinedValueRows := [][]interface{}{{int64(5)}, {int64(2)}, {int64(6)}, {int64(4)}}
|
||||
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
|
||||
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 4)
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
|
||||
|
||||
}
|
||||
}
|
||||
@@ -588,6 +593,9 @@ func TestIncrementalDeletes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ns := tree.NewTestNodeStore()
|
||||
pb := pool.NewBuffPool()
|
||||
logChunkSize := uint8(1)
|
||||
distanceType := vector.DistanceL2Squared{}
|
||||
flusher := ProximityFlusher{logChunkSize: logChunkSize, distanceType: distanceType}
|
||||
keyRows1 := [][]interface{}{
|
||||
{"[0.0, 1.0]"},
|
||||
{"[3.0, 4.0]"},
|
||||
@@ -599,7 +607,6 @@ func TestIncrementalDeletes(t *testing.T) {
|
||||
valueRows1 := [][]interface{}{{int64(1)}, {int64(2)}, {int64(3)}, {int64(4)}}
|
||||
values1 := buildTuples(t, ctx, ns, pb, testValDesc, valueRows1)
|
||||
|
||||
logChunkSize := uint8(1)
|
||||
m1 := createAndValidateProximityMap(t, ctx, ns, testKeyDesc, keys1, testValDesc, values1, logChunkSize)
|
||||
|
||||
mutableMap := newProximityMutableMap(m1)
|
||||
@@ -616,7 +623,7 @@ func TestIncrementalDeletes(t *testing.T) {
|
||||
err := mutableMap.Put(ctx, nextKey, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
newMap, err := ProximityFlusher{logChunkSize: logChunkSize}.Map(ctx, mutableMap)
|
||||
newMap, err := flusher.Map(ctx, mutableMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
combinedKeyRows := [][]interface{}{
|
||||
@@ -628,7 +635,7 @@ func TestIncrementalDeletes(t *testing.T) {
|
||||
combinedValueRows := [][]interface{}{{int64(2)}, {int64(3)}, {int64(4)}}
|
||||
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
|
||||
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 3)
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
|
||||
|
||||
}
|
||||
|
||||
@@ -640,7 +647,7 @@ func TestIncrementalDeletes(t *testing.T) {
|
||||
err := mutableMap.Put(ctx, nextKey, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
newMap, err := ProximityFlusher{}.Map(ctx, mutableMap)
|
||||
newMap, err := flusher.Map(ctx, mutableMap)
|
||||
require.NoError(t, err)
|
||||
|
||||
combinedKeyRows := [][]interface{}{
|
||||
@@ -651,7 +658,7 @@ func TestIncrementalDeletes(t *testing.T) {
|
||||
combinedValueRows := [][]interface{}{{int64(2)}, {int64(4)}}
|
||||
combinedValues := buildTuples(t, ctx, ns, pb, testValDesc, combinedValueRows)
|
||||
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, 2)
|
||||
validateProximityMap(t, ctx, ns, &newMap, testKeyDesc, testValDesc, combinedKeys, combinedValues, logChunkSize)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ type ProximityMutableMap = GenericMutableMap[ProximityMap, tree.ProximityMap[val
|
||||
|
||||
type ProximityFlusher struct {
|
||||
logChunkSize uint8
|
||||
distanceType vector.DistanceType
|
||||
}
|
||||
|
||||
var _ MutableMapFlusher[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]] = ProximityFlusher{}
|
||||
@@ -410,7 +411,7 @@ func (f ProximityFlusher) rebuildNode(ctx context.Context, ns tree.NodeStore, no
|
||||
}
|
||||
|
||||
func (f ProximityFlusher) GetDefaultSerializer(ctx context.Context, mutableMap *GenericMutableMap[ProximityMap, tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc]]) message.Serializer {
|
||||
return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize)
|
||||
return message.NewVectorIndexSerializer(mutableMap.NodeStore().Pool(), f.logChunkSize, f.distanceType)
|
||||
}
|
||||
|
||||
// newMutableMap returns a new MutableMap.
|
||||
@@ -420,7 +421,7 @@ func newProximityMutableMap(m ProximityMap) *ProximityMutableMap {
|
||||
keyDesc: m.keyDesc,
|
||||
valDesc: m.valDesc,
|
||||
maxPending: defaultMaxPending,
|
||||
flusher: ProximityFlusher{logChunkSize: m.logChunkSize},
|
||||
flusher: ProximityFlusher{logChunkSize: m.logChunkSize, distanceType: m.tuples.DistanceType},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,7 +443,7 @@ func (f ProximityFlusher) MapInterface(ctx context.Context, mut *ProximityMutabl
|
||||
|
||||
// TreeMap materializes all pending and applied mutations in the MutableMap.
|
||||
func (f ProximityFlusher) TreeMap(ctx context.Context, mut *ProximityMutableMap) (tree.ProximityMap[val.Tuple, val.Tuple, val.TupleDesc], error) {
|
||||
s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize)
|
||||
s := message.NewVectorIndexSerializer(mut.NodeStore().Pool(), f.logChunkSize, f.distanceType)
|
||||
return mut.flushWithSerializer(ctx, s)
|
||||
}
|
||||
|
||||
@@ -453,8 +454,9 @@ func (f ProximityFlusher) Map(ctx context.Context, mut *ProximityMutableMap) (Pr
|
||||
return ProximityMap{}, err
|
||||
}
|
||||
return ProximityMap{
|
||||
tuples: treeMap,
|
||||
keyDesc: mut.keyDesc,
|
||||
valDesc: mut.valDesc,
|
||||
tuples: treeMap,
|
||||
keyDesc: mut.keyDesc,
|
||||
valDesc: mut.valDesc,
|
||||
logChunkSize: f.logChunkSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user