mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-22 11:29:06 -05:00
Use map.IterRange for iterating over rows (#1192)
This commit is contained in:
@@ -202,17 +202,19 @@ func GetGetFuncForMapIter(mapItr types.MapIterator) func(ctx context.Context) (t
|
||||
// DoltMapIter uses a types.MapIterator to iterate over a types.Map and returns sql.Row instances that it reads and
|
||||
// converts
|
||||
type DoltMapIter struct {
|
||||
ctx context.Context
|
||||
kvGet KVGetFunc
|
||||
conv *KVToSqlRowConverter
|
||||
ctx context.Context
|
||||
kvGet KVGetFunc
|
||||
closeKVGetter func() error
|
||||
conv *KVToSqlRowConverter
|
||||
}
|
||||
|
||||
// NewDoltMapIter returns a new DoltMapIter
|
||||
func NewDoltMapIter(ctx context.Context, keyValGet KVGetFunc, conv *KVToSqlRowConverter) *DoltMapIter {
|
||||
func NewDoltMapIter(ctx context.Context, keyValGet KVGetFunc, closeKVGetter func() error, conv *KVToSqlRowConverter) *DoltMapIter {
|
||||
return &DoltMapIter{
|
||||
ctx: ctx,
|
||||
kvGet: keyValGet,
|
||||
conv: conv,
|
||||
ctx: ctx,
|
||||
kvGet: keyValGet,
|
||||
closeKVGetter: closeKVGetter,
|
||||
conv: conv,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,5 +230,9 @@ func (dmi *DoltMapIter) Next() (sql.Row, error) {
|
||||
}
|
||||
|
||||
func (dmi *DoltMapIter) Close() error {
|
||||
if dmi.closeKVGetter != nil {
|
||||
return dmi.closeKVGetter()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ package sqle
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
@@ -105,7 +106,11 @@ func newKeyedRowIter(ctx context.Context, tbl *DoltTable, partition *doltTablePa
|
||||
}
|
||||
|
||||
conv := NewKVToSqlRowConverter(tagToSqlColIdx, cols, len(cols))
|
||||
return NewDoltMapIter(ctx, GetGetFuncForMapIter(mapIter), conv), nil
|
||||
var closer func() error
|
||||
if cl, ok := mapIter.(io.Closer); ok {
|
||||
closer = cl.Close
|
||||
}
|
||||
return NewDoltMapIter(ctx, GetGetFuncForMapIter(mapIter), closer, conv), nil
|
||||
}
|
||||
|
||||
// Next returns the next row in this row iterator, or an io.EOF error if there aren't any more.
|
||||
|
||||
@@ -230,7 +230,7 @@ func (t *DoltTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
|
||||
numElements := rowData.Len()
|
||||
|
||||
if numElements == 0 {
|
||||
return newDoltTablePartitionIter(rowData, doltTablePartition{0, 1}), nil
|
||||
return newDoltTablePartitionIter(rowData, doltTablePartition{0, 0}), nil
|
||||
}
|
||||
|
||||
maxPartitions := uint64(partitionMultiplier * runtime.NumCPU())
|
||||
@@ -254,10 +254,24 @@ func (t *DoltTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
|
||||
return newDoltTablePartitionIter(rowData, partitions...), nil
|
||||
}
|
||||
|
||||
type emptyRowIterator struct{}
|
||||
|
||||
func (itr emptyRowIterator) Next() (sql.Row, error) {
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
func (itr emptyRowIterator) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the table rows for the partition given
|
||||
func (t *DoltTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
|
||||
switch typedPartition := partition.(type) {
|
||||
case doltTablePartition:
|
||||
if typedPartition.end == 0 {
|
||||
return emptyRowIterator{}, nil
|
||||
}
|
||||
|
||||
return newRowIterator(t, ctx, &typedPartition)
|
||||
case sqlutil.SinglePartition:
|
||||
return newRowIterator(t, ctx, nil)
|
||||
@@ -509,7 +523,7 @@ func (p doltTablePartition) Key() []byte {
|
||||
// for index = start; index < end. This iterator is not thread safe and should only be used from a single go routine
|
||||
// unless paired with a mutex
|
||||
func (p doltTablePartition) IteratorForPartition(ctx context.Context, m types.Map) (types.MapIterator, error) {
|
||||
return newPartitionIter(ctx, m, p.start, p.end)
|
||||
return m.RangeIterator(ctx, 64, p.start, p.end)
|
||||
}
|
||||
|
||||
type partitionIter struct {
|
||||
|
||||
+27
-1
@@ -419,7 +419,33 @@ type mapIterAllCallback func(key, value Value) error
|
||||
|
||||
func (m Map) IterAll(ctx context.Context, cb mapIterAllCallback) error {
|
||||
var k Value
|
||||
err := iterAll(ctx, m, func(v Value, idx uint64) error {
|
||||
err := iterAll(ctx, m, func(v Value, _ uint64) error {
|
||||
if k != nil {
|
||||
err := cb(k, v)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
k = nil
|
||||
} else {
|
||||
k = v
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
d.PanicIfFalse(k == nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Map) IterRange(ctx context.Context, startIdx, endIdx uint64, cb mapIterAllCallback) error {
|
||||
var k Value
|
||||
_, err := iterRange(ctx, m, startIdx, endIdx, func(v Value) error {
|
||||
if k != nil {
|
||||
err := cb(k, v)
|
||||
|
||||
|
||||
@@ -21,7 +21,12 @@
|
||||
|
||||
package types
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
// MapIterator is the interface used by iterators over Noms Maps.
|
||||
type MapIterator interface {
|
||||
@@ -58,3 +63,78 @@ func (mi *mapIterator) Next(ctx context.Context) (k, v Value, err error) {
|
||||
|
||||
return mi.currentKey, mi.currentValue, nil
|
||||
}
|
||||
|
||||
type mapKeyValuePair struct {
|
||||
k Value
|
||||
v Value
|
||||
}
|
||||
|
||||
var errClosed = errors.New("closed")
|
||||
|
||||
type readAheadRangeIter struct {
|
||||
ctx context.Context
|
||||
eg *errgroup.Group
|
||||
kvCh chan mapKeyValuePair
|
||||
}
|
||||
|
||||
func (itr *readAheadRangeIter) Next(context.Context) (Value, Value, error) {
|
||||
select {
|
||||
case kvp, ok := <-itr.kvCh:
|
||||
if !ok {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
return kvp.k, kvp.v, nil
|
||||
|
||||
case <-itr.ctx.Done():
|
||||
err := itr.eg.Wait()
|
||||
|
||||
if err != errClosed {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return nil, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (itr *readAheadRangeIter) Close() error {
|
||||
itr.eg.Go(func() error {
|
||||
return errClosed
|
||||
})
|
||||
|
||||
_ = itr.eg.Wait()
|
||||
close(itr.kvCh)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Map) RangeIterator(ctx context.Context, readAhead int, startIdx, endIdx uint64) (MapIterator, error) {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
keyValCh := make(chan mapKeyValuePair, readAhead)
|
||||
eg.Go(func() error {
|
||||
err := m.IterRange(ctx, startIdx, endIdx, func(key, value Value) error {
|
||||
kvp := mapKeyValuePair{key, value}
|
||||
select {
|
||||
case keyValCh <- kvp:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send an empty kvp to signify the end of the range
|
||||
kvp := mapKeyValuePair{}
|
||||
select {
|
||||
case keyValCh <- kvp:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
})
|
||||
|
||||
return &readAheadRangeIter{ctx, eg, keyValCh}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user