From 69de83a757623a7052d461a76d47e6204bafe719 Mon Sep 17 00:00:00 2001 From: Brian Hendriks Date: Tue, 16 Feb 2021 17:28:20 -0800 Subject: [PATCH] Fixed worker pool used for index lookups (#1304) --- .../sysbench_scripts/lua/local_sysbench.sh | 1 + .../doltcore/sqle/async_indexed_lookups.go | 110 ++++++++++ go/libraries/doltcore/sqle/dolt_index_test.go | 189 +++++++++++++++++- go/libraries/doltcore/sqle/index_row_iter.go | 158 ++++++++------- go/libraries/doltcore/sqle/sqlinsert_test.go | 48 ++--- go/libraries/doltcore/sqle/sqlreplace_test.go | 24 +-- .../doltcore/sqle/table_editor_test.go | 13 +- go/libraries/utils/async/ring_buffer.go | 173 ++++++++++++++++ go/libraries/utils/async/ring_buffer_test.go | 163 +++++++++++++++ 9 files changed, 759 insertions(+), 120 deletions(-) create mode 100644 benchmark/perf_tools/sysbench_scripts/lua/local_sysbench.sh create mode 100644 go/libraries/doltcore/sqle/async_indexed_lookups.go create mode 100644 go/libraries/utils/async/ring_buffer.go create mode 100644 go/libraries/utils/async/ring_buffer_test.go diff --git a/benchmark/perf_tools/sysbench_scripts/lua/local_sysbench.sh b/benchmark/perf_tools/sysbench_scripts/lua/local_sysbench.sh new file mode 100644 index 0000000000..e58cb23fe2 --- /dev/null +++ b/benchmark/perf_tools/sysbench_scripts/lua/local_sysbench.sh @@ -0,0 +1 @@ +sysbench --db-ps-mode=disable --rand-type=uniform --rand-seed=1 --percentile=50 --mysql-host=127.0.0.1 --mysql-user=root $@ diff --git a/go/libraries/doltcore/sqle/async_indexed_lookups.go b/go/libraries/doltcore/sqle/async_indexed_lookups.go new file mode 100644 index 0000000000..1946004169 --- /dev/null +++ b/go/libraries/doltcore/sqle/async_indexed_lookups.go @@ -0,0 +1,110 @@ +// Copyright 2020 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqle + +import ( + "context" + "fmt" + "runtime" + + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/dolt/go/libraries/utils/async" + "github.com/dolthub/dolt/go/store/types" +) + +type lookupResult struct { + idx uint64 + r sql.Row + err error +} + +// toLookup represents an table lookup that should be performed by one of the global asyncLookups instance's worker routines +type toLookup struct { + idx uint64 + t types.Tuple + tupleToRow func(types.Tuple) (sql.Row, error) + resBuf *async.RingBuffer +} + +// asyncLookups is a pool of worker routines reading from a channel doing table lookups +type asyncLookups struct { + ctx context.Context + toLookupCh chan toLookup +} + +// newAsyncLookups kicks off a number of worker routines and creates a channel for sending lookups to workers. The +// routines live for the life of the program +func newAsyncLookups(bufferSize int) *asyncLookups { + toLookupCh := make(chan toLookup, bufferSize) + art := &asyncLookups{toLookupCh: toLookupCh} + + return art +} + +func (art *asyncLookups) start(numWorkers int) { + for i := 0; i < numWorkers; i++ { + go func() { + art.workerFunc() + }() + } +} + +func (art *asyncLookups) workerFunc() { + f := func() { + var curr toLookup + var ok bool + + defer func() { + if r := recover(); r != nil { + // Attempt to write a failure to the channel and discard any additional errors + if err, ok := r.(error); ok { + _ = curr.resBuf.Push(lookupResult{idx: curr.idx, r: nil, err: err}) + } else { + _ = curr.resBuf.Push(lookupResult{idx: curr.idx, r: nil, err: fmt.Errorf("%v", r)}) + } + } + + // if the channel used for lookups is closed then fail spectacularly + if !ok { + panic("toLookup channel closed. All lookups will fail and will result in a deadlock") + } + }() + + for { + curr, ok = <-art.toLookupCh + + if !ok { + break + } + + r, err := curr.tupleToRow(curr.t) + _ = curr.resBuf.Push(lookupResult{idx: curr.idx, r: r, err: err}) + } + } + + // these routines will run forever unless f is allowed to panic which only happens when the lookup channel is closed + for { + f() + } +} + +// lookups is a global asyncLookups instance which is used by the indexLookupRowIterAdapter +var lookups *asyncLookups + +func init() { + lookups = newAsyncLookups(runtime.NumCPU() * 256) + lookups.start(runtime.NumCPU()) +} diff --git a/go/libraries/doltcore/sqle/dolt_index_test.go b/go/libraries/doltcore/sqle/dolt_index_test.go index 1dd21f71eb..e06dbcf451 100644 --- a/go/libraries/doltcore/sqle/dolt_index_test.go +++ b/go/libraries/doltcore/sqle/dolt_index_test.go @@ -16,13 +16,15 @@ package sqle import ( "context" + "errors" "fmt" "io" + "sort" + "strings" "testing" "time" "github.com/dolthub/go-mysql-server/sql" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" @@ -932,7 +934,7 @@ func TestDoltIndexBetween(t *testing.T) { } require.Equal(t, io.EOF, err) - assert.Equal(t, expectedRows, readRows) + requireUnorderedRowsEqual(t, expectedRows, readRows) indexLookup, err = index.DescendRange(test.lessThanOrEqual, test.greaterThanOrEqual) require.NoError(t, err) @@ -947,11 +949,190 @@ func TestDoltIndexBetween(t *testing.T) { } require.Equal(t, io.EOF, err) - assert.Equal(t, expectedRows, readRows) + requireUnorderedRowsEqual(t, expectedRows, readRows) }) } } +type rowSlice struct { + rows []sql.Row + sortErr error +} + +func (r *rowSlice) setSortErr(err error) { + if err == nil || r.sortErr != nil { + return + } + + r.sortErr = err +} + +func (r *rowSlice) Len() int { + return len(r.rows) +} + +func (r *rowSlice) Less(i, j int) bool { + r1 := r.rows[i] + r2 := r.rows[j] + + longerLen := len(r1) + if len(r2) > longerLen { + longerLen = len(r2) + } + + for pos := 0; pos < longerLen; pos++ { + if pos == len(r1) { + return true + } + + if pos == len(r2) { + return false + } + + c1, c2 := r1[pos], r2[pos] + + var cmp int + var err error + switch typedVal := c1.(type) { + case int: + cmp, err = signedCompare(int64(typedVal), c2) + case int64: + cmp, err = signedCompare(typedVal, c2) + case int32: + cmp, err = signedCompare(int64(typedVal), c2) + case int16: + cmp, err = signedCompare(int64(typedVal), c2) + case int8: + cmp, err = signedCompare(int64(typedVal), c2) + + case uint: + cmp, err = unsignedCompare(uint64(typedVal), c2) + case uint64: + cmp, err = unsignedCompare(typedVal, c2) + case uint32: + cmp, err = unsignedCompare(uint64(typedVal), c2) + case uint16: + cmp, err = unsignedCompare(uint64(typedVal), c2) + case uint8: + cmp, err = unsignedCompare(uint64(typedVal), c2) + + case float64: + cmp, err = floatCompare(float64(typedVal), c2) + case float32: + cmp, err = floatCompare(float64(typedVal), c2) + + case string: + cmp, err = stringCompare(typedVal, c2) + + default: + panic("not implemented please add") + } + + if err != nil { + r.setSortErr(err) + return false + } + + if cmp != 0 { + return cmp < 0 + } + } + + // equal + return false +} + +func signedCompare(n1 int64, c interface{}) (int, error) { + var n2 int64 + switch typedVal := c.(type) { + case int: + n2 = int64(typedVal) + case int64: + n2 = typedVal + case int32: + n2 = int64(typedVal) + case int16: + n2 = int64(typedVal) + case int8: + n2 = int64(typedVal) + default: + return 0, errors.New("comparing rows with different schemas") + } + + return int(n1 - n2), nil +} + +func unsignedCompare(n1 uint64, c interface{}) (int, error) { + var n2 uint64 + switch typedVal := c.(type) { + case uint: + n2 = uint64(typedVal) + case uint64: + n2 = typedVal + case uint32: + n2 = uint64(typedVal) + case uint16: + n2 = uint64(typedVal) + case uint8: + n2 = uint64(typedVal) + default: + return 0, errors.New("comparing rows with different schemas") + } + + if n1 == n2 { + return 0, nil + } else if n1 < n2 { + return -1, nil + } else { + return 1, nil + } +} + +func floatCompare(n1 float64, c interface{}) (int, error) { + var n2 float64 + switch typedVal := c.(type) { + case float32: + n2 = float64(typedVal) + case float64: + n2 = typedVal + default: + return 0, errors.New("comparing rows with different schemas") + } + + if n1 == n2 { + return 0, nil + } else if n1 < n2 { + return -1, nil + } else { + return 1, nil + } +} + +func stringCompare(s1 string, c interface{}) (int, error) { + s2, ok := c.(string) + if !ok { + return 0, errors.New("comparing rows with different schemas") + } + + return strings.Compare(s1, s2), nil +} + +func (r *rowSlice) Swap(i, j int) { + r.rows[i], r.rows[j] = r.rows[j], r.rows[i] +} + +func requireUnorderedRowsEqual(t *testing.T, rows1, rows2 []sql.Row) { + slice1 := &rowSlice{rows: rows1} + sort.Stable(slice1) + require.NoError(t, slice1.sortErr) + + slice2 := &rowSlice{rows: rows2} + sort.Stable(slice2) + require.NoError(t, slice2.sortErr) + + require.Equal(t, rows1, rows2) +} + func testDoltIndex(t *testing.T, keys []interface{}, expectedRows []sql.Row, indexLookupFn func(keys ...interface{}) (sql.IndexLookup, error)) { indexLookup, err := indexLookupFn(keys...) require.NoError(t, err) @@ -967,7 +1148,7 @@ func testDoltIndex(t *testing.T, keys []interface{}, expectedRows []sql.Row, ind } require.Equal(t, io.EOF, err) - assert.Equal(t, convertSqlRowToInt64(expectedRows), readRows) + requireUnorderedRowsEqual(t, convertSqlRowToInt64(expectedRows), readRows) } func doltIndexSetup(t *testing.T) map[string]DoltIndex { diff --git a/go/libraries/doltcore/sqle/index_row_iter.go b/go/libraries/doltcore/sqle/index_row_iter.go index 77ed95fe17..882efe5198 100644 --- a/go/libraries/doltcore/sqle/index_row_iter.go +++ b/go/libraries/doltcore/sqle/index_row_iter.go @@ -17,7 +17,7 @@ package sqle import ( "context" "io" - "runtime" + "sync" "github.com/dolthub/go-mysql-server/sql" @@ -27,20 +27,27 @@ import ( "github.com/dolthub/dolt/go/store/types" ) +const ( + ringBufferAllocSize = 1024 +) + +var resultBufferPool = &sync.Pool{ + New: func() interface{} { + return async.NewRingBuffer(ringBufferAllocSize) + }, +} + type indexLookupRowIterAdapter struct { idx DoltIndex keyIter nomsKeyIter pkTags map[uint64]int conv *KVToSqlRowConverter ctx *sql.Context - rowChan chan sql.Row - err error - buffer []sql.Row -} -type keyPos struct { - key types.Tuple - position int + read uint64 + count uint64 + + resultBuf *async.RingBuffer } // NewIndexLookupRowIterAdapter returns a new indexLookupRowIterAdapter. @@ -52,81 +59,87 @@ func NewIndexLookupRowIterAdapter(ctx *sql.Context, idx DoltIndex, keyIter nomsK cols := idx.Schema().GetAllCols().GetColumns() conv := NewKVToSqlRowConverterForCols(idx.IndexRowData().Format(), cols) + resBuf := resultBufferPool.Get().(*async.RingBuffer) + resBuf.Reset() iter := &indexLookupRowIterAdapter{ - idx: idx, - keyIter: keyIter, - conv: conv, - pkTags: pkTags, - ctx: ctx, - rowChan: make(chan sql.Row, runtime.NumCPU()*10), - buffer: make([]sql.Row, runtime.NumCPU()*5), + idx: idx, + keyIter: keyIter, + conv: conv, + pkTags: pkTags, + ctx: ctx, + resultBuf: resBuf, } - go iter.queueRows() + + go iter.queueRows(ctx) return iter } // Next returns the next row from the iterator. func (i *indexLookupRowIterAdapter) Next() (sql.Row, error) { - r, ok := <-i.rowChan - if !ok { // Only closes when we are finished iterating over the keys or an error has occurred. - if i.err != nil { - return nil, i.err + for i.count == 0 || i.read < i.count { + item, err := i.resultBuf.Pop() + + if err != nil { + return nil, err } - return nil, io.EOF + + res := item.(lookupResult) + + i.read++ + if res.err != nil { + if res.err == io.EOF { + i.count = res.idx + continue + } + + return nil, res.err + } + + return res.r, res.err } - return r, nil + + return nil, io.EOF } -func (*indexLookupRowIterAdapter) Close() error { +func (i *indexLookupRowIterAdapter) Close() error { + resultBufferPool.Put(i.resultBuf) return nil } -// queueRows reads each key from the key iterator and runs a goroutine for each logical processor to process the keys. -func (i *indexLookupRowIterAdapter) queueRows() { - defer close(i.rowChan) - exec := async.NewActionExecutor(i.ctx, i.processKey, uint32(runtime.NumCPU()), 0) +// queueRows reads each key from the key iterator and writes it to lookups.toLookupCh which manages a pool of worker +// routines which will process the requests in parallel. +func (i *indexLookupRowIterAdapter) queueRows(ctx context.Context) { + for idx := uint64(1); ; idx++ { + indexKey, err := i.keyIter.ReadKey(i.ctx) - var err error - for { - shouldBreak := false - pos := 0 - for ; pos < len(i.buffer); pos++ { - var indexKey types.Tuple - indexKey, err = i.keyIter.ReadKey(i.ctx) - if err != nil { - break - } - exec.Execute(keyPos{ - key: indexKey, - position: pos, + if err != nil { + i.resultBuf.Push(lookupResult{ + idx: idx, + r: nil, + err: err, }) + + return } - if err != nil { - if err == io.EOF { - shouldBreak = true - } else { - i.err = err - return - } + + lookup := toLookup{ + idx: idx, + t: indexKey, + tupleToRow: i.processKey, + resBuf: i.resultBuf, } - i.err = exec.WaitForEmpty() - if err != nil { - if err == io.EOF { - shouldBreak = true - } else { - i.err = err - return - } - } - for idx, r := range i.buffer { - if idx == pos { - break - } - i.rowChan <- r - } - if shouldBreak { - break + + select { + case lookups.toLookupCh <- lookup: + case <-ctx.Done(): + i.resultBuf.Push(lookupResult{ + idx: idx, + r: nil, + err: ctx.Err(), + }) + + return } } } @@ -174,32 +187,29 @@ func (i *indexLookupRowIterAdapter) indexKeyToTableKey(nbf *types.NomsBinFormat, } // processKey is called within queueRows and processes each key, sending the resulting row to the row channel. -func (i *indexLookupRowIterAdapter) processKey(_ context.Context, valInt interface{}) error { - val := valInt.(keyPos) - +func (i *indexLookupRowIterAdapter) processKey(indexKey types.Tuple) (sql.Row, error) { tableData := i.idx.TableData() - pkTupleVal, err := i.indexKeyToTableKey(tableData.Format(), val.key) + pkTupleVal, err := i.indexKeyToTableKey(tableData.Format(), indexKey) if err != nil { - return err + return nil, err } fieldsVal, ok, err := tableData.MaybeGetTuple(i.ctx, pkTupleVal) if err != nil { - return err + return nil, err } if !ok { - return nil + return nil, nil } sqlRow, err := i.conv.ConvertKVTuplesToSqlRow(pkTupleVal, fieldsVal) if err != nil { - return err + return nil, err } - i.buffer[val.position] = sqlRow - return nil + return sqlRow, nil } type coveringIndexRowIterAdapter struct { diff --git a/go/libraries/doltcore/sqle/sqlinsert_test.go b/go/libraries/doltcore/sqle/sqlinsert_test.go index 32343449ee..f1a036dc48 100644 --- a/go/libraries/doltcore/sqle/sqlinsert_test.go +++ b/go/libraries/doltcore/sqle/sqlinsert_test.go @@ -66,7 +66,7 @@ var BasicInsertTests = []InsertTest{ { Name: "insert no columns", InsertQuery: "insert into people values (2, 'Bart', 'Simpson', false, 10, 9, '00000000-0000-0000-0000-000000000002', 222)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, @@ -83,14 +83,14 @@ var BasicInsertTests = []InsertTest{ { Name: "insert full columns", InsertQuery: "insert into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values (2, 'Bart', 'Simpson', false, 10, 9, '00000000-0000-0000-0000-000000000002', 222)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, { Name: "insert full columns mixed order", InsertQuery: "insert into people (num_episodes, uuid, rating, age, is_married, last_name, first_name, id) values (222, '00000000-0000-0000-0000-000000000002', 9, 10, false, 'Simpson', 'Bart', 2)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, @@ -98,21 +98,21 @@ var BasicInsertTests = []InsertTest{ Name: "insert full columns negative values", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values (-7, "Maggie", "Simpson", false, -1, -5.1, '00000000-0000-0000-0000-000000000005', 677)`, - SelectQuery: "select * from people where id = -7", + SelectQuery: "select * from people where id = -7 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, NewPeopleRowWithOptionalFields(-7, "Maggie", "Simpson", false, -1, -5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)), ExpectedSchema: CompressSchema(PeopleTestSchema), }, { Name: "insert full columns null values", InsertQuery: "insert into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values (2, 'Bart', 'Simpson', null, null, null, null, null)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(CompressSchema(PeopleTestSchema), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson"))), ExpectedSchema: CompressSchema(PeopleTestSchema), }, { Name: "insert partial columns", InsertQuery: "insert into people (id, first_name, last_name) values (2, 'Bart', 'Simpson')", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson")), @@ -122,7 +122,7 @@ var BasicInsertTests = []InsertTest{ { Name: "insert partial columns mixed order", InsertQuery: "insert into people (last_name, first_name, id) values ('Simpson', 'Bart', 2)", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson")), @@ -157,7 +157,7 @@ var BasicInsertTests = []InsertTest{ { Name: "insert partial columns functions", InsertQuery: "insert into people (id, first_name, last_name) values (2, UPPER('Bart'), 'Simpson')", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("BART"), types.String("Simpson")), @@ -183,7 +183,7 @@ var BasicInsertTests = []InsertTest{ (9, "Jacqueline", "Bouvier", true, 80, 2), (10, "Patty", "Bouvier", false, 40, 7), (11, "Selma", "Bouvier", false, 40, 7)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id > 6", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id > 6 ORDER BY id", ExpectedRows: ToSqlRows(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating"), NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1), NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5), @@ -199,7 +199,7 @@ var BasicInsertTests = []InsertTest{ InsertQuery: `insert ignore into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", null, false, 1, 5.1), (8, "Milhouse", "Van Houten", false, 8, 3.5)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id > 6", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id > 6 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5)), ExpectedSchema: NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind, "is_married", types.BoolKind, "age", types.IntKind, "rating", types.FloatKind), @@ -224,7 +224,7 @@ var BasicInsertTests = []InsertTest{ InsertQuery: `insert ignore into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", "Simpson", false, 1, 5.1), (7, "Milhouse", "Van Houten", false, 8, 3.5)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1)), ExpectedSchema: NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind, "is_married", types.BoolKind, "age", types.IntKind, "rating", types.FloatKind), @@ -252,7 +252,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch int -> string", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", 100, false, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("100"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -263,7 +263,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch int -> bool", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", "Simpson", 1, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(true), types.Int(1), types.Float(5.1)), @@ -280,7 +280,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch string -> int", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values ("7", "Maggie", "Simpson", false, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -291,7 +291,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch string -> float", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", "Simpson", false, 1, "5.1")`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -302,7 +302,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch string -> uint", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, num_episodes) values (7, "Maggie", "Simpson", false, 1, "100")`, - SelectQuery: "select id, first_name, last_name, is_married, age, num_episodes from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, num_episodes from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "num_episodes")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Uint(100)), @@ -319,7 +319,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch float -> string", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, 8.1, "Simpson", false, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("8.1"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -330,7 +330,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch float -> bool", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", "Simpson", 0.5, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -341,7 +341,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch float -> int", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", "Simpson", false, 1.0, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -352,7 +352,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch bool -> int", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (true, "Maggie", "Simpson", false, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 1", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 1 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(1), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -363,7 +363,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch bool -> float", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, "Maggie", "Simpson", false, 1, true)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("Maggie"), types.String("Simpson"), types.Bool(false), types.Int(1), types.Float(1.0)), @@ -374,7 +374,7 @@ var BasicInsertTests = []InsertTest{ Name: "type mismatch bool -> string", InsertQuery: `insert into people (id, first_name, last_name, is_married, age, rating) values (7, true, "Simpson", false, 1, 5.1)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id = 7 ORDER BY id", ExpectedRows: ToSqlRows( CompressSchema(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating")), NewResultSetRow(types.Int(7), types.String("true"), types.String("Simpson" /*"West"*/), types.Bool(false), types.Int(1), types.Float(5.1)), @@ -417,7 +417,7 @@ var systemTableInsertTests = []InsertTest{ types.String("select 2+2 from dual"), types.String("description"))), InsertQuery: "insert into dolt_query_catalog (id, display_order, name, query, description) values ('abc123', 1, 'example', 'select 1+1 from dual', 'description')", - SelectQuery: "select * from dolt_query_catalog", + SelectQuery: "select * from dolt_query_catalog ORDER BY id", ExpectedRows: ToSqlRows(CompressSchema(dtables.DoltQueryCatalogSchema), NewRow(types.String("abc123"), types.Uint(1), types.String("example"), types.String("select 1+1 from dual"), types.String("description")), NewRow(types.String("existingEntry"), types.Uint(2), types.String("example"), types.String("select 2+2 from dual"), types.String("description")), @@ -428,7 +428,7 @@ var systemTableInsertTests = []InsertTest{ Name: "insert into dolt_schemas", AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, schemasTableDoltSchema()), InsertQuery: "insert into dolt_schemas (id, type, name, fragment) values (1, 'view', 'name', 'select 2+2 from dual')", - SelectQuery: "select * from dolt_schemas", + SelectQuery: "select * from dolt_schemas ORDER BY id", ExpectedRows: ToSqlRows(CompressSchema(schemasTableDoltSchema()), NewRow(types.String("view"), types.String("name"), types.String("select 2+2 from dual"), types.Int(1)), ), diff --git a/go/libraries/doltcore/sqle/sqlreplace_test.go b/go/libraries/doltcore/sqle/sqlreplace_test.go index a3a1dff025..63078ec065 100644 --- a/go/libraries/doltcore/sqle/sqlreplace_test.go +++ b/go/libraries/doltcore/sqle/sqlreplace_test.go @@ -58,7 +58,7 @@ var BasicReplaceTests = []ReplaceTest{ { Name: "replace no columns", ReplaceQuery: "replace into people values (2, 'Bart', 'Simpson', false, 10, 9, '00000000-0000-0000-0000-000000000002', 222)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, @@ -66,7 +66,7 @@ var BasicReplaceTests = []ReplaceTest{ Name: "replace set", ReplaceQuery: "replace into people set id = 2, first_name = 'Bart', last_name = 'Simpson'," + "is_married = false, age = 10, rating = 9, uuid = '00000000-0000-0000-0000-000000000002', num_episodes = 222", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, @@ -83,14 +83,14 @@ var BasicReplaceTests = []ReplaceTest{ { Name: "replace full columns", ReplaceQuery: "replace into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values (2, 'Bart', 'Simpson', false, 10, 9, '00000000-0000-0000-0000-000000000002', 222)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, { Name: "replace full columns mixed order", ReplaceQuery: "replace into people (num_episodes, uuid, rating, age, is_married, last_name, first_name, id) values (222, '00000000-0000-0000-0000-000000000002', 9, 10, false, 'Simpson', 'Bart', 2)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, Bart), ExpectedSchema: CompressSchema(PeopleTestSchema), }, @@ -98,21 +98,21 @@ var BasicReplaceTests = []ReplaceTest{ Name: "replace full columns negative values", ReplaceQuery: `replace into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values (-7, "Maggie", "Simpson", false, -1, -5.1, '00000000-0000-0000-0000-000000000005', 677)`, - SelectQuery: "select * from people where id = -7", + SelectQuery: "select * from people where id = -7 ORDER BY id", ExpectedRows: ToSqlRows(PeopleTestSchema, NewPeopleRowWithOptionalFields(-7, "Maggie", "Simpson", false, -1, -5.1, uuid.MustParse("00000000-0000-0000-0000-000000000005"), 677)), ExpectedSchema: CompressSchema(PeopleTestSchema), }, { Name: "replace full columns null values", ReplaceQuery: "replace into people (id, first_name, last_name, is_married, age, rating, uuid, num_episodes) values (2, 'Bart', 'Simpson', null, null, null, null, null)", - SelectQuery: "select * from people where id = 2", + SelectQuery: "select * from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows(CompressSchema(PeopleTestSchema), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson"))), ExpectedSchema: CompressSchema(PeopleTestSchema), }, { Name: "replace partial columns", ReplaceQuery: "replace into people (id, first_name, last_name) values (2, 'Bart', 'Simpson')", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson")), @@ -122,7 +122,7 @@ var BasicReplaceTests = []ReplaceTest{ { Name: "replace partial columns mixed order", ReplaceQuery: "replace into people (last_name, first_name, id) values ('Simpson', 'Bart', 2)", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson")), @@ -157,7 +157,7 @@ var BasicReplaceTests = []ReplaceTest{ { Name: "replace partial columns functions", ReplaceQuery: "replace into people (id, first_name, last_name) values (2, UPPER('Bart'), 'Simpson')", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("BART"), types.String("Simpson")), @@ -183,7 +183,7 @@ var BasicReplaceTests = []ReplaceTest{ (9, "Jacqueline", "Bouvier", true, 80, 2), (10, "Patty", "Bouvier", false, 40, 7), (11, "Selma", "Bouvier", false, 40, 7)`, - SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id > 6", + SelectQuery: "select id, first_name, last_name, is_married, age, rating from people where id > 6 ORDER BY id", ExpectedRows: ToSqlRows(SubsetSchema(PeopleTestSchema, "id", "first_name", "last_name", "is_married", "age", "rating"), NewPeopleRow(7, "Maggie", "Simpson", false, 1, 5.1), NewPeopleRow(8, "Milhouse", "Van Houten", false, 8, 3.5), @@ -202,7 +202,7 @@ var BasicReplaceTests = []ReplaceTest{ { Name: "replace partial columns multiple rows duplicate", ReplaceQuery: "replace into people (id, first_name, last_name) values (2, 'Bart', 'Simpson'), (2, 'Bart', 'Simpson')", - SelectQuery: "select id, first_name, last_name from people where id = 2", + SelectQuery: "select id, first_name, last_name from people where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson")), @@ -215,7 +215,7 @@ var BasicReplaceTests = []ReplaceTest{ NewSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind, "num", types.IntKind), NewRow(types.Int(2), types.String("Bart"), types.String("Simpson"), types.Int(44))), ReplaceQuery: "replace into temppeople (id, first_name, last_name, num) values (2, 'Bart', 'Simpson', 88)", - SelectQuery: "select id, first_name, last_name, num from temppeople where id = 2", + SelectQuery: "select id, first_name, last_name, num from temppeople where id = 2 ORDER BY id", ExpectedRows: ToSqlRows( NewResultSetSchema("id", types.IntKind, "first_name", types.StringKind, "last_name", types.StringKind, "num", types.IntKind), NewResultSetRow(types.Int(2), types.String("Bart"), types.String("Simpson"), types.Int(88))), diff --git a/go/libraries/doltcore/sqle/table_editor_test.go b/go/libraries/doltcore/sqle/table_editor_test.go index 73f199f3a4..be07ea53d6 100644 --- a/go/libraries/doltcore/sqle/table_editor_test.go +++ b/go/libraries/doltcore/sqle/table_editor_test.go @@ -68,7 +68,7 @@ func TestTableEditor(t *testing.T) { require.NoError(t, ed.Insert(ctx, r(fatTony, PeopleTestSchema))) require.NoError(t, ed.Insert(ctx, r(troyMclure, PeopleTestSchema))) }, - selectQuery: "select * from people where id >= 10", + selectQuery: "select * from people where id >= 10 ORDER BY id", expectedRows: ToSqlRows(PeopleTestSchema, edna, krusty, smithers, ralph, martin, skinner, fatTony, troyMclure, ), @@ -80,7 +80,7 @@ func TestTableEditor(t *testing.T) { require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema))) require.NoError(t, ed.Delete(ctx, r(edna, PeopleTestSchema))) }, - selectQuery: "select * from people where id >= 10", + selectQuery: "select * from people where id >= 10 ORDER BY id", expectedRows: ToSqlRows(PeopleTestSchema, krusty, ), @@ -94,7 +94,7 @@ func TestTableEditor(t *testing.T) { require.NoError(t, ed.Insert(ctx, r(fatTony, PeopleTestSchema))) require.NoError(t, ed.Delete(ctx, r(Homer, PeopleTestSchema))) }, - selectQuery: "select * from people where id >= 10 or id = 0", + selectQuery: "select * from people where id >= 10 or id = 0 ORDER BY id", expectedRows: ToSqlRows(PeopleTestSchema, krusty, fatTony, ), @@ -106,7 +106,7 @@ func TestTableEditor(t *testing.T) { require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema))) require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(PeopleTestSchema, edna, AgeTag, 1), PeopleTestSchema))) }, - selectQuery: "select * from people where id >= 10", + selectQuery: "select * from people where id >= 10 ORDER BY id", expectedRows: ToSqlRows(PeopleTestSchema, MutateRow(PeopleTestSchema, edna, AgeTag, 1), krusty, @@ -126,7 +126,7 @@ func TestTableEditor(t *testing.T) { require.NoError(t, ed.Delete(ctx, r(ralph, PeopleTestSchema))) require.NoError(t, ed.Insert(ctx, r(ralph, PeopleTestSchema))) }, - selectQuery: "select * from people where id >= 10", + selectQuery: "select * from people where id >= 10 ORDER BY id", expectedRows: ToSqlRows(PeopleTestSchema, MutateRow(PeopleTestSchema, edna, AgeTag, 1), krusty, @@ -141,7 +141,7 @@ func TestTableEditor(t *testing.T) { require.NoError(t, ed.Insert(ctx, r(krusty, PeopleTestSchema))) require.NoError(t, ed.Update(ctx, r(edna, PeopleTestSchema), r(MutateRow(PeopleTestSchema, edna, IdTag, 30), PeopleTestSchema))) }, - selectQuery: "select * from people where id >= 10", + selectQuery: "select * from people where id >= 10 ORDER BY id", expectedRows: ToSqlRows(PeopleTestSchema, krusty, MutateRow(PeopleTestSchema, edna, IdTag, 30), @@ -183,6 +183,7 @@ func TestTableEditor(t *testing.T) { actualRows, _, err := executeSelect(context.Background(), dEnv, root, test.selectQuery) require.NoError(t, err) + assert.Equal(t, test.expectedRows, actualRows) }) } diff --git a/go/libraries/utils/async/ring_buffer.go b/go/libraries/utils/async/ring_buffer.go new file mode 100644 index 0000000000..d70b5a51b6 --- /dev/null +++ b/go/libraries/utils/async/ring_buffer.go @@ -0,0 +1,173 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package async + +import ( + "io" + "os" + "sync" +) + +// RingBuffer is a dynamically sized ring buffer that is thread safe +type RingBuffer struct { + cond *sync.Cond + allocSize int + + closed bool + headPos int + tailPos int + headSlice int + tailSlice int + + items [][]interface{} +} + +// NewRingBuffer creates a new RingBuffer instance +func NewRingBuffer(allocSize int) *RingBuffer { + itemBuffer := make([]interface{}, allocSize*2) + items := [][]interface{}{itemBuffer[:allocSize], itemBuffer[allocSize:]} + + return &RingBuffer{ + cond: sync.NewCond(&sync.Mutex{}), + allocSize: allocSize, + items: items, + } +} + +// Reset clears a ring buffer so that it can be reused +func (rb *RingBuffer) Reset() { + rb.cond.L.Lock() + defer rb.cond.L.Unlock() + + rb.closed = false + rb.headPos = 0 + rb.tailPos = 0 + rb.headSlice = 0 + rb.tailSlice = 0 + + for i := 0; i < len(rb.items); i++ { + for j := 0; j < len(rb.items[i]); j++ { + rb.items[i][j] = nil + } + } +} + +// Close closes a RingBuffer so that no new items can be pushed onto it. Items that are already in the buffer can still +// be read via Pop and TryPop. Close will broadcast to all go routines waiting inside Pop +func (rb *RingBuffer) Close() error { + rb.cond.L.Lock() + defer rb.cond.L.Unlock() + + if rb.closed { + return os.ErrClosed + } + + rb.closed = true + rb.cond.Broadcast() + + return nil +} + +// Push will add a new item to the RingBuffer and signal a go routine waiting inside Pop that new data is available +func (rb *RingBuffer) Push(item interface{}) error { + rb.cond.L.Lock() + defer rb.cond.L.Unlock() + + if rb.closed { + return os.ErrClosed + } + + rb.items[rb.headSlice][rb.headPos] = item + rb.headPos++ + + if rb.headPos == rb.allocSize { + numSlices := len(rb.items) + nextSlice := (rb.headSlice + 1) % numSlices + + if nextSlice == rb.tailSlice { + newItems := make([][]interface{}, numSlices+1) + + j := 0 + for i := rb.tailSlice; i < numSlices; i, j = i+1, j+1 { + newItems[j] = rb.items[i] + } + + for i := 0; i < rb.tailSlice; i, j = i+1, j+1 { + newItems[j] = rb.items[i] + } + + newItems[numSlices] = make([]interface{}, rb.allocSize) + + rb.items = newItems + rb.tailSlice = 0 + rb.headSlice = numSlices + } else { + rb.headSlice = nextSlice + } + + rb.headPos = 0 + } + + rb.cond.Signal() + + return nil +} + +// noLockPop is used internally by methods that already hold a lock on the RingBuffer and should never be called directly +func (rb *RingBuffer) noLockPop() (interface{}, bool) { + if rb.tailPos == rb.headPos && rb.tailSlice == rb.headSlice { + return nil, false + } + + item := rb.items[rb.tailSlice][rb.tailPos] + rb.tailPos++ + + if rb.tailPos == rb.allocSize { + rb.tailPos = 0 + rb.tailSlice = (rb.tailSlice + 1) % len(rb.items) + } + + return item, true +} + +// TryPop will return the next item in the RingBuffer. If there are no items available TryPop will return immediately +// with with `ok` set to false. +func (rb *RingBuffer) TryPop() (item interface{}, ok bool) { + rb.cond.L.Lock() + defer rb.cond.L.Unlock() + + return rb.noLockPop() +} + +// Pop will return the next item in the RingBuffer. If there are no items available, Pop will wait until a new item is +// pushed, or the RingBuffer is closed. +func (rb *RingBuffer) Pop() (item interface{}, err error) { + rb.cond.L.Lock() + defer rb.cond.L.Unlock() + + for { + item, ok := rb.noLockPop() + + if ok { + return item, nil + } + + if rb.closed { + return nil, io.EOF + } + + rb.cond.Wait() + } +} diff --git a/go/libraries/utils/async/ring_buffer_test.go b/go/libraries/utils/async/ring_buffer_test.go new file mode 100644 index 0000000000..57d27dafe7 --- /dev/null +++ b/go/libraries/utils/async/ring_buffer_test.go @@ -0,0 +1,163 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package async + +import ( + "fmt" + "io" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSingleThread(t *testing.T) { + tests := []struct { + allocSize int + numItems int + }{ + {128, 127}, + {128, 128}, + {128, 129}, + {1, 1024}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("alloc %d items %d", test.allocSize, test.numItems), func(t *testing.T) { + rb := NewRingBuffer(test.allocSize) + + for i := 0; i < test.numItems; i++ { + err := rb.Push(i) + assert.NoError(t, err) + } + + for i := 0; i < test.numItems; i++ { + item, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, i, item.(int)) + } + + item, ok := rb.TryPop() + assert.Nil(t, item) + assert.False(t, ok) + }) + } +} + +func TestOneProducerOneConsumer(t *testing.T) { + tests := []struct { + allocSize int + numItems int + }{ + {128, 127}, + {128, 128}, + {128, 129}, + {1, 1024}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("alloc %d items %d", test.allocSize, test.numItems), func(t *testing.T) { + rb := NewRingBuffer(test.allocSize) + + go func() { + defer rb.Close() + + for i := 0; i < test.numItems; i++ { + err := rb.Push(i) + assert.NoError(t, err) + } + }() + + for i := 0; i < test.numItems; i++ { + item, err := rb.Pop() + assert.NoError(t, err) + assert.Equal(t, i, item.(int)) + } + + item, err := rb.Pop() + assert.Nil(t, item) + assert.Equal(t, io.EOF, err) + }) + } +} + +func TestNProducersNConsumers(t *testing.T) { + tests := []struct { + producers int + consumers int + allocSize int + itemsPerProducer int + }{ + {2, 8, 128, 127}, + {2, 8, 128, 128}, + {2, 8, 128, 129}, + {2, 8, 1, 1024}, + {8, 2, 1, 1024}, + {8, 8, 1, 1024}, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("producers %d consumers %d alloc %d items per producer %d", test.producers, test.consumers, test.allocSize, test.itemsPerProducer), func(t *testing.T) { + rb := NewRingBuffer(test.allocSize) + + producerGroup := &sync.WaitGroup{} + producerGroup.Add(test.producers) + for i := 0; i < test.producers; i++ { + go func() { + defer producerGroup.Done() + for i := 0; i < test.itemsPerProducer; i++ { + err := rb.Push(i) + assert.NoError(t, err) + } + }() + } + + consumerResults := make([][]int, test.consumers) + consumerGroup := &sync.WaitGroup{} + consumerGroup.Add(test.consumers) + for i := 0; i < test.consumers; i++ { + results := make([]int, test.itemsPerProducer) + consumerResults[i] = results + go func() { + defer consumerGroup.Done() + for { + item, err := rb.Pop() + + if err != nil { + assert.Equal(t, io.EOF, err) + return + } + + results[item.(int)]++ + } + }() + } + + producerGroup.Wait() + err := rb.Close() + assert.NoError(t, err) + consumerGroup.Wait() + + for i := 0; i < test.itemsPerProducer; i++ { + sum := 0 + for j := 0; j < test.consumers; j++ { + sum += consumerResults[j][i] + } + + assert.Equal(t, test.producers, sum) + } + }) + } +}