Use map.IterRange for iterating over rows (#1192)

This commit is contained in:
Brian Hendriks
2021-01-07 15:49:29 -08:00
committed by GitHub
parent e5334fe420
commit 133af01494
5 changed files with 143 additions and 12 deletions
+13 -7
View File
@@ -202,17 +202,19 @@ func GetGetFuncForMapIter(mapItr types.MapIterator) func(ctx context.Context) (t
// DoltMapIter uses a types.MapIterator to iterate over a types.Map and returns sql.Row instances that it reads and
// converts
type DoltMapIter struct {
ctx context.Context
kvGet KVGetFunc
conv *KVToSqlRowConverter
ctx context.Context
kvGet KVGetFunc
closeKVGetter func() error
conv *KVToSqlRowConverter
}
// NewDoltMapIter returns a new DoltMapIter
func NewDoltMapIter(ctx context.Context, keyValGet KVGetFunc, conv *KVToSqlRowConverter) *DoltMapIter {
func NewDoltMapIter(ctx context.Context, keyValGet KVGetFunc, closeKVGetter func() error, conv *KVToSqlRowConverter) *DoltMapIter {
return &DoltMapIter{
ctx: ctx,
kvGet: keyValGet,
conv: conv,
ctx: ctx,
kvGet: keyValGet,
closeKVGetter: closeKVGetter,
conv: conv,
}
}
@@ -228,5 +230,9 @@ func (dmi *DoltMapIter) Next() (sql.Row, error) {
}
func (dmi *DoltMapIter) Close() error {
if dmi.closeKVGetter != nil {
return dmi.closeKVGetter()
}
return nil
}
+6 -1
View File
@@ -16,6 +16,7 @@ package sqle
import (
"context"
"io"
"github.com/dolthub/go-mysql-server/sql"
@@ -105,7 +106,11 @@ func newKeyedRowIter(ctx context.Context, tbl *DoltTable, partition *doltTablePa
}
conv := NewKVToSqlRowConverter(tagToSqlColIdx, cols, len(cols))
return NewDoltMapIter(ctx, GetGetFuncForMapIter(mapIter), conv), nil
var closer func() error
if cl, ok := mapIter.(io.Closer); ok {
closer = cl.Close
}
return NewDoltMapIter(ctx, GetGetFuncForMapIter(mapIter), closer, conv), nil
}
// Next returns the next row in this row iterator, or an io.EOF error if there aren't any more.
+16 -2
View File
@@ -230,7 +230,7 @@ func (t *DoltTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
numElements := rowData.Len()
if numElements == 0 {
return newDoltTablePartitionIter(rowData, doltTablePartition{0, 1}), nil
return newDoltTablePartitionIter(rowData, doltTablePartition{0, 0}), nil
}
maxPartitions := uint64(partitionMultiplier * runtime.NumCPU())
@@ -254,10 +254,24 @@ func (t *DoltTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
return newDoltTablePartitionIter(rowData, partitions...), nil
}
type emptyRowIterator struct{}
func (itr emptyRowIterator) Next() (sql.Row, error) {
return nil, io.EOF
}
func (itr emptyRowIterator) Close() error {
return nil
}
// Returns the table rows for the partition given
func (t *DoltTable) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.RowIter, error) {
switch typedPartition := partition.(type) {
case doltTablePartition:
if typedPartition.end == 0 {
return emptyRowIterator{}, nil
}
return newRowIterator(t, ctx, &typedPartition)
case sqlutil.SinglePartition:
return newRowIterator(t, ctx, nil)
@@ -509,7 +523,7 @@ func (p doltTablePartition) Key() []byte {
// for index = start; index < end. This iterator is not thread safe and should only be used from a single go routine
// unless paired with a mutex
func (p doltTablePartition) IteratorForPartition(ctx context.Context, m types.Map) (types.MapIterator, error) {
return newPartitionIter(ctx, m, p.start, p.end)
return m.RangeIterator(ctx, 64, p.start, p.end)
}
type partitionIter struct {
+27 -1
View File
@@ -419,7 +419,33 @@ type mapIterAllCallback func(key, value Value) error
func (m Map) IterAll(ctx context.Context, cb mapIterAllCallback) error {
var k Value
err := iterAll(ctx, m, func(v Value, idx uint64) error {
err := iterAll(ctx, m, func(v Value, _ uint64) error {
if k != nil {
err := cb(k, v)
if err != nil {
return err
}
k = nil
} else {
k = v
}
return nil
})
if err != nil {
return err
}
d.PanicIfFalse(k == nil)
return nil
}
func (m Map) IterRange(ctx context.Context, startIdx, endIdx uint64, cb mapIterAllCallback) error {
var k Value
_, err := iterRange(ctx, m, startIdx, endIdx, func(v Value) error {
if k != nil {
err := cb(k, v)
+81 -1
View File
@@ -21,7 +21,12 @@
package types
import "context"
import (
"context"
"errors"
"golang.org/x/sync/errgroup"
)
// MapIterator is the interface used by iterators over Noms Maps.
type MapIterator interface {
@@ -58,3 +63,78 @@ func (mi *mapIterator) Next(ctx context.Context) (k, v Value, err error) {
return mi.currentKey, mi.currentValue, nil
}
type mapKeyValuePair struct {
k Value
v Value
}
var errClosed = errors.New("closed")
type readAheadRangeIter struct {
ctx context.Context
eg *errgroup.Group
kvCh chan mapKeyValuePair
}
func (itr *readAheadRangeIter) Next(context.Context) (Value, Value, error) {
select {
case kvp, ok := <-itr.kvCh:
if !ok {
return nil, nil, nil
}
return kvp.k, kvp.v, nil
case <-itr.ctx.Done():
err := itr.eg.Wait()
if err != errClosed {
return nil, nil, err
}
return nil, nil, nil
}
}
func (itr *readAheadRangeIter) Close() error {
itr.eg.Go(func() error {
return errClosed
})
_ = itr.eg.Wait()
close(itr.kvCh)
return nil
}
func (m Map) RangeIterator(ctx context.Context, readAhead int, startIdx, endIdx uint64) (MapIterator, error) {
eg, ctx := errgroup.WithContext(ctx)
keyValCh := make(chan mapKeyValuePair, readAhead)
eg.Go(func() error {
err := m.IterRange(ctx, startIdx, endIdx, func(key, value Value) error {
kvp := mapKeyValuePair{key, value}
select {
case keyValCh <- kvp:
case <-ctx.Done():
return ctx.Err()
}
return nil
})
if err != nil {
return err
}
// send an empty kvp to signify the end of the range
kvp := mapKeyValuePair{}
select {
case keyValCh <- kvp:
return nil
case <-ctx.Done():
return ctx.Err()
}
})
return &readAheadRangeIter{ctx, eg, keyValCh}, nil
}