Files
dolt/bats/server_multiclient_test.py
VinaiRachakonda 28a2f08b19 Initial commit. Compiles but errors
Fix up tests

fix commit tests....

Add author param to merge

formatting...

Fix array prob
2020-12-13 13:13:16 -05:00

199 lines
6.4 KiB
Python

import os
import sys
from queue import Queue
from threading import Thread
from helper.pytest import DoltConnection
# Utility functions
def print_err(e):
    """Write the string form of *e*, followed by a newline, to standard error."""
    sys.stderr.write("%s\n" % (e,))
def query(dc, query_str):
    """Run *query_str* on the connection *dc* and return its result unchanged.

    ``dc.query`` is always called with ``False`` as its second positional
    argument. NOTE(review): the meaning of that flag is defined by the
    connection class, which is not visible here -- confirm before changing.
    Callers in this file unpack the return value as ``(rows, row_count)``.
    """
    result = dc.query(query_str, False)
    return result
def query_with_expected_error(dc, non_error_msg, query_str):
    """Run *query_str* on *dc* and require that it fails.

    Args:
        dc: connection object exposing ``query(query_str, flag)``.
        non_error_msg: message of the exception raised when the query
            unexpectedly succeeds.
        query_str: the SQL statement that is expected to raise.

    Raises:
        Exception: with *non_error_msg* if the query completes without error.
    """
    try:
        dc.query(query_str, False)
    except Exception:
        # The query failed, which is exactly what the caller asked for.
        return
    # Raised outside the ``try`` so the handler above cannot swallow it.
    raise Exception(non_error_msg)
def row(pk, c1, c2):
    """Build a result-row dict with keys ``pk``, ``c1``, ``c2`` and stringified values."""
    columns = ("pk", "c1", "c2")
    values = (pk, c1, c2)
    return {column: str(value) for column, value in zip(columns, values)}
UPDATE_BRANCH_FAIL_MSG = "Failed to update branch"


def commit_and_update_branch(dc, commit_message, expected_hashes, branch_name):
    """Commit and move *branch_name* to the new commit via ``dolt_branches``.

    The UPDATE only matches when the branch's current hash equals one of
    *expected_hashes*. Each entry is inserted into the SQL verbatim, so it
    may be an expression such as ``@@repo1_head``. On success the session
    variable ``@@repo1_head`` is set to the hash of *branch_name*.

    Raises:
        Exception: with ``UPDATE_BRANCH_FAIL_MSG`` when the UPDATE does not
            affect exactly one row.
    """
    # "(hash = A or hash = B ...)" -- the branch must currently sit on one of these.
    hash_clauses = " or ".join("hash = %s" % candidate for candidate in expected_hashes)
    expected_hash = "(" + hash_clauses + ")"

    update_sql = 'UPDATE dolt_branches SET hash = Commit("-m", "%s") WHERE name = "%s" AND %s' % (
        commit_message,
        branch_name,
        expected_hash,
    )
    _, updated = query(dc, update_sql)
    if updated != 1:
        raise Exception(UPDATE_BRANCH_FAIL_MSG)

    query(dc, 'SET @@repo1_head=HASHOF("%s");' % branch_name)
def query_and_test_results(dc, query_str, expected):
    """Run *query_str* on *dc* and verify that the returned rows equal *expected*.

    Args:
        dc: connection to run the query on.
        query_str: the SQL query to execute.
        expected: the rows the query must return, compared with ``!=``.

    Raises:
        Exception: describing the query, the expected rows and the actual
            rows when they differ.
    """
    results, _ = query(dc, query_str)
    if results != expected:
        # Include the expected rows in the message; the original formatted an
        # empty ``str()`` here, which left the "Expected" section blank.
        raise Exception(
            "Unexpected results for query:\n\t%s\nExpected:\n\t%s\nActual:\n\t%s"
            % (query_str, str(expected), str(results))
        )
def resolve_theirs(dc):
    """Resolve every conflict on table ``test`` by taking the "their" side.

    Runs three statements in order: copy their rows into ``test``, delete
    rows that have no "their" version, then clear ``dolt_conflicts_test``.
    """
    statements = (
        # Take their version of every row that exists on their side.
        "REPLACE INTO test (pk, c1, c2) SELECT their_pk, their_c1, their_c2 FROM dolt_conflicts_test WHERE their_pk IS NOT NULL;",
        # Remove rows whose conflict entry has no "their" primary key.
        (
            "DELETE FROM test WHERE pk in (\n"
            "SELECT base_pk FROM dolt_conflicts_test WHERE their_pk IS NULL\n"
            ");"
        ),
        # Clear the conflict rows themselves.
        "DELETE FROM dolt_conflicts_test",
    )
    for statement in statements:
        query(dc, statement)
def create_branch(dc, branch_name):
    """Create *branch_name* at ``@@repo1_head`` by inserting into ``dolt_branches``.

    Raises:
        Exception: if the INSERT does not report exactly one affected row.
    """
    insert_sql = 'INSERT INTO dolt_branches (name, hash) VALUES ("%s", @@repo1_head);' % branch_name
    _rows, inserted = query(dc, insert_sql)
    if inserted == 1:
        return
    raise Exception("Failed to create branch")
# work functions
def connect(dc):
    """Work function: open the connection by calling ``dc.connect()``.

    The result of ``dc.connect()`` is discarded; this function returns None.
    """
    dc.connect()
def create_tables(dc):
    """Work function: create table ``test`` on master, commit, and verify it is listed."""
    create_sql = (
        "\n"
        "CREATE TABLE test (\n"
        "pk INT NOT NULL,\n"
        "c1 INT,\n"
        "c2 INT,\n"
        "PRIMARY KEY(pk));"
    )
    # Point this session at the current head of master before making changes.
    query(dc, 'SET @@repo1_head=HASHOF("master");')
    query(dc, create_sql)
    commit_and_update_branch(dc, "Created tables", ["@@repo1_head"], "master")
    query_and_test_results(dc, "SHOW TABLES;", [{"Table": "test"}])
def duplicate_table_create(dc):
    """Work function: attempt to create table ``test`` again and expect an error."""
    create_sql = (
        "\n"
        "CREATE TABLE test (\n"
        "pk INT NOT NULL,\n"
        "c1 INT,\n"
        "c2 INT,\n"
        "PRIMARY KEY(pk));"
    )
    query(dc, 'SET @@repo1_head=HASHOF("master");')
    query_with_expected_error(dc, "Should have failed creating duplicate table", create_sql)
def seed_master(dc):
    """Work function: insert three seed rows on master, commit, and verify the contents.

    Raises:
        Exception: if the INSERT does not report three affected rows.
    """
    query(dc, 'SET @@repo1_head=HASHOF("master");')
    _rows, inserted = query(dc, 'INSERT INTO test VALUES (0,0,0),(1,1,1),(2,2,2)')
    if inserted != 3:
        raise Exception("Failed to update rows")
    commit_and_update_branch(dc, "Seeded initial data", ["@@repo1_head"], "master")
    # Rows (0,0,0), (1,1,1), (2,2,2) in the string-valued form produced by row().
    expected = [row(n, n, n) for n in range(3)]
    query_and_test_results(dc, "SELECT pk, c1, c2 FROM test ORDER BY pk", expected)
def modify_pk0_on_master_and_commit(dc):
    """Work function: set ``c1=1`` for ``pk=0`` on master and commit the change."""
    statements = (
        'SET @@repo1_head=HASHOF("master");',
        "UPDATE test SET c1=1 WHERE pk=0;",
    )
    for statement in statements:
        query(dc, statement)
    commit_and_update_branch(dc, "set c1 to 1", ["@@repo1_head"], "master")
def modify_pk0_on_master_no_commit(dc):
    """Work function: set ``c1=2`` for ``pk=0`` starting from master, without committing."""
    statements = (
        'SET @@repo1_head=HASHOF("master");',
        "UPDATE test SET c1=2 WHERE pk=0",
    )
    for statement in statements:
        query(dc, statement)
def fail_to_commit(dc):
    """Work function: try to commit to master and require the branch update to fail.

    Raises:
        Exception: "Failed to fail commit" if the commit unexpectedly
            succeeds; any exception other than the expected
            ``UPDATE_BRANCH_FAIL_MSG`` failure is propagated unchanged.
    """
    try:
        commit_and_update_branch(dc, "Created tables", ["@@repo1_head"], "master")
    except Exception as err:
        # Only the branch-update failure is the expected outcome.
        if str(err) != UPDATE_BRANCH_FAIL_MSG:
            raise
        return
    raise Exception("Failed to fail commit")
def commit_to_feature(dc):
    """Work function: create branch ``feature`` at the session head and commit to it."""
    branch = "feature"
    create_branch(dc, branch)
    commit_and_update_branch(dc, "set c1 to 2", ["@@repo1_head"], branch)
def merge_resolve_commit(dc):
    """Work function: merge master, resolve the single conflict with theirs, and commit to master."""
    query(dc, 'SET @@repo1_head=Merge("master");')

    # The merge is expected to leave exactly one conflict on table "test".
    expected_conflicts = [{"table": "test", "num_conflicts": "1"}]
    query_and_test_results(dc, "SELECT * from dolt_conflicts;", expected_conflicts)

    resolve_theirs(dc)

    expected_rows = [row(0, 1, 0), row(1, 1, 1), row(2, 2, 2)]
    query_and_test_results(dc, "SELECT pk, c1, c2 FROM test ORDER BY pk", expected_rows)

    # master may currently be at either parent of the merge.
    merge_parents = ['HASHOF("HEAD^1")', 'HASHOF("HEAD^2")']
    commit_and_update_branch(dc, "resolved conflicts", merge_parents, "master")
# test script
# Number of connections, and of worker threads servicing them.
MAX_SIMULTANEOUS_CONNECTIONS = 2
# Server port, taken from the first command-line argument.
PORT_STR = sys.argv[1]
# One connection object per slot; they are opened later by the `connect` work function.
CONNECTIONS = [
    DoltConnection(port=int(PORT_STR), database="repo1", user='dolt', auto_commit=False)
    for _ in range(MAX_SIMULTANEOUS_CONNECTIONS)
]
# Queue through which work items are handed to the worker threads.
WORK_QUEUE = Queue()
# work item run by workers
class WorkItem:
    """A connection plus the ordered work functions to run against it."""

    def __init__(self, dc, *work_funcs):
        """Store the connection and its work functions; no failure recorded yet."""
        # Set when one of the work functions raises.
        self.exception = None
        # Tuple of callables, each invoked with ``dc``.
        self.work_funcs = work_funcs
        self.dc = dc
# worker thread function
def worker():
    """Worker-thread loop: pull work items off ``WORK_QUEUE`` and run them forever.

    Each item's work functions are called in order with the item's
    connection. If one raises, the exception is stored on that item's
    ``exception`` attribute and the item's remaining functions are skipped.
    """
    while True:
        item = WORK_QUEUE.get()
        try:
            for work_func in item.work_funcs:
                work_func(item.dc)
        except Exception as e:
            # Record the failure on the item that was actually being processed,
            # not on whatever global loop variable happens to be in scope.
            item.exception = e
        finally:
            # Exactly one task_done() per get(), on success and on failure,
            # so that WORK_QUEUE.join() can return.
            WORK_QUEUE.task_done()
# start the worker threads
# Daemon threads, so the never-ending worker loops do not keep the process alive at exit.
for _ in range(MAX_SIMULTANEOUS_CONNECTIONS):
    Thread(target=worker, daemon=True).start()
# This defines the actual test script. Each stage in the script has a list of work items. Each work item
# in a stage should have a different connection associated with it. Each connection's work is done in parallel;
# the work functions for a given connection are executed in order.
work_item_stages = [
    # Connection 0 connects and creates the table.
    [
        WorkItem(CONNECTIONS[0], connect, create_tables),
    ],
    # Connection 0 seeds master; connection 1 connects and tries a duplicate CREATE TABLE.
    [
        WorkItem(CONNECTIONS[0], seed_master),
        WorkItem(CONNECTIONS[1], connect, duplicate_table_create),
    ],
    # Both connections modify pk 0; only connection 0 commits.
    [
        WorkItem(CONNECTIONS[0], modify_pk0_on_master_and_commit),
        WorkItem(CONNECTIONS[1], modify_pk0_on_master_no_commit),
    ],
    # Connection 1 fails to commit to master, commits to a feature branch, then merges.
    [
        WorkItem(CONNECTIONS[1], fail_to_commit, commit_to_feature, merge_resolve_commit),
    ],
]
# Loop through the work item stages executing each stage by sending the work items for the stage to the worker threads
# and then waiting for all of them to finish before moving on to the next one. Checks for an error after every stage.
total_stages = len(work_item_stages)
for stage, work_items in enumerate(work_item_stages):
    print("Running stage %d / %d" % (stage, total_stages))

    # Hand every item of this stage to the workers, then wait for all of them.
    for work_item in work_items:
        WORK_QUEUE.put(work_item)
    WORK_QUEUE.join()

    # Stop at the first recorded failure in this stage, in work-item order.
    first_error = next(
        (done.exception for done in work_items if done.exception is not None),
        None,
    )
    if first_error is not None:
        print_err(first_error)
        sys.exit(1)