mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-13 19:28:50 -06:00
SQL Merge Function (#730)
* bh/merge-func * commit after merge * formatting * fix test
This commit is contained in:
@@ -45,13 +45,13 @@ class DoltConnection(object):
|
||||
def close(self):
|
||||
self.cnx.close()
|
||||
|
||||
def query(self, query_str):
|
||||
def query(self, query_str, exit_on_err=True):
|
||||
try:
|
||||
cursor = self.cnx.cursor()
|
||||
cursor.execute(query_str)
|
||||
|
||||
if cursor.description is None:
|
||||
return []
|
||||
return [], cursor.rowcount
|
||||
|
||||
raw = cursor.fetchall()
|
||||
|
||||
@@ -62,10 +62,12 @@ class DoltConnection(object):
|
||||
r[k] = str(curr[i])
|
||||
row_maps.append(r)
|
||||
|
||||
return row_maps
|
||||
return row_maps, cursor.rowcount
|
||||
|
||||
except BaseException as e:
|
||||
_print_err_and_exit(e)
|
||||
if exit_on_err:
|
||||
_print_err_and_exit(e)
|
||||
raise e
|
||||
|
||||
|
||||
class InfiniteRetryConnection(DoltConnection):
|
||||
|
||||
@@ -43,7 +43,7 @@ if query_results is not None:
|
||||
for i in range(len(queries)):
|
||||
query_str = queries[i].strip()
|
||||
print('executing:', query_str)
|
||||
actual_rows = dc.query(query_str)
|
||||
actual_rows, num_rows = dc.query(query_str)
|
||||
|
||||
if expected[i] is not None:
|
||||
expected_rows = csv_to_row_maps(expected[i])
|
||||
@@ -89,6 +89,28 @@ start_sql_server() {
|
||||
wait_for_connection $PORT 5000
|
||||
}
|
||||
|
||||
start_sql_multi_user_server() {
|
||||
DEFAULT_DB="$1"
|
||||
let PORT="$$ % (65536-1024) + 1024"
|
||||
echo "
|
||||
log_level: debug
|
||||
|
||||
user:
|
||||
name: dolt
|
||||
|
||||
listener:
|
||||
host: 0.0.0.0
|
||||
port: $PORT
|
||||
max_connections: 10
|
||||
|
||||
behavior:
|
||||
autocommit: false
|
||||
" > .cliconfig.yaml
|
||||
dolt sql-server --config .cliconfig.yaml &
|
||||
SERVER_PID=$!
|
||||
wait_for_connection $PORT 5000
|
||||
}
|
||||
|
||||
|
||||
start_multi_db_server() {
|
||||
DEFAULT_DB="$1"
|
||||
|
||||
198
bats/server_multiclient_test.py
Normal file
198
bats/server_multiclient_test.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
|
||||
from helper.pytest import DoltConnection
|
||||
|
||||
|
||||
# Utility functions
|
||||
|
||||
def print_err(e):
|
||||
print(e, file=sys.stderr)
|
||||
|
||||
def query(dc, query_str):
|
||||
return dc.query(query_str, False)
|
||||
|
||||
def query_with_expected_error(dc, non_error_msg , query_str):
|
||||
try:
|
||||
dc.query(query_str, False)
|
||||
raise Exception(non_error_msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
def row(pk, c1, c2):
|
||||
return {"pk":str(pk),"c1":str(c1),"c2":str(c2)}
|
||||
|
||||
UPDATE_BRANCH_FAIL_MSG = "Failed to update branch"
|
||||
|
||||
def commit_and_update_branch(dc, commit_message, expected_hashes, branch_name):
|
||||
expected_hash = "("
|
||||
for i, eh in enumerate(expected_hashes):
|
||||
if i != 0:
|
||||
expected_hash += " or "
|
||||
expected_hash += "hash = %s" % eh
|
||||
expected_hash += ")"
|
||||
|
||||
query_str = 'UPDATE dolt_branches SET hash = Commit("%s") WHERE name = "%s" AND %s' % (commit_message, branch_name, expected_hash)
|
||||
_, row_count = query(dc, query_str)
|
||||
|
||||
if row_count != 1:
|
||||
raise Exception(UPDATE_BRANCH_FAIL_MSG)
|
||||
|
||||
query(dc, 'SET @@repo1_head=HASHOF("%s");' % branch_name)
|
||||
|
||||
def query_and_test_results(dc, query_str, expected):
|
||||
results, _ = query(dc, query_str)
|
||||
|
||||
if results != expected:
|
||||
raise Exception("Unexpected results for query:\n\t%s\nExpected:\n\t%s\nActual:\n\t%s" % (query_str, str(), str(results)))
|
||||
|
||||
def resolve_theirs(dc):
|
||||
query_str = "REPLACE INTO test (pk, c1, c2) SELECT their_pk, their_c1, their_c2 FROM dolt_conflicts_test WHERE their_pk IS NOT NULL;"
|
||||
query(dc, query_str)
|
||||
|
||||
query_str = """DELETE FROM test WHERE pk in (
|
||||
SELECT base_pk FROM dolt_conflicts_test WHERE their_pk IS NULL
|
||||
);"""
|
||||
query(dc, query_str)
|
||||
|
||||
query(dc, "DELETE FROM dolt_conflicts_test")
|
||||
|
||||
def create_branch(dc, branch_name):
|
||||
query_str = 'INSERT INTO dolt_branches (name, hash) VALUES ("%s", @@repo1_head);' % branch_name
|
||||
_, row_count = query(dc, query_str)
|
||||
|
||||
if row_count != 1:
|
||||
raise Exception("Failed to create branch")
|
||||
|
||||
|
||||
# work functions
|
||||
|
||||
def connect(dc):
|
||||
dc.connect()
|
||||
|
||||
def create_tables(dc):
|
||||
query(dc, 'SET @@repo1_head=HASHOF("master");')
|
||||
query(dc, """
|
||||
CREATE TABLE test (
|
||||
pk INT NOT NULL,
|
||||
c1 INT,
|
||||
c2 INT,
|
||||
PRIMARY KEY(pk));""")
|
||||
commit_and_update_branch(dc, "Created tables", ["@@repo1_head"], "master")
|
||||
query_and_test_results(dc, "SHOW TABLES;", [{"Table": "test"}])
|
||||
|
||||
def duplicate_table_create(dc):
|
||||
query(dc, 'SET @@repo1_head=HASHOF("master");')
|
||||
query_with_expected_error(dc, "Should have failed creating duplicate table", """
|
||||
CREATE TABLE test (
|
||||
pk INT NOT NULL,
|
||||
c1 INT,
|
||||
c2 INT,
|
||||
PRIMARY KEY(pk));""")
|
||||
|
||||
|
||||
def seed_master(dc):
|
||||
query(dc, 'SET @@repo1_head=HASHOF("master");')
|
||||
_, row_count = query(dc, 'INSERT INTO test VALUES (0,0,0),(1,1,1),(2,2,2)')
|
||||
|
||||
if row_count != 3:
|
||||
raise Exception("Failed to update rows")
|
||||
|
||||
commit_and_update_branch(dc, "Seeded initial data", ["@@repo1_head"], "master")
|
||||
expected = [row(0,0,0), row(1,1,1), row(2,2,2)]
|
||||
query_and_test_results(dc, "SELECT pk, c1, c2 FROM test ORDER BY pk", expected)
|
||||
|
||||
def modify_pk0_on_master_and_commit(dc):
|
||||
query(dc, 'SET @@repo1_head=HASHOF("master");')
|
||||
query(dc, "UPDATE test SET c1=1 WHERE pk=0;")
|
||||
commit_and_update_branch(dc, "set c1 to 1", ["@@repo1_head"], "master")
|
||||
|
||||
def modify_pk0_on_master_no_commit(dc):
|
||||
query(dc, 'SET @@repo1_head=HASHOF("master");')
|
||||
query(dc, "UPDATE test SET c1=2 WHERE pk=0")
|
||||
|
||||
def fail_to_commit(dc):
|
||||
try:
|
||||
commit_and_update_branch(dc, "Created tables", ["@@repo1_head"], "master")
|
||||
raise Exception("Failed to fail commit")
|
||||
except Exception as e:
|
||||
if str(e) != UPDATE_BRANCH_FAIL_MSG:
|
||||
raise e
|
||||
|
||||
def commit_to_feature(dc):
|
||||
create_branch(dc, "feature")
|
||||
commit_and_update_branch(dc, "set c1 to 2", ["@@repo1_head"], "feature")
|
||||
|
||||
def merge_resolve_commit(dc):
|
||||
query(dc, 'SET @@repo1_head=Merge("master");')
|
||||
query_and_test_results(dc, "SELECT * from dolt_conflicts;", [{"table": "test", "num_conflicts": "1"}])
|
||||
resolve_theirs(dc)
|
||||
expected = [row(0,1,0), row(1,1,1), row(2,2,2)]
|
||||
query_and_test_results(dc, "SELECT pk, c1, c2 FROM test ORDER BY pk", expected)
|
||||
commit_and_update_branch(dc, "resolved conflicts", ['HASHOF("HEAD^1")', 'HASHOF("HEAD^2")'], "master")
|
||||
|
||||
|
||||
# test script
|
||||
MAX_SIMULTANEOUS_CONNECTIONS = 2
|
||||
PORT_STR = sys.argv[1]
|
||||
|
||||
CONNECTIONS = [None]*MAX_SIMULTANEOUS_CONNECTIONS
|
||||
for i in range(MAX_SIMULTANEOUS_CONNECTIONS):
|
||||
CONNECTIONS[i] = DoltConnection(port=int(PORT_STR), database="repo1", user='dolt', auto_commit=False)
|
||||
|
||||
WORK_QUEUE = Queue()
|
||||
|
||||
# work item run by workers
|
||||
class WorkItem(object):
|
||||
def __init__(self, dc, *work_funcs):
|
||||
self.dc = dc
|
||||
self.work_funcs = work_funcs
|
||||
self.exception = None
|
||||
|
||||
|
||||
# worker thread function
|
||||
def worker():
|
||||
while True:
|
||||
try:
|
||||
item = WORK_QUEUE.get()
|
||||
|
||||
for work_func in item.work_funcs:
|
||||
work_func(item.dc)
|
||||
|
||||
WORK_QUEUE.task_done()
|
||||
except Exception as e:
|
||||
work_item.exception = e
|
||||
WORK_QUEUE.task_done()
|
||||
|
||||
# start the worker threads
|
||||
for i in range(MAX_SIMULTANEOUS_CONNECTIONS):
|
||||
t = Thread(target=worker)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
# This defines the actual test script. Each stage in the script has a list of work items. Each work item
|
||||
# in a stage should have a different connection associated with it. Each connections work is done in parallel
|
||||
# each of the work functions for a connection is executed in order.
|
||||
work_item_stages = [
|
||||
[WorkItem(CONNECTIONS[0], connect, create_tables)],
|
||||
[WorkItem(CONNECTIONS[0], seed_master), WorkItem(CONNECTIONS[1], connect, duplicate_table_create)],
|
||||
[WorkItem(CONNECTIONS[0], modify_pk0_on_master_and_commit), WorkItem(CONNECTIONS[1], modify_pk0_on_master_no_commit)],
|
||||
[WorkItem(CONNECTIONS[1], fail_to_commit, commit_to_feature, merge_resolve_commit)]
|
||||
]
|
||||
|
||||
# Loop through the work item stages executing each stage by sending the work items for the stage to the worker threads
|
||||
# and then waiting for all of them to finish before moving on to the next one. Checks for an error after every stage.
|
||||
for stage, work_items in enumerate(work_item_stages):
|
||||
print("Running stage %d / %d" % (stage,len(work_item_stages)))
|
||||
for work_item in work_items:
|
||||
WORK_QUEUE.put(work_item)
|
||||
|
||||
WORK_QUEUE.join()
|
||||
|
||||
for work_item in work_items:
|
||||
if work_item.exception is not None:
|
||||
print_err(work_item.exception)
|
||||
sys.exit(1)
|
||||
@@ -20,6 +20,17 @@ teardown() {
|
||||
teardown_common
|
||||
}
|
||||
|
||||
@test "multi-client" {
|
||||
skiponwindows "Has dependencies that are missing on the Jenkins Windows installation."
|
||||
|
||||
cd repo1
|
||||
start_sql_multi_user_server repo1
|
||||
|
||||
cd $BATS_TEST_DIRNAME
|
||||
let PORT="$$ % (65536-1024) + 1024"
|
||||
python3 server_multiclient_test.py $PORT
|
||||
}
|
||||
|
||||
@test "test autocommit" {
|
||||
skiponwindows "Has dependencies that are missing on the Jenkins Windows installation."
|
||||
|
||||
|
||||
@@ -409,7 +409,7 @@ func (ddb *DoltDB) CommitWithParentSpecs(ctx context.Context, valHash hash.Hash,
|
||||
return ddb.CommitWithParentCommits(ctx, valHash, dref, parentCommits, cm)
|
||||
}
|
||||
|
||||
func (ddb *DoltDB) WriteCommitDanglingCommit(ctx context.Context, valHash hash.Hash, parentCommits []*Commit, cm *CommitMeta) (*Commit, error) {
|
||||
func (ddb *DoltDB) WriteDanglingCommit(ctx context.Context, valHash hash.Hash, parentCommits []*Commit, cm *CommitMeta) (*Commit, error) {
|
||||
var commitSt types.Struct
|
||||
val, err := ddb.db.ReadValue(ctx, valHash)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
|
||||
dbName := ctx.GetCurrentDatabase()
|
||||
dSess := sqle.DSessFromSess(ctx.Session)
|
||||
parent, err := dSess.GetParentCommit(ctx, dbName)
|
||||
parent, _, err := dSess.GetParentCommit(ctx, dbName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -90,7 +90,7 @@ func (cf *CommitFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm, err := ddb.WriteCommitDanglingCommit(ctx, h, []*doltdb.Commit{parent}, meta)
|
||||
cm, err := ddb.WriteDanglingCommit(ctx, h, []*doltdb.Commit{parent}, meta)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -72,7 +72,7 @@ func (t *HashOf) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
if strings.ToUpper(name) == "HEAD" {
|
||||
sess := sqle.DSessFromSess(ctx.Session)
|
||||
|
||||
cm, err = sess.GetParentCommit(ctx, dbName)
|
||||
cm, _, err = sess.GetParentCommit(ctx, dbName)
|
||||
} else {
|
||||
name, err = getBranchInsensitive(ctx, name, ddb)
|
||||
|
||||
|
||||
@@ -23,4 +23,5 @@ func init() {
|
||||
// TODO: fix function registration
|
||||
function.Defaults = append(function.Defaults, sql.Function1{Name: HashOfFuncName, Fn: NewHashOf})
|
||||
function.Defaults = append(function.Defaults, sql.Function1{Name: CommitFuncName, Fn: NewCommitFunc})
|
||||
function.Defaults = append(function.Defaults, sql.Function1{Name: MergeFuncName, Fn: NewMergeFunc})
|
||||
}
|
||||
|
||||
202
go/libraries/doltcore/sqle/dfunctions/merge.go
Normal file
202
go/libraries/doltcore/sqle/dfunctions/merge.go
Normal file
@@ -0,0 +1,202 @@
|
||||
// Copyright 2020 Liquidata, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dfunctions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/liquidata-inc/go-mysql-server/sql"
|
||||
"github.com/liquidata-inc/go-mysql-server/sql/expression"
|
||||
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/merge"
|
||||
"github.com/liquidata-inc/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/liquidata-inc/dolt/go/store/hash"
|
||||
)
|
||||
|
||||
const MergeFuncName = "merge"
|
||||
|
||||
type MergeFunc struct {
|
||||
expression.UnaryExpression
|
||||
}
|
||||
|
||||
// NewMergeFunc creates a new MergeFunc expression.
|
||||
func NewMergeFunc(e sql.Expression) sql.Expression {
|
||||
return &MergeFunc{expression.UnaryExpression{Child: e}}
|
||||
}
|
||||
|
||||
// Eval implements the Expression interface.
|
||||
func (cf *MergeFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
|
||||
val, err := cf.Child.Eval(ctx, row)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if val == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sess := sqle.DSessFromSess(ctx.Session)
|
||||
if sess.Username == "" || sess.Email == "" {
|
||||
return nil, errors.New("commit function failure: Username and/or email not configured")
|
||||
}
|
||||
|
||||
dbName := sess.GetCurrentDatabase()
|
||||
ddb, ok := sess.GetDoltDB(dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
root, ok := sess.GetRoot(dbName)
|
||||
if !ok {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
parent, ph, parentRoot, err := getParent(ctx, err, sess, dbName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = checkForUncommittedChanges(root, parentRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm, cmh, err := getBranchCommit(ctx, ok, val, err, ddb)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergeRoot, _, err := merge.MergeCommits(ctx, ddb, parent, cm)
|
||||
if err == merge.ErrFastForward {
|
||||
return cmh.String(), nil
|
||||
}
|
||||
|
||||
h, err := ddb.WriteRootValue(ctx, mergeRoot)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
commitMessage := fmt.Sprintf("SQL Generated commit merging %s into %s", ph.String(), cmh.String())
|
||||
meta, err := doltdb.NewCommitMeta(sess.Username, sess.Email, commitMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mergeCommit, err := ddb.WriteDanglingCommit(ctx, h, []*doltdb.Commit{parent, cm}, meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h, err = mergeCommit.HashOf()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return h.String(), nil
|
||||
}
|
||||
|
||||
func checkForUncommittedChanges(root *doltdb.RootValue, parentRoot *doltdb.RootValue) error {
|
||||
rh, err := root.HashOf()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prh, err := parentRoot.HashOf()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if rh != prh {
|
||||
return errors.New("cannot merge with uncommitted changes")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getBranchCommit(ctx *sql.Context, ok bool, val interface{}, err error, ddb *doltdb.DoltDB) (*doltdb.Commit, hash.Hash, error) {
|
||||
paramStr, ok := val.(string)
|
||||
|
||||
if !ok {
|
||||
return nil, hash.Hash{}, errors.New("branch name is not a string")
|
||||
}
|
||||
|
||||
name, err := getBranchInsensitive(ctx, paramStr, ddb)
|
||||
|
||||
if err != nil {
|
||||
return nil, hash.Hash{}, err
|
||||
}
|
||||
|
||||
cs, err := doltdb.NewCommitSpec("HEAD", name)
|
||||
|
||||
if err != nil {
|
||||
return nil, hash.Hash{}, err
|
||||
}
|
||||
|
||||
cm, err := ddb.Resolve(ctx, cs)
|
||||
|
||||
if err != nil {
|
||||
return nil, hash.Hash{}, err
|
||||
}
|
||||
|
||||
cmh, err := cm.HashOf()
|
||||
|
||||
if err != nil {
|
||||
return nil, hash.Hash{}, err
|
||||
}
|
||||
|
||||
return cm, cmh, nil
|
||||
}
|
||||
|
||||
func getParent(ctx *sql.Context, err error, sess *sqle.DoltSession, dbName string) (*doltdb.Commit, hash.Hash, *doltdb.RootValue, error) {
|
||||
parent, ph, err := sess.GetParentCommit(ctx, dbName)
|
||||
|
||||
if err != nil {
|
||||
return nil, hash.Hash{}, nil, err
|
||||
}
|
||||
|
||||
parentRoot, err := parent.GetRootValue()
|
||||
|
||||
if err != nil {
|
||||
return nil, hash.Hash{}, nil, err
|
||||
}
|
||||
|
||||
return parent, ph, parentRoot, nil
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (cf *MergeFunc) String() string {
|
||||
return fmt.Sprintf("Merge(%s)", cf.Child.String())
|
||||
}
|
||||
|
||||
// IsNullable implements the Expression interface.
|
||||
func (cf *MergeFunc) IsNullable() bool {
|
||||
return cf.Child.IsNullable()
|
||||
}
|
||||
|
||||
// WithChildren implements the Expression interface.
|
||||
func (cf *MergeFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) {
|
||||
if len(children) != 1 {
|
||||
return nil, sql.ErrInvalidChildrenNumber.New(cf, len(children), 1)
|
||||
}
|
||||
|
||||
return NewMergeFunc(children[0]), nil
|
||||
}
|
||||
|
||||
// Type implements the Expression interface.
|
||||
func (cf *MergeFunc) Type() sql.Type {
|
||||
return sql.Text
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (dt *DiffTable) Schema() sql.Schema {
|
||||
|
||||
func (dt *DiffTable) Partitions(ctx *sql.Context) (sql.PartitionIter, error) {
|
||||
sess := DSessFromSess(ctx.Session)
|
||||
rootCmt, err := sess.GetParentCommit(ctx, dt.dbName)
|
||||
rootCmt, _, err := sess.GetParentCommit(ctx, dt.dbName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -122,33 +122,34 @@ func (sess *DoltSession) GetRoot(dbName string) (*doltdb.RootValue, bool) {
|
||||
}
|
||||
|
||||
// GetParentCommit returns the parent commit of the current session.
|
||||
func (sess *DoltSession) GetParentCommit(ctx context.Context, dbName string) (*doltdb.Commit, error) {
|
||||
func (sess *DoltSession) GetParentCommit(ctx context.Context, dbName string) (*doltdb.Commit, hash.Hash, error) {
|
||||
dbd, dbFound := sess.dbDatas[dbName]
|
||||
|
||||
if !dbFound {
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
return nil, hash.Hash{}, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
_, value := sess.Session.Get(dbName + HeadKeySuffix)
|
||||
valStr, isStr := value.(string)
|
||||
|
||||
if !isStr || !hash.IsValid(valStr) {
|
||||
return nil, doltdb.ErrInvalidHash
|
||||
return nil, hash.Hash{}, doltdb.ErrInvalidHash
|
||||
}
|
||||
|
||||
h := hash.Parse(valStr)
|
||||
cs, err := doltdb.NewCommitSpec(valStr, "")
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, hash.Hash{}, err
|
||||
}
|
||||
|
||||
cm, err := dbd.ddb.Resolve(ctx, cs)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, hash.Hash{}, err
|
||||
}
|
||||
|
||||
return cm, nil
|
||||
return cm, h, nil
|
||||
}
|
||||
|
||||
func (sess *DoltSession) Set(ctx context.Context, key string, typ sql.Type, value interface{}) error {
|
||||
|
||||
@@ -70,7 +70,7 @@ func NewHistoryTable(ctx *sql.Context, db Database, tblName string) (sql.Table,
|
||||
return nil, sql.ErrDatabaseNotFound.New(dbName)
|
||||
}
|
||||
|
||||
head, err := sess.GetParentCommit(ctx, dbName)
|
||||
head, _, err := sess.GetParentCommit(ctx, dbName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -84,7 +84,7 @@ type LogItr struct {
|
||||
// NewLogItr creates a LogItr from the current environment.
|
||||
func NewLogItr(sqlCtx *sql.Context, dbName string, ddb *doltdb.DoltDB) (*LogItr, error) {
|
||||
sess := DSessFromSess(sqlCtx.Session)
|
||||
commit, err := sess.GetParentCommit(sqlCtx, dbName)
|
||||
commit, _, err := sess.GetParentCommit(sqlCtx, dbName)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
Reference in New Issue
Block a user