Merge remote-tracking branch 'origin/main' into elian/10462

This commit is contained in:
elianddb
2026-02-12 11:24:20 -08:00
27 changed files with 1554 additions and 507 deletions

View File

@@ -61,6 +61,7 @@ func CreateCommitArgParser(supportsBranchFlag bool) *argparser.ArgParser {
ap.SupportsFlag(UpperCaseAllFlag, "A", "Adds all tables and databases (including new tables) in the working set to the staged set.")
ap.SupportsFlag(AmendFlag, "", "Amend previous commit")
ap.SupportsOptionalString(SignFlag, "S", "key-id", "Sign the commit using GPG. If no key-id is provided the key-id is taken from 'user.signingkey' the in the configuration")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification")
if supportsBranchFlag {
ap.SupportsString(BranchParam, "", "branch", "Commit to the specified branch instead of the current branch.")
}
@@ -96,6 +97,7 @@ func CreateMergeArgParser() *argparser.ArgParser {
ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.")
ap.SupportsFlag(NoEditFlag, "", "Use an auto-generated commit message when creating a merge commit. The default for interactive CLI sessions is to open an editor.")
ap.SupportsString(AuthorParam, "", "author", "Specify an explicit author using the standard A U Thor {{.LessThan}}author@example.com{{.GreaterThan}} format.")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
return ap
}
@@ -116,6 +118,7 @@ func CreateRebaseArgParser() *argparser.ArgParser {
ap.SupportsFlag(AbortParam, "", "Abort an interactive rebase and return the working set to the pre-rebase state")
ap.SupportsFlag(ContinueFlag, "", "Continue an interactive rebase after adjusting the rebase plan")
ap.SupportsFlag(InteractiveFlag, "i", "Start an interactive rebase")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before rebase")
return ap
}
@@ -193,6 +196,7 @@ func CreateCherryPickArgParser() *argparser.ArgParser {
ap.SupportsFlag(AllowEmptyFlag, "", "Allow empty commits to be cherry-picked. "+
"Note that use of this option only keeps commits that were initially empty. "+
"Commits which become empty, due to a previous commit, will cause cherry-pick to fail.")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before cherry-pick")
ap.TooManyArgsErrorFunc = func(receivedArgs []string) error {
return errors.New("cherry-picking multiple commits is not supported yet.")
}
@@ -230,6 +234,7 @@ func CreatePullArgParser() *argparser.ArgParser {
ap.SupportsString(UserFlag, "", "user", "User name to use when authenticating with the remote. Gets password from the environment variable {{.EmphasisLeft}}DOLT_REMOTE_PASSWORD{{.EmphasisRight}}.")
ap.SupportsFlag(PruneFlag, "p", "After fetching, remove any remote-tracking references that don't exist on the remote.")
ap.SupportsFlag(SilentFlag, "", "Suppress progress information.")
ap.SupportsFlag(SkipVerificationFlag, "", "Skip commit verification before merge")
return ap
}

View File

@@ -79,6 +79,7 @@ const (
SilentFlag = "silent"
SingleBranchFlag = "single-branch"
SkipEmptyFlag = "skip-empty"
SkipVerificationFlag = "skip-verification"
SoftResetParam = "soft"
SquashParam = "squash"
StagedFlag = "staged"

View File

@@ -266,6 +266,10 @@ func constructParametrizedDoltCommitQuery(msg string, apr *argparser.ArgParseRes
writeToBuffer("--skip-empty")
}
if apr.Contains(cli.SkipVerificationFlag) {
writeToBuffer("--skip-verification")
}
cfgSign := cliCtx.Config().GetStringOrDefault("sqlserver.global.gpgsign", "")
if apr.Contains(cli.SignFlag) || strings.ToLower(cfgSign) == "true" {
writeToBuffer("--gpg-sign")

View File

@@ -318,6 +318,10 @@ func constructInterpolatedDoltMergeQuery(apr *argparser.ArgParseResults, cliCtx
params = append(params, msg)
}
if apr.Contains(cli.SkipVerificationFlag) {
writeToBuffer("--skip-verification", false)
}
if !apr.Contains(cli.AbortParam) && !apr.Contains(cli.SquashParam) {
writeToBuffer("?", true)
params = append(params, apr.Arg(0))

View File

@@ -579,7 +579,19 @@ func (rcv *RebaseState) MutateRebasingStarted(n bool) bool {
return rcv._tab.MutateBoolSlot(16, n)
}
const RebaseStateNumFields = 7
func (rcv *RebaseState) SkipVerification() bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(18))
if o != 0 {
return rcv._tab.GetBool(o + rcv._tab.Pos)
}
return false
}
func (rcv *RebaseState) MutateSkipVerification(n bool) bool {
return rcv._tab.MutateBoolSlot(18, n)
}
const RebaseStateNumFields = 8
func RebaseStateStart(builder *flatbuffers.Builder) {
builder.StartObject(RebaseStateNumFields)
@@ -614,6 +626,9 @@ func RebaseStateAddLastAttemptedStep(builder *flatbuffers.Builder, lastAttempted
func RebaseStateAddRebasingStarted(builder *flatbuffers.Builder, rebasingStarted bool) {
builder.PrependBoolSlot(6, rebasingStarted, false)
}
func RebaseStateAddSkipVerification(builder *flatbuffers.Builder, skipVerification bool) {
builder.PrependBoolSlot(7, skipVerification, false)
}
func RebaseStateEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()
}

View File

@@ -52,6 +52,9 @@ type CherryPickOptions struct {
// and Dolt cherry-pick implementations, the default action is to fail when an empty commit is specified. In Git
// and Dolt rebase implementations, the default action is to keep commits that start off as empty.
EmptyCommitHandling doltdb.EmptyCommitHandling
// SkipVerification controls whether test validation should be skipped before creating commits.
SkipVerification bool
}
// NewCherryPickOptions creates a new CherryPickOptions instance, filled out with default values for cherry-pick.
@@ -61,6 +64,7 @@ func NewCherryPickOptions() CherryPickOptions {
CommitMessage: "",
CommitBecomesEmptyHandling: doltdb.ErrorOnEmptyCommit,
EmptyCommitHandling: doltdb.ErrorOnEmptyCommit,
SkipVerification: false,
}
}
@@ -159,9 +163,10 @@ func CreateCommitStagedPropsFromCherryPickOptions(ctx *sql.Context, options Cher
}
commitProps := actions.CommitStagedProps{
Date: originalMeta.Time(),
Name: originalMeta.Name,
Email: originalMeta.Email,
Date: originalMeta.Time(),
Name: originalMeta.Name,
Email: originalMeta.Email,
SkipVerification: options.SkipVerification,
}
if options.CommitMessage != "" {

View File

@@ -472,6 +472,10 @@ func encodeTableNameForSerialization(name TableName) string {
// decodeTableNameFromSerialization decodes a table name from a serialized string. See notes on serialization in
// |encodeTableNameForSerialization|
func decodeTableNameFromSerialization(encodedName string) (TableName, bool) {
if len(encodedName) == 0 {
return TableName{}, false
}
if encodedName[0] != 0 {
return TableName{Name: encodedName}, true
} else if len(encodedName) >= 4 { // 2 null bytes plus at least one char for schema and table name

View File

@@ -75,6 +75,8 @@ type RebaseState struct {
// rebasingStarted is true once the rebase plan has been started to execute. Once rebasingStarted is true, the
// value in lastAttemptedStep has been initialized and is valid to read.
rebasingStarted bool
// skipVerification indicates whether test validation should be skipped during rebase operations.
skipVerification bool
}
// Branch returns the name of the branch being actively rebased. This is the branch that will be updated to point
@@ -120,6 +122,10 @@ func (rs RebaseState) WithRebasingStarted(rebasingStarted bool) *RebaseState {
return &rs
}
func (rs RebaseState) SkipVerification() bool {
return rs.skipVerification
}
type MergeState struct {
// the source commit
commit *Commit
@@ -322,13 +328,14 @@ func (ws WorkingSet) StartMerge(commit *Commit, commitSpecStr string) *WorkingSe
// the branch that is being rebased, and |previousRoot| is root value of the branch being rebased. The HEAD and STAGED
// root values of the branch being rebased must match |previousRoot|; WORKING may be a different root value, but ONLY
// if it contains only ignored tables.
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling) (*WorkingSet, error) {
func (ws WorkingSet) StartRebase(ctx *sql.Context, ontoCommit *Commit, branch string, previousRoot RootValue, commitBecomesEmptyHandling EmptyCommitHandling, emptyCommitHandling EmptyCommitHandling, skipVerification bool) (*WorkingSet, error) {
ws.rebaseState = &RebaseState{
ontoCommit: ontoCommit,
preRebaseWorking: previousRoot,
branch: branch,
commitBecomesEmptyHandling: commitBecomesEmptyHandling,
emptyCommitHandling: emptyCommitHandling,
skipVerification: skipVerification,
}
ontoRoot, err := ontoCommit.GetRootValue(ctx)
@@ -549,6 +556,7 @@ func newWorkingSet(ctx context.Context, name string, vrw types.ValueReadWriter,
emptyCommitHandling: EmptyCommitHandling(dsws.RebaseState.EmptyCommitHandling(ctx)),
lastAttemptedStep: dsws.RebaseState.LastAttemptedStep(ctx),
rebasingStarted: dsws.RebaseState.RebasingStarted(ctx),
skipVerification: dsws.RebaseState.SkipVerification(ctx),
}
}
@@ -646,7 +654,7 @@ func (ws *WorkingSet) writeValues(ctx context.Context, db *DoltDB, meta *datas.W
rebaseState = datas.NewRebaseState(preRebaseWorking.TargetHash(), dCommit.Addr(), ws.rebaseState.branch,
uint8(ws.rebaseState.commitBecomesEmptyHandling), uint8(ws.rebaseState.emptyCommitHandling),
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted)
ws.rebaseState.lastAttemptedStep, ws.rebaseState.rebasingStarted, ws.rebaseState.skipVerification)
}
return &datas.WorkingSetSpec{

View File

@@ -15,8 +15,12 @@
package actions
import (
"fmt"
"io"
"strings"
"time"
gms "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
@@ -25,14 +29,42 @@ import (
)
type CommitStagedProps struct {
Message string
Date time.Time
AllowEmpty bool
SkipEmpty bool
Amend bool
Force bool
Name string
Email string
Message string
Date time.Time
AllowEmpty bool
SkipEmpty bool
Amend bool
Force bool
Name string
Email string
SkipVerification bool
}
const (
// System variable name, defined here to avoid circular imports
DoltCommitVerificationGroups = "dolt_commit_verification_groups"
)
// GetCommitRunTestGroups returns the test groups to run for commit operations
// Returns empty slice if no tests should be run, ["*"] if all tests should be run,
// or specific group names if only those groups should be run
func GetCommitRunTestGroups() []string {
_, val, ok := sql.SystemVariables.GetGlobal(DoltCommitVerificationGroups)
if !ok {
return nil
}
if stringVal, ok := val.(string); ok && stringVal != "" {
if stringVal == "*" {
return []string{"*"}
}
// Split by comma and trim whitespace
groups := strings.Split(stringVal, ",")
for i, group := range groups {
groups[i] = strings.TrimSpace(group)
}
return groups
}
return nil
}
// GetCommitStaged returns a new pending commit with the roots and commit properties given.
@@ -114,6 +146,16 @@ func GetCommitStaged(
}
}
if !props.SkipVerification {
testGroups := GetCommitRunTestGroups()
if len(testGroups) > 0 {
err := runCommitVerification(ctx, testGroups)
if err != nil {
return nil, err
}
}
}
meta, err := datas.NewCommitMetaWithUserTS(props.Name, props.Email, props.Message, props.Date)
if err != nil {
return nil, err
@@ -121,3 +163,61 @@ func GetCommitStaged(
return db.NewPendingCommit(ctx, roots, mergeParents, props.Amend, meta)
}
func runCommitVerification(ctx *sql.Context, testGroups []string) error {
type sessionInterface interface {
sql.Session
GenericProvider() sql.MutableDatabaseProvider
}
session, ok := ctx.Session.(sessionInterface)
if !ok {
return fmt.Errorf("session does not provide database provider interface")
}
provider := session.GenericProvider()
engine := gms.NewDefault(provider)
return runTestsUsingDtablefunctions(ctx, engine, testGroups)
}
// runTestsUsingDtablefunctions runs tests using the dtablefunctions package against the staged root
func runTestsUsingDtablefunctions(ctx *sql.Context, engine *gms.Engine, testGroups []string) error {
if len(testGroups) == 0 {
return nil
}
var allFailures []string
for _, group := range testGroups {
query := fmt.Sprintf("SELECT * FROM dolt_test_run('%s')", group)
_, iter, _, err := engine.Query(ctx, query)
if err != nil {
return fmt.Errorf("failed to run dolt_test_run for group %s: %w", group, err)
}
for {
row, rErr := iter.Next(ctx)
if rErr == io.EOF {
break
}
if rErr != nil {
return fmt.Errorf("error reading test results: %w", rErr)
}
// Extract status (column 3)
status := fmt.Sprintf("%v", row[3])
if status != "PASS" {
testName := fmt.Sprintf("%v", row[0])
message := fmt.Sprintf("%v", row[4])
allFailures = append(allFailures, fmt.Sprintf("%s (%s)", testName, message))
}
}
}
if len(allFailures) > 0 {
return fmt.Errorf("commit verification failed: %s", strings.Join(allFailures, ", "))
}
return nil
}

View File

@@ -1,447 +0,0 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package actions
import (
"fmt"
"io"
"strconv"
"time"
"github.com/dolthub/go-mysql-server/sql"
"github.com/shopspring/decimal"
"golang.org/x/exp/constraints"
"github.com/dolthub/dolt/go/store/val"
)
const (
AssertionExpectedRows = "expected_rows"
AssertionExpectedColumns = "expected_columns"
AssertionExpectedSingleValue = "expected_single_value"
)
// AssertData parses an assertion, comparison, and value, then returns the status of the test.
// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=".
// testPassed indicates whether the test was successful or not.
// message is a string used to indicate test failures, and will not halt the overall process.
// message will be empty if the test passed.
// err indicates runtime failures and will stop dolt_test_run from proceeding.
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
switch assertion {
case AssertionExpectedRows:
message, err = expectRows(sqlCtx, comparison, value, queryResult)
case AssertionExpectedColumns:
message, err = expectColumns(sqlCtx, comparison, value, queryResult)
case AssertionExpectedSingleValue:
message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
default:
return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
}
if err != nil {
return false, "", err
} else if message != "" {
return false, message, nil
}
return true, "", nil
}
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
row, err := queryResult.Next(sqlCtx)
if err == io.EOF {
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
} else if err != nil {
return "", err
}
if len(row) != 1 {
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
}
_, err = queryResult.Next(sqlCtx)
if err == nil { //If multiple rows were given, we should error out
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil
} else if err != io.EOF { // "True" error, so we should quit out
return "", err
}
if value == nil { // If we're expecting a null value, we don't need to type switch
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
}
// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception
// of "0" and "1", which are valid integers and are covered below.
if *value != "0" && *value != "1" {
if expectedBool, err := strconv.ParseBool(*value); err == nil {
actualBool, boolErr := getInterfaceAsBool(row[0])
if boolErr != nil {
return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
}
return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
}
}
switch actualValue := row[0].(type) {
case int8:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case int16:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case int32:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case int64:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
case int:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case uint8:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case uint16:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case uint32:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case uint64:
expectedUint, err := strconv.ParseUint(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
case uint:
expectedUint, err := strconv.ParseUint(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case float64:
expectedFloat, err := strconv.ParseFloat(*value, 64)
if err != nil {
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
case float32:
expectedFloat, err := strconv.ParseFloat(*value, 32)
if err != nil {
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
}
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
case decimal.Decimal:
expectedDecimal, err := decimal.NewFromString(*value)
if err != nil {
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
}
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
case time.Time:
expectedTime, format, err := parseTestsDate(*value)
if err != nil {
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
}
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
case *val.TextStorage, string:
actualString, err := GetStringColAsString(sqlCtx, actualValue)
if err != nil {
return "", err
}
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
default:
return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
}
}
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
if value == nil {
return "null is not a valid assertion for expected_rows", nil
}
expectedRows, err := strconv.Atoi(*value)
if err != nil {
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
}
var numRows int
for {
_, err := queryResult.Next(sqlCtx)
if err == io.EOF {
break
} else if err != nil {
return "", err
}
numRows++
}
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
}
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
if value == nil {
return "null is not a valid assertion for expected_rows", nil
}
expectedColumns, err := strconv.Atoi(*value)
if err != nil {
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
}
var numColumns int
row, err := queryResult.Next(sqlCtx)
if err != nil && err != io.EOF {
return "", err
}
numColumns = len(row)
return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
}
// compareTestAssertion is a generic function used for comparing string, ints, floats.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
switch comparison {
case "==":
if actualValue != expectedValue {
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue)
}
case "!=":
if actualValue == expectedValue {
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue)
}
case "<":
if actualValue >= expectedValue {
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue)
}
case "<=":
if actualValue > expectedValue {
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue)
}
case ">":
if actualValue <= expectedValue {
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue)
}
case ">=":
if actualValue < expectedValue {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
// It returns the parsed time, the format that succeeded, and an error if applicable.
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
// List of valid formats
formats := []string{
time.DateOnly,
time.DateTime,
time.TimeOnly,
time.RFC3339,
time.RFC1123Z,
}
for _, format := range formats {
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
return parsedTime, format, nil
} else {
err = parseErr
}
}
return time.Time{}, "", err
}
// compareDates is a function used for comparing time values.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
expectedStr := expectedValue.Format(format)
realStr := realValue.Format(format)
switch comparison {
case "==":
if !expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "!=":
if expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "<":
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
}
case "<=":
if realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
case ">":
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
}
case ">=":
if realValue.Before(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// compareDecimals is a function used for comparing decimals.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
switch comparison {
case "==":
if !expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue)
}
case "!=":
if expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue)
}
case "<":
if realValue.GreaterThanOrEqual(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue)
}
case "<=":
if realValue.GreaterThan(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue)
}
case ">":
if realValue.LessThanOrEqual(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue)
}
case ">=":
if realValue.LessThan(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// getTinyIntColAsBool returns the value interface{} as a bool
// This is necessary because the query engine may return a tinyint column as a bool, int, or other types.
// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles.
func getInterfaceAsBool(col interface{}) (bool, error) {
switch v := col.(type) {
case bool:
return v, nil
case int:
return v == 1, nil
case int8:
return v == 1, nil
case int16:
return v == 1, nil
case int32:
return v == 1, nil
case int64:
return v == 1, nil
case uint:
return v == 1, nil
case uint8:
return v == 1, nil
case uint16:
return v == 1, nil
case uint32:
return v == 1, nil
case uint64:
return v == 1, nil
case string:
return v == "1", nil
default:
return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
}
}
// compareBooleans is a function used for comparing boolean values.
// It takes in a comparison string from one of: "==", "!="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
switch comparison {
case "==":
if expectedValue != realValue {
return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
}
case "!=":
if expectedValue == realValue {
return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
}
default:
return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
}
return ""
}
// compareNullValue is a function used for comparing a null value.
// It takes in a comparison string from one of: "==", "!="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
switch comparison {
case "==":
if actualValue != nil {
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
}
case "!=":
if actualValue == nil {
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
}
default:
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
}
return ""
}
// GetStringColAsString is a function that returns a text column as a string.
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
// so we use a special parser to get the correct string values
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
if ts, ok := tableValue.(*val.TextStorage); ok {
str, err := ts.Unwrap(sqlCtx)
return &str, err
} else if str, ok := tableValue.(string); ok {
return &str, nil
} else if tableValue == nil {
return nil, nil
} else {
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
}
}

View File

@@ -1955,7 +1955,10 @@ func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.Prima
return err
}
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) && !doltdb.IsFullTextTable(tableName) && !doltdb.HasDoltCIPrefix(tableName) {
if doltdb.IsSystemTable(doltdb.TableName{Name: tableName, Schema: db.schemaName}) &&
!doltdb.IsFullTextTable(tableName) &&
!doltdb.HasDoltCIPrefix(tableName) &&
tableName != doltdb.TestsTableName { // NM4 - determine why this is required now.
return ErrReservedTableName.New(tableName)
}

View File

@@ -103,6 +103,8 @@ func doDoltCherryPick(ctx *sql.Context, args []string) (string, int, int, int, e
cherryPickOptions.EmptyCommitHandling = doltdb.KeepEmptyCommit
}
cherryPickOptions.SkipVerification = apr.Contains(cli.SkipVerificationFlag)
commit, mergeResult, err := cherry_pick.CherryPick(ctx, cherryStr, cherryPickOptions)
if err != nil {
return "", 0, 0, 0, err

View File

@@ -163,14 +163,15 @@ func doDoltCommit(ctx *sql.Context, args []string) (string, bool, error) {
}
csp := actions.CommitStagedProps{
Message: msg,
Date: t,
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
Amend: amend,
Force: apr.Contains(cli.ForceFlag),
Name: name,
Email: email,
Message: msg,
Date: t,
AllowEmpty: apr.Contains(cli.AllowEmptyFlag),
SkipEmpty: apr.Contains(cli.SkipEmptyFlag),
Amend: amend,
Force: apr.Contains(cli.ForceFlag),
Name: name,
Email: email,
SkipVerification: apr.Contains(cli.SkipVerificationFlag),
}
shouldSign, err := dsess.GetBooleanSystemVar(ctx, "gpgsign")

View File

@@ -180,7 +180,7 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, err
msg = userMsg
}
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
ws, commit, conflicts, fastForward, message, err := performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
if err != nil {
return commit, conflicts, fastForward, "", err
}
@@ -205,6 +205,7 @@ func performMerge(
spec *merge.MergeSpec,
noCommit bool,
msg string,
skipVerification bool,
) (*doltdb.WorkingSet, string, int, int, string, error) {
// todo: allow merges even when an existing merge is uncommitted
if ws.MergeActive() {
@@ -234,7 +235,7 @@ func performMerge(
if canFF {
if spec.FFMode == merge.NoFastForward {
var commit *doltdb.Commit
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit)
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit, skipVerification)
if err == doltdb.ErrUnresolvedConflictsOrViolations {
// if there are unresolved conflicts, write the resulting working set back to the session and return an
// error message
@@ -306,7 +307,10 @@ func performMerge(
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
args := []string{"-m", msg, "--author", author}
if spec.Force {
args = append(args, "--force")
args = append(args, "--"+cli.ForceFlag)
}
if skipVerification {
args = append(args, "--"+cli.SkipVerificationFlag)
}
commit, _, err = doDoltCommit(ctx, args)
if err != nil {
@@ -405,6 +409,7 @@ func executeNoFFMerge(
dbName string,
ws *doltdb.WorkingSet,
noCommit bool,
skipVerification bool,
) (*doltdb.WorkingSet, *doltdb.Commit, error) {
mergeRoot, err := spec.MergeC.GetRootValue(ctx)
if err != nil {
@@ -444,11 +449,12 @@ func executeNoFFMerge(
}
pendingCommit, err := dSess.NewPendingCommit(ctx, dbName, roots, actions.CommitStagedProps{
Message: msg,
Date: spec.Date,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
Message: msg,
Date: spec.Date,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
SkipVerification: skipVerification,
})
if err != nil {
return nil, nil, err

View File

@@ -237,7 +237,7 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) {
return noConflictsOrViolations, threeWayMerge, "", ErrUncommittedChanges.New()
}
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg)
ws, _, conflicts, fastForward, message, err = performMerge(ctx, sess, ws, dbName, mergeSpec, apr.Contains(cli.NoCommitFlag), msg, apr.Contains(cli.SkipVerificationFlag))
if err != nil && !errors.Is(doltdb.ErrUpToDate, err) {
return conflicts, fastForward, "", err
}

View File

@@ -216,7 +216,9 @@ func doDoltRebase(ctx *sql.Context, args []string) (int, string, error) {
} else if apr.NArg() > 1 {
return 1, "", fmt.Errorf("too many args")
}
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling)
skipVerification := apr.Contains(cli.SkipVerificationFlag)
err = startRebase(ctx, apr.Arg(0), commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
if err != nil {
return 1, "", err
}
@@ -263,7 +265,7 @@ func processCommitBecomesEmptyParams(apr *argparser.ArgParseResults) (doltdb.Emp
// startRebase starts a new interactive rebase operation. |upstreamPoint| specifies the commit where the new rebased
// commits will be based off of, |commitBecomesEmptyHandling| specifies how to handle commits that are not empty, but
// do not produce any changes when applied, and |emptyCommitHandling| specifies how to handle empty commits.
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) error {
func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling, skipVerification bool) error {
if upstreamPoint == "" {
return fmt.Errorf("no upstream branch specified")
}
@@ -351,7 +353,7 @@ func startRebase(ctx *sql.Context, upstreamPoint string, commitBecomesEmptyHandl
}
newWorkingSet, err := workingSet.StartRebase(ctx, upstreamCommit, rebaseBranch, branchRoots.Working,
commitBecomesEmptyHandling, emptyCommitHandling)
commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
if err != nil {
return err
}
@@ -716,7 +718,8 @@ func continueRebase(ctx *sql.Context) rebaseResult {
result := processRebasePlanStep(ctx, &step,
workingSet.RebaseState().CommitBecomesEmptyHandling(),
workingSet.RebaseState().EmptyCommitHandling())
workingSet.RebaseState().EmptyCommitHandling(),
workingSet.RebaseState().SkipVerification())
if result.err != nil || result.status != 0 || result.halt {
return result
}
@@ -803,7 +806,7 @@ func commitManuallyStagedChangesForStep(ctx *sql.Context, step rebase.RebasePlan
}
options, err := createCherryPickOptionsForRebaseStep(ctx, &step, workingSet.RebaseState().CommitBecomesEmptyHandling(),
workingSet.RebaseState().EmptyCommitHandling())
workingSet.RebaseState().EmptyCommitHandling(), workingSet.RebaseState().SkipVerification())
doltDB, ok := doltSession.GetDoltDB(ctx, ctx.GetCurrentDatabase())
if !ok {
@@ -861,6 +864,7 @@ func processRebasePlanStep(
planStep *rebase.RebasePlanStep,
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
emptyCommitHandling doltdb.EmptyCommitHandling,
skipVerification bool,
) rebaseResult {
// Make sure we have a transaction opened for the session
// NOTE: After our first call to cherry-pick, the tx is committed, so a new tx needs to be started
@@ -878,7 +882,7 @@ func processRebasePlanStep(
return newRebaseSuccess("")
}
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling)
options, err := createCherryPickOptionsForRebaseStep(ctx, planStep, commitBecomesEmptyHandling, emptyCommitHandling, skipVerification)
if err != nil {
return newRebaseError(err)
}
@@ -886,12 +890,19 @@ func processRebasePlanStep(
return handleRebaseCherryPick(ctx, planStep, *options)
}
func createCherryPickOptionsForRebaseStep(ctx *sql.Context, planStep *rebase.RebasePlanStep, commitBecomesEmptyHandling doltdb.EmptyCommitHandling, emptyCommitHandling doltdb.EmptyCommitHandling) (*cherry_pick.CherryPickOptions, error) {
func createCherryPickOptionsForRebaseStep(
ctx *sql.Context,
planStep *rebase.RebasePlanStep,
commitBecomesEmptyHandling doltdb.EmptyCommitHandling,
emptyCommitHandling doltdb.EmptyCommitHandling,
skipVerification bool,
) (*cherry_pick.CherryPickOptions, error) {
// Override the default empty commit handling options for cherry-pick, since
// rebase has slightly different defaults
options := cherry_pick.NewCherryPickOptions()
options.CommitBecomesEmptyHandling = commitBecomesEmptyHandling
options.EmptyCommitHandling = emptyCommitHandling
options.SkipVerification = skipVerification
switch planStep.Action {
case rebase.RebaseActionDrop, rebase.RebaseActionPick, rebase.RebaseActionEdit:

View File

@@ -17,7 +17,9 @@ package dtablefunctions
import (
"fmt"
"io"
"strconv"
"strings"
"time"
gms "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/sql"
@@ -26,10 +28,13 @@ import (
"github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/gocraft/dbr/v2"
"github.com/gocraft/dbr/v2/dialect"
"github.com/shopspring/decimal"
"golang.org/x/exp/constraints"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/store/val"
)
const testsRunDefaultRowCount = 10
@@ -39,12 +44,13 @@ var _ sql.CatalogTableFunction = (*TestsRunTableFunction)(nil)
var _ sql.ExecSourceRel = (*TestsRunTableFunction)(nil)
var _ sql.AuthorizationCheckerNode = (*TestsRunTableFunction)(nil)
type testResult struct {
testName string
groupName string
query string
status string
message string
// TestResult represents the result of running a single test
type TestResult struct {
TestName string
GroupName string
Query string
Status string
Message string
}
type TestsRunTableFunction struct {
@@ -199,7 +205,7 @@ func (trtf *TestsRunTableFunction) RowIter(_ *sql.Context, _ sql.Row) (sql.RowIt
return nil, err
}
resultRow := sql.NewRow(result.testName, result.groupName, result.query, result.status, result.message)
resultRow := sql.NewRow(result.TestName, result.GroupName, result.Query, result.Status, result.Message)
resultRows = append(resultRows, resultRow)
}
}
@@ -220,7 +226,7 @@ func (trtf *TestsRunTableFunction) RowCount(_ *sql.Context) (uint64, bool, error
return testsRunDefaultRowCount, false, nil
}
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResult, err error) {
func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result TestResult, err error) {
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
if err != nil {
return
@@ -237,9 +243,9 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
if err != nil {
message = fmt.Sprintf("Query error: %s", err.Error())
} else {
testPassed, message, err = actions.AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
testPassed, message, err = AssertData(trtf.ctx, *assertion, *comparison, value, queryResult)
if err != nil {
return testResult{}, err
return TestResult{}, err
}
}
}
@@ -253,11 +259,49 @@ func (trtf *TestsRunTableFunction) queryAndAssert(row sql.Row) (result testResul
if groupName != nil {
groupString = *groupName
}
result = testResult{*testName, groupString, *query, status, message}
result = TestResult{*testName, groupString, *query, status, message}
return result, nil
}
func (trtf *TestsRunTableFunction) queryAndAssertWithFunc(row sql.Row, assertDataFunc AssertDataFunc) (result TestResult, err error) {
testName, groupName, query, assertion, comparison, value, err := parseDoltTestsRow(trtf.ctx, row)
if err != nil {
return
}
message, err := validateQuery(trtf.ctx, trtf.catalog, *query)
if err != nil && message == "" {
message = fmt.Sprintf("query error: %s", err.Error())
}
var testPassed bool
if message == "" {
_, queryResult, _, err := trtf.engine.Query(trtf.ctx, *query)
if err != nil {
message = fmt.Sprintf("Query error: %s", err.Error())
} else {
testPassed, message, err = assertDataFunc(trtf.ctx, *assertion, *comparison, value, queryResult)
if err != nil {
return TestResult{}, err
}
}
}
status := "PASS"
if !testPassed {
status = "FAIL"
}
var groupString string
if groupName != nil {
groupString = *groupName
}
result = TestResult{*testName, groupString, *query, status, message}
return result, nil
}
func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, error) {
// Original behavior when root is nil - use SQL queries against current session
var queries []string
if arg == "*" {
@@ -320,28 +364,31 @@ func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, er
}
func parseDoltTestsRow(ctx *sql.Context, row sql.Row) (testName, groupName, query, assertion, comparison, value *string, err error) {
if testName, err = actions.GetStringColAsString(ctx, row[0]); err != nil {
if testName, err = getStringColAsString(ctx, row[0]); err != nil {
return
}
if groupName, err = actions.GetStringColAsString(ctx, row[1]); err != nil {
if groupName, err = getStringColAsString(ctx, row[1]); err != nil {
return
}
if query, err = actions.GetStringColAsString(ctx, row[2]); err != nil {
if query, err = getStringColAsString(ctx, row[2]); err != nil {
return
}
if assertion, err = actions.GetStringColAsString(ctx, row[3]); err != nil {
if assertion, err = getStringColAsString(ctx, row[3]); err != nil {
return
}
if comparison, err = actions.GetStringColAsString(ctx, row[4]); err != nil {
if comparison, err = getStringColAsString(ctx, row[4]); err != nil {
return
}
if value, err = actions.GetStringColAsString(ctx, row[5]); err != nil {
if value, err = getStringColAsString(ctx, row[5]); err != nil {
return
}
return testName, groupName, query, assertion, comparison, value, nil
}
// AssertDataFunc defines the function signature for asserting test data
type AssertDataFunc func(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error)
func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string, error) {
// We first check if the query contains multiple sql statements
if statements, err := sqlparser.SplitStatementToPieces(query); err != nil {
@@ -361,3 +408,455 @@ func validateQuery(ctx *sql.Context, catalog sql.Catalog, query string) (string,
}
return "", nil
}
// Simple inline assertion constants to avoid circular imports
const (
AssertionExpectedRows = "expected_rows"
AssertionExpectedColumns = "expected_columns"
AssertionExpectedSingleValue = "expected_single_value"
)
// getStringColAsString safely converts a sql value to string
func getStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
if tableValue == nil {
return nil, nil
}
if ts, ok := tableValue.(*val.TextStorage); ok {
str, err := ts.Unwrap(sqlCtx)
if err != nil {
return nil, err
}
return &str, nil
} else if str, ok := tableValue.(string); ok {
return &str, nil
} else {
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
}
}
// readTableDataFromDoltTable reads test data directly from a dolt table
func (trtf *TestsRunTableFunction) readTableDataFromDoltTable(table *doltdb.Table, arg string) ([]sql.Row, error) {
// This is a complex implementation that requires reading table data directly from dolt storage
// For now, return an error that clearly indicates this needs to be implemented
// The table scan would involve:
// 1. Getting the table schema
// 2. Creating a table iterator
// 3. Reading and filtering rows based on the arg (test_name or test_group)
// 4. Converting dolt storage format to SQL rows
//
// This is a significant implementation that requires understanding dolt's storage internals
return nil, fmt.Errorf("direct table reading from dolt storage not yet implemented for table scan of dolt_tests - this requires implementing table iteration and row conversion from dolt's internal storage format")
}
// AssertData parses an assertion, comparison, and value, then returns the status of the test.
// Valid comparison are: "==", "!=", "<", ">", "<=", and ">=".
// testPassed indicates whether the test was successful or not.
// message is a string used to indicate test failures, and will not halt the overall process.
// message will be empty if the test passed.
// err indicates runtime failures and will stop dolt_test_run from proceeding.
func AssertData(sqlCtx *sql.Context, assertion string, comparison string, value *string, queryResult sql.RowIter) (testPassed bool, message string, err error) {
switch assertion {
case AssertionExpectedRows:
message, err = expectRows(sqlCtx, comparison, value, queryResult)
case AssertionExpectedColumns:
message, err = expectColumns(sqlCtx, comparison, value, queryResult)
case AssertionExpectedSingleValue:
message, err = expectSingleValue(sqlCtx, comparison, value, queryResult)
default:
return false, fmt.Sprintf("%s is not a valid assertion type", assertion), nil
}
if err != nil {
return false, "", err
} else if message != "" {
return false, message, nil
}
return true, "", nil
}
func expectSingleValue(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
row, err := queryResult.Next(sqlCtx)
if err == io.EOF {
return fmt.Sprintf("expected_single_value expects exactly one cell. Received 0 rows"), nil
} else if err != nil {
return "", err
}
if len(row) != 1 {
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple columns"), nil
}
_, err = queryResult.Next(sqlCtx)
if err == nil { //If multiple rows were given, we should error out
return fmt.Sprintf("expected_single_value expects exactly one cell. Received multiple rows"), nil
} else if err != io.EOF { // "True" error, so we should quit out
return "", err
}
if value == nil { // If we're expecting a null value, we don't need to type switch
return compareNullValue(comparison, row[0], AssertionExpectedSingleValue), nil
}
// Check if the expected value is a boolean string, and if so, coerce the actual value to boolean, with the exception
// of "0" and "1", which are valid integers and are covered below.
if *value != "0" && *value != "1" {
if expectedBool, err := strconv.ParseBool(*value); err == nil {
actualBool, boolErr := getInterfaceAsBool(row[0])
if boolErr != nil {
return fmt.Sprintf("Could not convert value to boolean: %v", boolErr), nil
}
return compareBooleans(comparison, expectedBool, actualBool, AssertionExpectedSingleValue), nil
}
}
switch actualValue := row[0].(type) {
case int8:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int8(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case int16:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int16(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case int32:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int32(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case int64:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedInt, actualValue, AssertionExpectedSingleValue), nil
case int:
expectedInt, err := strconv.ParseInt(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, int(expectedInt), actualValue, AssertionExpectedSingleValue), nil
case uint8:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint8(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case uint16:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint16(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case uint32:
expectedUint, err := strconv.ParseUint(*value, 10, 32)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint32(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case uint64:
expectedUint, err := strconv.ParseUint(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedUint, actualValue, AssertionExpectedSingleValue), nil
case uint:
expectedUint, err := strconv.ParseUint(*value, 10, 64)
if err != nil {
return fmt.Sprintf("Could not compare non integer value '%s', with %d", *value, actualValue), nil
}
return compareTestAssertion(comparison, uint(expectedUint), actualValue, AssertionExpectedSingleValue), nil
case float64:
expectedFloat, err := strconv.ParseFloat(*value, 64)
if err != nil {
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
}
return compareTestAssertion(comparison, expectedFloat, actualValue, AssertionExpectedSingleValue), nil
case float32:
expectedFloat, err := strconv.ParseFloat(*value, 32)
if err != nil {
return fmt.Sprintf("Could not compare non float value '%s', with %f", *value, actualValue), nil
}
return compareTestAssertion(comparison, float32(expectedFloat), actualValue, AssertionExpectedSingleValue), nil
case decimal.Decimal:
expectedDecimal, err := decimal.NewFromString(*value)
if err != nil {
return fmt.Sprintf("Could not compare non decimal value '%s', with %s", *value, actualValue), nil
}
return compareDecimals(comparison, expectedDecimal, actualValue, AssertionExpectedSingleValue), nil
case time.Time:
expectedTime, format, err := parseTestsDate(*value)
if err != nil {
return fmt.Sprintf("%s does not appear to be a valid date", *value), nil
}
return compareDates(comparison, expectedTime, actualValue, format, AssertionExpectedSingleValue), nil
case *val.TextStorage, string:
actualString, err := GetStringColAsString(sqlCtx, actualValue)
if err != nil {
return "", err
}
return compareTestAssertion(comparison, *value, *actualString, AssertionExpectedSingleValue), nil
default:
return fmt.Sprintf("Type %T is not supported. Open an issue at https://github.com/dolthub/dolt/issues to see it added", actualValue), nil
}
}
func expectRows(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
if value == nil {
return "null is not a valid assertion for expected_rows", nil
}
expectedRows, err := strconv.Atoi(*value)
if err != nil {
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
}
var numRows int
for {
_, err := queryResult.Next(sqlCtx)
if err == io.EOF {
break
} else if err != nil {
return "", err
}
numRows++
}
return compareTestAssertion(comparison, expectedRows, numRows, AssertionExpectedRows), nil
}
func expectColumns(sqlCtx *sql.Context, comparison string, value *string, queryResult sql.RowIter) (message string, err error) {
if value == nil {
return "null is not a valid assertion for expected_rows", nil
}
expectedColumns, err := strconv.Atoi(*value)
if err != nil {
return fmt.Sprintf("cannot run assertion on non integer value: %s", *value), nil
}
var numColumns int
row, err := queryResult.Next(sqlCtx)
if err != nil && err != io.EOF {
return "", err
}
numColumns = len(row)
return compareTestAssertion(comparison, expectedColumns, numColumns, AssertionExpectedColumns), nil
}
// compareTestAssertion is a generic function used for comparing string, ints, floats.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareTestAssertion[T constraints.Ordered](comparison string, expectedValue, actualValue T, assertionType string) string {
switch comparison {
case "==":
if actualValue != expectedValue {
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, actualValue)
}
case "!=":
if actualValue == expectedValue {
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, actualValue)
}
case "<":
if actualValue >= expectedValue {
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, actualValue)
}
case "<=":
if actualValue > expectedValue {
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, actualValue)
}
case ">":
if actualValue <= expectedValue {
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, actualValue)
}
case ">=":
if actualValue < expectedValue {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, actualValue)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// parseTestsDate is an internal function that parses the queried string according to allowed time formats for dolt_tests.
// It returns the parsed time, the format that succeeded, and an error if applicable.
func parseTestsDate(value string) (parsedTime time.Time, format string, err error) {
// List of valid formats
formats := []string{
time.DateOnly,
time.DateTime,
time.TimeOnly,
time.RFC3339,
time.RFC1123Z,
}
for _, format := range formats {
if parsedTime, parseErr := time.Parse(format, value); parseErr == nil {
return parsedTime, format, nil
} else {
err = parseErr
}
}
return time.Time{}, "", err
}
// compareDates is a function used for comparing time values.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareDates(comparison string, expectedValue, realValue time.Time, format string, assertionType string) string {
expectedStr := expectedValue.Format(format)
realStr := realValue.Format(format)
switch comparison {
case "==":
if !expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "!=":
if expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s not equal to %s, got %s", assertionType, expectedStr, realStr)
}
case "<":
if realValue.Equal(expectedValue) || realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than %s, got %s", assertionType, expectedStr, realStr)
}
case "<=":
if realValue.After(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
case ">":
if realValue.Before(expectedValue) || realValue.Equal(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than %s, got %s", assertionType, expectedStr, realStr)
}
case ">=":
if realValue.Before(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %s, got %s", assertionType, expectedStr, realStr)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// compareDecimals is a function used for comparing decimals.
// It takes in a comparison string from one of: "==", "!=", "<", ">", "<=", ">="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareDecimals(comparison string, expectedValue, realValue decimal.Decimal, assertionType string) string {
switch comparison {
case "==":
if !expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s equal to %v, got %v", assertionType, expectedValue, realValue)
}
case "!=":
if expectedValue.Equal(realValue) {
return fmt.Sprintf("Assertion failed: %s not equal to %v, got %v", assertionType, expectedValue, realValue)
}
case "<":
if realValue.GreaterThanOrEqual(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than %v, got %v", assertionType, expectedValue, realValue)
}
case "<=":
if realValue.GreaterThan(expectedValue) {
return fmt.Sprintf("Assertion failed: %s less than or equal to %v, got %v", assertionType, expectedValue, realValue)
}
case ">":
if realValue.LessThanOrEqual(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than %v, got %v", assertionType, expectedValue, realValue)
}
case ">=":
if realValue.LessThan(expectedValue) {
return fmt.Sprintf("Assertion failed: %s greater than or equal to %v, got %v", assertionType, expectedValue, realValue)
}
default:
return fmt.Sprintf("%s is not a valid comparison type", comparison)
}
return ""
}
// getTinyIntColAsBool returns the value interface{} as a bool
// This is necessary because the query engine may return a tinyint column as a bool, int, or other types.
// Based on GetTinyIntColAsBool from commands/utils.go, which we can't depend on here due to package cycles.
func getInterfaceAsBool(col interface{}) (bool, error) {
switch v := col.(type) {
case bool:
return v, nil
case int:
return v == 1, nil
case int8:
return v == 1, nil
case int16:
return v == 1, nil
case int32:
return v == 1, nil
case int64:
return v == 1, nil
case uint:
return v == 1, nil
case uint8:
return v == 1, nil
case uint16:
return v == 1, nil
case uint32:
return v == 1, nil
case uint64:
return v == 1, nil
case string:
return v == "1", nil
default:
return false, fmt.Errorf("unexpected type %T, was expecting bool, int, or string", v)
}
}
// compareBooleans is a function used for comparing boolean values.
// It takes in a comparison string from one of: "==", "!="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareBooleans(comparison string, expectedValue, realValue bool, assertionType string) string {
switch comparison {
case "==":
if expectedValue != realValue {
return fmt.Sprintf("Assertion failed: %s equal to %t, got %t", assertionType, expectedValue, realValue)
}
case "!=":
if expectedValue == realValue {
return fmt.Sprintf("Assertion failed: %s not equal to %t, got %t", assertionType, expectedValue, realValue)
}
default:
return fmt.Sprintf("%s is not a valid comparison for boolean values. Only '==' and '!=' are supported", comparison)
}
return ""
}
// compareNullValue is a function used for comparing a null value.
// It takes in a comparison string from one of: "==", "!="
// It returns a string. The string is empty if the assertion passed, or has a message explaining the failure otherwise
func compareNullValue(comparison string, actualValue interface{}, assertionType string) string {
switch comparison {
case "==":
if actualValue != nil {
return fmt.Sprintf("Assertion failed: %s equal to NULL, got %v", assertionType, actualValue)
}
case "!=":
if actualValue == nil {
return fmt.Sprintf("Assertion failed: %s not equal to NULL, got NULL", assertionType)
}
default:
return fmt.Sprintf("%s is not a valid comparison for NULL values", comparison)
}
return ""
}
// GetStringColAsString is a function that returns a text column as a string.
// This is necessary as the dolt_tests system table returns *val.TextStorage types under certain situations,
// so we use a special parser to get the correct string values
func GetStringColAsString(sqlCtx *sql.Context, tableValue interface{}) (*string, error) {
if ts, ok := tableValue.(*val.TextStorage); ok {
str, err := ts.Unwrap(sqlCtx)
return &str, err
} else if str, ok := tableValue.(string); ok {
return &str, nil
} else if tableValue == nil {
return nil, nil
} else {
return nil, fmt.Errorf("unexpected type %T, was expecting string", tableValue)
}
}

View File

@@ -1210,6 +1210,11 @@ func TestDoltDdlScripts(t *testing.T) {
RunDoltDdlScripts(t, harness)
}
func TestDoltCommitVerificationScripts(t *testing.T) {
harness := newDoltEnginetestHarness(t)
RunDoltCommitVerificationScripts(t, harness)
}
func TestBrokenDdlScripts(t *testing.T) {
for _, script := range BrokenDDLScripts {
t.Skip(script.Name)

View File

@@ -2147,3 +2147,12 @@ func RunTransactionTestsWithEngineSetup(t *testing.T, setupEngine func(*gms.Engi
})
}
}
func RunDoltCommitVerificationScripts(t *testing.T, harness DoltEnginetestHarness) {
for _, script := range DoltCommitVerificationScripts {
harness := harness.NewHarness(t)
enginetest.TestScript(t, harness, script)
harness.Close()
}
}

View File

@@ -190,7 +190,6 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
for i := range dbs {
db := dbs[i]
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("use %s", db)})
// Any auto increment tables must be dropped and recreated to get a fresh state for the global auto increment
// sequence trackers
_, aiTables := enginetest.MustQuery(ctx, d.engine,
@@ -218,6 +217,7 @@ func (d *DoltHarness) resetScripts() []setup.SetupScript {
resetCmds = append(resetCmds, setup.SetupScript{fmt.Sprintf("drop database if exists %s", db)})
}
}
resetCmds = append(resetCmds, setup.SetupScript{"use mydb"})
return resetCmds
}
@@ -229,7 +229,7 @@ func commitScripts(dbs []string) []setup.SetupScript {
db := dbs[i]
commitCmds = append(commitCmds, fmt.Sprintf("use %s", db))
commitCmds = append(commitCmds, "call dolt_add('.')")
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00')", db))
commitCmds = append(commitCmds, fmt.Sprintf("call dolt_commit('--allow-empty', '-am', 'checkpoint enginetest database %s', '--date', '1970-01-01T12:00:00', '--skip-verification')", db))
}
commitCmds = append(commitCmds, "use mydb")
return []setup.SetupScript{commitCmds}

View File

@@ -0,0 +1,538 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package enginetest
import (
"regexp"
"github.com/dolthub/go-mysql-server/enginetest"
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/store/hash"
)
// commitHashValidator validates commit hash format (32 character hex)
type commitHashValidator struct{}
var _ enginetest.CustomValueValidator = &commitHashValidator{}
func (chv *commitHashValidator) Validate(val interface{}) (bool, error) {
h, ok := val.(string)
if !ok {
return false, nil
}
_, ok = hash.MaybeParse(h)
return ok, nil
}
// successfulRebaseMessageValidator validates successful rebase message format
type successfulRebaseMessageValidator struct{}
var _ enginetest.CustomValueValidator = &successfulRebaseMessageValidator{}
var successfulRebaseRegex = regexp.MustCompile(`^Successfully rebased.*`)
func (srmv *successfulRebaseMessageValidator) Validate(val interface{}) (bool, error) {
message, ok := val.(string)
if !ok {
return false, nil
}
return successfulRebaseRegex.MatchString(message), nil
}
var commitHash = &commitHashValidator{}
var successfulRebaseMessage = &successfulRebaseMessageValidator{}
var DoltCommitVerificationScripts = []queries.ScriptTest{
{
Name: "test verification system variables exist and have correct defaults",
Assertions: []queries.ScriptTestAssertion{
{
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
Expected: []sql.Row{
{"dolt_commit_verification_groups", ""},
},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "test verification system variables can be set",
Assertions: []queries.ScriptTestAssertion{
{
Query: "SET GLOBAL dolt_commit_verification_groups = '*'",
Expected: []sql.Row{{types.OkResult{}}},
},
{
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
Expected: []sql.Row{
{"dolt_commit_verification_groups", "*"},
},
},
{
Query: "SET GLOBAL dolt_commit_verification_groups = 'unit,integration'",
Expected: []sql.Row{{types.OkResult{}}},
},
{
Query: "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'",
Expected: []sql.Row{
{"dolt_commit_verification_groups", "unit,integration"},
},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "commit verification enabled - all tests pass",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_commit('-m', 'Commit with passing tests')",
ExpectedColumns: sql.Schema{
{Name: "hash", Type: types.LongText, Nullable: false},
},
Expected: []sql.Row{{commitHash}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "commit verification enabled - tests fail, commit aborted",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
"CALL dolt_add('.')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_commit('-m', 'Commit that should fail verification')",
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
},
{
Query: "CALL dolt_commit('--skip-verification','-m', 'skip verification')",
Expected: []sql.Row{{commitHash}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "commit with test verification - specific test groups",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = 'unit'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
"('test_will_fail', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
"CALL dolt_add('.')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_commit('-m', 'Commit with unit tests only')",
Expected: []sql.Row{{commitHash}},
},
{
Query: "SET GLOBAL dolt_commit_verification_groups = 'integration'",
SkipResultsCheck: true,
},
{
Query: "CALL dolt_commit('--allow-empty', '--amend', '-m', 'fail please')",
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 2)",
},
{
Query: "CALL dolt_commit('--allow-empty', '--amend', '--skip-verification', '-m', 'skip the tests')",
Expected: []sql.Row{{commitHash}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "cherry-pick with test verification enabled - tests pass",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'add test')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'",
"CALL dolt_add('.')",
"call dolt_commit_hash_out(@commit_1_hash,'--skip-verification', '-m', 'Add Bob and update test')",
"INSERT INTO users VALUES (3, 'Charlie', 'chuck@exampl.com')",
"CALL dolt_add('.')",
"call dolt_commit_hash_out(@commit_2_hash,'--skip-verification', '-m', 'Add Charlie')",
"CALL dolt_checkout('main')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_cherry_pick(@commit_1_hash)",
Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
},
{
Query: "CALL dolt_cherry_pick(@commit_2_hash)",
ExpectedErrStr: "commit verification failed: test_user_count_update (Assertion failed: expected_single_value equal to 2, got 3)",
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "cherry-pick with test verification enabled - tests fail, aborted",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Initial commit')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"CALL dolt_add('.')",
"call dolt_commit_hash_out(@commit_hash,'--skip-verification', '-m', 'Add Bob but dont update test')",
"CALL dolt_checkout('main')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_cherry_pick(@commit_hash)",
ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 1, got 2)",
},
{
Query: "CALL dolt_cherry_pick('--skip-verification', @commit_hash)",
Expected: []sql.Row{{commitHash, int64(0), int64(0), int64(0)}},
},
{
Query: "select * from dolt_test_run('*')",
Expected: []sql.Row{
{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 1, got 2"},
},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "rebase with test verification enabled - tests pass",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Initial commit')",
"DELETE FROM users where id = 1",
"INSERT INTO users VALUES (1, 'Zed', 'zed@example.com')",
"CALL dolt_commit('-am', 'drop Alice, add Zed')", // tests still pass here.
"CALL dolt_checkout('-b', 'feature', 'HEAD~1')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Add Bob and update test')",
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
"UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Add Charlie, update test')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_rebase('main')",
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "rebase with test verification enabled - tests fail, aborted",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Initial commit')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Add Bob but dont update test')",
"CALL dolt_checkout('main')",
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')", // this will trip the existing test.
"CALL dolt_checkout('feature')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_rebase('main')",
ExpectedErrStr: "commit verification failed: test_users_count (Assertion failed: expected_single_value equal to 2, got 3)",
},
{
Query: "CALL dolt_rebase('--abort')",
Expected: []sql.Row{{0, "Interactive rebase aborted"}},
},
{
Query: "CALL dolt_rebase('--skip-verification', 'main')",
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
},
{
Query: "select * from dolt_test_run('*')",
Expected: []sql.Row{
{"test_users_count", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 2, got 3"},
},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "interactive rebase with --skip-verification flag should persist across continue operations",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob but dont update test')", // This will cause test to fail
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie')",
"CALL dolt_checkout('main')",
"INSERT INTO users VALUES (4, 'David', 'david@example.com')", // Add a commit to main to create divergence
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add David on main')",
"CALL dolt_checkout('feature')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_rebase('--interactive', '--skip-verification', 'main')",
Expected: []sql.Row{{0, "interactive rebase started on branch dolt_rebase_feature; adjust the rebase plan in the dolt_rebase table, then continue rebasing by calling dolt_rebase('--continue')"}},
},
{
Query: "CALL dolt_rebase('--continue')", // This should NOT require --skip-verification flag but should still skip tests
Expected: []sql.Row{{int64(0), successfulRebaseMessage}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "test verification with no dolt_tests errors",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"CALL dolt_add('.')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_commit('-m', 'Commit without dolt_tests table')",
ExpectedErrStr: "failed to run dolt_test_run for group *: could not find tests for argument: *",
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "test verification with mixed test groups - only specified groups run",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = 'unit'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '2'), " +
"('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
"CALL dolt_add('.')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_commit('-m', 'Commit with unit tests only - should pass')",
Expected: []sql.Row{{commitHash}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "test verification error message includes test details",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com'), (2, 'Bob', 'bob@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_specific_failure', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
"CALL dolt_add('.')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_commit('-m', 'Commit with specific test failure')",
ExpectedErrStr: "commit verification failed: test_specific_failure (Assertion failed: expected_single_value equal to 999, got 2)",
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "merge with test verification enabled - tests pass",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Alice\"', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('-m', 'Initial commit')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = \"Bob\"', 'expected_single_value', '==', '1')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
"CALL dolt_checkout('main')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('feature')",
Expected: []sql.Row{{commitHash, int64(1), int64(0), "merge successful"}},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "merge with test verification enabled - tests fail, merge aborted",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
"CALL dolt_checkout('main')",
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('feature')",
ExpectedErrStr: "commit verification failed: test_will_fail (Assertion failed: expected_single_value equal to 999, got 3)",
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
{
Name: "merge with --skip-verification flag bypasses verification",
SetUpScript: []string{
"SET GLOBAL dolt_commit_verification_groups = '*'",
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100))",
"INSERT INTO users VALUES (1, 'Alice', 'alice@example.com')",
"INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES " +
"('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Initial commit with failing test')",
"CALL dolt_checkout('-b', 'feature')",
"INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Bob')",
"CALL dolt_checkout('main')",
"INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')",
"CALL dolt_add('.')",
"CALL dolt_commit('--skip-verification', '-m', 'Add Charlie to force non-FF merge')",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--skip-verification', 'feature')",
Expected: []sql.Row{{commitHash, int64(0), int64(0), "merge successful"}},
},
{
Query: "select * from dolt_test_run('*')",
Expected: []sql.Row{
{"test_will_fail", "unit", "SELECT COUNT(*) FROM users", "FAIL", "Assertion failed: expected_single_value equal to 999, got 3"},
},
},
{ // Test harness bleeds GLOBAL variable changes across tests, so reset after each test.
Query: "SET GLOBAL dolt_commit_verification_groups = ''",
SkipResultsCheck: true,
},
},
},
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/types"
_ "github.com/dolthub/go-mysql-server/sql/variables"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
)
@@ -292,6 +293,13 @@ var DoltSystemVariables = []sql.SystemVariable{
Type: types.NewSystemBoolType(dsess.AllowCICreation),
Default: int8(0),
},
&sql.MysqlSystemVariable{
Name: actions.DoltCommitVerificationGroups,
Dynamic: true,
Scope: sql.GetMysqlScope(sql.SystemVariableScope_Global),
Type: types.NewSystemStringType(actions.DoltCommitVerificationGroups),
Default: "",
},
}
func AddDoltSystemVariables() {

View File

@@ -67,6 +67,10 @@ table RebaseState {
// The rebasing_started field indicates if execution of the rebase plan has been started or not. Once execution of the
// plan has been started, the last_attempted_step field holds a reference to the most recent plan step attempted.
rebasing_started:bool;
// When set to true, the rebase process will skip performing commit
// verification if it would otherwise run.
skip_verification:bool;
}
// KEEP THIS IN SYNC WITH fileidentifiers.go

View File

@@ -169,6 +169,7 @@ type RebaseState struct {
commitBecomesEmptyHandling uint8
emptyCommitHandling uint8
rebasingStarted bool
skipVerification bool
}
func (rs *RebaseState) PreRebaseWorkingAddr() hash.Hash {
@@ -206,6 +207,10 @@ func (rs *RebaseState) EmptyCommitHandling(_ context.Context) uint8 {
return rs.emptyCommitHandling
}
func (rs *RebaseState) SkipVerification(_ context.Context) bool {
return rs.skipVerification
}
type MergeState struct {
preMergeWorkingAddr *hash.Hash
fromCommitAddr *hash.Hash
@@ -457,6 +462,7 @@ func (h serialWorkingSetHead) HeadWorkingSet() (*WorkingSetHead, error) {
rebaseState.EmptyCommitHandling(),
rebaseState.LastAttemptedStep(),
rebaseState.RebasingStarted(),
rebaseState.SkipVerification(),
)
}

View File

@@ -196,6 +196,7 @@ func workingset_flatbuffer(working hash.Hash, staged *hash.Hash, mergeState *Mer
serial.RebaseStateAddEmptyCommitHandling(builder, rebaseState.emptyCommitHandling)
serial.RebaseStateAddLastAttemptedStep(builder, rebaseState.lastAttemptedStep)
serial.RebaseStateAddRebasingStarted(builder, rebaseState.rebasingStarted)
serial.RebaseStateAddSkipVerification(builder, rebaseState.skipVerification)
rebaseStateOffset = serial.RebaseStateEnd(builder)
}
@@ -264,7 +265,7 @@ func NewMergeState(
}
}
func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8, lastAttemptedStep float32, rebasingStarted bool) *RebaseState {
func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch string, commitBecomesEmptyHandling uint8, emptyCommitHandling uint8, lastAttemptedStep float32, rebasingStarted bool, skipVerification bool) *RebaseState {
return &RebaseState{
preRebaseWorkingAddr: &preRebaseWorkingRoot,
ontoCommitAddr: &commitAddr,
@@ -273,6 +274,7 @@ func NewRebaseState(preRebaseWorkingRoot hash.Hash, commitAddr hash.Hash, branch
emptyCommitHandling: emptyCommitHandling,
lastAttemptedStep: lastAttemptedStep,
rebasingStarted: rebasingStarted,
skipVerification: skipVerification,
}
}

View File

@@ -0,0 +1,253 @@
#!/usr/bin/env bats
load $BATS_TEST_DIRNAME/helper/common.bash
setup() {
setup_common
dolt sql <<SQL
CREATE TABLE users (
id INT PRIMARY KEY,
name VARCHAR(100) NOT NULL,
email VARCHAR(100)
);
INSERT INTO users VALUES (1, 'Alice', 'alice@example.com');
CALL DOLT_ADD('.');
CALL DOLT_COMMIT('-m', 'Initial commit');
SQL
}
getHeadHash() {
run dolt sql -r csv -q "select commit_hash from dolt_log limit 1 offset 0;"
[ "$status" -eq 0 ] || return 1
echo "${lines[1]}"
}
@test "commit_verification: system variables can be set" {
run dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
[ "$status" -eq 0 ]
run dolt sql -q "SHOW GLOBAL VARIABLES LIKE 'dolt_commit_verification_groups'"
[ "$status" -eq 0 ]
[[ "$output" =~ "*" ]] || false
}
@test "commit_verification: commit with tests enabled - all tests pass" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1'),
('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = "Alice"', 'expected_single_value', '==', '1');
SQL
dolt add .
run dolt commit -m "Commit with passing tests"
[ "$status" -eq 0 ]
}
@test "commit_verification: abort commit, then skip verification to bypass" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999');
SQL
dolt add .
run dolt commit -m "Commit that should fail verification"
[ "$status" -ne 0 ]
[[ "$output" =~ "commit verification failed" ]] || false
[[ "$output" =~ "test_will_fail" ]] || false
[[ "$output" =~ "expected_single_value equal to 999, got 1" ]] || false
run dolt commit --skip-verification -m "Skip verification commit"
[ "$status" -eq 0 ]
}
@test "commit_verification: specific test groups - only specified groups run" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = 'unit'"
# Add tests in different groups
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_users_unit', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1'),
('test_users_integration', 'integration', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999');
SQL
dolt add .
# Commit should succeed because only unit tests run (integration test that would fail is ignored)
run dolt commit -m "Commit with unit tests only"
[ "$status" -eq 0 ]
}
@test "commit_verification: merge with tests enabled - tests pass" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_alice_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = "Alice"', 'expected_single_value', '==', '1');
SQL
dolt add .
dolt commit -m "Initial commit"
dolt checkout -b feature
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_bob_exists', 'unit', 'SELECT COUNT(*) FROM users WHERE name = "Bob"', 'expected_single_value', '==', '1');
SQL
dolt add .
dolt commit -m "Add Bob"
dolt checkout main
run dolt merge feature
[ "$status" -eq 0 ]
}
@test "commit_verification: merge with tests enabled - tests fail, merge aborted" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_will_fail', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '999');
SQL
dolt add .
dolt commit --skip-verification -m "Initial commit with failing test"
dolt checkout -b feature
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
dolt add .
dolt commit --skip-verification -m "Add Bob"
# Add Charlie to main to force non-fast-forward merge
dolt checkout main
dolt sql -q "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')"
dolt add .
dolt commit --skip-verification -m "Add Charlie"
run dolt merge feature
[ "$status" -ne 0 ]
[[ "$output" =~ "commit verification failed" ]] || false
[[ "$output" =~ "test_will_fail" ]] || false
[[ "$output" =~ "expected_single_value equal to 999, got 3" ]] || false
run dolt merge --skip-verification feature
[ "$status" -eq 0 ]
}
@test "commit_verification: cherry-pick with tests enabled - tests pass" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_user_count_update', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
SQL
dolt add .
dolt commit --skip-verification -m "Add test"
dolt checkout -b feature
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
dolt sql -q "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_user_count_update'"
dolt add .
dolt commit --skip-verification -m "Add Bob and update test"
commit_hash=$(getHeadHash)
dolt checkout main
run dolt cherry-pick $commit_hash
[ "$status" -eq 0 ]
}
@test "commit_verification: cherry-pick with tests enabled - tests fail, aborted" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
SQL
dolt add .
dolt commit -m "Initial commit"
dolt checkout -b feature
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
dolt add .
dolt commit --skip-verification -m "Add Bob but don't update test"
commit_hash=$(getHeadHash)
dolt checkout main
run dolt cherry-pick $commit_hash
[ "$status" -ne 0 ]
[[ "$output" =~ "commit verification failed" ]] || false
[[ "$output" =~ "test_users_count" ]] || false
[[ "$output" =~ "expected_single_value equal to 1, got 2" ]] || false
run dolt cherry-pick --skip-verification $commit_hash
[ "$status" -eq 0 ]
}
@test "commit_verification: rebase with tests enabled - tests pass" {
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
SQL
dolt add .
dolt commit -m "Initial commit"
dolt sql -q "UPDATE users SET name = 'Zed' WHERE id = 1"
dolt commit -am "Update Alice to Zed" # Tests still pass
dolt checkout -b feature HEAD~1
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
dolt sql -q "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'"
dolt add .
dolt commit -m "Add Bob and update test"
dolt sql -q "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')"
dolt sql -q "UPDATE dolt_tests SET assertion_value = '3' WHERE test_name = 'test_users_count'"
dolt add .
dolt commit -m "Add Charlie, update test"
run dolt rebase main
[ "$status" -eq 0 ]
[[ "$output" =~ "Successfully rebased" ]] || false
}
@test "commit_verification: rebase with tests enabled - tests fail, aborted" {
skip "Rebase restart of workflow on failed verification is currently busted."
dolt sql -q "SET @@PERSIST.dolt_commit_verification_groups = '*'"
dolt sql <<SQL
INSERT INTO dolt_tests (test_name, test_group, test_query, assertion_type, assertion_comparator, assertion_value) VALUES
('test_users_count', 'unit', 'SELECT COUNT(*) FROM users', 'expected_single_value', '==', '1');
SQL
dolt add .
dolt commit -m "Initial commit"
dolt checkout -b feature
dolt sql -q "INSERT INTO users VALUES (2, 'Bob', 'bob@example.com')"
dolt sql -q "UPDATE dolt_tests SET assertion_value = '2' WHERE test_name = 'test_users_count'"
dolt add .
dolt commit -m "Add Bob and update test"
dolt checkout main
dolt sql -q "INSERT INTO users VALUES (3, 'Charlie', 'charlie@example.com')"
dolt add .
dolt commit --skip-verification -m "Add Charlie"
dolt checkout feature
run dolt rebase main
[ "$status" -ne 0 ]
[[ "$output" =~ "commit verification failed" ]] || false
[[ "$output" =~ "test_users_count" ]] || false
[[ "$output" =~ "Expected '2' but got '3'" ]] || false
run dolt rebase --skip-verification main
[ "$status" -eq 0 ]
[[ "$output" =~ "Successfully rebased" ]] || false
}

View File

@@ -144,6 +144,7 @@ SKIP_SERVER_TESTS=$(cat <<-EOM
~branch-activity.bats~
~mutual-tls-auth.bats~
~requires-repo.bats~
~commit_verification.bats~
EOM
)