SQL Merge Function (#730)

* bh/merge-func

* commit after merge

* formatting

* fix test
This commit is contained in:
Brian Hendriks
2020-06-23 22:41:43 -07:00
committed by GitHub
13 changed files with 455 additions and 18 deletions

View File

@@ -45,13 +45,13 @@ class DoltConnection(object):
def close(self):
self.cnx.close()
def query(self, query_str):
def query(self, query_str, exit_on_err=True):
try:
cursor = self.cnx.cursor()
cursor.execute(query_str)
if cursor.description is None:
return []
return [], cursor.rowcount
raw = cursor.fetchall()
@@ -62,10 +62,12 @@ class DoltConnection(object):
r[k] = str(curr[i])
row_maps.append(r)
return row_maps
return row_maps, cursor.rowcount
except BaseException as e:
_print_err_and_exit(e)
if exit_on_err:
_print_err_and_exit(e)
raise e
class InfiniteRetryConnection(DoltConnection):

View File

@@ -43,7 +43,7 @@ if query_results is not None:
for i in range(len(queries)):
query_str = queries[i].strip()
print('executing:', query_str)
actual_rows = dc.query(query_str)
actual_rows, num_rows = dc.query(query_str)
if expected[i] is not None:
expected_rows = csv_to_row_maps(expected[i])
@@ -89,6 +89,28 @@ start_sql_server() {
wait_for_connection $PORT 5000
}
# Starts a dolt sql-server in the background from a generated YAML config
# (.cliconfig.yaml in the current directory) and waits for it to accept
# connections. The config uses user "dolt", allows up to 10 connections and
# turns autocommit off.
# $1 is stored in DEFAULT_DB; it is not referenced again inside this function.
# Sets PORT and SERVER_PID as shell variables for the caller.
start_sql_multi_user_server() {
DEFAULT_DB="$1"
# Derive the port from this shell's PID ($$), mapped into the range 1024-65535.
let PORT="$$ % (65536-1024) + 1024"
echo "
log_level: debug
user:
name: dolt
listener:
host: 0.0.0.0
port: $PORT
max_connections: 10
behavior:
autocommit: false
" > .cliconfig.yaml
# Launch the server in the background and remember its PID.
dolt sql-server --config .cliconfig.yaml &
SERVER_PID=$!
# NOTE(review): 5000 is presumably a timeout; confirm against wait_for_connection.
wait_for_connection $PORT 5000
}
start_multi_db_server() {
DEFAULT_DB="$1"

View File

@@ -0,0 +1,198 @@
import os
import sys
from queue import Queue
from threading import Thread
from helper.pytest import DoltConnection
# Utility functions
def print_err(e):
    """Write the string form of ``e`` to standard error, followed by a newline."""
    sys.stderr.write("%s\n" % (e,))
def query(dc, query_str):
    """Run ``query_str`` through ``dc.query`` and return whatever it returns.

    ``False`` is passed as the second positional argument of ``dc.query``
    (``exit_on_err`` on ``DoltConnection.query``).
    """
    run_query = dc.query
    return run_query(query_str, False)
def query_with_expected_error(dc, non_error_msg, query_str):
    """Run ``query_str`` on ``dc`` and require that it fails.

    Args:
        dc: connection object exposing ``query(query_str, exit_on_err)``.
        non_error_msg: message of the exception raised when the query
            unexpectedly succeeds.
        query_str: the SQL statement that is expected to be rejected.

    Raises:
        Exception: with ``non_error_msg`` if the query completes without
            raising.
    """
    try:
        dc.query(query_str, False)
    except Exception:
        # The query failed, which is the outcome this helper is checking for.
        return

    # Raised outside the try block so the failure is not swallowed by the
    # handler above (the original bare ``except`` hid it).
    raise Exception(non_error_msg)
def row(pk, c1, c2):
    """Build an expected result row: a dict of the three columns as strings."""
    columns = ("pk", "c1", "c2")
    values = (pk, c1, c2)
    return {name: str(value) for name, value in zip(columns, values)}
UPDATE_BRANCH_FAIL_MSG = "Failed to update branch"


def commit_and_update_branch(dc, commit_message, expected_hashes, branch_name):
    """Commit via SQL and move ``branch_name`` to the new commit.

    The UPDATE on ``dolt_branches`` only matches when the branch's current
    hash equals one of the SQL expressions in ``expected_hashes``. On success
    the session's ``@@repo1_head`` is set to the branch's hash.

    Raises:
        Exception: with ``UPDATE_BRANCH_FAIL_MSG`` when the UPDATE does not
            report exactly one affected row.
    """
    hash_checks = " or ".join("hash = %s" % candidate for candidate in expected_hashes)
    update_sql = 'UPDATE dolt_branches SET hash = Commit("%s") WHERE name = "%s" AND %s' % (
        commit_message,
        branch_name,
        "(" + hash_checks + ")",
    )

    _, updated = query(dc, update_sql)
    if updated != 1:
        raise Exception(UPDATE_BRANCH_FAIL_MSG)

    query(dc, 'SET @@repo1_head=HASHOF("%s");' % branch_name)
def query_and_test_results(dc, query_str, expected):
    """Run ``query_str`` and check that the returned rows equal ``expected``.

    Args:
        dc: connection to run the query on.
        query_str: SQL to execute.
        expected: the rows the query must return, compared with ``!=``.

    Raises:
        Exception: describing the query, the expected rows and the actual
            rows when they differ.
    """
    results, _ = query(dc, query_str)

    if results != expected:
        # The expected rows must be formatted here; the original passed a bare
        # str(), which left the "Expected" section of the message empty.
        raise Exception("Unexpected results for query:\n\t%s\nExpected:\n\t%s\nActual:\n\t%s" % (query_str, str(expected), str(results)))
def resolve_theirs(dc):
    """Resolve all conflicts on table ``test`` by taking "their" side.

    Runs three statements in order: copy their rows into ``test``, delete the
    rows whose ``their_pk`` is NULL, then clear ``dolt_conflicts_test``.
    """
    statements = (
        "REPLACE INTO test (pk, c1, c2) SELECT their_pk, their_c1, their_c2 FROM dolt_conflicts_test WHERE their_pk IS NOT NULL;",
        """DELETE FROM test WHERE pk in (
SELECT base_pk FROM dolt_conflicts_test WHERE their_pk IS NULL
);""",
        "DELETE FROM dolt_conflicts_test",
    )
    for statement in statements:
        query(dc, statement)
def create_branch(dc, branch_name):
    """Create ``branch_name`` at the session's ``@@repo1_head``.

    Raises:
        Exception: when the INSERT into ``dolt_branches`` does not report
            exactly one affected row.
    """
    insert_sql = 'INSERT INTO dolt_branches (name, hash) VALUES ("%s", @@repo1_head);' % branch_name
    inserted = query(dc, insert_sql)[1]
    if inserted != 1:
        raise Exception("Failed to create branch")
# work functions
def connect(dc):
    """Work function: open the connection held by ``dc``.

    The return value of ``dc.connect()`` is discarded.
    """
    dc.connect()
    return None
def create_tables(dc):
    """Work function: create table ``test`` on master and commit it.

    Points the session at master, creates the table, commits with master
    required to still be at ``@@repo1_head``, then checks SHOW TABLES.
    """
    create_sql = """
CREATE TABLE test (
pk INT NOT NULL,
c1 INT,
c2 INT,
PRIMARY KEY(pk));"""

    query(dc, 'SET @@repo1_head=HASHOF("master");')
    query(dc, create_sql)
    commit_and_update_branch(dc, "Created tables", ["@@repo1_head"], "master")

    expected_tables = [{"Table": "test"}]
    query_and_test_results(dc, "SHOW TABLES;", expected_tables)
def duplicate_table_create(dc):
    """Work function: check that creating table ``test`` again on master is rejected."""
    create_sql = """
CREATE TABLE test (
pk INT NOT NULL,
c1 INT,
c2 INT,
PRIMARY KEY(pk));"""

    query(dc, 'SET @@repo1_head=HASHOF("master");')
    query_with_expected_error(dc, "Should have failed creating duplicate table", create_sql)
def seed_master(dc):
    """Work function: insert three rows on master, commit, and verify them.

    Raises:
        Exception: when the INSERT does not report three affected rows.
    """
    query(dc, 'SET @@repo1_head=HASHOF("master");')

    inserted = query(dc, 'INSERT INTO test VALUES (0,0,0),(1,1,1),(2,2,2)')[1]
    if inserted != 3:
        raise Exception("Failed to update rows")

    commit_and_update_branch(dc, "Seeded initial data", ["@@repo1_head"], "master")

    # Each seeded row has the same value in all three columns.
    expected = [row(n, n, n) for n in range(3)]
    query_and_test_results(dc, "SELECT pk, c1, c2 FROM test ORDER BY pk", expected)
def modify_pk0_on_master_and_commit(dc):
    """Work function: on master, set ``c1`` to 1 for ``pk=0`` and commit the change."""
    for statement in ('SET @@repo1_head=HASHOF("master");', "UPDATE test SET c1=1 WHERE pk=0;"):
        query(dc, statement)
    commit_and_update_branch(dc, "set c1 to 1", ["@@repo1_head"], "master")
def modify_pk0_on_master_no_commit(dc):
    """Work function: on master, set ``c1`` to 2 for ``pk=0`` without committing."""
    for statement in ('SET @@repo1_head=HASHOF("master");', "UPDATE test SET c1=2 WHERE pk=0"):
        query(dc, statement)
def fail_to_commit(dc):
    """Work function: check that committing to master fails to update the branch.

    Raises:
        Exception: "Failed to fail commit" when the commit unexpectedly
            succeeds; any other exception from the commit attempt whose
            message is not ``UPDATE_BRANCH_FAIL_MSG`` is re-raised.
    """
    try:
        commit_and_update_branch(dc, "Created tables", ["@@repo1_head"], "master")
    except Exception as err:
        # Only the branch-update failure is the expected outcome.
        if str(err) != UPDATE_BRANCH_FAIL_MSG:
            raise err
    else:
        raise Exception("Failed to fail commit")
def commit_to_feature(dc):
    """Work function: create branch ``feature`` and commit the session's changes to it."""
    branch = "feature"
    create_branch(dc, branch)
    commit_and_update_branch(dc, "set c1 to 2", ["@@repo1_head"], branch)
def merge_resolve_commit(dc):
    """Work function: merge master into the session head, resolve, and commit to master.

    Expects exactly one conflict on table ``test``, resolves it with
    ``resolve_theirs``, verifies the resulting rows, and commits with master
    required to be at either parent of the merge commit.
    """
    query(dc, 'SET @@repo1_head=Merge("master");')

    expected_conflicts = [{"table": "test", "num_conflicts": "1"}]
    query_and_test_results(dc, "SELECT * from dolt_conflicts;", expected_conflicts)

    resolve_theirs(dc)

    expected_rows = [row(0, 1, 0), row(1, 1, 1), row(2, 2, 2)]
    query_and_test_results(dc, "SELECT pk, c1, c2 FROM test ORDER BY pk", expected_rows)

    merge_parents = ['HASHOF("HEAD^1")', 'HASHOF("HEAD^2")']
    commit_and_update_branch(dc, "resolved conflicts", merge_parents, "master")
# test script

# Number of client connections (and worker threads) used by the test.
MAX_SIMULTANEOUS_CONNECTIONS = 2

# The server port is passed as the first command line argument.
PORT_STR = sys.argv[1]

# One connection object per simulated client.
CONNECTIONS = []
for i in range(MAX_SIMULTANEOUS_CONNECTIONS):
    CONNECTIONS.append(
        DoltConnection(port=int(PORT_STR), database="repo1", user='dolt', auto_commit=False)
    )

# Queue that hands WorkItems to the worker threads.
WORK_QUEUE = Queue()
# work item run by workers
class WorkItem:
    """A connection plus the work functions to run on it, in order.

    ``exception`` starts as ``None`` and is where a failure of the item is
    recorded for later inspection.
    """

    def __init__(self, dc, *work_funcs):
        self.exception = None
        self.work_funcs = work_funcs
        self.dc = dc
# worker thread function
def worker():
    """Process WorkItems from ``WORK_QUEUE`` forever.

    For each item, its work functions are called in order with the item's
    connection. If one raises, the exception is stored on that item's
    ``exception`` attribute and the remaining functions are skipped. The
    queue is notified once per item whatever the outcome, so a
    ``WORK_QUEUE.join()`` caller is always released.
    """
    while True:
        item = WORK_QUEUE.get()
        try:
            for work_func in item.work_funcs:
                work_func(item.dc)
        except Exception as e:
            # Record the failure on the item that was actually being run; the
            # original wrote to the global ``work_item`` loop variable of the
            # main thread, which may be a different item.
            item.exception = e
        finally:
            WORK_QUEUE.task_done()
# start the worker threads
# Daemon threads, one per connection, so they do not keep the process alive
# once the main script exits.
for i in range(MAX_SIMULTANEOUS_CONNECTIONS):
    t = Thread(target=worker, daemon=True)
    t.start()
# This defines the actual test script. Each stage in the script has a list of work items, and each work item
# in a stage should have a different connection associated with it. The work of different connections is done
# in parallel, while the work functions for a single connection are executed in order.
work_item_stages = [
    # Stage: the first client connects and creates the table.
    [
        WorkItem(CONNECTIONS[0], connect, create_tables),
    ],
    # Stage: the first client seeds master while the second connects and tries to re-create the table.
    [
        WorkItem(CONNECTIONS[0], seed_master),
        WorkItem(CONNECTIONS[1], connect, duplicate_table_create),
    ],
    # Stage: both clients modify pk 0 on master; only the first one commits.
    [
        WorkItem(CONNECTIONS[0], modify_pk0_on_master_and_commit),
        WorkItem(CONNECTIONS[1], modify_pk0_on_master_no_commit),
    ],
    # Stage: the second client fails to commit to master, commits to a feature branch, then merges.
    [
        WorkItem(CONNECTIONS[1], fail_to_commit, commit_to_feature, merge_resolve_commit),
    ],
]
# Loop through the work item stages, executing each stage by sending its work items to the worker threads
# and then waiting for all of them to finish before moving on to the next one. Errors are checked after every
# stage; the first one found is printed to stderr and the script exits with status 1.
# Stages are numbered from 1 so the progress line reads "1 / N" ... "N / N" (it used to start at 0).
for stage, work_items in enumerate(work_item_stages, start=1):
    print("Running stage %d / %d" % (stage, len(work_item_stages)))

    for work_item in work_items:
        WORK_QUEUE.put(work_item)

    # Blocks until the workers have called task_done() for every queued item.
    WORK_QUEUE.join()

    for work_item in work_items:
        if work_item.exception is not None:
            print_err(work_item.exception)
            sys.exit(1)

View File

@@ -20,6 +20,17 @@ teardown() {
teardown_common
}
# Starts a multi-user sql-server from inside repo1 and runs
# server_multiclient_test.py against it.
@test "multi-client" {
skiponwindows "Has dependencies that are missing on the Jenkins Windows installation."
cd repo1
start_sql_multi_user_server repo1
# The Python script is run from the directory containing this test file.
cd $BATS_TEST_DIRNAME
# Recompute PORT with the same expression start_sql_multi_user_server uses,
# presumably so the client targets the port the server was configured with.
let PORT="$$ % (65536-1024) + 1024"
python3 server_multiclient_test.py $PORT
}
@test "test autocommit" {
skiponwindows "Has dependencies that are missing on the Jenkins Windows installation."

View File

@@ -409,7 +409,7 @@ func (ddb *DoltDB) CommitWithParentSpecs(ctx context.Context, valHash hash.Hash,
return ddb.CommitWithParentCommits(ctx, valHash, dref, parentCommits, cm)
}
func (ddb *DoltDB) WriteCommitDanglingCommit(ctx context.Context, valHash hash.Hash, parentCommits []*Commit, cm *CommitMeta) (*Commit, error) {
func (ddb *DoltDB) WriteDanglingCommit(ctx context.Context, valHash hash.Hash, parentCommits []*Commit, cm *CommitMeta) (*Commit, error) {
var commitSt types.Struct
val, err := ddb.db.ReadValue(ctx, valHash)

View File

@@ -56,7 +56,7 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
dbName := ctx.GetCurrentDatabase()
dSess := sqle.DSessFromSess(ctx.Session)
parent, err := dSess.GetParentCommit(ctx, dbName)
parent, _, err := dSess.GetParentCommit(ctx, dbName)
if err != nil {
return nil, err
@@ -90,7 +90,7 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return nil, err
}
cm, err := ddb.WriteCommitDanglingCommit(ctx, h, []*doltdb.Commit{parent}, meta)
cm, err := ddb.WriteDanglingCommit(ctx, h, []*doltdb.Commit{parent}, meta)
if err != nil {
return nil, err

View File

@@ -72,7 +72,7 @@ func (t *HashOf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
if strings.ToUpper(name) == "HEAD" {
sess := sqle.DSessFromSess(ctx.Session)
cm, err = sess.GetParentCommit(ctx, dbName)
cm, _, err = sess.GetParentCommit(ctx, dbName)
} else {
name, err = getBranchInsensitive(ctx, name, ddb)

View File

@@ -23,4 +23,5 @@ func init() {
// TODO: fix function registration
function.Defaults = append(function.Defaults, sql.Function1{Name: HashOfFuncName, Fn: NewHashOf})
function.Defaults = append(function.Defaults, sql.Function1{Name: CommitFuncName, Fn: NewCommitFunc})
function.Defaults = append(function.Defaults, sql.Function1{Name: MergeFuncName, Fn: NewMergeFunc})
}

View File

@@ -0,0 +1,202 @@
// Copyright 2020 Liquidata, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dfunctions
import (
"errors"
"fmt"
"github.com/liquidata-inc/go-mysql-server/sql"
"github.com/liquidata-inc/go-mysql-server/sql/expression"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/merge"
"github.com/liquidata-inc/dolt/go/libraries/doltcore/sqle"
"github.com/liquidata-inc/dolt/go/store/hash"
)
// MergeFuncName is the name of the SQL merge function; it is the Name used
// when NewMergeFunc is added to the default function list.
const MergeFuncName = "merge"
// MergeFunc is a unary SQL expression whose child evaluates to the name of the
// branch to merge.
type MergeFunc struct {
expression.UnaryExpression
}
// NewMergeFunc creates a new MergeFunc expression wrapping e as its only child.
func NewMergeFunc(e sql.Expression) sql.Expression {
	unary := expression.UnaryExpression{Child: e}
	return &MergeFunc{UnaryExpression: unary}
}
// Eval implements the Expression interface.
//
// The child expression is evaluated to a branch name, and that branch's HEAD
// commit is merged with the session's current parent commit. The result is a
// commit hash string:
//   - when the merge is a fast-forward, the hash of the branch's commit;
//   - otherwise the hash of a new commit, written with WriteDanglingCommit,
//     whose parents are the session's parent commit and the branch's commit.
//
// A nil child value yields a nil result. An error is returned when the session
// has no username or email, when the database cannot be found, when the
// session root differs from the parent commit's root (uncommitted changes),
// or when any step of the merge fails.
func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
	val, err := cf.Child.Eval(ctx, row)

	if err != nil {
		return nil, err
	} else if val == nil {
		return nil, nil
	}

	sess := sqle.DSessFromSess(ctx.Session)

	// A merge commit needs an author, so the session must carry both values.
	if sess.Username == "" || sess.Email == "" {
		return nil, errors.New("merge function failure: Username and/or email not configured")
	}

	dbName := sess.GetCurrentDatabase()
	ddb, ok := sess.GetDoltDB(dbName)

	if !ok {
		return nil, sql.ErrDatabaseNotFound.New(dbName)
	}

	root, ok := sess.GetRoot(dbName)

	if !ok {
		return nil, sql.ErrDatabaseNotFound.New(dbName)
	}

	parent, ph, parentRoot, err := getParent(ctx, err, sess, dbName)

	if err != nil {
		return nil, err
	}

	err = checkForUncommittedChanges(root, parentRoot)

	if err != nil {
		return nil, err
	}

	cm, cmh, err := getBranchCommit(ctx, ok, val, err, ddb)

	if err != nil {
		return nil, err
	}

	mergeRoot, _, err := merge.MergeCommits(ctx, ddb, parent, cm)

	if err == merge.ErrFastForward {
		// Nothing to merge: the branch's commit itself is the result.
		return cmh.String(), nil
	} else if err != nil {
		// Any other merge error must stop here; previously it was ignored and
		// the (possibly nil) merge root was written anyway.
		return nil, err
	}

	h, err := ddb.WriteRootValue(ctx, mergeRoot)

	if err != nil {
		return nil, err
	}

	// NOTE(review): the message names the session's parent hash first and the
	// merged branch's hash second; confirm this is the intended "X into Y" order.
	commitMessage := fmt.Sprintf("SQL Generated commit merging %s into %s", ph.String(), cmh.String())
	meta, err := doltdb.NewCommitMeta(sess.Username, sess.Email, commitMessage)

	if err != nil {
		return nil, err
	}

	mergeCommit, err := ddb.WriteDanglingCommit(ctx, h, []*doltdb.Commit{parent, cm}, meta)

	if err != nil {
		return nil, err
	}

	h, err = mergeCommit.HashOf()

	if err != nil {
		return nil, err
	}

	return h.String(), nil
}
// checkForUncommittedChanges compares the hash of root with the hash of
// parentRoot and returns an error when they differ, or when either hash
// cannot be computed.
func checkForUncommittedChanges(root *doltdb.RootValue, parentRoot *doltdb.RootValue) error {
	rootHash, rootErr := root.HashOf()
	if rootErr != nil {
		return rootErr
	}

	parentHash, parentErr := parentRoot.HashOf()
	if parentErr != nil {
		return parentErr
	}

	if rootHash == parentHash {
		return nil
	}

	return errors.New("cannot merge with uncommitted changes")
}
// getBranchCommit resolves val, which must be a string naming a branch, to the
// commit at that branch's HEAD and returns the commit together with its hash.
// The branch name is looked up with getBranchInsensitive. The incoming values
// of the ok and err parameters are not read.
func getBranchCommit(ctx *sql.Context, ok bool, val interface{}, err error, ddb *doltdb.DoltDB) (*doltdb.Commit, hash.Hash, error) {
	var noHash hash.Hash

	branchParam, isString := val.(string)
	if !isString {
		return nil, noHash, errors.New("branch name is not a string")
	}

	branchName, err := getBranchInsensitive(ctx, branchParam, ddb)
	if err != nil {
		return nil, noHash, err
	}

	spec, err := doltdb.NewCommitSpec("HEAD", branchName)
	if err != nil {
		return nil, noHash, err
	}

	commit, err := ddb.Resolve(ctx, spec)
	if err != nil {
		return nil, noHash, err
	}

	commitHash, err := commit.HashOf()
	if err != nil {
		return nil, noHash, err
	}

	return commit, commitHash, nil
}
// getParent returns the session's parent commit for dbName, that commit's
// hash, and its root value. The incoming value of the err parameter is not read.
func getParent(ctx *sql.Context, err error, sess *sqle.DoltSession, dbName string) (*doltdb.Commit, hash.Hash, *doltdb.RootValue, error) {
	var noHash hash.Hash

	commit, commitHash, err := sess.GetParentCommit(ctx, dbName)
	if err != nil {
		return nil, noHash, nil, err
	}

	commitRoot, err := commit.GetRootValue()
	if err != nil {
		return nil, noHash, nil, err
	}

	return commit, commitHash, commitRoot, nil
}
// String implements the Stringer interface. The result is the child's string
// form wrapped as "Merge(<child>)".
func (cf *MergeFunc) String() string {
	return "Merge(" + cf.Child.String() + ")"
}
// IsNullable implements the Expression interface. The function is nullable
// exactly when its child expression is.
func (cf *MergeFunc) IsNullable() bool {
	child := cf.Child
	return child.IsNullable()
}
// WithChildren implements the Expression interface. Exactly one child is
// required; it becomes the child of a new MergeFunc.
func (cf *MergeFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) {
	if n := len(children); n != 1 {
		return nil, sql.ErrInvalidChildrenNumber.New(cf, n, 1)
	}

	return NewMergeFunc(children[0]), nil
}
// Type implements the Expression interface. The function's result type is
// always sql.Text.
func (cf *MergeFunc) Type() sql.Type {
	var resultType sql.Type = sql.Text
	return resultType
}

View File

@@ -145,7 +145,7 @@ func (dt *DiffTable) Schema() sql.Schema {
func (dt *DiffTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
sess := DSessFromSess(ctx.Session)
rootCmt, err := sess.GetParentCommit(ctx, dt.dbName)
rootCmt, _, err := sess.GetParentCommit(ctx, dt.dbName)
if err != nil {
return nil, err

View File

@@ -122,33 +122,34 @@ func (sess *DoltSession) GetRoot(dbName string) (*doltdb.RootValue, bool) {
}
// GetParentCommit returns the parent commit of the current session.
func (sess *DoltSession) GetParentCommit(ctx context.Context, dbName string) (*doltdb.Commit, error) {
func (sess *DoltSession) GetParentCommit(ctx context.Context, dbName string) (*doltdb.Commit, hash.Hash, error) {
dbd, dbFound := sess.dbDatas[dbName]
if !dbFound {
return nil, sql.ErrDatabaseNotFound.New(dbName)
return nil, hash.Hash{}, sql.ErrDatabaseNotFound.New(dbName)
}
_, value := sess.Session.Get(dbName + HeadKeySuffix)
valStr, isStr := value.(string)
if !isStr || !hash.IsValid(valStr) {
return nil, doltdb.ErrInvalidHash
return nil, hash.Hash{}, doltdb.ErrInvalidHash
}
h := hash.Parse(valStr)
cs, err := doltdb.NewCommitSpec(valStr, "")
if err != nil {
return nil, err
return nil, hash.Hash{}, err
}
cm, err := dbd.ddb.Resolve(ctx, cs)
if err != nil {
return nil, err
return nil, hash.Hash{}, err
}
return cm, nil
return cm, h, nil
}
func (sess *DoltSession) Set(ctx context.Context, key string, typ sql.Type, value interface{}) error {

View File

@@ -70,7 +70,7 @@ func NewHistoryTable(ctx *sql.Context, db Database, tblName string) (sql.Table,
return nil, sql.ErrDatabaseNotFound.New(dbName)
}
head, err := sess.GetParentCommit(ctx, dbName)
head, _, err := sess.GetParentCommit(ctx, dbName)
if err != nil {
return nil, err

View File

@@ -84,7 +84,7 @@ type LogItr struct {
// NewLogItr creates a LogItr from the current environment.
func NewLogItr(sqlCtx *sql.Context, dbName string, ddb *doltdb.DoltDB) (*LogItr, error) {
sess := DSessFromSess(sqlCtx.Session)
commit, err := sess.GetParentCommit(sqlCtx, dbName)
commit, _, err := sess.GetParentCommit(sqlCtx, dbName)
if err != nil {
return nil, err