diff --git a/go/cmd/dolt/commands/tblcmds/import.go b/go/cmd/dolt/commands/tblcmds/import.go index 69d7dcaaf5..622cdbb352 100644 --- a/go/cmd/dolt/commands/tblcmds/import.go +++ b/go/cmd/dolt/commands/tblcmds/import.go @@ -26,6 +26,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/fatih/color" "golang.org/x/sync/errgroup" + "golang.org/x/text/message" "github.com/dolthub/dolt/go/cmd/dolt/cli" "github.com/dolthub/dolt/go/cmd/dolt/commands" @@ -422,7 +423,8 @@ var displayStrLen int func importStatsCB(stats types.AppliedEditStats) { noEffect := stats.NonExistentDeletes + stats.SameVal total := noEffect + stats.Modifications + stats.Additions - displayStr := fmt.Sprintf("Rows Processed: %d, Additions: %d, Modifications: %d, Had No Effect: %d", total, stats.Additions, stats.Modifications, noEffect) + p := message.NewPrinter(message.MatchLanguage("en")) // adds commas + displayStr := p.Sprintf("Rows Processed: %d, Additions: %d, Modifications: %d, Had No Effect: %d", total, stats.Additions, stats.Modifications, noEffect) displayStrLen = cli.DeleteAndPrint(displayStrLen, displayStr) } diff --git a/go/go.mod b/go/go.mod index 19a8e5f507..9d06bd960e 100644 --- a/go/go.mod +++ b/go/go.mod @@ -68,7 +68,7 @@ require ( ) require ( - github.com/dolthub/go-mysql-server v0.11.1-0.20220304002938-823e425edf58 + github.com/dolthub/go-mysql-server v0.11.1-0.20220304213711-4d7d9a2c6f81 github.com/google/flatbuffers v2.0.5+incompatible github.com/gosuri/uilive v0.0.4 github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6 @@ -76,6 +76,7 @@ require ( github.com/shirou/gopsutil/v3 v3.22.1 github.com/xitongsys/parquet-go v1.6.1 github.com/xitongsys/parquet-go-source v0.0.0-20211010230925-397910c5e371 + golang.org/x/text v0.3.7 ) require ( @@ -120,7 +121,6 @@ require ( golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 // indirect golang.org/x/mod v0.5.1 // indirect golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect - golang.org/x/text v0.3.7 // indirect golang.org/x/tools v0.1.9 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go/go.sum b/go/go.sum index 199001d2c8..f3a9ea2d5e 100755 --- a/go/go.sum +++ b/go/go.sum @@ -170,8 +170,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= -github.com/dolthub/go-mysql-server v0.11.1-0.20220304002938-823e425edf58 h1:Fgi+KjilXXJe1fYsdlr+xf7tubprV3wtKP+Y3WyCILM= -github.com/dolthub/go-mysql-server v0.11.1-0.20220304002938-823e425edf58/go.mod h1:5WoXPdkIrkNBjKH+Y1XMfwREEtPXOW/yN8QfulFpZ1s= +github.com/dolthub/go-mysql-server v0.11.1-0.20220304213711-4d7d9a2c6f81 h1:uk9aHMW7ji1rbSBhAq0h/Ncy4/mIN+7cFqk/zQES3Zo= +github.com/dolthub/go-mysql-server v0.11.1-0.20220304213711-4d7d9a2c6f81/go.mod h1:5WoXPdkIrkNBjKH+Y1XMfwREEtPXOW/yN8QfulFpZ1s= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g= github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms= github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8= diff --git a/go/libraries/doltcore/mvdata/engine_table_writer.go b/go/libraries/doltcore/mvdata/engine_table_writer.go index 9e41fa758d..9db7cafe88 100644 --- a/go/libraries/doltcore/mvdata/engine_table_writer.go +++ b/go/libraries/doltcore/mvdata/engine_table_writer.go @@ -93,12 +93,9 @@ func NewSqlEngineTableWriter(ctx context.Context, dEnv *env.DoltEnv, createTable return nil, err } - var doltCreateTableSchema sql.PrimaryKeySchema - if options.Operation == CreateOp { - doltCreateTableSchema, err = sqlutil.FromDoltSchema(options.TableToWriteTo, createTableSchema) - if err != nil { - return nil, err - } + doltCreateTableSchema, err := sqlutil.FromDoltSchema(options.TableToWriteTo, createTableSchema) + if err != nil { + return nil, err } doltRowOperationSchema, err := sqlutil.FromDoltSchema(options.TableToWriteTo, rowOperationSchema) @@ -183,11 +180,11 @@ func (s *SqlEngineTableWriter) WriteRows(ctx context.Context, inputChannel chan } // If the length of the row does not match the schema then we have an update operation. - if len(row) != len(s.rowOperationSchema.Schema) { + if len(row) != len(s.tableSchema.Schema) { oldRow := row[:len(row)/2] newRow := row[len(row)/2:] - if ok, err := oldRow.Equals(newRow, s.rowOperationSchema.Schema); err == nil { + if ok, err := oldRow.Equals(newRow, s.tableSchema.Schema); err == nil { if ok { s.stats.SameVal++ } else { @@ -208,11 +205,11 @@ func (s *SqlEngineTableWriter) WriteRows(ctx context.Context, inputChannel chan if err != nil { return err } + defer func() { - if err != nil { - iter.Close(s.sqlCtx) // save the error that should be propagated. - } else { - err = iter.Close(s.sqlCtx) + rerr := iter.Close(s.sqlCtx) + if err == nil { + err = rerr } }() diff --git a/go/libraries/doltcore/schema/col_coll.go b/go/libraries/doltcore/schema/col_coll.go index b8160a916a..f0d4b73f1d 100644 --- a/go/libraries/doltcore/schema/col_coll.go +++ b/go/libraries/doltcore/schema/col_coll.go @@ -292,10 +292,16 @@ func FilterColCollection(cc *ColCollection, cb func(col Column) bool) *ColCollec } func ColCollUnion(colColls ...*ColCollection) (*ColCollection, error) { + var allTags = make(map[uint64]bool) var allCols []Column for _, sch := range colColls { err := sch.Iter(func(tag uint64, col Column) (stop bool, err error) { + // skip if already seen + if _, ok := allTags[tag]; ok { + return false, nil + } allCols = append(allCols, col) + allTags[tag] = true return false, nil }) diff --git a/go/libraries/doltcore/sqle/index/dolt_index.go b/go/libraries/doltcore/sqle/index/dolt_index.go index 45d74bf91c..66122d5167 100644 --- a/go/libraries/doltcore/sqle/index/dolt_index.go +++ b/go/libraries/doltcore/sqle/index/dolt_index.go @@ -31,7 +31,7 @@ import ( ) type DoltIndex interface { - sql.Index + sql.FilteredIndex Schema() schema.Schema IndexSchema() schema.Schema TableData() durable.Index @@ -289,6 +289,10 @@ RangeLoop: }, nil } +func (di doltIndex) HandledFilters(filters []sql.Expression) []sql.Expression { + return filters +} + // Database implement sql.Index func (di doltIndex) Database() string { return di.dbName diff --git a/go/utils/remotesrv/http.go b/go/utils/remotesrv/http.go index fca1c772da..3514940725 100644 --- a/go/utils/remotesrv/http.go +++ b/go/utils/remotesrv/http.go @@ -28,10 +28,14 @@ import ( remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1" - "github.com/dolthub/dolt/go/libraries/utils/iohelp" "github.com/dolthub/dolt/go/store/hash" ) +var ( + ErrReadOutOfBounds = errors.New("cannot read file for given length and " + + "offset since the read would exceed the size of the file") +) + var expectedFiles = make(map[string]*remotesapi.TableFileDetails) func ServeHTTP(respWr http.ResponseWriter, req *http.Request) { @@ -53,13 +57,7 @@ func ServeHTTP(respWr http.ResponseWriter, req *http.Request) { statusCode := http.StatusMethodNotAllowed switch req.Method { case http.MethodGet: - rangeStr := req.Header.Get("Range") - - if rangeStr == "" { - statusCode = readFile(logger, org, repo, hashStr, respWr) - } else { - statusCode = readChunk(logger, org, repo, hashStr, rangeStr, respWr) - } + statusCode = readTableFile(logger, org, repo, hashStr, respWr, req) case http.MethodPost, http.MethodPut: statusCode = writeTableFile(logger, org, repo, hashStr, req) @@ -70,6 +68,63 @@ func ServeHTTP(respWr http.ResponseWriter, req *http.Request) { } } +func readTableFile(logger func(string), org, repo, fileId string, respWr http.ResponseWriter, req *http.Request) int { + rangeStr := req.Header.Get("Range") + path := filepath.Join(org, repo, fileId) + + var r io.ReadCloser + var readSize int64 + var fileErr error + { + if rangeStr == "" { + logger("going to read entire file") + r, readSize, fileErr = getFileReader(path) + } else { + offset, length, err := offsetAndLenFromRange(rangeStr) + if err != nil { + logger(err.Error()) + return http.StatusBadRequest + } + logger(fmt.Sprintf("going to read file at offset %d, length %d", offset, length)) + readSize = length + r, fileErr = getFileReaderAt(path, offset, length) + } + } + if fileErr != nil { + logger(fileErr.Error()) + if errors.Is(fileErr, os.ErrNotExist) { + return http.StatusNotFound + } else if errors.Is(fileErr, ErrReadOutOfBounds) { + return http.StatusBadRequest + } + return http.StatusInternalServerError + } + defer func() { + err := r.Close() + if err != nil { + err = fmt.Errorf("failed to close file at path %s: %w", path, err) + logger(err.Error()) + } + }() + + logger(fmt.Sprintf("opened file at path %s, going to read %d bytes", path, readSize)) + + n, err := io.Copy(respWr, r) + if err != nil { + err = fmt.Errorf("failed to write data to response writer: %w", err) + logger(err.Error()) + return http.StatusInternalServerError + } + if n != readSize { + logger(fmt.Sprintf("wanted to write %d bytes from file (%s) but only wrote %d", readSize, path, n)) + return http.StatusInternalServerError + } + + logger(fmt.Sprintf("wrote %d bytes", n)) + + return http.StatusOK +} + func writeTableFile(logger func(string), org, repo, fileId string, request *http.Request) int { _, ok := hash.MaybeParse(fileId) @@ -157,127 +212,46 @@ func offsetAndLenFromRange(rngStr string) (int64, int64, error) { return int64(start), int64(end-start) + 1, nil } -func readFile(logger func(string), org, repo, fileId string, writer io.Writer) int { - path := filepath.Join(org, repo, fileId) +// getFileReader opens a file at the given path and returns an io.ReadCloser, +// the corresponding file's filesize, and a http status. +func getFileReader(path string) (io.ReadCloser, int64, error) { + return openFile(path) +} +func openFile(path string) (*os.File, int64, error) { info, err := os.Stat(path) - if err != nil { - logger("file not found. path: " + path) - return http.StatusNotFound + return nil, 0, fmt.Errorf("failed to get stats for file at path %s: %w", path, err) } f, err := os.Open(path) - if err != nil { - logger("failed to open file. file: " + path + " err: " + err.Error()) - return http.StatusInternalServerError + return nil, 0, fmt.Errorf("failed to open file at path %s: %w", path, err) } - defer func() { - err := f.Close() - - if err != nil { - logger(fmt.Sprintf("Close failed. file: %s, err: %v", path, err)) - } else { - logger("Close Successful") - } - }() - - n, err := io.Copy(writer, f) - - if err != nil { - logger("failed to write data to response. err : " + err.Error()) - return -1 - } - - if n != info.Size() { - logger(fmt.Sprintf("failed to write entire file to response. Copied %d of %d err: %v", n, info.Size(), err)) - return -1 - } - - return -1 + return f, info.Size(), nil } -func readChunk(logger func(string), org, repo, fileId, rngStr string, writer io.Writer) int { - offset, length, err := offsetAndLenFromRange(rngStr) - - if err != nil { - logger(fmt.Sprintln(rngStr, "is not a valid range")) - return http.StatusBadRequest - } - - data, retVal := readLocalRange(logger, org, repo, fileId, int64(offset), int64(length)) - - if retVal != -1 { - return retVal - } - - logger(fmt.Sprintf("writing %d bytes", len(data))) - err = iohelp.WriteAll(writer, data) - - if err != nil { - logger("failed to write data to response " + err.Error()) - return -1 - } - - logger("Successfully wrote data") - return -1 +type closerReaderWrapper struct { + io.Reader + io.Closer } -func readLocalRange(logger func(string), org, repo, fileId string, offset, length int64) ([]byte, int) { - path := filepath.Join(org, repo, fileId) - - logger(fmt.Sprintf("Attempting to read bytes %d to %d from %s", offset, offset+length, path)) - info, err := os.Stat(path) - +func getFileReaderAt(path string, offset int64, length int64) (io.ReadCloser, error) { + f, fSize, err := openFile(path) if err != nil { - logger(fmt.Sprintf("file %s not found", path)) - return nil, http.StatusNotFound + return nil, err } - logger(fmt.Sprintf("Verified file %s exists", path)) - - if info.Size() < int64(offset+length) { - logger(fmt.Sprintf("Attempted to read bytes %d to %d, but the file is only %d bytes in size", offset, offset+length, info.Size())) - return nil, http.StatusBadRequest + if fSize < int64(offset+length) { + return nil, fmt.Errorf("failed to read file %s at offset %d, length %d: %w", path, offset, length, ErrReadOutOfBounds) } - logger(fmt.Sprintf("Verified the file is large enough to contain the range")) - f, err := os.Open(path) - + _, err = f.Seek(int64(offset), 0) if err != nil { - logger(fmt.Sprintf("Failed to open %s: %v", path, err)) - return nil, http.StatusInternalServerError + return nil, fmt.Errorf("failed to seek file at path %s to offset %d: %w", path, offset, err) } - defer func() { - err := f.Close() - - if err != nil { - logger(fmt.Sprintf("Close failed. file: %s, err: %v", path, err)) - } else { - logger("Close Successful") - } - }() - - logger(fmt.Sprintf("Successfully opened file")) - pos, err := f.Seek(int64(offset), 0) - - if err != nil { - logger(fmt.Sprintf("Failed to seek to %d: %v", offset, err)) - return nil, http.StatusInternalServerError - } - - logger(fmt.Sprintf("Seek succeeded. Current position is %d", pos)) - diff := offset - pos - data, err := iohelp.ReadNBytes(f, int(diff+int64(length))) - - if err != nil { - logger(fmt.Sprintf("Failed to read %d bytes: %v", diff+length, err)) - return nil, http.StatusInternalServerError - } - - logger(fmt.Sprintf("Successfully read %d bytes", len(data))) - return data[diff:], -1 + r := closerReaderWrapper{io.LimitReader(f, length), f} + return r, nil } diff --git a/integration-tests/bats/import-create-tables.bats b/integration-tests/bats/import-create-tables.bats index 8c4fc344e5..eab34c8f88 100755 --- a/integration-tests/bats/import-create-tables.bats +++ b/integration-tests/bats/import-create-tables.bats @@ -698,7 +698,7 @@ DELIM run dolt table import -s schema.sql -c keyless data.csv [ "$status" -eq 0 ] - [[ "$output" =~ "Rows Processed: 1, Additions: 0, Modifications: 1, Had No Effect: 0" ]] || false + [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false run dolt sql -r csv -q "select * from keyless" diff --git a/integration-tests/bats/import-replace-tables.bats b/integration-tests/bats/import-replace-tables.bats index d6644d9856..bc13d70b13 100644 --- a/integration-tests/bats/import-replace-tables.bats +++ b/integration-tests/bats/import-replace-tables.bats @@ -335,7 +335,7 @@ DELIM run dolt table import -r test 1pk5col-ints-updt.csv [ "$status" -eq 0 ] - [[ "$output" =~ "Rows Processed: 1, Additions: 0, Modifications: 1, Had No Effect: 0" ]] || false + [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false run dolt sql -r csv -q "select * from test" diff --git a/integration-tests/bats/import-update-tables.bats b/integration-tests/bats/import-update-tables.bats index 92f2f1830e..cae450cf5a 100644 --- a/integration-tests/bats/import-update-tables.bats +++ b/integration-tests/bats/import-update-tables.bats @@ -49,7 +49,7 @@ SQL cat < check-constraint-sch.sql CREATE TABLE persons ( - ID int NOT NULL, + ID int PRIMARY KEY, LastName varchar(255) NOT NULL, FirstName varchar(255), Age int CHECK (Age>=18) @@ -209,7 +209,6 @@ CREATE TABLE employees ( ); SQL run dolt table import -u employees `batshelper employees-tbl-schema-unordered.json` - echo "$output" [ "$status" -eq 0 ] [[ "$output" =~ "Rows Processed: 3, Additions: 3, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false @@ -532,7 +531,7 @@ DELIM run dolt table import -u test 1pk5col-ints-updt.csv [ "$status" -eq 0 ] - [[ "$output" =~ "Rows Processed: 1, Additions: 0, Modifications: 1, Had No Effect: 0" ]] || false + [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false run dolt sql -r csv -q "select * from test" @@ -552,7 +551,7 @@ DELIM run dolt table import -u test 1pk5col-ints-updt.csv [ "$status" -eq 0 ] - [[ "$output" =~ "Rows Processed: 1, Additions: 0, Modifications: 1, Had No Effect: 0" ]] || false + [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false run dolt sql -r csv -q "select * from test" @@ -653,7 +652,7 @@ DELIM run dolt table import -u keyless data.csv [ "$status" -eq 0 ] - [[ "$output" =~ "Rows Processed: 1, Additions: 0, Modifications: 1, Had No Effect: 0" ]] || false + [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false run dolt sql -r csv -q "select * from keyless order by c0, c1 DESC" @@ -682,4 +681,38 @@ DELIM ! [[ "$output" =~ "[4,little,doe,1]" ]] || false [[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false [[ "$output" =~ "Import completed successfully." ]] || false + + run dolt sql -r csv -q "select * from persons" + [[ "$output" =~ "1,jon,doe,20" ]] || false +} + +@test "import-update-tables: large amounts of no effect rows" { + dolt sql -q "create table t(pk int primary key)" + dolt sql -q "alter table t add constraint cx CHECK (pk < 10)" + dolt sql -q "Insert into t values (1),(2),(3),(4),(5),(6),(7),(8),(9) " + + cat < file.csv +pk +1 +2 +3 +4 +5 +6 +10000 +DELIM + + run dolt table import -u --continue t file.csv + [ "$status" -eq 0 ] + [[ "$output" =~ "Rows Processed: 6, Additions: 0, Modifications: 0, Had No Effect: 6" ]] || false + [[ "$output" =~ "The following rows were skipped:" ]] || false + [[ "$output" =~ "[10000]" ]] || false + + run dolt sql -r csv -q "select * from t" + [[ "$output" =~ "1" ]] || false + [[ "$output" =~ "2" ]] || false + [[ "$output" =~ "3" ]] || false + [[ "$output" =~ "4" ]] || false + [[ "$output" =~ "5" ]] || false + [[ "$output" =~ "6" ]] || false } diff --git a/integration-tests/bats/sql-merge.bats b/integration-tests/bats/sql-merge.bats index b2a7faaa99..f1c59ae6b8 100644 --- a/integration-tests/bats/sql-merge.bats +++ b/integration-tests/bats/sql-merge.bats @@ -762,6 +762,26 @@ SQL [ $status -eq 0 ] } +@test "sql-merge: identical schema changes with data changes merges correctly" { + dolt sql -q "create table t (i int primary key)" + dolt commit -am "initial commit" + dolt branch b1 + dolt branch b2 + dolt checkout b1 + dolt sql -q "alter table t add column j int" + dolt sql -q "insert into t values (1, 1)" + dolt commit -am "changes to b1" + dolt checkout b2 + dolt sql -q "alter table t add column j int" + dolt sql -q "insert into t values (2, 2)" + dolt commit -am "changes to b2" + dolt checkout main + run dolt merge b1 + [ $status -eq 0 ] + run dolt merge b2 + [ $status -eq 0 ] +} + get_head_commit() { dolt log -n 1 | grep -m 1 commit | cut -c 13-44 } diff --git a/integration-tests/transactions/concurrent_tx_test.go b/integration-tests/transactions/concurrent_tx_test.go new file mode 100644 index 0000000000..2f5c6512e0 --- /dev/null +++ b/integration-tests/transactions/concurrent_tx_test.go @@ -0,0 +1,303 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transactions + +import ( + "fmt" + "sync" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/gocraft/dbr/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var defaultConfig = ServerConfig{ + database: "mysql", + host: "127.0.0.1", + port: 3316, + user: "root", + password: "toor", +} + +func TestConcurrentTransactions(t *testing.T) { + sequential := &sync.Mutex{} + for _, test := range txTests { + t.Run(test.name, func(t *testing.T) { + sequential.Lock() + defer sequential.Unlock() + testConcurrentTx(t, test) + }) + } +} + +type ConcurrentTxTest struct { + name string + queries []concurrentQuery +} + +type concurrentQuery struct { + conn string + query string + assertion selector + expected []testRow +} + +type selector func(s *dbr.Session) *dbr.SelectStmt + +type testRow struct { + Pk, C0 int +} + +const ( + one = "one" + two = "two" +) + +var txTests = []ConcurrentTxTest{ + { + name: "smoke test", + queries: []concurrentQuery{ + { + conn: one, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data") + }, + expected: []testRow{ + {1, 1}, + {2, 2}, + {3, 3}, + }, + }, + }, + }, + { + name: "concurrent transactions", + queries: []concurrentQuery{ + { + conn: one, + query: "BEGIN;", + }, + { + conn: two, + query: "BEGIN;", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data") + }, + expected: []testRow{ + {1, 1}, {2, 2}, {3, 3}, + }, + }, + { + conn: one, + query: "INSERT INTO tx.data VALUES (4,4)", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data") + }, + expected: []testRow{ + {1, 1}, {2, 2}, {3, 3}, + }, + }, + { + conn: one, + query: "COMMIT", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data") + }, + expected: []testRow{ + {1, 1}, {2, 2}, {3, 3}, + }, + }, + { + conn: two, + query: "COMMIT", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data") + }, + expected: []testRow{ + {1, 1}, {2, 2}, {3, 3}, {4, 4}, + }, + }, + }, + }, + { + name: "concurrent updates", + queries: []concurrentQuery{ + { + conn: one, + query: "BEGIN;", + }, + { + conn: two, + query: "BEGIN;", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data").Where("pk = 1") + }, + expected: []testRow{ + {1, 1}, + }, + }, + { + conn: one, + query: "UPDATE tx.data SET c0 = c0 + 10 WHERE pk = 1;", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data").Where("pk = 1") + }, + expected: []testRow{ + {1, 1}, + }, + }, + { + conn: one, + query: "COMMIT", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data").Where("pk = 1") + }, + expected: []testRow{ + {1, 1}, + }, + }, + { + conn: two, + query: "UPDATE tx.data SET c0 = c0 + 10 WHERE pk = 1;", + }, + { + conn: two, + assertion: func(s *dbr.Session) *dbr.SelectStmt { + return s.Select("*").From("tx.data").Where("pk = 1") + }, + expected: []testRow{ + {1, 21}, + }, + }, + { + conn: two, + query: "COMMIT", + }, + }, + }, +} + +func setupCommon(sess *dbr.Session) (err error) { + queries := []string{ + "DROP DATABASE IF EXISTS tx;", + "CREATE DATABASE IF NOT EXISTS tx;", + "USE tx;", + "CREATE TABLE data (pk int primary key, c0 int);", + "INSERT INTO data VALUES (1,1),(2,2),(3,3);", + } + + for _, q := range queries { + if _, err = sess.Exec(q); err != nil { + return + } + } + return +} + +func testConcurrentTx(t *testing.T, test ConcurrentTxTest) { + conns, err := createNamedConnections(defaultConfig, one, two) + require.NoError(t, err) + defer func() { require.NoError(t, closeNamedConnections(conns)) }() + + err = setupCommon(conns[one]) + defer func() { require.NoError(t, teardownCommon(conns[one])) }() + + for _, q := range test.queries { + conn := conns[q.conn] + if q.query != "" { + _, err = conn.Exec(q.query) + require.NoError(t, err) + } + + if q.assertion == nil { + continue + } + + var actual []testRow + _, err = q.assertion(conn).Load(&actual) + require.NoError(t, err) + assert.Equal(t, q.expected, actual) + } +} + +func teardownCommon(sess *dbr.Session) (err error) { + _, err = sess.Exec("DROP DATABASE tx;") + return +} + +type ServerConfig struct { + database string + host string + port int + user string + password string +} + +type namedConnections map[string]*dbr.Session + +// ConnectionString returns a Data Source Name (DSN) to be used by go clients for connecting to a running server. +func ConnectionString(config ServerConfig) string { + return fmt.Sprintf("%v:%v@tcp(%v:%v)/%s", + config.user, + config.password, + config.host, + config.port, + config.database, + ) +} + +func createNamedConnections(config ServerConfig, names ...string) (nc namedConnections, err error) { + nc = make(namedConnections, len(names)) + for _, name := range names { + var c *dbr.Connection + if c, err = dbr.Open("mysql", ConnectionString(config), nil); err != nil { + return nil, err + } + nc[name] = c.NewSession(nil) + } + return +} + +func closeNamedConnections(nc namedConnections) (err error) { + for _, conn := range nc { + if err = conn.Close(); err != nil { + return + } + } + return +} diff --git a/integration-tests/transactions/go.mod b/integration-tests/transactions/go.mod new file mode 100644 index 0000000000..7bc4d414d2 --- /dev/null +++ b/integration-tests/transactions/go.mod @@ -0,0 +1,13 @@ +module github.com/dolthub/dolt/integration-tests/transactions + +go 1.17 + +require github.com/go-sql-driver/mysql v1.6.0 + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/gocraft/dbr/v2 v2.7.3 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.7.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect +) diff --git a/integration-tests/transactions/go.sum b/integration-tests/transactions/go.sum new file mode 100644 index 0000000000..5c53f6164d --- /dev/null +++ b/integration-tests/transactions/go.sum @@ -0,0 +1,14 @@ +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/gocraft/dbr/v2 v2.7.3 h1:5/PTRiBkdD2FoHpnrCMoEUw5Wf/Cl3l3PjJ02Wm+pwM= +github.com/gocraft/dbr/v2 v2.7.3/go.mod h1:8IH98S8M8J0JSEiYk0MPH26ZDUKemiQ/GvmXL5jo+Uw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=