mirror of
https://github.com/dolthub/dolt.git
synced 2026-02-23 18:19:50 -06:00
Merge remote-tracking branch 'origin/main' into elian/10462
This commit is contained in:
@@ -61,6 +61,7 @@ func CreateCommitArgParser(supportsBranchFlag bool) *argparser.ArgParser {
|
||||
ap.SupportsFlag(UpperCaseAllFlag, "A", "Adds all tables and databases (including new tables) in the working set to the staged set.")
|
||||
ap.SupportsFlag(AmendFlag, "", "Amend previous commit")
|
||||
ap.SupportsOptionalString(SignFlag, "S", "key-id", "Sign the commit using GPG. If no key-id is provided the key-id is taken from 'user.signingkey' the in the configuration")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification")
|
||||
if supportsBranchFlag {
|
||||
ap.SupportsString(BranchParam, "", "branch", "Commit to the specified branch instead of the current branch.")
|
||||
}
|
||||
@@ -96,6 +97,7 @@ func CreateMergeArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.")
|
||||
ap.SupportsFlag(NoEditFlag, "", "Use an auto-generated commit message when creating a merge commit. The default for interactive CLI sessions is to open an editor.")
|
||||
ap.SupportsString(AuthorParam, "", "author", "Specify an explicit author using the standard A U Thor {{.LessThan}}author@example.com{{.GreaterThan}} format.")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
|
||||
|
||||
return ap
|
||||
}
|
||||
@@ -116,6 +118,7 @@ func CreateRebaseArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(AbortParam, "", "Abort an interactive rebase and return the working set to the pre-rebase state")
|
||||
ap.SupportsFlag(ContinueFlag, "", "Continue an interactive rebase after adjusting the rebase plan")
|
||||
ap.SupportsFlag(InteractiveFlag, "i", "Start an interactive rebase")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before rebase")
|
||||
return ap
|
||||
}
|
||||
|
||||
@@ -193,6 +196,7 @@ func CreateCherryPickArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(AllowEmptyFlag, "", "Allow empty commits to be cherry-picked. "+
|
||||
"Note that use of this option only keeps commits that were initially empty. "+
|
||||
"Commits which become empty, due to a previous commit, will cause cherry-pick to fail.")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before cherry-pick")
|
||||
ap.TooManyArgsErrorFunc = func(receivedArgs []string) error {
|
||||
return errors.New("cherry-picking multiple commits is not supported yet.")
|
||||
}
|
||||
@@ -230,6 +234,7 @@ func CreatePullArgParser() *argparser.ArgParser {
|
||||
ap.SupportsString(UserFlag, "", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.")
|
||||
ap.SupportsFlag(PruneFlag, "p", "After fetching, remove any remote-tracking references that don't exist on the remote.")
|
||||
ap.SupportsFlag(SilentFlag, "", "Suppress progress information.")
|
||||
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
|
||||
return ap
|
||||
}
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ const (
|
||||
SilentFlag = "silent"
|
||||
SingleBranchFlag = "single-branch"
|
||||
SkipEmptyFlag = "skip-empty"
|
||||
SkipVerificationFlag = "skip-verification"
|
||||
SoftResetParam = "soft"
|
||||
SquashParam = "squash"
|
||||
StagedFlag = "staged"
|
||||
|
||||
@@ -266,6 +266,10 @@ func constructParametrizedDoltCommitQuery(msg string, apr *argparser.ArgParseRes
|
||||
writeToBuffer("--skip-empty")
|
||||
}
|
||||
|
||||
if apr.Contains(cli.SkipVerificationFlag) {
|
||||
writeToBuffer("--skip-verification")
|
||||
}
|
||||
|
||||
cfgSign := cliCtx.Config().GetStringOrDefault("sqlserver.global.gpgsign", "")
|
||||
if apr.Contains(cli.SignFlag) || strings.ToLower(cfgSign) == "true" {
|
||||
writeToBuffer("--gpg-sign")
|
||||
|
||||
@@ -318,6 +318,10 @@ func constructInterpolatedDoltMergeQuery(apr *argparser.ArgParseResults, cliCtx
|
||||
params = append(params, msg)
|
||||
}
|
||||
|
||||
if apr.Contains(cli.SkipVerificationFlag) {
|
||||
writeToBuffer("--skip-verification", false)
|
||||
}
|
||||
|
||||
if !apr.Contains(cli.AbortParam) && !apr.Contains(cli.SquashParam) {
|
||||
writeToBuffer("?", true)
|
||||
params = append(params, apr.Arg(0))
|
||||
|
||||
@@ -579,7 +579,19 @@ func (rcv *RebaseState) MutateRebasingStarted(n bool) bool {
|
||||
return rcv._tab.MutateBoolSlot(16, n)
|
||||
}
|
||||
|
||||
const RebaseStateNumFields = 7
|
||||
func (rcv *RebaseState) SkipVerification() bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(18))
|
||||
if o != 0 {
|
||||
return rcv._tab.GetBool(o + rcv._tab.Pos)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rcv *RebaseState) MutateSkipVerification(n bool) bool {
|
||||
return rcv._tab.MutateBoolSlot(18, n)
|
||||
}
|
||||
|
||||
const RebaseStateNumFields = 8
|
||||
|
||||
func RebaseStateStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(RebaseStateNumFields)
|
||||
@@ -614,6 +626,9 @@ func RebaseStateAddLastAttemptedStep(builder *flatbuffers.Builder, lastAttempted
|
||||
func RebaseStateAddRebasingStarted(builder *flatbuffers.Builder, rebasingStarted bool) {
|
||||
builder.PrependBoolSlot(6, rebasingStarted, false)
|
||||
}
|
||||
func RebaseStateAddSkipVerification(builder *flatbuffers.Builder, skipVerification bool) {
|
||||
builder.PrependBoolSlot(7, skipVerification, false)
|
||||
}
|
||||
func RebaseStateEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
return builder.EndObject()
|
||||
}
|
||||
|
||||
@@ -52,6 +52,9 @@ type CherryPickOptions struct {
|
||||
// and Dolt cherry-pick implementations, the default action is to fail when an empty commit is specified. In Git
|
||||
// and Dolt rebase implementations, the default action is to keep commits that start off as empty.
|
||||
EmptyCommitHandling doltdb.EmptyCommitHandling
|
||||
|
||||
// SkipVerification controls whether test validation should be skipped before creating commits.
|
||||
SkipVerification bool
|
||||
}
|
||||
|
||||
// NewCherryPickOptions creates a new CherryPickOptions instance, filled out with default values for cherry-pick.
|
||||
@@ -61,6 +64,7 @@ func NewCherryPickOptions() CherryPickOptions {
|
||||
CommitMessage: "",
|
||||
CommitBecomesEmptyHandling: doltdb.ErrorOnEmptyCommit,
|
||||
EmptyCommitHandling: doltdb.ErrorOnEmptyCommit,
|
||||
SkipVerification: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,9 +163,10 @@ func CreateCommitStagedPropsFromCherryPickOptions(ctx *sql.Context, options Cher
|
||||
}
|
||||
|
||||
commitProps := actions.CommitStagedProps{
|
||||
Date: originalMeta.Time(),
|
||||
Name: originalMeta.Name,
|
||||
Email: originalMeta.Email,
|
||||
Date: originalMeta.Time(),
|
||||
Name: originalMeta.Name,
|
||||
Email: originalMeta.Email,
|
||||
SkipVerification: options.SkipVerification,
|
||||
}
|
||||
|
||||
if options.CommitMessage != "" {
|
||||
|
||||
@@ -472,6 +472,10 @@ func encodeTableNameForSerialization(name TableName) string {
|
||||
// decodeTableNameFromSerialization decodes a table name from a serialized string. See notes on serialization in
|
||||
// |encodeTableNameForSerialization|
|
||||
func decodeTableNameFromSerialization(encodedName string) (TableName, bool) {
|
||||
if len(encodedName) == 0 {
|
||||
return TableName{}, false
|
||||
}
|
||||
|
||||
if encodedName[0] != 0 {
|
||||
return TableName{Name: encodedName}, true
|
||||
} else if len(encodedName) >= 4 { // 2 null bytes plus at least one char for schema and table name
|
||||
|
||||
@@ -75,6 +75,8 @@ type RebaseState struct {
|
||||
// rebasingStarted is true once the rebase plan has been started to execute. Once rebasingStarted is true, the
|
||||
// value in lastAttemptedStep has been initialized and is valid to read.
|
||||
rebasingStarted bool
|
||||
// skipVerification indicates whether test validation should be skipped during rebase operations.
|
||||
skipVerification bool
|
||||
}
|
||||
|
||||
// Branch returns the name of the branch being actively rebased. This is the branch that will be updated to point
|
||||
@@ -120,6 +122,10 @@ func (rs RebaseState) WithRebasingStarted(rebasingStarted bool) *RebaseState {
|
||||
return &rs
|
||||
}
|
||||
|
||||
func (rs RebaseState) SkipVerification() bool {
|
||||
return rs.skipVerification
|
||||
}
|
||||
|
||||
type MergeState struct {
|
||||
// the source commit
|
||||
commit *Commit
|
||||
@@ -322,13 +328,14 @@ func (ws WorkingSet) StartMerge(commit *Commit, commitSpecStr string) *WorkingSe
|
||||
// the branch that is being rebased, and |previousRoot| is root value of the branch being rebased. The HEAD and STAGED
|
||||
// root values of the branch being rebased must match |previousRoot|; WORKING may be a different root value, but ONLY
|
||||
// if it contains only ignored tables.
|
||||
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling) (*WorkingSet, error) {
|
||||
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling, skipVerification bool) (*WorkingSet, error) {
|
||||
ws.rebaseState = &RebaseState{
|
||||
ontoCommit: ontoCommit,
|
||||
preRebaseWorking: previousRoot,
|
||||
branch: branch,
|
||||
commitBecomesEmptyHandling: commitBecomesEmptyHandling,
|
||||
emptyCommitHandling: emptyCommitHandling,
|
||||
skipVerification: skipVerification,
|
||||
}
|
||||
|
||||
ontoRoot, err := ontoCommit.GetRootValue(ctx)
|
||||
@@ -549,6 +556,7 @@ func newWorkingSet(ctx context.Context, name string, vrw types.ValueReadWriter,
|
||||
emptyCommitHandling: EmptyCommitHandling(dsws.RebaseState.EmptyCommitHandling(ctx)),
|
||||
lastAttemptedStep: dsws.RebaseState.LastAttemptedStep(ctx),
|
||||
rebasingStarted: dsws.RebaseState.RebasingStarted(ctx),
|
||||
skipVerification: dsws.RebaseState.SkipVerification(ctx),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -646,7 +654,7 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W
|
||||
|
||||
rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch,
|
||||
uint8(ws.rebaseState.commitBecomesEmptyHandling), uint8(ws.rebaseState.emptyCommitHandling),
|
||||
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted)
|
||||
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted, ws.rebaseState.skipVerification)
|
||||
}
|
||||
|
||||
return &datas.WorkingSetSpec{
|
||||
|
||||
116
go/libraries/doltcore/env/actions/commit.go
vendored
116
go/libraries/doltcore/env/actions/commit.go
vendored
@@ -15,8 +15,12 @@
|
||||
package actions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gms "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
|
||||
@@ -25,14 +29,42 @@ import (
|
||||
)
|
||||
|
||||
type CommitStagedProps struct {
|
||||
Message string
|
||||
Date time.Time
|
||||
AllowEmpty bool
|
||||
SkipEmpty bool
|
||||
Amend bool
|
||||
Force bool
|
||||
Name string
|
||||
Email string
|
||||
Message string
|
||||
Date time.Time
|
||||
AllowEmpty bool
|
||||
SkipEmpty bool
|
||||
Amend bool
|
||||
Force bool
|
||||
Name string
|
||||
Email string
|
||||
SkipVerification bool
|
||||
}
|
||||
|
||||
const (
|
||||
// System variable name, defined here to avoid circular imports
|
||||
DoltCommitVerificationGroups = "dolt_commit_verification_groups"
|
||||
)
|
||||
|
||||
// GetCommitRunTestGroups returns the test groups to run for commit operations
|
||||
// Returns empty slice if no tests should be run, ["*"] if all tests should be run,
|
||||
// or specific group names if only those groups should be run
|
||||
func GetCommitRunTestGroups() []string {
|
||||
_, val, ok := sql.SystemVariables.GetGlobal(DoltCommitVerificationGroups)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if stringVal, ok := val.(string); ok && stringVal != "" {
|
||||
if stringVal == "*" {
|
||||
return []string{"*"}
|
||||
}
|
||||
// Split by comma and trim whitespace
|
||||
groups := strings.Split(stringVal, ",")
|
||||
for i, group := range groups {
|
||||
groups[i] = strings.TrimSpace(group)
|
||||
}
|
||||
return groups
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCommitStaged returns a new pending commit with the roots and commit properties given.
|
||||
@@ -114,6 +146,16 @@ func GetCommitStaged(
|
||||
}
|
||||
}
|
||||
|
||||
if !props.SkipVerification {
|
||||
testGroups := GetCommitRunTestGroups()
|
||||
if len(testGroups) > 0 {
|
||||
err := runCommitVerification(ctx, testGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
meta, err := datas.NewCommitMetaWithUserTS(props.Name, props.Email, props.Message, props.Date)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -121,3 +163,61 @@ func GetCommitStaged(
|
||||
|
||||
return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta)
|
||||
}
|
||||
|
||||
func runCommitVerification(ctx *sql.Context, testGroups []string) error {
|
||||
type sessionInterface interface {
|
||||
sql.Session
|
||||
GenericProvider() sql.MutableDatabaseProvider
|
||||
}
|
||||
|
||||
session, ok := ctx.Session.(sessionInterface)
|
||||
if !ok {
|
||||
return fmt.Errorf("session does not provide database provider interface")
|
||||
}
|
||||
|
||||
provider := session.GenericProvider()
|
||||
engine := gms.NewDefault(provider)
|
||||
|
||||
return runTestsUsingDtablefunctions(ctx, engine, testGroups)
|
||||
}
|
||||
|
||||
// runTestsUsingDtablefunctions runs tests using the dtablefunctions package against the staged root
|
||||
func runTestsUsingDtablefunctions(ctx *sql.Context, engine *gms.Engine, testGroups []string) error {
|
||||
if len(testGroups) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var allFailures []string
|
||||
|
||||
for _, group := range testGroups {
|
||||
query := fmt.Sprintf("SELECT * FROM dolt_test_run('%s')", group)
|
||||
_, iter, _, err := engine.Query(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run dolt_test_run for group %s: %w", group, err)
|
||||
}
|
||||
|
||||
for {
|
||||
row, rErr := iter.Next(ctx)
|
||||
if rErr == io.EOF {
|
||||
break
|
||||
}
|
||||
if rErr != nil {
|
||||
return fmt.Errorf("error reading test results: %w", rErr)
|
||||
}
|
||||
|
||||
// Extract status (column 3)
|
||||
status := fmt.Sprintf("%v", row[3])
|
||||
if status != "PASS" {
|
||||
testName := fmt.Sprintf("%v", row[0])
|
||||
message := fmt.Sprintf("%v", row[4])
|
||||
allFailures = append(allFailures, fmt.Sprintf("%s (%s)", testName, message))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(allFailures) > 0 {
|
||||
return fmt.Errorf("commit verification failed: %s", strings.Join(allFailures, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,447 +0,0 @@
|
||||
// Copyright 2025 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package actions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/exp/constraints"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/val"
|
||||
)
|
||||
|
||||
const (
|
||||
AssertionExpectedRows = "expected_rows"
|
||||
AssertionExpectedColumns = "expected_columns"
|
||||
AssertionExpectedSingleValue = "expected_single_value"
|
||||
)
|
||||
|
||||
// AssertData parses an assertion, comparison, and value, then returns the status of the test.
|
||||
// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=".
|
||||
// testPassed indicates whether the test was successful or not.
|
||||
// message is a string used to indicate test failures, and will not halt the overall process.
|
||||
// message will be empty if the test passed.
|
||||
// err indicates runtime failures and will stop dolt_test_run from proceeding.
|
||||
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
|
||||
switch assertion {
|
||||
case AssertionExpectedRows:
|
||||
message, err = expectRows(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedColumns:
|
||||
message, err = expectColumns(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedSingleValue:
|
||||
message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
|
||||
default:
|
||||
return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
} else if message != "" {
|
||||
return false, message, nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(row) != 1 {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
|
||||
}
|
||||
_, err = queryResult.Next(sqlCtx)
|
||||
if err == nil { //If multiple rows were given, we should error out
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil
|
||||
} else if err != io.EOF { // "True" error, so we should quit out
|
||||
return "", err
|
||||
}
|
||||
|
||||
if value == nil { // If we're expecting a null value, we don't need to type switch
|
||||
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
|
||||
}
|
||||
|
||||
// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception
|
||||
// of "0" and "1", which are valid integers and are covered below.
|
||||
if *value != "0" && *value != "1" {
|
||||
if expectedBool, err := strconv.ParseBool(*value); err == nil {
|
||||
actualBool, boolErr := getInterfaceAsBool(row[0])
|
||||
if boolErr != nil {
|
||||
return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
|
||||
}
|
||||
return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
switch actualValue := row[0].(type) {
|
||||
case int8:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int16:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int32:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int64:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
|
||||
case int:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint8:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint16:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint32:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint64:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case float64:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
|
||||
case float32:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
|
||||
case decimal.Decimal:
|
||||
expectedDecimal, err := decimal.NewFromString(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
|
||||
}
|
||||
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
|
||||
case time.Time:
|
||||
expectedTime, format, err := parseTestsDate(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
|
||||
}
|
||||
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
|
||||
case *val.TextStorage, string:
|
||||
actualString, err := GetStringColAsString(sqlCtx, actualValue)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
|
||||
default:
|
||||
return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedRows, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numRows int
|
||||
for {
|
||||
_, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
numRows++
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
|
||||
}
|
||||
|
||||
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedColumns, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numColumns int
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err != nil && err != io.EOF {
|
||||
return "", err
|
||||
}
|
||||
numColumns = len(row)
|
||||
return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
|
||||
}
|
||||
|
||||
// compareTestAssertion is a generic function used for comparing string, ints, floats.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<":
|
||||
if actualValue >= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<=":
|
||||
if actualValue > expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">":
|
||||
if actualValue <= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">=":
|
||||
if actualValue < expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
|
||||
// It returns the parsed time, the format that succeeded, and an error if applicable.
|
||||
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
|
||||
// List of valid formats
|
||||
formats := []string{
|
||||
time.DateOnly,
|
||||
time.DateTime,
|
||||
time.TimeOnly,
|
||||
time.RFC3339,
|
||||
time.RFC1123Z,
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
|
||||
return parsedTime, format, nil
|
||||
} else {
|
||||
err = parseErr
|
||||
}
|
||||
}
|
||||
return time.Time{}, "", err
|
||||
}
|
||||
|
||||
// compareDates is a function used for comparing time values.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
|
||||
expectedStr := expectedValue.Format(format)
|
||||
realStr := realValue.Format(format)
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<":
|
||||
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">":
|
||||
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.Before(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareDecimals is a function used for comparing decimals.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<":
|
||||
if realValue.GreaterThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.GreaterThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">":
|
||||
if realValue.LessThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.LessThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getTinyIntColAsBool returns the value interface{} as a bool
|
||||
// This is necessary because the query engine may return a tinyint column as a bool, int, or other types.
|
||||
// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles.
|
||||
func getInterfaceAsBool(col interface{}) (bool, error) {
|
||||
switch v := col.(type) {
|
||||
case bool:
|
||||
return v, nil
|
||||
case int:
|
||||
return v == 1, nil
|
||||
case int8:
|
||||
return v == 1, nil
|
||||
case int16:
|
||||
return v == 1, nil
|
||||
case int32:
|
||||
return v == 1, nil
|
||||
case int64:
|
||||
return v == 1, nil
|
||||
case uint:
|
||||
return v == 1, nil
|
||||
case uint8:
|
||||
return v == 1, nil
|
||||
case uint16:
|
||||
return v == 1, nil
|
||||
case uint32:
|
||||
return v == 1, nil
|
||||
case uint64:
|
||||
return v == 1, nil
|
||||
case string:
|
||||
return v == "1", nil
|
||||
default:
|
||||
return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
|
||||
}
|
||||
}
|
||||
|
||||
// compareBooleans is a function used for comparing boolean values.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if expectedValue != realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue == realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareNullValue is a function used for comparing a null value.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != nil {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == nil {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetStringColAsString is a function that returns a text column as a string.
|
||||
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
|
||||
// so we use a special parser to get the correct string values
|
||||
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
|
||||
if ts, ok := tableValue.(*val.TextStorage); ok {
|
||||
str, err := ts.Unwrap(sqlCtx)
|
||||
return &str, err
|
||||
} else if str, ok := tableValue.(string); ok {
|
||||
return &str, nil
|
||||
} else if tableValue == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
|
||||
}
|
||||
}
|
||||
@@ -1955,7 +1955,10 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima
|
||||
return err
|
||||
}
|
||||
|
||||
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) && !doltdb.IsFullTextTable(tableName) && !doltdb.HasDoltCIPrefix(tableName) {
|
||||
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) &&
|
||||
!doltdb.IsFullTextTable(tableName) &&
|
||||
!doltdb.HasDoltCIPrefix(tableName) &&
|
||||
tableName != doltdb.TestsTableName { // NM4 - determine why this is required now.
|
||||
return ErrReservedTableName.New(tableName)
|
||||
}
|
||||
|
||||
|
||||
@@ -103,6 +103,8 @@ func doDoltCherryPick(ctx *sql.Context, args []string) (string, int, int, int, e
|
||||
cherryPickOptions.EmptyCommitHandling = doltdb.KeepEmptyCommit
|
||||
}
|
||||
|
||||
cherryPickOptions.SkipVerification = apr.Contains(cli.SkipVerificationFlag)
|
||||
|
||||
commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherryPickOptions)
|
||||
if err != nil {
|
||||
return "", 0, 0, 0, err
|
||||
|
||||
@@ -163,14 +163,15 @@ func doDoltCommit(ctx *sql.Context, args []string) (string, bool, error) {
|
||||
}
|
||||
|
||||
csp := actions.CommitStagedProps{
|
||||
Message: msg,
|
||||
Date: t,
|
||||
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
|
||||
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
|
||||
Amend: amend,
|
||||
Force: apr.Contains(cli.ForceFlag),
|
||||
Name: name,
|
||||
Email: email,
|
||||
Message: msg,
|
||||
Date: t,
|
||||
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
|
||||
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
|
||||
Amend: amend,
|
||||
Force: apr.Contains(cli.ForceFlag),
|
||||
Name: name,
|
||||
Email: email,
|
||||
SkipVerification: apr.Contains(cli.SkipVerificationFlag),
|
||||
}
|
||||
|
||||
shouldSign, err := dsess.GetBooleanSystemVar(ctx, "gpgsign")
|
||||
|
||||
@@ -180,7 +180,7 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, err
|
||||
msg = userMsg
|
||||
}
|
||||
|
||||
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
|
||||
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
|
||||
if err != nil {
|
||||
return commit, conflicts, fastForward, "", err
|
||||
}
|
||||
@@ -205,6 +205,7 @@ func performMerge(
|
||||
spec *merge.MergeSpec,
|
||||
noCommit bool,
|
||||
msg string,
|
||||
skipVerification bool,
|
||||
) (*doltdb.WorkingSet, string, int, int, string, error) {
|
||||
// todo: allow merges even when an existing merge is uncommitted
|
||||
if ws.MergeActive() {
|
||||
@@ -234,7 +235,7 @@ func performMerge(
|
||||
if canFF {
|
||||
if spec.FFMode == merge.NoFastForward {
|
||||
var commit *doltdb.Commit
|
||||
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit)
|
||||
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit, skipVerification)
|
||||
if err == doltdb.ErrUnresolvedConflictsOrViolations {
|
||||
// if there are unresolved conflicts, write the resulting working set back to the session and return an
|
||||
// error message
|
||||
@@ -306,7 +307,10 @@ func performMerge(
|
||||
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
|
||||
args := []string{"-m", msg, "--author", author}
|
||||
if spec.Force {
|
||||
args = append(args, "--force")
|
||||
args = append(args, "--"+cli.ForceFlag)
|
||||
}
|
||||
if skipVerification {
|
||||
args = append(args, "--"+cli.SkipVerificationFlag)
|
||||
}
|
||||
commit, _, err = doDoltCommit(ctx, args)
|
||||
if err != nil {
|
||||
@@ -405,6 +409,7 @@ func executeNoFFMerge(
|
||||
dbName string,
|
||||
ws *doltdb.WorkingSet,
|
||||
noCommit bool,
|
||||
skipVerification bool,
|
||||
) (*doltdb.WorkingSet, *doltdb.Commit, error) {
|
||||
mergeRoot, err := spec.MergeC.GetRootValue(ctx)
|
||||
if err != nil {
|
||||
@@ -444,11 +449,12 @@ func executeNoFFMerge(
|
||||
}
|
||||
|
||||
pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{
|
||||
Message: msg,
|
||||
Date: spec.Date,
|
||||
Force: spec.Force,
|
||||
Name: spec.Name,
|
||||
Email: spec.Email,
|
||||
Message: msg,
|
||||
Date: spec.Date,
|
||||
Force: spec.Force,
|
||||
Name: spec.Name,
|
||||
Email: spec.Email,
|
||||
SkipVerification: skipVerification,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -237,7 +237,7 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) {
|
||||
return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New()
|
||||
}
|
||||
|
||||
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
|
||||
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
|
||||
if err != nil && !errors.Is(doltdb.ErrUpToDate, err) {
|
||||
return conflicts, fastForward, "", err
|
||||
}
|
||||
|
||||
@@ -216,7 +216,9 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) {
|
||||
} else if apr.NArg() > 1 {
|
||||
return 1, "", fmt.Errorf("too many args")
|
||||
}
|
||||
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling)
|
||||
|
||||
skipVerification := apr.Contains(cli.SkipVerificationFlag)
|
||||
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
|
||||
if err != nil {
|
||||
return 1, "", err
|
||||
}
|
||||
@@ -263,7 +265,7 @@ func processCommitBecomesEmptyParams(apr *argparser.ArgParseResults) (doltdb.Emp
|
||||
// startRebase starts a new interactive rebase operation. |upstreamPoint| specifies the commit where the new rebased
|
||||
// commits will be based off of, |commitBecomesEmptyHandling| specifies how to handle commits that are not empty, but
|
||||
// do not produce any changes when applied, and |emptyCommitHandling| specifies how to handle empty commits.
|
||||
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error {
|
||||
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling, skipVerification bool) error {
|
||||
if upstreamPoint == "" {
|
||||
return fmt.Errorf("no upstream branch specified")
|
||||
}
|
||||
@@ -351,7 +353,7 @@ func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandl
|
||||
}
|
||||
|
||||
newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working,
|
||||
commitBecomesEmptyHandling, emptyCommitHandling)
|
||||
commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -716,7 +718,8 @@ func continueRebase(ctx *sql.Context) rebaseResult {
|
||||
|
||||
result := processRebasePlanStep(ctx, &step,
|
||||
workingSet.RebaseState().CommitBecomesEmptyHandling(),
|
||||
workingSet.RebaseState().EmptyCommitHandling())
|
||||
workingSet.RebaseState().EmptyCommitHandling(),
|
||||
workingSet.RebaseState().SkipVerification())
|
||||
if result.err != nil || result.status != 0 || result.halt {
|
||||
return result
|
||||
}
|
||||
@@ -803,7 +806,7 @@ func commitManuallyStagedChangesForStep(ctx *sql.Context, step rebase.RebasePlan
|
||||
}
|
||||
|
||||
options, err := createCherryPickOptionsForRebaseStep(ctx, &step, workingSet.RebaseState().CommitBecomesEmptyHandling(),
|
||||
workingSet.RebaseState().EmptyCommitHandling())
|
||||
workingSet.RebaseState().EmptyCommitHandling(), workingSet.RebaseState().SkipVerification())
|
||||
|
||||
doltDB, ok := doltSession.GetDoltDB(ctx, ctx.GetCurrentDatabase())
|
||||
if !ok {
|
||||
@@ -861,6 +864,7 @@ func processRebasePlanStep(
|
||||
planStep *rebase.RebasePlanStep,
|
||||
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
|
||||
emptyCommitHandling doltdb.EmptyCommitHandling,
|
||||
skipVerification bool,
|
||||
) rebaseResult {
|
||||
// Make sure we have a transaction opened for the session
|
||||
// NOTE: After our first call to cherry-pick, the tx is committed, so a new tx needs to be started
|
||||
@@ -878,7 +882,7 @@ func processRebasePlanStep(
|
||||
return newRebaseSuccess("")
|
||||
}
|
||||
|
||||
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling)
|
||||
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
|
||||
if err != nil {
|
||||
return newRebaseError(err)
|
||||
}
|
||||
@@ -886,12 +890,19 @@ func processRebasePlanStep(
|
||||
return handleRebaseCherryPick(ctx, planStep, *options)
|
||||
}
|
||||
|
||||
func createCherryPickOptionsForRebaseStep(ctx *sql.Context, planStep *rebase.RebasePlanStep, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) (*cherry_pick.CherryPickOptions, error) {
|
||||
func createCherryPickOptionsForRebaseStep(
|
||||
ctx *sql.Context,
|
||||
planStep *rebase.RebasePlanStep,
|
||||
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
|
||||
emptyCommitHandling doltdb.EmptyCommitHandling,
|
||||
skipVerification bool,
|
||||
) (*cherry_pick.CherryPickOptions, error) {
|
||||
// Override the default empty commit handling options for cherry-pick, since
|
||||
// rebase has slightly different defaults
|
||||
options := cherry_pick.NewCherryPickOptions()
|
||||
options.CommitBecomesEmptyHandling = commitBecomesEmptyHandling
|
||||
options.EmptyCommitHandling = emptyCommitHandling
|
||||
options.SkipVerification = skipVerification
|
||||
|
||||
switch planStep.Action {
|
||||
case rebase.RebaseActionDrop, rebase.RebaseActionPick, rebase.RebaseActionEdit:
|
||||
|
||||
@@ -17,7 +17,9 @@ package dtablefunctions
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
gms "github.com/dolthub/go-mysql-server"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
@@ -26,10 +28,13 @@ import (
|
||||
"github.com/dolthub/vitess/go/vt/sqlparser"
|
||||
"github.com/gocraft/dbr/v2"
|
||||
"github.com/gocraft/dbr/v2/dialect"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/exp/constraints"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
|
||||
"github.com/dolthub/dolt/go/store/val"
|
||||
)
|
||||
|
||||
const testsRunDefaultRowCount = 10
|
||||
@@ -39,12 +44,13 @@ var _ sql.CatalogTableFunction = (*TestsRunTableFunction)(nil)
|
||||
var _ sql.ExecSourceRel = (*TestsRunTableFunction)(nil)
|
||||
var _ sql.AuthorizationCheckerNode = (*TestsRunTableFunction)(nil)
|
||||
|
||||
type testResult struct {
|
||||
testName string
|
||||
groupName string
|
||||
query string
|
||||
status string
|
||||
message string
|
||||
// TestResult represents the result of running a single test
|
||||
type TestResult struct {
|
||||
TestName string
|
||||
GroupName string
|
||||
Query string
|
||||
Status string
|
||||
Message string
|
||||
}
|
||||
|
||||
type TestsRunTableFunction struct {
|
||||
@@ -199,7 +205,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resultRow := sql.NewRow(result.testName, result.groupName, result.query, result.status, result.message)
|
||||
resultRow := sql.NewRow(result.TestName, result.GroupName, result.Query, result.Status, result.Message)
|
||||
resultRows = append(resultRows, resultRow)
|
||||
}
|
||||
}
|
||||
@@ -220,7 +226,7 @@ func (trtf *TestsRunTableFunction) RowCount(_ *sql.Context) (uint64, bool, error
|
||||
return testsRunDefaultRowCount, false, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResult, err error) {
|
||||
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result TestResult, err error) {
|
||||
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -237,9 +243,9 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
|
||||
if err != nil {
|
||||
message = fmt.Sprintf("Query error: %s", err.Error())
|
||||
} else {
|
||||
testPassed, message, err = actions.AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
|
||||
testPassed, message, err = AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
|
||||
if err != nil {
|
||||
return testResult{}, err
|
||||
return TestResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -253,11 +259,49 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
|
||||
if groupName != nil {
|
||||
groupString = *groupName
|
||||
}
|
||||
result = testResult{*testName, groupString, *query, status, message}
|
||||
result = TestResult{*testName, groupString, *query, status, message}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) queryAndAssertWithFunc(row sql.Row, assertDataFunc AssertDataFunc) (result TestResult, err error) {
|
||||
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
message, err := validateQuery(trtf.ctx, trtf.catalog, *query)
|
||||
if err != nil && message == "" {
|
||||
message = fmt.Sprintf("query error: %s", err.Error())
|
||||
}
|
||||
|
||||
var testPassed bool
|
||||
if message == "" {
|
||||
_, queryResult, _, err := trtf.engine.Query(trtf.ctx, *query)
|
||||
if err != nil {
|
||||
message = fmt.Sprintf("Query error: %s", err.Error())
|
||||
} else {
|
||||
testPassed, message, err = assertDataFunc(trtf.ctx, *assertion, *comparison, value, queryResult)
|
||||
if err != nil {
|
||||
return TestResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status := "PASS"
|
||||
if !testPassed {
|
||||
status = "FAIL"
|
||||
}
|
||||
|
||||
var groupString string
|
||||
if groupName != nil {
|
||||
groupString = *groupName
|
||||
}
|
||||
result = TestResult{*testName, groupString, *query, status, message}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) {
|
||||
// Original behavior when root is nil - use SQL queries against current session
|
||||
var queries []string
|
||||
|
||||
if arg == "*" {
|
||||
@@ -320,28 +364,31 @@ func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, er
|
||||
}
|
||||
|
||||
func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value *string, err error) {
|
||||
if testName, err = actions.GetStringColAsString(ctx, row[0]); err != nil {
|
||||
if testName, err = getStringColAsString(ctx, row[0]); err != nil {
|
||||
return
|
||||
}
|
||||
if groupName, err = actions.GetStringColAsString(ctx, row[1]); err != nil {
|
||||
if groupName, err = getStringColAsString(ctx, row[1]); err != nil {
|
||||
return
|
||||
}
|
||||
if query, err = actions.GetStringColAsString(ctx, row[2]); err != nil {
|
||||
if query, err = getStringColAsString(ctx, row[2]); err != nil {
|
||||
return
|
||||
}
|
||||
if assertion, err = actions.GetStringColAsString(ctx, row[3]); err != nil {
|
||||
if assertion, err = getStringColAsString(ctx, row[3]); err != nil {
|
||||
return
|
||||
}
|
||||
if comparison, err = actions.GetStringColAsString(ctx, row[4]); err != nil {
|
||||
if comparison, err = getStringColAsString(ctx, row[4]); err != nil {
|
||||
return
|
||||
}
|
||||
if value, err = actions.GetStringColAsString(ctx, row[5]); err != nil {
|
||||
if value, err = getStringColAsString(ctx, row[5]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return testName, groupName, query, assertion, comparison, value, nil
|
||||
}
|
||||
|
||||
// AssertDataFunc defines the function signature for asserting test data
|
||||
type AssertDataFunc func(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error)
|
||||
|
||||
func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string, error) {
|
||||
// We first check if the query contains multiple sql statements
|
||||
if statements, err := sqlparser.SplitStatementToPieces(query); err != nil {
|
||||
@@ -361,3 +408,455 @@ func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string,
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Simple inline assertion constants to avoid circular imports
|
||||
const (
|
||||
AssertionExpectedRows = "expected_rows"
|
||||
AssertionExpectedColumns = "expected_columns"
|
||||
AssertionExpectedSingleValue = "expected_single_value"
|
||||
)
|
||||
|
||||
// getStringColAsString safely converts a sql value to string
|
||||
func getStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
|
||||
if tableValue == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if ts, ok := tableValue.(*val.TextStorage); ok {
|
||||
str, err := ts.Unwrap(sqlCtx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &str, nil
|
||||
} else if str, ok := tableValue.(string); ok {
|
||||
return &str, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
|
||||
}
|
||||
}
|
||||
|
||||
// readTableDataFromDoltTable reads test data directly from a dolt table
|
||||
func (trtf *TestsRunTableFunction) readTableDataFromDoltTable(table *doltdb.Table, arg string) ([]sql.Row, error) {
|
||||
// This is a complex implementation that requires reading table data directly from dolt storage
|
||||
// For now, return an error that clearly indicates this needs to be implemented
|
||||
// The table scan would involve:
|
||||
// 1. Getting the table schema
|
||||
// 2. Creating a table iterator
|
||||
// 3. Reading and filtering rows based on the arg (test_name or test_group)
|
||||
// 4. Converting dolt storage format to SQL rows
|
||||
//
|
||||
// This is a significant implementation that requires understanding dolt's storage internals
|
||||
return nil, fmt.Errorf("direct table reading from dolt storage not yet implemented for table scan of dolt_tests - this requires implementing table iteration and row conversion from dolt's internal storage format")
|
||||
}
|
||||
|
||||
// AssertData parses an assertion, comparison, and value, then returns the status of the test.
|
||||
// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=".
|
||||
// testPassed indicates whether the test was successful or not.
|
||||
// message is a string used to indicate test failures, and will not halt the overall process.
|
||||
// message will be empty if the test passed.
|
||||
// err indicates runtime failures and will stop dolt_test_run from proceeding.
|
||||
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
|
||||
switch assertion {
|
||||
case AssertionExpectedRows:
|
||||
message, err = expectRows(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedColumns:
|
||||
message, err = expectColumns(sqlCtx, comparison, value, queryResult)
|
||||
case AssertionExpectedSingleValue:
|
||||
message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
|
||||
default:
|
||||
return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
} else if message != "" {
|
||||
return false, message, nil
|
||||
}
|
||||
return true, "", nil
|
||||
}
|
||||
|
||||
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(row) != 1 {
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
|
||||
}
|
||||
_, err = queryResult.Next(sqlCtx)
|
||||
if err == nil { //If multiple rows were given, we should error out
|
||||
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil
|
||||
} else if err != io.EOF { // "True" error, so we should quit out
|
||||
return "", err
|
||||
}
|
||||
|
||||
if value == nil { // If we're expecting a null value, we don't need to type switch
|
||||
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
|
||||
}
|
||||
|
||||
// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception
|
||||
// of "0" and "1", which are valid integers and are covered below.
|
||||
if *value != "0" && *value != "1" {
|
||||
if expectedBool, err := strconv.ParseBool(*value); err == nil {
|
||||
actualBool, boolErr := getInterfaceAsBool(row[0])
|
||||
if boolErr != nil {
|
||||
return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
|
||||
}
|
||||
return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
switch actualValue := row[0].(type) {
|
||||
case int8:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int16:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int32:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case int64:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
|
||||
case int:
|
||||
expectedInt, err := strconv.ParseInt(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint8:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint16:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint32:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint64:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
|
||||
case uint:
|
||||
expectedUint, err := strconv.ParseUint(*value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
|
||||
case float64:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 64)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
|
||||
case float32:
|
||||
expectedFloat, err := strconv.ParseFloat(*value, 32)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
|
||||
}
|
||||
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
|
||||
case decimal.Decimal:
|
||||
expectedDecimal, err := decimal.NewFromString(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
|
||||
}
|
||||
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
|
||||
case time.Time:
|
||||
expectedTime, format, err := parseTestsDate(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
|
||||
}
|
||||
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
|
||||
case *val.TextStorage, string:
|
||||
actualString, err := GetStringColAsString(sqlCtx, actualValue)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
|
||||
default:
|
||||
return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
|
||||
}
|
||||
}
|
||||
|
||||
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedRows, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numRows int
|
||||
for {
|
||||
_, err := queryResult.Next(sqlCtx)
|
||||
if err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
return "", err
|
||||
}
|
||||
numRows++
|
||||
}
|
||||
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
|
||||
}
|
||||
|
||||
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
|
||||
if value == nil {
|
||||
return "null is not a valid assertion for expected_rows", nil
|
||||
}
|
||||
expectedColumns, err := strconv.Atoi(*value)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
|
||||
}
|
||||
|
||||
var numColumns int
|
||||
row, err := queryResult.Next(sqlCtx)
|
||||
if err != nil && err != io.EOF {
|
||||
return "", err
|
||||
}
|
||||
numColumns = len(row)
|
||||
return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
|
||||
}
|
||||
|
||||
// compareTestAssertion is a generic function used for comparing string, ints, floats.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<":
|
||||
if actualValue >= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case "<=":
|
||||
if actualValue > expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">":
|
||||
if actualValue <= expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
case ">=":
|
||||
if actualValue < expectedValue {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
|
||||
// It returns the parsed time, the format that succeeded, and an error if applicable.
|
||||
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
|
||||
// List of valid formats
|
||||
formats := []string{
|
||||
time.DateOnly,
|
||||
time.DateTime,
|
||||
time.TimeOnly,
|
||||
time.RFC3339,
|
||||
time.RFC1123Z,
|
||||
}
|
||||
|
||||
for _, format := range formats {
|
||||
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
|
||||
return parsedTime, format, nil
|
||||
} else {
|
||||
err = parseErr
|
||||
}
|
||||
}
|
||||
return time.Time{}, "", err
|
||||
}
|
||||
|
||||
// compareDates is a function used for comparing time values.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
|
||||
expectedStr := expectedValue.Format(format)
|
||||
realStr := realValue.Format(format)
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<":
|
||||
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.After(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">":
|
||||
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.Before(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareDecimals is a function used for comparing decimals.
|
||||
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if !expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue.Equal(realValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<":
|
||||
if realValue.GreaterThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "<=":
|
||||
if realValue.GreaterThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">":
|
||||
if realValue.LessThanOrEqual(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case ">=":
|
||||
if realValue.LessThan(expectedValue) {
|
||||
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison type", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getTinyIntColAsBool returns the value interface{} as a bool
|
||||
// This is necessary because the query engine may return a tinyint column as a bool, int, or other types.
|
||||
// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles.
|
||||
func getInterfaceAsBool(col interface{}) (bool, error) {
|
||||
switch v := col.(type) {
|
||||
case bool:
|
||||
return v, nil
|
||||
case int:
|
||||
return v == 1, nil
|
||||
case int8:
|
||||
return v == 1, nil
|
||||
case int16:
|
||||
return v == 1, nil
|
||||
case int32:
|
||||
return v == 1, nil
|
||||
case int64:
|
||||
return v == 1, nil
|
||||
case uint:
|
||||
return v == 1, nil
|
||||
case uint8:
|
||||
return v == 1, nil
|
||||
case uint16:
|
||||
return v == 1, nil
|
||||
case uint32:
|
||||
return v == 1, nil
|
||||
case uint64:
|
||||
return v == 1, nil
|
||||
case string:
|
||||
return v == "1", nil
|
||||
default:
|
||||
return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
|
||||
}
|
||||
}
|
||||
|
||||
// compareBooleans is a function used for comparing boolean values.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if expectedValue != realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
case "!=":
|
||||
if expectedValue == realValue {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// compareNullValue is a function used for comparing a null value.
|
||||
// It takes in a comparison string from one of: "==", "!="
|
||||
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
|
||||
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
|
||||
switch comparison {
|
||||
case "==":
|
||||
if actualValue != nil {
|
||||
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
|
||||
}
|
||||
case "!=":
|
||||
if actualValue == nil {
|
||||
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
|
||||
}
|
||||
default:
|
||||
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetStringColAsString is a function that returns a text column as a string.
|
||||
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
|
||||
// so we use a special parser to get the correct string values
|
||||
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
|
||||
if ts, ok := tableValue.(*val.TextStorage); ok {
|
||||
str, err := ts.Unwrap(sqlCtx)
|
||||
return &str, err
|
||||
} else if str, ok := tableValue.(string); ok {
|
||||
return &str, nil
|
||||
} else if tableValue == nil {
|
||||
return nil, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1210,6 +1210,11 @@ func TestDoltDdlScripts(t *testing.T) {
|
||||
RunDoltDdlScripts(t, harness)
|
||||
}
|
||||
|
||||
func TestDoltCommitVerificationScripts(t *testing.T) {
|
||||
harness := newDoltEnginetestHarness(t)
|
||||
RunDoltCommitVerificationScripts(t, harness)
|
||||
}
|
||||
|
||||
func TestBrokenDdlScripts(t *testing.T) {
|
||||
for _, script := range BrokenDDLScripts {
|
||||
t.Skip(script.Name)
|
||||
|
||||
@@ -2147,3 +2147,12 @@ func RunTransactionTestsWithEngineSetup(t *testing.T, setupEngine func(*gms.Engi
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func RunDoltCommitVerificationScripts(t *testing.T, harness DoltEnginetestHarness) {
|
||||
for _, script := range DoltCommitVerificationScripts {
|
||||
harness := harness.NewHarness(t)
|
||||
|
||||
enginetest.TestScript(t, harness, script)
|
||||
harness.Close()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +190,6 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
|
||||
for i := range dbs {
|
||||
db := dbs[i]
|
||||
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("use %s", db)})
|
||||
|
||||
// Any auto increment tables must be dropped and recreated to get a fresh state for the global auto increment
|
||||
// sequence trackers
|
||||
_, aiTables := enginetest.MustQuery(ctx, d.engine,
|
||||
@@ -218,6 +217,7 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
|
||||
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("drop database if exists %s", db)})
|
||||
}
|
||||
}
|
||||
|
||||
resetCmds = append(resetCmds, setup.SetupScript{"use mydb"})
|
||||
return resetCmds
|
||||
}
|
||||
@@ -229,7 +229,7 @@ func commitScripts(dbs []string) []setup.SetupScript {
|
||||
db := dbs[i]
|
||||
commitCmds = append(commitCmds, fmt.Sprintf("use %s", db))
|
||||
commitCmds = append(commitCmds, "call dolt_add('.')")
|
||||
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00')", db))
|
||||
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00', '--skip-verification')", db))
|
||||
}
|
||||
commitCmds = append(commitCmds, "use mydb")
|
||||
return []setup.SetupScript{commitCmds}
|
||||
|
||||
@@ -0,0 +1,538 @@
|
||||
// Copyright 2025 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package enginetest
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/enginetest"
|
||||
"github.com/dolthub/go-mysql-server/enginetest/queries"
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
)
|
||||
|
||||
// commitHashValidator validates commit hash format (32 character hex)
|
||||
type commitHashValidator struct{}
|
||||
|
||||
var _ enginetest.CustomValueValidator = &commitHashValidator{}
|
||||
|
||||
func (chv *commitHashValidator) Validate(val interface{}) (bool, error) {
|
||||
h, ok := val.(string)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
_, ok = hash.MaybeParse(h)
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
// successfulRebaseMessageValidator validates successful rebase message format
|
||||
type successfulRebaseMessageValidator struct{}
|
||||
|
||||
var _ enginetest.CustomValueValidator = &successfulRebaseMessageValidator{}
|
||||
var successfulRebaseRegex = regexp.MustCompile(`^Successfully rebased.*`)
|
||||
|
||||
func (srmv *successfulRebaseMessageValidator) Validate(val interface{}) (bool, error) {
|
||||
message, ok := val.(string)
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
return successfulRebaseRegex.MatchString(message), nil
|
||||
}
|
||||
|
||||
var commitHash = &commitHashValidator{}
|
||||
var successfulRebaseMessage = &successfulRebaseMessageValidator{}
|
||||
|
||||
var DoltCommitVerificationScripts = []queries.ScriptTest{
|
||||
{
|
||||
Name: "test verification system variables exist and have correct defaults",
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
|
||||
Expected: []sql.Row{
|
||||
{"dolt_commit_verification_groups", ""},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification system variables can be set",
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
Expected: []sql.Row{{types.OkResult{}}},
|
||||
},
|
||||
{
|
||||
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
|
||||
Expected: []sql.Row{
|
||||
{"dolt_commit_verification_groups", "*"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = 'unit,integration'",
|
||||
Expected: []sql.Row{{types.OkResult{}}},
|
||||
},
|
||||
{
|
||||
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
|
||||
Expected: []sql.Row{
|
||||
{"dolt_commit_verification_groups", "unit,integration"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit verification enabled - all tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with passing tests')",
|
||||
ExpectedColumns: sql.Schema{
|
||||
{Name: "hash", Type: types.LongText, Nullable: false},
|
||||
},
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit verification enabled - tests fail, commit aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit that should fail verification')",
|
||||
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_commit('--skip-verification','-m', 'skip verification')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "commit with test verification - specific test groups",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = 'unit'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with unit tests only')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = 'integration'",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_commit('--allow-empty', '--amend', '-m', 'fail please')",
|
||||
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_commit('--allow-empty', '--amend', '--skip-verification', '-m', 'skip the tests')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cherry-pick with test verification enabled - tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'add test')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'",
|
||||
"CALL dolt_add('.')",
|
||||
"call dolt_commit_hash_out(@commit_1_hash,'--skip-verification', '-m', 'Add Bob and update test')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'chuck@exampl.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"call dolt_commit_hash_out(@commit_2_hash,'--skip-verification', '-m', 'Add Charlie')",
|
||||
"CALL dolt_checkout('main')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick(@commit_1_hash)",
|
||||
Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick(@commit_2_hash)",
|
||||
ExpectedErrStr: "commit verification failed: test_user_count_update (Assertion failed: expected_single_value equal to 2, got 3)",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "cherry-pick with test verification enabled - tests fail, aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"call dolt_commit_hash_out(@commit_hash,'--skip-verification', '-m', 'Add Bob but dont update test')",
|
||||
"CALL dolt_checkout('main')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick(@commit_hash)",
|
||||
ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 1, got 2)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_cherry_pick('--skip-verification', @commit_hash)",
|
||||
Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
|
||||
},
|
||||
{
|
||||
Query: "select * from dolt_test_run('*')",
|
||||
Expected: []sql.Row{
|
||||
{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 1, got 2"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "rebase with test verification enabled - tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"DELETE FROM users where id = 1",
|
||||
"INSERT INTO users VALUES (1, 'Zed', 'zed@example.com')",
|
||||
"CALL dolt_commit('-am', 'drop Alice, add Zed')", // tests still pass here.
|
||||
"CALL dolt_checkout('-b', 'feature', 'HEAD~1')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Add Bob and update test')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Add Charlie, update test')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_rebase('main')",
|
||||
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "rebase with test verification enabled - tests fail, aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Add Bob but dont update test')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')", // this will trip the existing test.
|
||||
"CALL dolt_checkout('feature')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_rebase('main')",
|
||||
ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 2, got 3)",
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_rebase('--abort')",
|
||||
Expected: []sql.Row{{0, "Interactive rebase aborted"}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_rebase('--skip-verification', 'main')",
|
||||
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
|
||||
},
|
||||
{
|
||||
Query: "select * from dolt_test_run('*')",
|
||||
Expected: []sql.Row{
|
||||
{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 2, got 3"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "interactive rebase with --skip-verification flag should persist across continue operations",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob but dont update test')", // This will cause test to fail
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (4, 'David', 'david@example.com')", // Add a commit to main to create divergence
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add David on main')",
|
||||
"CALL dolt_checkout('feature')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_rebase('--interactive', '--skip-verification', 'main')",
|
||||
Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_feature; adjust the rebase plan in the dolt_rebase table, then continue rebasing by calling dolt_rebase('--continue')"}},
|
||||
},
|
||||
{
|
||||
Query: "CALL dolt_rebase('--continue')", // This should NOT require --skip-verification flag but should still skip tests
|
||||
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification with no dolt_tests errors",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit without dolt_tests table')",
|
||||
ExpectedErrStr: "failed to run dolt_test_run for group *: could not find tests for argument: *",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification with mixed test groups - only specified groups run",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = 'unit'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
|
||||
"('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with unit tests only - should pass')",
|
||||
Expected: []sql.Row{{commitHash}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "test verification error message includes test details",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_specific_failure', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_commit('-m', 'Commit with specific test failure')",
|
||||
ExpectedErrStr: "commit verification failed: test_specific_failure (Assertion failed: expected_single_value equal to 999, got 2)",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "merge with test verification enabled - tests pass",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('-m', 'Initial commit')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Bob\"', 'expected_single_value', '==', '1')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
|
||||
"CALL dolt_checkout('main')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('feature')",
|
||||
Expected: []sql.Row{{commitHash, int64(1), int64(0), "merge successful"}},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "merge with test verification enabled - tests fail, merge aborted",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('feature')",
|
||||
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 3)",
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "merge with --skip-verification flag bypasses verification",
|
||||
SetUpScript: []string{
|
||||
"SET GLOBAL dolt_commit_verification_groups = '*'",
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
|
||||
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
|
||||
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
|
||||
"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
|
||||
"CALL dolt_checkout('-b', 'feature')",
|
||||
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
|
||||
"CALL dolt_checkout('main')",
|
||||
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
|
||||
"CALL dolt_add('.')",
|
||||
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "CALL dolt_merge('--skip-verification', 'feature')",
|
||||
Expected: []sql.Row{{commitHash, int64(0), int64(0), "merge successful"}},
|
||||
},
|
||||
{
|
||||
Query: "select * from dolt_test_run('*')",
|
||||
Expected: []sql.Row{
|
||||
{"test_will_fail", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 999, got 3"},
|
||||
},
|
||||
},
|
||||
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
|
||||
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
|
||||
SkipResultsCheck: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/dolthub/go-mysql-server/sql/types"
|
||||
_ "github.com/dolthub/go-mysql-server/sql/variables"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
)
|
||||
|
||||
@@ -292,6 +293,13 @@ var DoltSystemVariables = []sql.SystemVariable{
|
||||
Type: types.NewSystemBoolType(dsess.AllowCICreation),
|
||||
Default: int8(0),
|
||||
},
|
||||
&sql.MysqlSystemVariable{
|
||||
Name: actions.DoltCommitVerificationGroups,
|
||||
Dynamic: true,
|
||||
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Global),
|
||||
Type: types.NewSystemStringType(actions.DoltCommitVerificationGroups),
|
||||
Default: "",
|
||||
},
|
||||
}
|
||||
|
||||
func AddDoltSystemVariables() {
|
||||
|
||||
@@ -67,6 +67,10 @@ table RebaseState {
|
||||
// The rebasing_started field indicates if execution of the rebase plan has been started or not. Once execution of the
|
||||
// plan has been started, the last_attempted_step field holds a reference to the most recent plan step attempted.
|
||||
rebasing_started:bool;
|
||||
|
||||
// When set to true, the rebase process will skip performing commit
|
||||
// verification if it would otherwise run.
|
||||
skip_verification:bool;
|
||||
}
|
||||
|
||||
// KEEP THIS IN SYNC WITH fileidentifiers.go
|
||||
|
||||
@@ -169,6 +169,7 @@ type RebaseState struct {
|
||||
commitBecomesEmptyHandling uint8
|
||||
emptyCommitHandling uint8
|
||||
rebasingStarted bool
|
||||
skipVerification bool
|
||||
}
|
||||
|
||||
func (rs *RebaseState) PreRebaseWorkingAddr() hash.Hash {
|
||||
@@ -206,6 +207,10 @@ func (rs *RebaseState) EmptyCommitHandling(_ context.Context) uint8 {
|
||||
return rs.emptyCommitHandling
|
||||
}
|
||||
|
||||
func (rs *RebaseState) SkipVerification(_ context.Context) bool {
|
||||
return rs.skipVerification
|
||||
}
|
||||
|
||||
type MergeState struct {
|
||||
preMergeWorkingAddr *hash.Hash
|
||||
fromCommitAddr *hash.Hash
|
||||
@@ -457,6 +462,7 @@ func (h serialWorkingSetHead) HeadWorkingSet() (*WorkingSetHead, error) {
|
||||
rebaseState.EmptyCommitHandling(),
|
||||
rebaseState.LastAttemptedStep(),
|
||||
rebaseState.RebasingStarted(),
|
||||
rebaseState.SkipVerification(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -196,6 +196,7 @@ func workingset_flatbuffer(working hash.Hash, staged *hash.Hash, mergeState *Mer
|
||||
serial.RebaseStateAddEmptyCommitHandling(builder, rebaseState.emptyCommitHandling)
|
||||
serial.RebaseStateAddLastAttemptedStep(builder, rebaseState.lastAttemptedStep)
|
||||
serial.RebaseStateAddRebasingStarted(builder, rebaseState.rebasingStarted)
|
||||
serial.RebaseStateAddSkipVerification(builder, rebaseState.skipVerification)
|
||||
rebaseStateOffset = serial.RebaseStateEnd(builder)
|
||||
}
|
||||
|
||||
@@ -264,7 +265,7 @@ func NewMergeState(
|
||||
}
|
||||
}
|
||||
|
||||
func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8, lastAttemptedStep float32, rebasingStarted bool) *RebaseState {
|
||||
func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8, lastAttemptedStep float32, rebasingStarted bool, skipVerification bool) *RebaseState {
|
||||
return &RebaseState{
|
||||
preRebaseWorkingAddr: &preRebaseWorkingRoot,
|
||||
ontoCommitAddr: &commitAddr,
|
||||
@@ -273,6 +274,7 @@ func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch
|
||||
emptyCommitHandling: emptyCommitHandling,
|
||||
lastAttemptedStep: lastAttemptedStep,
|
||||
rebasingStarted: rebasingStarted,
|
||||
skipVerification: skipVerification,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
253
integration-tests/bats/commit_verification.bats
Normal file
253
integration-tests/bats/commit_verification.bats
Normal file
@@ -0,0 +1,253 @@
|
||||
#!/usr/bin/env bats
|
||||
load $BATS_TEST_DIRNAME/helper/common.bash
|
||||
|
||||
setup() {
|
||||
setup_common
|
||||
|
||||
dolt sql <<SQL
|
||||
CREATE TABLE users (
|
||||
id INT PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
email VARCHAR(100)
|
||||
);
|
||||
INSERT INTO users VALUES (1, 'Alice', 'alice@example.com');
|
||||
CALL DOLT_ADD('.');
|
||||
CALL DOLT_COMMIT('-m', 'Initial commit');
|
||||
SQL
|
||||
}
|
||||
|
||||
getHeadHash() {
|
||||
run dolt sql -r csv -q "select commit_hash from dolt_log limit 1 offset 0;"
|
||||
[ "$status" -eq 0 ] || return 1
|
||||
echo "${lines[1]}"
|
||||
}
|
||||
|
||||
@test "commit_verification: system variables can be set" {
|
||||
run dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
[ "$status" -eq 0 ]
|
||||
|
||||
run dolt sql -q "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "*" ]] || false
|
||||
}
|
||||
|
||||
@test "commit_verification: commit with tests enabled - all tests pass" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1'),
|
||||
('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = "Alice"', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
|
||||
dolt add .
|
||||
|
||||
run dolt commit -m "Commit with passing tests"
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: abort commit, then skip verification to bypass" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999');
|
||||
SQL
|
||||
|
||||
dolt add .
|
||||
|
||||
run dolt commit -m "Commit that should fail verification"
|
||||
[ "$status" -ne 0 ]
|
||||
[[ "$output" =~ "commit verification failed" ]] || false
|
||||
[[ "$output" =~ "test_will_fail" ]] || false
|
||||
[[ "$output" =~ "expected_single_value equal to 999, got 1" ]] || false
|
||||
|
||||
run dolt commit --skip-verification -m "Skip verification commit"
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: specific test groups - only specified groups run" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = 'unit'"
|
||||
|
||||
# Add tests in different groups
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1'),
|
||||
('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999');
|
||||
SQL
|
||||
|
||||
dolt add .
|
||||
|
||||
# Commit should succeed because only unit tests run (integration test that would fail is ignored)
|
||||
run dolt commit -m "Commit with unit tests only"
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: merge with tests enabled - tests pass" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = "Alice"', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit -m "Initial commit"
|
||||
|
||||
dolt checkout -b feature
|
||||
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = "Bob"', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit -m "Add Bob"
|
||||
|
||||
dolt checkout main
|
||||
run dolt merge feature
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: merge with tests enabled - tests fail, merge aborted" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Initial commit with failing test"
|
||||
|
||||
dolt checkout -b feature
|
||||
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Add Bob"
|
||||
|
||||
# Add Charlie to main to force non-fast-forward merge
|
||||
dolt checkout main
|
||||
dolt sql -q "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')"
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Add Charlie"
|
||||
|
||||
run dolt merge feature
|
||||
[ "$status" -ne 0 ]
|
||||
[[ "$output" =~ "commit verification failed" ]] || false
|
||||
[[ "$output" =~ "test_will_fail" ]] || false
|
||||
[[ "$output" =~ "expected_single_value equal to 999, got 3" ]] || false
|
||||
|
||||
run dolt merge --skip-verification feature
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: cherry-pick with tests enabled - tests pass" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Add test"
|
||||
|
||||
dolt checkout -b feature
|
||||
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
|
||||
dolt sql -q "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'"
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Add Bob and update test"
|
||||
commit_hash=$(getHeadHash)
|
||||
|
||||
dolt checkout main
|
||||
run dolt cherry-pick $commit_hash
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: cherry-pick with tests enabled - tests fail, aborted" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit -m "Initial commit"
|
||||
|
||||
dolt checkout -b feature
|
||||
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Add Bob but don't update test"
|
||||
commit_hash=$(getHeadHash)
|
||||
|
||||
dolt checkout main
|
||||
run dolt cherry-pick $commit_hash
|
||||
[ "$status" -ne 0 ]
|
||||
[[ "$output" =~ "commit verification failed" ]] || false
|
||||
[[ "$output" =~ "test_users_count" ]] || false
|
||||
[[ "$output" =~ "expected_single_value equal to 1, got 2" ]] || false
|
||||
|
||||
run dolt cherry-pick --skip-verification $commit_hash
|
||||
[ "$status" -eq 0 ]
|
||||
}
|
||||
|
||||
@test "commit_verification: rebase with tests enabled - tests pass" {
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit -m "Initial commit"
|
||||
|
||||
dolt sql -q "UPDATE users SET name = 'Zed' WHERE id = 1"
|
||||
dolt commit -am "Update Alice to Zed" # Tests still pass
|
||||
|
||||
dolt checkout -b feature HEAD~1
|
||||
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
|
||||
dolt sql -q "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'"
|
||||
dolt add .
|
||||
dolt commit -m "Add Bob and update test"
|
||||
|
||||
dolt sql -q "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')"
|
||||
dolt sql -q "UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'"
|
||||
dolt add .
|
||||
dolt commit -m "Add Charlie, update test"
|
||||
|
||||
run dolt rebase main
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "Successfully rebased" ]] || false
|
||||
}
|
||||
|
||||
@test "commit_verification: rebase with tests enabled - tests fail, aborted" {
|
||||
skip "Rebase restart of workflow on failed verification is currently busted."
|
||||
|
||||
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
|
||||
|
||||
dolt sql <<SQL
|
||||
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
|
||||
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
|
||||
SQL
|
||||
dolt add .
|
||||
dolt commit -m "Initial commit"
|
||||
|
||||
dolt checkout -b feature
|
||||
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
|
||||
dolt sql -q "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'"
|
||||
dolt add .
|
||||
dolt commit -m "Add Bob and update test"
|
||||
|
||||
dolt checkout main
|
||||
dolt sql -q "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')"
|
||||
dolt add .
|
||||
dolt commit --skip-verification -m "Add Charlie"
|
||||
|
||||
dolt checkout feature
|
||||
|
||||
run dolt rebase main
|
||||
[ "$status" -ne 0 ]
|
||||
[[ "$output" =~ "commit verification failed" ]] || false
|
||||
[[ "$output" =~ "test_users_count" ]] || false
|
||||
[[ "$output" =~ "Expected '2' but got '3'" ]] || false
|
||||
|
||||
run dolt rebase --skip-verification main
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "Successfully rebased" ]] || false
|
||||
}
|
||||
@@ -144,6 +144,7 @@ SKIP_SERVER_TESTS=$(cat <<-EOM
|
||||
~branch-activity.bats~
|
||||
~mutual-tls-auth.bats~
|
||||
~requires-repo.bats~
|
||||
~commit_verification.bats~
|
||||
EOM
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user