mirror of
https://github.com/dolthub/dolt.git
synced 2026-01-26 02:58:44 -06:00
Heavily improved SQL write performance
This commit is contained in:
committed by
Daylon Wilkins
parent
6d930b7933
commit
2fa2b8fd63
@@ -6,13 +6,14 @@ package eventsapi
|
||||
import (
|
||||
context "context"
|
||||
fmt "fmt"
|
||||
math "math"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
duration "github.com/golang/protobuf/ptypes/duration"
|
||||
timestamp "github.com/golang/protobuf/ptypes/timestamp"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
math "math"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
|
||||
@@ -5,8 +5,9 @@ package eventsapi
|
||||
|
||||
import (
|
||||
fmt "fmt"
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
math "math"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
|
||||
@@ -6,11 +6,12 @@ package remotesapi
|
||||
import (
|
||||
context "context"
|
||||
fmt "fmt"
|
||||
math "math"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
math "math"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
|
||||
@@ -6,11 +6,12 @@ package remotesapi
|
||||
import (
|
||||
context "context"
|
||||
fmt "fmt"
|
||||
math "math"
|
||||
|
||||
proto "github.com/golang/protobuf/proto"
|
||||
grpc "google.golang.org/grpc"
|
||||
codes "google.golang.org/grpc/codes"
|
||||
status "google.golang.org/grpc/status"
|
||||
math "math"
|
||||
)
|
||||
|
||||
// Reference imports to suppress errors if they are not otherwise used.
|
||||
|
||||
@@ -19,10 +19,10 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/cmd/dolt/errhand"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/row"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/utils/async"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
@@ -34,15 +34,12 @@ var ErrDuplicatePrimaryKeyFmt = "duplicate primary key given: (%v)"
|
||||
//
|
||||
// This type is thread-safe, and may be used in a multi-threaded environment.
|
||||
type TableEditor struct {
|
||||
t *Table
|
||||
tSch schema.Schema
|
||||
ed *types.MapEditor
|
||||
updated bool // Whether the table has been updated
|
||||
insertedKeys map[hash.Hash]types.Value
|
||||
addedKeys map[hash.Hash]types.Value
|
||||
removedKeys map[hash.Hash]types.Value
|
||||
affectedKeys map[hash.Hash]types.Value
|
||||
indexEds []*IndexEditor
|
||||
t *Table
|
||||
tSch schema.Schema
|
||||
tea *tableEditAccumulator
|
||||
aq *async.ActionExecutor
|
||||
nbf *types.NomsBinFormat
|
||||
indexEds []*IndexEditor
|
||||
|
||||
// This mutex blocks on each operation, so that map reads and updates are serialized
|
||||
writeMutex *sync.Mutex
|
||||
@@ -50,20 +47,52 @@ type TableEditor struct {
|
||||
flushMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
type tableEditAccumulator struct {
|
||||
ed types.EditAccumulator
|
||||
opCount uint64
|
||||
insertedKeys map[hash.Hash]types.Value
|
||||
addedKeys map[hash.Hash]types.Value
|
||||
removedKeys map[hash.Hash]types.Value
|
||||
affectedKeys map[hash.Hash]types.Value
|
||||
}
|
||||
|
||||
const tableEditorMaxOps = 16384
|
||||
|
||||
func NewTableEditor(ctx context.Context, t *Table, tableSch schema.Schema) (*TableEditor, error) {
|
||||
// initialize the mutexes here since they're not reset
|
||||
te := &TableEditor{
|
||||
t: t,
|
||||
tSch: tableSch,
|
||||
tea: newTableEditAcc(t.Format()),
|
||||
nbf: t.Format(),
|
||||
indexEds: make([]*IndexEditor, tableSch.Indexes().Count()),
|
||||
writeMutex: &sync.Mutex{},
|
||||
flushMutex: &sync.RWMutex{},
|
||||
}
|
||||
err := te.reset(ctx, t, tableSch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
te.aq = async.NewActionExecutor(ctx, te.flushEditAccumulator, 1, 1)
|
||||
|
||||
for i, index := range tableSch.Indexes().AllIndexes() {
|
||||
indexData, err := t.GetIndexRowData(ctx, index.Name())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
te.indexEds[i] = NewIndexEditor(index, indexData)
|
||||
}
|
||||
return te, nil
|
||||
}
|
||||
|
||||
func (te *TableEditor) Insert(ctx context.Context, dRow row.Row) error {
|
||||
func newTableEditAcc(nbf *types.NomsBinFormat) *tableEditAccumulator {
|
||||
return &tableEditAccumulator{
|
||||
ed: types.CreateEditAccForMapEdits(nbf),
|
||||
insertedKeys: make(map[hash.Hash]types.Value),
|
||||
addedKeys: make(map[hash.Hash]types.Value),
|
||||
removedKeys: make(map[hash.Hash]types.Value),
|
||||
affectedKeys: make(map[hash.Hash]types.Value),
|
||||
}
|
||||
}
|
||||
|
||||
// InsertRow adds the given row to the table. If the row already exists, use UpdateRow.
|
||||
func (te *TableEditor) InsertRow(ctx context.Context, dRow row.Row) error {
|
||||
defer te.autoFlush()
|
||||
te.flushMutex.RLock()
|
||||
defer te.flushMutex.RUnlock()
|
||||
|
||||
@@ -92,23 +121,35 @@ func (te *TableEditor) Insert(ctx context.Context, dRow row.Row) error {
|
||||
|
||||
// If we've already inserted this key as part of this insert operation, that's an error. Inserting a row that
|
||||
// already exists in the table will be handled in Close().
|
||||
if _, ok := te.addedKeys[keyHash]; ok {
|
||||
if _, ok := te.tea.addedKeys[keyHash]; ok {
|
||||
value, err := types.EncodedValue(ctx, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
|
||||
}
|
||||
te.insertedKeys[keyHash] = key
|
||||
te.addedKeys[keyHash] = key
|
||||
te.affectedKeys[keyHash] = key
|
||||
te.tea.insertedKeys[keyHash] = key
|
||||
te.tea.addedKeys[keyHash] = key
|
||||
te.tea.affectedKeys[keyHash] = key
|
||||
|
||||
te.ed = te.ed.Set(key, dRow.NomsMapValue(te.tSch))
|
||||
te.updated = true
|
||||
te.tea.ed.AddEdit(key, dRow.NomsMapValue(te.tSch))
|
||||
te.tea.opCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (te *TableEditor) Delete(ctx context.Context, dRow row.Row) error {
|
||||
// DeleteKey removes the given key from the table.
|
||||
func (te *TableEditor) DeleteKey(ctx context.Context, key types.Tuple) error {
|
||||
defer te.autoFlush()
|
||||
te.flushMutex.RLock()
|
||||
defer te.flushMutex.RUnlock()
|
||||
|
||||
return te.delete(key)
|
||||
}
|
||||
|
||||
// DeleteRow removes the given row from the table. This essentially acts as a convenience function for DeleteKey, while
|
||||
// ensuring proper thread safety.
|
||||
func (te *TableEditor) DeleteRow(ctx context.Context, dRow row.Row) error {
|
||||
defer te.autoFlush()
|
||||
te.flushMutex.RLock()
|
||||
defer te.flushMutex.RUnlock()
|
||||
|
||||
@@ -116,25 +157,13 @@ func (te *TableEditor) Delete(ctx context.Context, dRow row.Row) error {
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get row key").AddCause(err).Build()
|
||||
}
|
||||
keyHash, err := key.Hash(dRow.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Regarding the lock's position here, refer to the comment in Insert()
|
||||
te.writeMutex.Lock()
|
||||
defer te.writeMutex.Unlock()
|
||||
|
||||
delete(te.addedKeys, keyHash)
|
||||
te.removedKeys[keyHash] = key
|
||||
te.affectedKeys[keyHash] = key
|
||||
|
||||
te.ed = te.ed.Remove(key)
|
||||
te.updated = true
|
||||
return nil
|
||||
return te.delete(key.(types.Tuple))
|
||||
}
|
||||
|
||||
func (te *TableEditor) Update(ctx context.Context, dOldRow row.Row, dNewRow row.Row) error {
|
||||
// UpdateRow takes the current row and new rows, and updates it accordingly.
|
||||
func (te *TableEditor) UpdateRow(ctx context.Context, dOldRow row.Row, dNewRow row.Row) error {
|
||||
defer te.autoFlush()
|
||||
te.flushMutex.RLock()
|
||||
defer te.flushMutex.RUnlock()
|
||||
|
||||
@@ -155,7 +184,7 @@ func (te *TableEditor) Update(ctx context.Context, dOldRow row.Row, dNewRow row.
|
||||
}
|
||||
oldKeyEqualsNewKey := dOldKeyVal.Equals(dNewKeyVal)
|
||||
|
||||
// Regarding the lock's position here, refer to the comment in Insert()
|
||||
// Regarding the lock's position here, refer to the comment in InsertRow
|
||||
te.writeMutex.Lock()
|
||||
defer te.writeMutex.Unlock()
|
||||
|
||||
@@ -167,181 +196,186 @@ func (te *TableEditor) Update(ctx context.Context, dOldRow row.Row, dNewRow row.
|
||||
}
|
||||
|
||||
// If the old value of the primary key we just updated was previously inserted, then we need to remove it now.
|
||||
if _, ok := te.insertedKeys[oldHash]; ok {
|
||||
delete(te.insertedKeys, oldHash)
|
||||
te.ed.Remove(dOldKeyVal)
|
||||
if _, ok := te.tea.insertedKeys[oldHash]; ok {
|
||||
delete(te.tea.insertedKeys, oldHash)
|
||||
te.tea.ed.AddEdit(dOldKeyVal, nil)
|
||||
te.tea.opCount++
|
||||
}
|
||||
|
||||
te.addedKeys[newHash] = dNewKeyVal
|
||||
te.removedKeys[oldHash] = dOldKeyVal
|
||||
te.affectedKeys[oldHash] = dOldKeyVal
|
||||
te.tea.addedKeys[newHash] = dNewKeyVal
|
||||
te.tea.removedKeys[oldHash] = dOldKeyVal
|
||||
te.tea.affectedKeys[oldHash] = dOldKeyVal
|
||||
}
|
||||
|
||||
te.affectedKeys[newHash] = dNewKeyVal
|
||||
te.tea.affectedKeys[newHash] = dNewKeyVal
|
||||
|
||||
te.ed.Set(dNewKeyVal, dNewRow.NomsMapValue(te.tSch))
|
||||
te.updated = true
|
||||
te.tea.ed.AddEdit(dNewKeyVal, dNewRow.NomsMapValue(te.tSch))
|
||||
te.tea.opCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush finalizes all of the changes and returns the updated Table.
|
||||
func (te *TableEditor) Flush(ctx context.Context) (*Table, error) {
|
||||
// Flush finalizes all of the changes made so far.
|
||||
func (te *TableEditor) Flush() {
|
||||
te.flushMutex.Lock()
|
||||
defer te.flushMutex.Unlock()
|
||||
|
||||
if !te.updated {
|
||||
return te.t, nil
|
||||
if te.tea.opCount > 0 {
|
||||
te.aq.Execute(te.tea)
|
||||
te.tea = newTableEditAcc(te.nbf)
|
||||
}
|
||||
}
|
||||
|
||||
// Table returns a Table based on the edits given, if any. If Flush() was not called prior, it will be called here.
|
||||
func (te *TableEditor) Table() (*Table, error) {
|
||||
te.Flush()
|
||||
err := te.aq.WaitForEmpty()
|
||||
return te.t, err
|
||||
}
|
||||
|
||||
// autoFlush is called at the end of every write call (after all locks have been released) and checks if we need to
|
||||
// automatically flush the edits.
|
||||
func (te *TableEditor) autoFlush() {
|
||||
te.flushMutex.RLock()
|
||||
runFlush := te.tea.opCount >= tableEditorMaxOps
|
||||
te.flushMutex.RUnlock()
|
||||
|
||||
if runFlush {
|
||||
te.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (te *TableEditor) delete(key types.Tuple) error {
|
||||
keyHash, err := key.Hash(te.t.Format())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
te.writeMutex.Lock()
|
||||
defer te.writeMutex.Unlock()
|
||||
|
||||
delete(te.tea.addedKeys, keyHash)
|
||||
te.tea.removedKeys[keyHash] = key
|
||||
te.tea.affectedKeys[keyHash] = key
|
||||
|
||||
te.tea.ed.AddEdit(key, nil)
|
||||
te.tea.opCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (te *TableEditor) flushEditAccumulator(ctx context.Context, teaInterface interface{}) error {
|
||||
// We don't call any locks here since this is called from an ActionExecutor with a concurrency of 1
|
||||
tea := teaInterface.(*tableEditAccumulator)
|
||||
originalRowData, err := te.t.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
}
|
||||
|
||||
// For all added keys, check for and report a collision
|
||||
for keyHash, addedKey := range te.addedKeys {
|
||||
if _, ok := te.removedKeys[keyHash]; !ok {
|
||||
_, rowExists, err := te.t.GetRow(ctx, addedKey.(types.Tuple), te.tSch)
|
||||
for keyHash, addedKey := range tea.addedKeys {
|
||||
if _, ok := tea.removedKeys[keyHash]; !ok {
|
||||
_, rowExists, err := originalRowData.MaybeGet(ctx, addedKey)
|
||||
if err != nil {
|
||||
return nil, errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
return errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
}
|
||||
if rowExists {
|
||||
value, err := types.EncodedValue(ctx, addedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
return nil, fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
|
||||
return fmt.Errorf(ErrDuplicatePrimaryKeyFmt, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
// For all removed keys, remove the map entries that weren't added elsewhere by other updates
|
||||
for keyHash, removedKey := range te.removedKeys {
|
||||
if _, ok := te.addedKeys[keyHash]; !ok {
|
||||
te.ed.Remove(removedKey)
|
||||
for keyHash, removedKey := range tea.removedKeys {
|
||||
if _, ok := tea.addedKeys[keyHash]; !ok {
|
||||
tea.ed.AddEdit(removedKey, nil)
|
||||
}
|
||||
}
|
||||
|
||||
updated, err := te.ed.Map(ctx)
|
||||
accEdits, err := tea.ed.FinishedEditing()
|
||||
if err != nil {
|
||||
_ = te.reset(ctx, te.t, te.tSch)
|
||||
return nil, errhand.BuildDError("failed to modify table").AddCause(err).Build()
|
||||
return errhand.BuildDError("failed to finalize table changes").AddCause(err).Build()
|
||||
}
|
||||
originalRowData, err := te.t.GetRowData(ctx)
|
||||
updated, _, err := types.ApplyEdits(ctx, accEdits, originalRowData)
|
||||
if err != nil {
|
||||
_ = te.reset(ctx, te.t, te.tSch)
|
||||
return nil, errhand.BuildDError("failed to read table").AddCause(err).Build()
|
||||
return errhand.BuildDError("failed to modify table").AddCause(err).Build()
|
||||
}
|
||||
newTable, err := te.t.UpdateRows(ctx, updated)
|
||||
if err != nil {
|
||||
_ = te.reset(ctx, te.t, te.tSch)
|
||||
return nil, errhand.BuildDError("failed to update rows").AddCause(err).Build()
|
||||
return errhand.BuildDError("failed to update rows").AddCause(err).Build()
|
||||
}
|
||||
newTable, err = te.updateIndexes(ctx, newTable, originalRowData, updated)
|
||||
newTable, err = te.updateIndexes(ctx, tea, newTable, originalRowData, updated)
|
||||
if err != nil {
|
||||
_ = te.reset(ctx, te.t, te.tSch)
|
||||
return nil, errhand.BuildDError("failed to update indexes").AddCause(err).Build()
|
||||
return errhand.BuildDError("failed to update indexes").AddCause(err).Build()
|
||||
}
|
||||
|
||||
// Set the TableEditor to the new table state
|
||||
err = te.reset(ctx, newTable, te.tSch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newTable, nil
|
||||
}
|
||||
|
||||
// reset sets the TableEditor to the given table
|
||||
func (te *TableEditor) reset(ctx context.Context, t *Table, tableSch schema.Schema) error {
|
||||
tableData, err := t.GetRowData(ctx)
|
||||
if err != nil {
|
||||
return errhand.BuildDError("failed to get row data.").AddCause(err).Build()
|
||||
}
|
||||
|
||||
te.t = t
|
||||
te.tSch = tableSch
|
||||
te.ed = tableData.Edit()
|
||||
te.updated = false
|
||||
te.insertedKeys = make(map[hash.Hash]types.Value)
|
||||
te.addedKeys = make(map[hash.Hash]types.Value)
|
||||
te.removedKeys = make(map[hash.Hash]types.Value)
|
||||
te.affectedKeys = make(map[hash.Hash]types.Value)
|
||||
te.indexEds = make([]*IndexEditor, tableSch.Indexes().Count())
|
||||
|
||||
for i, index := range tableSch.Indexes().AllIndexes() {
|
||||
indexData, err := t.GetIndexRowData(ctx, index.Name())
|
||||
if err != nil {
|
||||
panic(err) // should never have an index that does not have data, even an empty index
|
||||
}
|
||||
te.indexEds[i] = NewIndexEditor(index, indexData)
|
||||
}
|
||||
te.t = newTable
|
||||
return nil
|
||||
}
|
||||
|
||||
func (te *TableEditor) updateIndexes(ctx context.Context, tbl *Table, originalRowData types.Map, updated types.Map) (*Table, error) {
|
||||
// We don't call any locks here since this is only called from Flush, which acquires a lock
|
||||
func (te *TableEditor) updateIndexes(ctx context.Context, tea *tableEditAccumulator, tbl *Table, originalRowData types.Map, updated types.Map) (*Table, error) {
|
||||
// We don't call any locks here since this is called from an ActionExecutor with a concurrency of 1
|
||||
if len(te.indexEds) == 0 {
|
||||
return tbl, nil
|
||||
}
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
var anyErr error // we only care to catch any error, doesn't matter if it's overwritten
|
||||
indexActionQueue := async.NewActionExecutor(ctx, func(_ context.Context, keyInt interface{}) error {
|
||||
key := keyInt.(types.Value)
|
||||
|
||||
for _, key := range te.affectedKeys {
|
||||
wg.Add(1)
|
||||
go func(key types.Value) {
|
||||
defer wg.Done()
|
||||
var originalRow row.Row
|
||||
var updatedRow row.Row
|
||||
|
||||
var originalRow row.Row
|
||||
var updatedRow row.Row
|
||||
|
||||
if val, ok, err := originalRowData.MaybeGet(ctx, key); err == nil && ok {
|
||||
originalRow, err = row.FromNoms(te.tSch, key.(types.Tuple), val.(types.Tuple))
|
||||
if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
if val, ok, err := originalRowData.MaybeGet(ctx, key); err == nil && ok {
|
||||
originalRow, err = row.FromNoms(te.tSch, key.(types.Tuple), val.(types.Tuple))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if val, ok, err := updated.MaybeGet(ctx, key); err == nil && ok {
|
||||
updatedRow, err = row.FromNoms(te.tSch, key.(types.Tuple), val.(types.Tuple))
|
||||
if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
if val, ok, err := updated.MaybeGet(ctx, key); err == nil && ok {
|
||||
updatedRow, err = row.FromNoms(te.tSch, key.(types.Tuple), val.(types.Tuple))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, indexEd := range te.indexEds {
|
||||
var err error
|
||||
var originalIndexRow row.Row
|
||||
var updatedIndexRow row.Row
|
||||
if originalRow != nil {
|
||||
originalIndexRow, err = originalRow.ReduceToIndex(indexEd.Index())
|
||||
if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
if updatedRow != nil {
|
||||
updatedIndexRow, err = updatedRow.ReduceToIndex(indexEd.Index())
|
||||
if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
err = indexEd.UpdateIndex(ctx, originalIndexRow, updatedIndexRow)
|
||||
for _, indexEd := range te.indexEds {
|
||||
var err error
|
||||
var originalIndexRow row.Row
|
||||
var updatedIndexRow row.Row
|
||||
if originalRow != nil {
|
||||
originalIndexRow, err = originalRow.ReduceToIndex(indexEd.Index())
|
||||
if err != nil {
|
||||
anyErr = err
|
||||
return
|
||||
return err
|
||||
}
|
||||
}
|
||||
}(key)
|
||||
if updatedRow != nil {
|
||||
updatedIndexRow, err = updatedRow.ReduceToIndex(indexEd.Index())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = indexEd.UpdateIndex(ctx, originalIndexRow, updatedIndexRow)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}, 4, 0)
|
||||
|
||||
for _, key := range tea.affectedKeys {
|
||||
indexActionQueue.Execute(key)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
if anyErr != nil {
|
||||
return nil, anyErr
|
||||
err := indexActionQueue.WaitForEmpty()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, indexEd := range te.indexEds {
|
||||
|
||||
@@ -66,7 +66,7 @@ func TestTableEditorConcurrency(t *testing.T) {
|
||||
2: types.Int(val),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Insert(context.Background(), dRow))
|
||||
require.NoError(t, tableEditor.InsertRow(context.Background(), dRow))
|
||||
wg.Done()
|
||||
}(j)
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func TestTableEditorConcurrency(t *testing.T) {
|
||||
2: types.Int(val + 1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Update(context.Background(), dOldRow, dNewRow))
|
||||
require.NoError(t, tableEditor.UpdateRow(context.Background(), dOldRow, dNewRow))
|
||||
wg.Done()
|
||||
}(j)
|
||||
}
|
||||
@@ -102,13 +102,13 @@ func TestTableEditorConcurrency(t *testing.T) {
|
||||
2: types.Int(val),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Delete(context.Background(), dRow))
|
||||
require.NoError(t, tableEditor.DeleteRow(context.Background(), dRow))
|
||||
wg.Done()
|
||||
}(j)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
newTable, err := tableEditor.Flush(context.Background())
|
||||
newTable, err := tableEditor.Table()
|
||||
require.NoError(t, err)
|
||||
newTableData, err := newTable.GetRowData(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -157,9 +157,9 @@ func TestTableEditorConcurrencyPostInsert(t *testing.T) {
|
||||
2: types.Int(i),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Insert(context.Background(), dRow))
|
||||
require.NoError(t, tableEditor.InsertRow(context.Background(), dRow))
|
||||
}
|
||||
table, err = tableEditor.Flush(context.Background())
|
||||
table, err = tableEditor.Table()
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < tableEditorConcurrencyIterations; i++ {
|
||||
@@ -182,7 +182,7 @@ func TestTableEditorConcurrencyPostInsert(t *testing.T) {
|
||||
2: types.Int(val + 1),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Update(context.Background(), dOldRow, dNewRow))
|
||||
require.NoError(t, tableEditor.UpdateRow(context.Background(), dOldRow, dNewRow))
|
||||
wg.Done()
|
||||
}(j)
|
||||
}
|
||||
@@ -196,13 +196,13 @@ func TestTableEditorConcurrencyPostInsert(t *testing.T) {
|
||||
2: types.Int(val),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Delete(context.Background(), dRow))
|
||||
require.NoError(t, tableEditor.DeleteRow(context.Background(), dRow))
|
||||
wg.Done()
|
||||
}(j)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
newTable, err := tableEditor.Flush(context.Background())
|
||||
newTable, err := tableEditor.Table()
|
||||
require.NoError(t, err)
|
||||
newTableData, err := newTable.GetRowData(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -252,10 +252,10 @@ func TestTableEditorWriteAfterFlush(t *testing.T) {
|
||||
2: types.Int(i),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Insert(context.Background(), dRow))
|
||||
require.NoError(t, tableEditor.InsertRow(context.Background(), dRow))
|
||||
}
|
||||
|
||||
_, err = tableEditor.Flush(context.Background())
|
||||
_, err = tableEditor.Table()
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 10; i < 20; i++ {
|
||||
@@ -265,10 +265,10 @@ func TestTableEditorWriteAfterFlush(t *testing.T) {
|
||||
2: types.Int(i),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tableEditor.Delete(context.Background(), dRow))
|
||||
require.NoError(t, tableEditor.DeleteRow(context.Background(), dRow))
|
||||
}
|
||||
|
||||
newTable, err := tableEditor.Flush(context.Background())
|
||||
newTable, err := tableEditor.Table()
|
||||
require.NoError(t, err)
|
||||
newTableData, err := newTable.GetRowData(context.Background())
|
||||
require.NoError(t, err)
|
||||
@@ -289,7 +289,7 @@ func TestTableEditorWriteAfterFlush(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
sameTable, err := tableEditor.Flush(context.Background())
|
||||
sameTable, err := tableEditor.Table()
|
||||
require.NoError(t, err)
|
||||
sameTableData, err := sameTable.GetRowData(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -30,6 +30,9 @@ import (
|
||||
"github.com/liquidata-inc/dolt/go/store/types"
|
||||
)
|
||||
|
||||
// TableDataLocationUpdateRate is the number of writes that will process before the updated stats are displayed.
|
||||
const TableDataLocationUpdateRate = 32768
|
||||
|
||||
// ErrNoPK is an error returned if a schema is missing a required primary key
|
||||
var ErrNoPK = errors.New("schema does not contain a primary key")
|
||||
|
||||
@@ -207,7 +210,7 @@ type tableEditorWriteCloser struct {
|
||||
var _ DataMoverCloser = (*tableEditorWriteCloser)(nil)
|
||||
|
||||
func (te *tableEditorWriteCloser) GetTable(ctx context.Context) (*doltdb.Table, error) {
|
||||
return te.tableEditor.Flush(ctx)
|
||||
return te.tableEditor.Table()
|
||||
}
|
||||
|
||||
// GetSchema implements TableWriteCloser
|
||||
@@ -217,20 +220,14 @@ func (te *tableEditorWriteCloser) GetSchema() schema.Schema {
|
||||
|
||||
// WriteRow implements TableWriteCloser
|
||||
func (te *tableEditorWriteCloser) WriteRow(ctx context.Context, r row.Row) error {
|
||||
if atomic.LoadInt64(&te.opsSoFar) >= 65536 {
|
||||
if te.statsCB != nil && atomic.LoadInt64(&te.opsSoFar) >= TableDataLocationUpdateRate {
|
||||
atomic.StoreInt64(&te.opsSoFar, 0)
|
||||
_, err := te.tableEditor.Flush(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if te.statsCB != nil {
|
||||
te.statsCB(te.stats)
|
||||
}
|
||||
te.statsCB(te.stats)
|
||||
}
|
||||
if te.insertOnly {
|
||||
_ = atomic.AddInt64(&te.opsSoFar, 1)
|
||||
te.stats.Additions++
|
||||
return te.tableEditor.Insert(ctx, r)
|
||||
return te.tableEditor.InsertRow(ctx, r)
|
||||
} else {
|
||||
pkTuple, err := r.NomsMapKey(te.tableSch).Value(ctx)
|
||||
if err != nil {
|
||||
@@ -243,7 +240,7 @@ func (te *tableEditorWriteCloser) WriteRow(ctx context.Context, r row.Row) error
|
||||
if !ok {
|
||||
_ = atomic.AddInt64(&te.opsSoFar, 1)
|
||||
te.stats.Additions++
|
||||
return te.tableEditor.Insert(ctx, r)
|
||||
return te.tableEditor.InsertRow(ctx, r)
|
||||
}
|
||||
oldRow, err := row.FromNoms(te.tableSch, pkTuple.(types.Tuple), val.(types.Tuple))
|
||||
if err != nil {
|
||||
@@ -255,13 +252,13 @@ func (te *tableEditorWriteCloser) WriteRow(ctx context.Context, r row.Row) error
|
||||
}
|
||||
_ = atomic.AddInt64(&te.opsSoFar, 1)
|
||||
te.stats.Modifications++
|
||||
return te.tableEditor.Update(ctx, oldRow, r)
|
||||
return te.tableEditor.UpdateRow(ctx, oldRow, r)
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements TableWriteCloser
|
||||
func (te *tableEditorWriteCloser) Close(ctx context.Context) error {
|
||||
_, err := te.tableEditor.Flush(ctx)
|
||||
func (te *tableEditorWriteCloser) Close(_ context.Context) error {
|
||||
_, err := te.tableEditor.Table()
|
||||
if te.statsCB != nil {
|
||||
te.statsCB(te.stats)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (te *sqlTableEditor) Insert(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return te.tableEditor.Insert(ctx, dRow)
|
||||
return te.tableEditor.InsertRow(ctx, dRow)
|
||||
}
|
||||
|
||||
func (te *sqlTableEditor) Delete(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
@@ -67,7 +67,7 @@ func (te *sqlTableEditor) Delete(ctx *sql.Context, sqlRow sql.Row) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return te.tableEditor.Delete(ctx, dRow)
|
||||
return te.tableEditor.DeleteRow(ctx, dRow)
|
||||
}
|
||||
|
||||
func (te *sqlTableEditor) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Row) error {
|
||||
@@ -80,7 +80,7 @@ func (te *sqlTableEditor) Update(ctx *sql.Context, oldRow sql.Row, newRow sql.Ro
|
||||
return err
|
||||
}
|
||||
|
||||
return te.tableEditor.Update(ctx, dOldRow, dNewRow)
|
||||
return te.tableEditor.UpdateRow(ctx, dOldRow, dNewRow)
|
||||
}
|
||||
|
||||
// Close implements Closer
|
||||
@@ -93,7 +93,7 @@ func (te *sqlTableEditor) Close(ctx *sql.Context) error {
|
||||
}
|
||||
|
||||
func (te *sqlTableEditor) flush(ctx *sql.Context) error {
|
||||
newTable, err := te.tableEditor.Flush(ctx)
|
||||
newTable, err := te.tableEditor.Table()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
135
go/libraries/utils/async/action_executor.go
Normal file
135
go/libraries/utils/async/action_executor.go
Normal file
@@ -0,0 +1,135 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package async
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// Action is the function called by an ActionExecutor on each given value.
|
||||
type Action func(ctx context.Context, val interface{}) error
|
||||
|
||||
// ActionExecutor is designed for asynchronous workloads that should run when a new task is available. The closest analog
|
||||
// would be to have a long-running goroutine that receives from a channel, however ActionExecutor provides three major
|
||||
// points of differentiation. The first is that there is no need to close the queue, as goroutines automatically exit
|
||||
// when the queue is empty. The second is that a concurrency parameter may be set, that will spin up goroutines as
|
||||
// needed until the maximum number is attained. The third is that you don't have to declare the buffer size beforehand
|
||||
// as with channels, allowing the queue to respond to demand. You may declare a max buffer though, for RAM-limited
|
||||
// situations, which then blocks appends until the buffer is below the max given.
|
||||
type ActionExecutor struct {
|
||||
action Action
|
||||
ctx context.Context
|
||||
concurrency uint32
|
||||
err error
|
||||
finished *WaitGroup
|
||||
linkedList *list.List
|
||||
running uint32
|
||||
maxBuffer uint64
|
||||
syncCond *sync.Cond
|
||||
}
|
||||
|
||||
// NewActionExecutor returns an ActionExecutor that will run the given action on each appended value, and run up to the max
|
||||
// number of goroutines as defined by concurrency. If concurrency is 0, then it is set to 1. If maxBuffer is 0, then it
|
||||
// is unlimited. Panics on a nil action.
|
||||
func NewActionExecutor(ctx context.Context, action Action, concurrency uint32, maxBuffer uint64) *ActionExecutor {
|
||||
if action == nil {
|
||||
panic("action cannot be nil")
|
||||
}
|
||||
if concurrency == 0 {
|
||||
concurrency = 1
|
||||
}
|
||||
return &ActionExecutor{
|
||||
action: action,
|
||||
concurrency: concurrency,
|
||||
ctx: ctx,
|
||||
finished: &WaitGroup{},
|
||||
linkedList: list.New(),
|
||||
running: 0,
|
||||
maxBuffer: maxBuffer,
|
||||
syncCond: sync.NewCond(&sync.Mutex{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Execute adds the value to the end of the queue to be executed. If any action encountered an error before this call,
|
||||
// then the value is not added and this returns immediately.
|
||||
func (aq *ActionExecutor) Execute(val interface{}) {
|
||||
aq.syncCond.L.Lock()
|
||||
defer aq.syncCond.L.Unlock()
|
||||
|
||||
if aq.err != nil { // if we've errored before, then no point in running anything again
|
||||
return
|
||||
}
|
||||
|
||||
for aq.maxBuffer != 0 && uint64(aq.linkedList.Len()) >= aq.maxBuffer {
|
||||
aq.syncCond.Wait()
|
||||
}
|
||||
aq.finished.Add(1)
|
||||
aq.linkedList.PushBack(val)
|
||||
|
||||
if aq.running < aq.concurrency {
|
||||
aq.running++
|
||||
go aq.work()
|
||||
}
|
||||
}
|
||||
|
||||
// WaitForEmpty waits until the queue is empty, and then returns any errors that any actions may have encountered.
|
||||
func (aq *ActionExecutor) WaitForEmpty() error {
|
||||
aq.finished.Wait()
|
||||
return aq.err
|
||||
}
|
||||
|
||||
// work runs until the list is empty. If any error occurs from any action, then we do not call any further actions,
|
||||
// although we still iterate over the list and clear it.
|
||||
func (aq *ActionExecutor) work() {
|
||||
for {
|
||||
aq.syncCond.L.Lock() // check element list and valid state, so we lock
|
||||
|
||||
element := aq.linkedList.Front()
|
||||
if element == nil {
|
||||
aq.running--
|
||||
aq.syncCond.L.Unlock() // early exit, so we unlock
|
||||
return // we don't signal here since the buffer is empty, hence the return in the first place
|
||||
}
|
||||
_ = aq.linkedList.Remove(element)
|
||||
encounteredError := aq.err != nil
|
||||
|
||||
aq.syncCond.Signal() // if an append is waiting because of a full buffer, we signal for it to continue
|
||||
aq.syncCond.L.Unlock() // done checking list and state, so we unlock
|
||||
|
||||
if !encounteredError {
|
||||
var err error
|
||||
func() { // this func is to capture a potential panic from the action, and present it as an error instead
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in ActionExecutor:\n%v", r)
|
||||
}
|
||||
}()
|
||||
err = aq.action(aq.ctx, element.Value)
|
||||
}()
|
||||
// Technically, two actions could error at the same time and only one would persist their error. For async
|
||||
// tasks, we don't care as much about which action errored, just that an action error.
|
||||
if err != nil {
|
||||
aq.syncCond.L.Lock()
|
||||
aq.err = err
|
||||
aq.syncCond.L.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
aq.finished.Done()
|
||||
}
|
||||
}
|
||||
168
go/libraries/utils/async/action_executor_test.go
Normal file
168
go/libraries/utils/async/action_executor_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package async
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestActionExecutorOrdered(t *testing.T) {
|
||||
expectedStr := "abcdefghijklmnopqrstuvwxyz"
|
||||
outStr := ""
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
str := val.(string)
|
||||
outStr += str
|
||||
return nil
|
||||
}, 1, 0)
|
||||
for _, char := range expectedStr {
|
||||
actionExecutor.Execute(string(char))
|
||||
}
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedStr, outStr)
|
||||
}
|
||||
|
||||
func TestActionExecutorOrderedBuffered(t *testing.T) {
|
||||
expectedStr := "abcdefghijklmnopqrstuvwxyz"
|
||||
outStr := ""
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
str := val.(string)
|
||||
outStr += str
|
||||
return nil
|
||||
}, 1, 3)
|
||||
for _, char := range expectedStr {
|
||||
actionExecutor.Execute(string(char))
|
||||
}
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedStr, outStr)
|
||||
}
|
||||
|
||||
func TestActionExecutorUnordered(t *testing.T) {
|
||||
expectedValue := int64(50005000)
|
||||
outValue := int64(0)
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
atomic.AddInt64(&outValue, val.(int64))
|
||||
return nil
|
||||
}, 5, 0)
|
||||
for i := int64(1); i <= 10000; i++ {
|
||||
actionExecutor.Execute(i)
|
||||
}
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, outValue)
|
||||
}
|
||||
|
||||
func TestActionExecutorUnorderedBuffered(t *testing.T) {
|
||||
expectedValue := int64(50005000)
|
||||
outValue := int64(0)
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
atomic.AddInt64(&outValue, val.(int64))
|
||||
return nil
|
||||
}, 5, 10)
|
||||
for i := int64(1); i <= 10000; i++ {
|
||||
actionExecutor.Execute(i)
|
||||
}
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedValue, outValue)
|
||||
}
|
||||
|
||||
func TestActionExecutorUnnecessaryWaits(t *testing.T) {
|
||||
outValue := int64(0)
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
atomic.AddInt64(&outValue, val.(int64))
|
||||
return nil
|
||||
}, 5, 10)
|
||||
for i := int64(1); i <= 10000; i++ {
|
||||
actionExecutor.Execute(i)
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionExecutorError(t *testing.T) {
|
||||
for _, conBuf := range []struct {
|
||||
concurrency uint32
|
||||
maxBuffer uint64
|
||||
}{
|
||||
{1, 0},
|
||||
{5, 0},
|
||||
{10, 0},
|
||||
{1, 1},
|
||||
{5, 1},
|
||||
{10, 1},
|
||||
{1, 5},
|
||||
{5, 5},
|
||||
{10, 5},
|
||||
{1, 10},
|
||||
{5, 10},
|
||||
{10, 10},
|
||||
} {
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
if val.(int64) == 11 {
|
||||
return errors.New("hey there")
|
||||
}
|
||||
return nil
|
||||
}, conBuf.concurrency, conBuf.maxBuffer)
|
||||
for i := int64(1); i <= 100; i++ {
|
||||
actionExecutor.Execute(i)
|
||||
}
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
assert.Error(t, err)
|
||||
sameErr := actionExecutor.WaitForEmpty()
|
||||
assert.Equal(t, err, sameErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionExecutorPanicRecovery(t *testing.T) {
|
||||
for _, conBuf := range []struct {
|
||||
concurrency uint32
|
||||
maxBuffer uint64
|
||||
}{
|
||||
{1, 0},
|
||||
{5, 0},
|
||||
{10, 0},
|
||||
{1, 1},
|
||||
{5, 1},
|
||||
{10, 1},
|
||||
{1, 5},
|
||||
{5, 5},
|
||||
{10, 5},
|
||||
{1, 10},
|
||||
{5, 10},
|
||||
{10, 10},
|
||||
} {
|
||||
actionExecutor := NewActionExecutor(context.Background(), func(ctx context.Context, val interface{}) error {
|
||||
if val.(int64) == 22 {
|
||||
panic("hey there")
|
||||
}
|
||||
return nil
|
||||
}, conBuf.concurrency, conBuf.maxBuffer)
|
||||
for i := int64(1); i <= 100; i++ {
|
||||
actionExecutor.Execute(i)
|
||||
}
|
||||
err := actionExecutor.WaitForEmpty()
|
||||
require.Error(t, err)
|
||||
}
|
||||
}
|
||||
66
go/libraries/utils/async/wait_group.go
Normal file
66
go/libraries/utils/async/wait_group.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package async
|
||||
|
||||
import "sync"
|
||||
|
||||
// WaitGroup functions similarly to sync.WaitGroup that ships with Go, with the key difference being that this one
|
||||
// allows calls to Add with a positive delta to occur while another thread is waiting, while the sync version
|
||||
// may panic. The tradeoff is a performance reduction since we now lock on all modifications to the counter.
|
||||
type WaitGroup struct {
|
||||
counter int64 // we allow negative counters and don't panic on them, as it could be useful for the caller
|
||||
initOnce sync.Once
|
||||
syncCond *sync.Cond
|
||||
}
|
||||
|
||||
// Add adds delta, which may be negative, to the WaitGroup counter. If the counter becomes zero, all goroutines blocked
|
||||
// on Wait are released. If the counter goes negative, Add panics.
|
||||
func (wg *WaitGroup) Add(delta int64) {
|
||||
wg.init()
|
||||
wg.syncCond.L.Lock()
|
||||
defer wg.syncCond.L.Unlock()
|
||||
|
||||
wg.counter += delta
|
||||
if wg.counter < 0 {
|
||||
panic("negative WaitGroup counter")
|
||||
} else if wg.counter == 0 {
|
||||
wg.syncCond.Broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
// Done decrements the WaitGroup counter by one.
|
||||
func (wg *WaitGroup) Done() {
|
||||
wg.Add(-1)
|
||||
}
|
||||
|
||||
// Wait blocks until the WaitGroup counter is less than or equal to zero.
|
||||
func (wg *WaitGroup) Wait() {
|
||||
wg.init()
|
||||
wg.syncCond.L.Lock()
|
||||
defer wg.syncCond.L.Unlock()
|
||||
|
||||
for wg.counter > 0 {
|
||||
wg.syncCond.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
// sync.WaitGroup allows the user to use the zero value of a wait group with &sync.WaitGroup{}. Since this is supposed
|
||||
// to be a drop-in replacement, the user would expect to call &async.WaitGroup{}. Since we need some setup, we make use
|
||||
// of sync.Once to run that setup the first time the wait group is used.
|
||||
func (wg *WaitGroup) init() {
|
||||
wg.initOnce.Do(func() {
|
||||
wg.syncCond = sync.NewCond(&sync.Mutex{})
|
||||
})
|
||||
}
|
||||
62
go/libraries/utils/async/wait_group_test.go
Normal file
62
go/libraries/utils/async/wait_group_test.go
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package async
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWaitGroupAddWait(_ *testing.T) {
|
||||
wg := &WaitGroup{}
|
||||
wg.Add(100)
|
||||
go func() {
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Done()
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWaitGroupAddWhileWait(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
assert.Nil(t, r)
|
||||
}()
|
||||
wg := &WaitGroup{}
|
||||
for i := 0; i < 5000000; i++ {
|
||||
wg.Add(1)
|
||||
go wg.Done()
|
||||
}
|
||||
go func() {
|
||||
for i := 0; i < 5000000; i++ {
|
||||
wg.Add(1)
|
||||
wg.Done()
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestWaitGroupPanicOnNegative(t *testing.T) {
|
||||
defer func() {
|
||||
r := recover()
|
||||
assert.NotNil(t, r)
|
||||
}()
|
||||
wg := &WaitGroup{}
|
||||
wg.Add(1)
|
||||
wg.Done()
|
||||
wg.Done()
|
||||
}
|
||||
Reference in New Issue
Block a user