Merge remote-tracking branch 'origin/main' into contrib

This commit is contained in:
Neil Macneale IV
2026-01-05 16:13:46 -08:00
215 changed files with 7505 additions and 4056 deletions
+21
@@ -173,6 +173,27 @@ jobs:
if (pull.keepAlive) process.exit(0);
const checkSuiteRes = await github.rest.checks.listSuitesForRef({
owner,
repo,
ref: pull.headRef,
});
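// Keep the PR open unless every github-actions check suite on the head ref completed successfully.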
if (checkSuiteRes.data) {
for (const suite of checkSuiteRes.data.check_suites) {
console.log("suite id:", suite.id);
console.log("suite app slug:", suite.app.slug);
console.log("suite status:", suite.status);
console.log("suite conclusion:", suite.conclusion);
if (suite.app.slug === "github-actions") {
if (suite.status !== "completed" || suite.conclusion !== "success") {
console.log(`Leaving pr open due to status: ${suite.status}, conclusion: ${suite.conclusion}`);
process.exit(0);
}
}
}
}
console.log(`Closing open pr ${pull.number}`);
await github.rest.issues.createComment({
issue_number: pull.number,
+20
@@ -170,6 +170,26 @@ jobs:
asset_path: go/out/install.sh
asset_name: install.sh
asset_content_type: text/plain
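# Upload the Linux RPM packages (amd64 and arm64) produced for this release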
- name: Upload Linux AMD64 RPM
id: upload-linux-amd64-rpm
uses: dolthub/upload-release-asset@v2
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: go/out/dolt-${{ needs.format-version.outputs.version }}-1.x86_64.rpm
asset_name: dolt-${{ needs.format-version.outputs.version }}-1.x86_64.rpm
asset_content_type: application/zip
- name: Upload Linux ARM64 RPM
id: upload-linux-arm64-rpm
uses: dolthub/upload-release-asset@v2
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.create_release.outputs.upload_url }}
asset_path: go/out/dolt-${{ needs.format-version.outputs.version }}-1.aarch64.rpm
asset_name: dolt-${{ needs.format-version.outputs.version }}-1.aarch64.rpm
asset_content_type: application/zip
create-windows-msi:
needs: [format-version, create-pgo-release]
+3 -3
@@ -89,9 +89,9 @@ jobs:
- name: Install Maven
working-directory: ./.ci_bin
run: |
curl -LO https://dlcdn.apache.org/maven/maven-3/3.9.11/binaries/apache-maven-3.9.11-bin.tar.gz
tar -xf apache-maven-3.9.11-bin.tar.gz
echo "$(pwd)/apache-maven-3.9.11/bin" >> $GITHUB_PATH
curl -LO https://dlcdn.apache.org/maven/maven-3/3.9.12/binaries/apache-maven-3.9.12-bin.tar.gz
tar -xf apache-maven-3.9.12-bin.tar.gz
echo "$(pwd)/apache-maven-3.9.12/bin" >> $GITHUB_PATH
- name: Install Hadoop
working-directory: ./.ci_bin
run: |
+3 -3
@@ -78,9 +78,9 @@ jobs:
if: ${{ env.use_credentials != 'true' }}
working-directory: ./.ci_bin
run: |
curl -LO https://dlcdn.apache.org/maven/maven-3/3.9.11/binaries/apache-maven-3.9.11-bin.tar.gz
tar -xf apache-maven-3.9.11-bin.tar.gz
echo "$(pwd)/apache-maven-3.9.11/bin" >> $GITHUB_PATH
curl -LO https://dlcdn.apache.org/maven/maven-3/3.9.12/binaries/apache-maven-3.9.12-bin.tar.gz
tar -xf apache-maven-3.9.12-bin.tar.gz
echo "$(pwd)/apache-maven-3.9.12/bin" >> $GITHUB_PATH
- name: Install Hadoop
if: ${{ env.use_credentials != 'true' }}
working-directory: ./.ci_bin
+3 -3
@@ -94,9 +94,9 @@ jobs:
- name: Install Maven
working-directory: ./.ci_bin
run: |
curl -LO https://dlcdn.apache.org/maven/maven-3/3.9.11/binaries/apache-maven-3.9.11-bin.tar.gz
tar -xf apache-maven-3.9.11-bin.tar.gz
echo "$(pwd)/apache-maven-3.9.11/bin" >> $GITHUB_PATH
curl -LO https://dlcdn.apache.org/maven/maven-3/3.9.12/binaries/apache-maven-3.9.12-bin.tar.gz
tar -xf apache-maven-3.9.12-bin.tar.gz
echo "$(pwd)/apache-maven-3.9.12/bin" >> $GITHUB_PATH
- name: Install Hadoop
working-directory: ./.ci_bin
run: |
+7 -28
@@ -46,15 +46,6 @@ func ParseAuthor(authorStr string) (string, string, error) {
return name, email, nil
}
const (
SyncBackupId = "sync"
SyncBackupUrlId = "sync-url"
RestoreBackupId = "restore"
AddBackupId = "add"
RemoveBackupId = "remove"
RemoveBackupShortId = "rm"
)
var branchForceFlagDesc = "Reset {{.LessThan}}branchname{{.GreaterThan}} to {{.LessThan}}startpoint{{.GreaterThan}}, even if {{.LessThan}}branchname{{.GreaterThan}} exists already. Without {{.EmphasisLeft}}-f{{.EmphasisRight}}, {{.EmphasisLeft}}dolt branch{{.EmphasisRight}} refuses to change an existing branch. In combination with {{.EmphasisLeft}}-d{{.EmphasisRight}} (or {{.EmphasisLeft}}--delete{{.EmphasisRight}}), allow deleting the branch irrespective of its merged status. In combination with -m (or {{.EmphasisLeft}}--move{{.EmphasisRight}}), allow renaming the branch even if the new branch name already exists, the same applies for {{.EmphasisLeft}}-c{{.EmphasisRight}} (or {{.EmphasisLeft}}--copy{{.EmphasisRight}})."
// CreateCommitArgParser creates the argparser shared dolt commit cli and DOLT_COMMIT.
@@ -97,6 +88,7 @@ func CreateMergeArgParser() *argparser.ArgParser {
return errors.New("Error: Dolt does not support merging from multiple commits. You probably meant to checkout one and then merge from the other.")
}
ap.SupportsFlag(NoFFParam, "", "Create a merge commit even when the merge resolves as a fast-forward.")
ap.SupportsFlag(FFOnlyParam, "", "Refuse to merge unless the current HEAD is already up to date or the merge can be resolved as a fast-forward.")
ap.SupportsFlag(SquashParam, "", "Merge changes to the working set without updating the commit history")
ap.SupportsString(MessageArg, "m", "msg", "Use the given {{.LessThan}}msg{{.GreaterThan}} as the commit message.")
ap.SupportsFlag(AbortParam, "", "Abort the in-progress merge and return the working set to the state before the merge started.")
@@ -227,6 +219,7 @@ func CreatePullArgParser() *argparser.ArgParser {
ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"remoteBranch", "The name of a branch on the specified remote to be merged into the current working set."})
ap.SupportsFlag(SquashParam, "", "Merge changes to the working set without updating the commit history")
ap.SupportsFlag(NoFFParam, "", "Create a merge commit even when the merge resolves as a fast-forward.")
ap.SupportsFlag(FFOnlyParam, "", "Refuse to merge unless the current HEAD is already up to date or the merge can be resolved as a fast-forward.")
ap.SupportsFlag(ForceFlag, "f", "Update from the remote HEAD even if there are errors.")
ap.SupportsFlag(CommitFlag, "", "Perform the merge and commit the result. This is the default option, but can be overridden with the --no-commit flag. Note that this option does not affect fast-forward merges, which don't create a new merge commit, and if any merge conflicts or constraint violations are detected, no commit will be attempted.")
ap.SupportsFlag(NoCommitFlag, "", "Perform the merge and stop just before creating a merge commit. Note this will not prevent a fast-forward merge; use the --no-ff arg together with the --no-commit arg to prevent both fast-forwards and merge commits.")
@@ -320,6 +313,7 @@ func CreateDiffArgParser(isTableFunction bool) *argparser.ArgParser {
ap.SupportsString(FormatFlag, "r", "result output format", "How to format diff output. Valid values are tabular, sql, json. Defaults to tabular.")
ap.SupportsString(WhereParam, "", "column", "filters columns based on values in the diff. See {{.EmphasisLeft}}dolt diff --help{{.EmphasisRight}} for details.")
ap.SupportsInt(LimitParam, "", "record_count", "limits to the first N diffs.")
ap.SupportsString(FilterParam, "", "diff_type", "filters results based on the type of change (added, modified, renamed, dropped). 'removed' is accepted as an alias for 'dropped'.")
ap.SupportsFlag(StagedFlag, "", "Show only the staged data changes.")
ap.SupportsFlag(CachedFlag, "c", "Synonym for --staged")
ap.SupportsFlag(MergeBase, "", "Uses merge base of the first commit and second commit (or HEAD if not supplied) as the first commit")
@@ -391,36 +385,21 @@ func CreateGlobalArgParser(name string) *argparser.ArgParser {
return ap
}
var awsParams = []string{dbfactory.AWSRegionParam, dbfactory.AWSCredsTypeParam, dbfactory.AWSCredsFileParam, dbfactory.AWSCredsProfile}
var AwsParams = []string{dbfactory.AWSRegionParam, dbfactory.AWSCredsTypeParam, dbfactory.AWSCredsFileParam, dbfactory.AWSCredsProfile}
var ossParams = []string{dbfactory.OSSCredsFileParam, dbfactory.OSSCredsProfile}
func ProcessBackupArgs(apr *argparser.ArgParseResults, scheme, backupUrl string) (map[string]string, error) {
params := map[string]string{}
var err error
switch scheme {
case dbfactory.AWSScheme:
err = AddAWSParams(backupUrl, apr, params)
case dbfactory.OSSScheme:
err = AddOSSParams(backupUrl, apr, params)
default:
err = VerifyNoAwsParams(apr)
}
return params, err
}
func AddAWSParams(remoteUrl string, apr *argparser.ArgParseResults, params map[string]string) error {
isAWS := strings.HasPrefix(remoteUrl, "aws")
if !isAWS {
for _, p := range awsParams {
for _, p := range AwsParams {
if _, ok := apr.GetValue(p); ok {
return fmt.Errorf("%s param is only valid for aws cloud remotes in the format aws://dynamo-table:s3-bucket/database", p)
}
}
}
for _, p := range awsParams {
for _, p := range AwsParams {
if val, ok := apr.GetValue(p); ok {
params[p] = val
}
@@ -450,7 +429,7 @@ func AddOSSParams(remoteUrl string, apr *argparser.ArgParseResults, params map[s
}
func VerifyNoAwsParams(apr *argparser.ArgParseResults) error {
if awsParams := apr.GetValues(awsParams...); len(awsParams) > 0 {
if awsParams := apr.GetValues(AwsParams...); len(awsParams) > 0 {
awsParamKeys := make([]string, 0, len(awsParams))
for k := range awsParams {
awsParamKeys = append(awsParamKeys, k)
+2
@@ -53,6 +53,7 @@ const (
NoCommitFlag = "no-commit"
NoEditFlag = "no-edit"
NoFFParam = "no-ff"
FFOnlyParam = "ff-only"
NoPrettyFlag = "no-pretty"
NoTLSFlag = "no-tls"
NoJsonMergeFlag = "dont-merge-json"
@@ -99,6 +100,7 @@ const (
SummaryFlag = "summary"
WhereParam = "where"
LimitParam = "limit"
FilterParam = "filter"
MergeBase = "merge-base"
DiffMode = "diff-mode"
ReverseFlag = "reverse"
+37 -2
@@ -18,6 +18,7 @@ import (
"context"
"os"
"path/filepath"
"time"
"github.com/sirupsen/logrus"
@@ -47,7 +48,12 @@ func (cmd JournalInspectCmd) Docs() *cli.CommandDocumentation {
ShortDesc: "Inspect a Dolt journal file and display information about it",
LongDesc: `This tool is intended for debugging Dolt journal files. Since it is intended to debug potentially
corrupted files, it is best run from a location which doesn't attempt to load databases. For example, go to /tmp and run
dolt admin journal-inspect /path/to/journal/file`,
dolt admin journal-inspect /path/to/journal/file
When using the --filter-roots or --filter-chunks options, a new journal file will be created next to the original
file with the .filtered extension. This new journal will be identical to the original except that it will not contain
records of the specified type with any of the specified hashes. Multiple hashes can be provided as a comma-separated
list. The two filter options can be used together to filter both root and chunk records.`,
Synopsis: []string{
"<journal-path>",
},
@@ -61,10 +67,12 @@ func (cmd JournalInspectCmd) ArgParser() *argparser.ArgParser {
ap.SupportsFlag("verbose", "v", "Display verbose output during inspection (same as -r -c")
ap.SupportsFlag("crc-scan", "", "Scan invalid sections for valid CRCs. SLOW.")
ap.SupportsFlag("snappy-scan", "", "Scan invalid sections for snappy tags and content headers.")
ap.SupportsString("filter-roots", "", "hashcode1,hashcode2,...", "Create filtered copy of journal excluding the specified root hashes (comma-separated)")
ap.SupportsString("filter-chunks", "", "hashcode1,hashcode2,...", "Create filtered copy of journal excluding the specified chunk hashes (comma-separated)")
return ap
}
func (cmd JournalInspectCmd) Exec(_ context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int {
func (cmd JournalInspectCmd) Exec(_ context.Context, commandStr string, args []string, _ *env.DoltEnv, _ cli.CliContext) int {
ap := cmd.ArgParser()
usage, _ := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, cli.CommandDocumentationContent{}, ap))
apr := cli.ParseArgsOrDie(ap, args, usage)
@@ -93,6 +101,8 @@ func (cmd JournalInspectCmd) Exec(_ context.Context, commandStr string, args []s
crcScan := apr.Contains("crc-scan")
snappScan := apr.Contains("snappy-scan")
filterRootsStr := apr.GetValueOrDefault("filter-roots", "")
filterChunksStr := apr.GetValueOrDefault("filter-chunks", "")
if _, err := os.Stat(journalPath); os.IsNotExist(err) {
cli.PrintErrln("Error: Journal file does not exist:", journalPath)
@@ -105,6 +115,31 @@ func (cmd JournalInspectCmd) Exec(_ context.Context, commandStr string, args []s
return 1
}
// Handle filter mode
if filterRootsStr != "" || filterChunksStr != "" {
result, exitCode := nbs.JournalFilter(absPath, filterRootsStr, filterChunksStr)
if exitCode != 0 {
return exitCode
}
// Only print shell commands if a filtered file was actually created
if result.OutputPath != "" && result.FilteredRecords > 0 {
// Print shell commands to replace the journal file
now := time.Now()
dateString := now.Format("2006_01_02_150405")
cli.Println("")
cli.Printf("Filtered file: %s\n", result.OutputPath)
cli.Println("")
cli.Println("To replace the original journal file, run these commands:")
cli.Printf("cp %s %s_saved_%s\n", result.OriginalPath, result.OriginalPath, dateString)
cli.Printf("mv %s %s\n", result.OutputPath, result.OriginalPath)
cli.Printf("rm %s\n", filepath.Join(filepath.Dir(result.OriginalPath), "journal.idx"))
}
return 0
}
// JournalInspect returns an exit code. Its entire purpose is to print errors, after all.
return nbs.JournalInspect(absPath, seeRoots, seeChunks, crcScan, snappScan)
}
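For orientation, here is a minimal sketch (not taken from the commit) of driving the new nbs.JournalFilter helper directly. Only the arguments and result fields used by Exec above are assumed to exist; the journal path and hash values are illustrative.

// Sketch: drop two chunk records from a journal and report where the filtered copy landed.
result, exitCode := nbs.JournalFilter(
	"/path/to/journal/file", // journal to filter (illustrative)
	"",                      // no root-record hashes to remove
	"hash1,hash2",           // chunk-record hashes to remove, comma-separated (illustrative)
)
if exitCode == 0 && result.FilteredRecords > 0 {
	cli.Printf("filtered journal written to %s; original untouched at %s\n", result.OutputPath, result.OriginalPath)
}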
+3 -1
@@ -36,6 +36,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/dconfig"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
)
@@ -581,6 +582,7 @@ func jsonMessage(role string, content string) ([]byte, error) {
}
func getCreateTableStatements(ctx *sql.Context, sqlEngine *engine.SqlEngine, dEnv *env.DoltEnv) (string, error) {
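// Quote identifiers through the schema formatter configured on this session,
// rather than assuming MySQL-style quoting.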
formatter := overrides.SchemaFormatterFromContext(ctx)
sb := strings.Builder{}
root, err := dEnv.WorkingRoot(ctx)
@@ -590,7 +592,7 @@ func getCreateTableStatements(ctx *sql.Context, sqlEngine *engine.SqlEngine, dEn
tables, err := root.GetTableNames(ctx, doltdb.DefaultSchemaName, true)
for _, table := range tables {
_, iter, _, err := sqlEngine.Query(ctx, fmt.Sprintf("SHOW CREATE TABLE %s", sql.QuoteIdentifier(table)))
_, iter, _, err := sqlEngine.Query(ctx, fmt.Sprintf("SHOW CREATE TABLE %s", formatter.QuoteIdentifier(table)))
if err != nil {
return "", err
}
+96 -281
@@ -16,42 +16,42 @@ package commands
import (
"context"
"encoding/json"
"strings"
"fmt"
eventsapi "github.com/dolthub/eventsapi_schema/dolt/services/eventsapi/v1alpha1"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dprocedures"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/store/datas/pull"
"github.com/dolthub/dolt/go/store/types"
eventsapi "github.com/dolthub/eventsapi_schema/dolt/services/eventsapi/v1alpha1"
)
const DoltBackupCommandName = "backup"
var backupDocs = cli.CommandDocumentationContent{
ShortDesc: "Manage server backups",
LongDesc: `With no arguments, shows a list of existing backups. Several subcommands are available to perform operations on backups, point in time snapshots of a database's contents.
ShortDesc: "Manage database backups, including creation, sync, and restore.",
LongDesc: `
With no arguments, shows a list of existing backups. Several subcommands are available to perform operations on backups: point-in-time snapshots of a database's contents.
{{.EmphasisLeft}}add{{.EmphasisRight}}
Adds a backup named {{.LessThan}}name{{.GreaterThan}} for the database at {{.LessThan}}url{{.GreaterThan}}.
The {{.LessThan}}url{{.GreaterThan}} parameter supports url schemes of http, https, aws, gs, and file. The url prefix defaults to https. If the {{.LessThan}}url{{.GreaterThan}} parameter is in the format {{.EmphasisLeft}}<organization>/<repository>{{.EmphasisRight}} then dolt will use the {{.EmphasisLeft}}backups.default_host{{.EmphasisRight}} from your configuration file (Which will be dolthub.com unless changed).
The {{.LessThan}}url{{.GreaterThan}} parameter supports http, https, aws, gs, and file schemes (https as default). If the {{.LessThan}}url{{.GreaterThan}} parameter is in the format {{.EmphasisLeft}}<organization>/<repository>{{.EmphasisRight}} then dolt will use the {{.EmphasisLeft}}backups.default_host{{.EmphasisRight}} from your configuration file (dolthub.com by default).
The URL address must be unique to existing remotes and backups.
AWS cloud backup urls should be of the form {{.EmphasisLeft}}aws://[dynamo-table:s3-bucket]/database{{.EmphasisRight}}. You may configure your aws cloud backup using the optional parameters {{.EmphasisLeft}}aws-region{{.EmphasisRight}}, {{.EmphasisLeft}}aws-creds-type{{.EmphasisRight}}, {{.EmphasisLeft}}aws-creds-file{{.EmphasisRight}}.
AWS cloud backup URLs should be of the form {{.EmphasisLeft}}aws://[dynamo-table:s3-bucket]/database{{.EmphasisRight}}. You may configure your AWS cloud backup using the optional parameters {{.EmphasisLeft}}aws-region{{.EmphasisRight}}, {{.EmphasisLeft}}aws-creds-type{{.EmphasisRight}}, {{.EmphasisLeft}}aws-creds-file{{.EmphasisRight}}, {{.EmphasisLeft}}aws-creds-profile{{.EmphasisRight}}.
aws-creds-type specifies the means by which credentials should be retrieved in order to access the specified cloud resources (specifically the dynamo table, and the s3 bucket). Valid values are 'role', 'env', or 'file'.
aws-creds-type specifies the means by which credentials should be retrieved in order to access the specified cloud resources (required for DynamoDB tables and S3 buckets). Valid values are 'role', 'env', or 'file'.
role: Use the credentials installed for the current user
env: Looks for environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY
file: Uses the credentials file specified by the parameter aws-creds-file
role: Use the credentials installed for the current user.
env: Looks for environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY.
file: Uses the credentials file specified by the parameter aws-creds-file.
GCP backup urls should be of the form gs://gcs-bucket/database and will use the credentials setup using the gcloud command line available from Google.
GCP backup URLs should follow the format {{.EmphasisLeft}}gs://gcs-bucket/database{{.EmphasisRight}}. Backups will use the credentials that you configure using the gcloud CLI.
The local filesystem can be used as a backup by providing a repository url in the format file://absolute path. See https://en.wikipedia.org/wiki/File_URI_scheme
The local filesystem can be used as a backup by providing a repository URL in the format {{.EmphasisLeft}}file://absolute-path{{.EmphasisRight}}. See https://en.wikipedia.org/wiki/File_URI_scheme.
{{.EmphasisLeft}}remove{{.EmphasisRight}}, {{.EmphasisLeft}}rm{{.EmphasisRight}}
Remove the backup named {{.LessThan}}name{{.GreaterThan}}. All configuration settings for the backup are removed. The contents of the backup are not affected.
@@ -62,20 +62,21 @@ Restore a Dolt database from a given {{.LessThan}}url{{.GreaterThan}} into a spe
{{.EmphasisLeft}}sync{{.EmphasisRight}}
Snapshot the database and upload to the backup {{.LessThan}}name{{.GreaterThan}}. This includes branches, tags, working sets, and remote tracking refs.
{{.EmphasisLeft}}sync-url{{.EmphasisRight}}
Snapshot the database and upload the backup to {{.LessThan}}url{{.GreaterThan}}. Like sync, this includes branches, tags, working sets, and remote tracking refs, but it does not require you to create a named backup`,
Snapshot the database and upload the backup to {{.LessThan}}url{{.GreaterThan}}. Like sync, this includes branches, tags, working sets, and remote tracking refs, but it does not require you to create a named backup.
`,
Synopsis: []string{
"[-v | --verbose]",
"add [--aws-region {{.LessThan}}region{{.GreaterThan}}] [--aws-creds-type {{.LessThan}}creds-type{{.GreaterThan}}] [--aws-creds-file {{.LessThan}}file{{.GreaterThan}}] [--aws-creds-profile {{.LessThan}}profile{{.GreaterThan}}] {{.LessThan}}name{{.GreaterThan}} {{.LessThan}}url{{.GreaterThan}}",
"remove {{.LessThan}}name{{.GreaterThan}}",
"restore [--force] {{.LessThan}}url{{.GreaterThan}} {{.LessThan}}name{{.GreaterThan}}",
"restore [--aws-region {{.LessThan}}region{{.GreaterThan}}] [--aws-creds-type {{.LessThan}}creds-type{{.GreaterThan}}] [--aws-creds-file {{.LessThan}}file{{.GreaterThan}}] [--aws-creds-profile {{.LessThan}}profile{{.GreaterThan}}] [--force] {{.LessThan}}url{{.GreaterThan}} {{.LessThan}}name{{.GreaterThan}}",
"sync {{.LessThan}}name{{.GreaterThan}}",
"sync-url [--aws-region {{.LessThan}}region{{.GreaterThan}}] [--aws-creds-type {{.LessThan}}creds-type{{.GreaterThan}}] [--aws-creds-file {{.LessThan}}file{{.GreaterThan}}] [--aws-creds-profile {{.LessThan}}profile{{.GreaterThan}}] {{.LessThan}}url{{.GreaterThan}}",
},
}
var VerboseErrUsage = errhand.BuildDError("").SetPrintUsage().Build()
type BackupCmd struct{}
// Name returns the name of the Dolt CLI command. This is what is used on the command line to invoke the command
@@ -106,287 +107,101 @@ func (cmd BackupCmd) EventType() eventsapi.ClientEventType {
return eventsapi.ClientEventType_REMOTE
}
// Exec executes the command
func (cmd BackupCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, cliCtx cli.CliContext) int {
ap := cmd.ArgParser()
help, usage := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, backupDocs, ap))
apr := cli.ParseArgsOrDie(ap, args, help)
// Exec executes the `dolt backup` command with the provided subcommand. If no subcommand is provided, the dolt_backups
// table is printed.
func (cmd BackupCmd) Exec(ctx context.Context, commandStr string, args []string, _ *env.DoltEnv, cliCtx cli.CliContext) int {
argParser := cmd.ArgParser()
help, usage := cli.HelpAndUsagePrinters(cli.CommandDocsForCommandString(commandStr, backupDocs, argParser))
apr := cli.ParseArgsOrDie(argParser, args, help)
var verr errhand.VerboseError
// All the sub commands except `restore` require a valid environment
if apr.NArg() == 0 || apr.Arg(0) != cli.RestoreBackupId {
if !cli.CheckEnvIsValid(dEnv) {
return 2
}
}
switch {
case apr.NArg() == 0:
verr = printBackups(dEnv, apr)
case apr.Arg(0) == cli.AddBackupId:
verr = addBackup(dEnv, apr)
case apr.Arg(0) == cli.RemoveBackupId:
verr = removeBackup(ctx, dEnv, apr)
case apr.Arg(0) == cli.RemoveBackupShortId:
verr = removeBackup(ctx, dEnv, apr)
case apr.Arg(0) == cli.SyncBackupId:
verr = syncBackup(ctx, dEnv, apr)
case apr.Arg(0) == cli.SyncBackupUrlId:
verr = syncBackupUrl(ctx, dEnv, apr)
case apr.Arg(0) == cli.RestoreBackupId:
verr = restoreBackup(ctx, dEnv, apr)
default:
verr = errhand.BuildDError("").SetPrintUsage().Build()
}
return HandleVErrAndExitCode(verr, usage)
}
func removeBackup(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.NArg() != 2 {
return errhand.BuildDError("").SetPrintUsage().Build()
}
old := strings.TrimSpace(apr.Arg(1))
err := dEnv.RemoveBackup(ctx, old)
switch err {
case nil:
return nil
case env.ErrFailedToWriteRepoState:
return errhand.BuildDError("error: failed to save change to repo state").AddCause(err).Build()
case env.ErrFailedToDeleteBackup:
return errhand.BuildDError("error: failed to delete backup tracking ref").AddCause(err).Build()
case env.ErrFailedToReadFromDb:
return errhand.BuildDError("error: failed to read from db").AddCause(err).Build()
case env.ErrBackupNotFound:
return errhand.BuildDError("error: unknown backup: '%s' ", old).Build()
default:
return errhand.BuildDError("error: unknown error").AddCause(err).Build()
}
}
func addBackup(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.NArg() != 3 {
return errhand.BuildDError("").SetPrintUsage().Build()
}
backupName := strings.TrimSpace(apr.Arg(1))
backupUrl := apr.Arg(2)
scheme, absBackupUrl, err := env.GetAbsRemoteUrl(dEnv.FS, dEnv.Config, backupUrl)
queryEngine, err := cliCtx.QueryEngine(ctx)
if err != nil {
return errhand.BuildDError("error: '%s' is not valid.", backupUrl).AddCause(err).Build()
return HandleVErrAndExitCode(errhand.VerboseErrorFromError(err), usage)
}
params, err := cli.ProcessBackupArgs(apr, scheme, absBackupUrl)
if apr.NArg() == 0 {
verboseErr := printDoltBackupsTable(&queryEngine, apr.Contains(cli.VerboseFlag))
return HandleVErrAndExitCode(verboseErr, usage)
}
switch apr.Arg(0) {
case dprocedures.DoltBackupParamAdd:
if apr.NArg() != 3 {
return HandleVErrAndExitCode(VerboseErrUsage, usage)
}
case dprocedures.DoltBackupParamRemove,
dprocedures.DoltBackupParamRm,
dprocedures.DoltBackupParamSync,
dprocedures.DoltBackupParamSyncUrl:
if apr.NArg() != 2 {
return HandleVErrAndExitCode(VerboseErrUsage, usage)
}
case dprocedures.DoltBackupParamRestore:
if apr.NArg() < 3 {
return HandleVErrAndExitCode(VerboseErrUsage, usage)
}
default:
return HandleVErrAndExitCode(VerboseErrUsage, usage)
}
verboseErr := callDoltBackupProc(&queryEngine, args)
return HandleVErrAndExitCode(verboseErr, usage)
}
// callDoltBackupProc calls the dolt_backup stored procedure with the given parameters.
func callDoltBackupProc(queryEngine *cli.QueryEngineResult, params []string) errhand.VerboseError {
query, err := interpolateStoredProcedureCall(dprocedures.DoltBackupProcedureName, params)
if err != nil {
return errhand.BuildDError("failed to interpolate stored procedure %s", dprocedures.DoltBackupProcedureName).AddCause(err).Build()
}
_, err = cli.GetRowsForSql(queryEngine.Queryist, queryEngine.Context, query)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
r := env.NewRemote(backupName, backupUrl, params)
err = dEnv.AddBackup(r)
switch err {
case nil:
return nil
case env.ErrBackupAlreadyExists:
return errhand.BuildDError("error: a backup named '%s' already exists.", r.Name).AddDetails("remove it before running this command again").Build()
case env.ErrBackupNotFound:
return errhand.BuildDError("error: unknown backup: '%s' ", r.Name).Build()
case env.ErrInvalidBackupURL:
return errhand.BuildDError("error: '%s' is not valid.", r.Url).AddCause(err).Build()
case env.ErrInvalidBackupName:
return errhand.BuildDError("error: invalid backup name: %s", r.Name).Build()
default:
return errhand.BuildDError("error: Unable to save changes.").AddCause(err).Build()
}
}
func printBackups(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
backups, err := dEnv.GetBackups()
if err != nil {
return errhand.BuildDError("Unable to get backups from the local directory").AddCause(err).Build()
}
for _, r := range backups.Snapshot() {
if apr.Contains(cli.VerboseFlag) {
paramStr := make([]byte, 0)
if len(r.Params) > 0 {
paramStr, _ = json.Marshal(r.Params)
}
cli.Printf("%s %s %s\n", r.Name, r.Url, paramStr)
} else {
cli.Println(r.Name)
}
}
return nil
}
func syncBackupUrl(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.NArg() != 2 {
return errhand.BuildDError("").SetPrintUsage().Build()
}
backupUrl := apr.Arg(1)
scheme, absBackupUrl, err := env.GetAbsRemoteUrl(dEnv.FS, dEnv.Config, backupUrl)
if err != nil {
return errhand.BuildDError("error: '%s' is not valid.", backupUrl).AddCause(err).Build()
}
params, err := cli.ProcessBackupArgs(apr, scheme, absBackupUrl)
// printDoltBackupsTable queries the dolt_backups table and prints the results. If the verbose flag is set, it prints
// name, url, and params columns. Otherwise, it prints only the name column.
func printDoltBackupsTable(queryEngine *cli.QueryEngineResult, showVerbose bool) errhand.VerboseError {
query := fmt.Sprintf("SELECT * FROM `%s`", doltdb.BackupsTableName)
schema, rowIter, _, err := queryEngine.Queryist.Query(queryEngine.Context, query)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
b := env.NewRemote("__temp__", backupUrl, params)
return backup(ctx, dEnv, b)
}
func syncBackup(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.NArg() != 2 {
return errhand.BuildDError("").SetPrintUsage().Build()
}
backupName := strings.TrimSpace(apr.Arg(1))
backups, err := dEnv.GetBackups()
rows, err := sql.RowIterToRows(queryEngine.Context, rowIter)
if err != nil {
return errhand.BuildDError("Unable to get backups from the local directory").AddCause(err).Build()
return errhand.BuildDError("failed to retrieve slice for %s", doltdb.BackupsTableName).AddCause(err).Build()
}
b, ok := backups.Get(backupName)
if !ok {
return errhand.BuildDError("error: unknown backup: '%s' ", backupName).Build()
}
return backup(ctx, dEnv, b)
}
func backup(ctx context.Context, dEnv *env.DoltEnv, b env.Remote) errhand.VerboseError {
destDb, err := b.GetRemoteDB(ctx, dEnv.DoltDB(ctx).ValueReadWriter().Format(), dEnv)
if err != nil {
return errhand.BuildDError("error: unable to open destination.").AddCause(err).Build()
}
tmpDir, err := dEnv.TempTableFilesDir()
if err != nil {
return errhand.BuildDError("error: ").AddCause(err).Build()
}
err = actions.SyncRoots(ctx, dEnv.DoltDB(ctx), destDb, tmpDir, buildProgStarter(defaultLanguage), stopProgFuncs)
switch err {
case nil:
return nil
case pull.ErrDBUpToDate:
return nil
case env.ErrBackupAlreadyExists:
return errhand.BuildDError("error: a backup named '%s' already exists.", b.Name).AddDetails("remove it before running this command again").Build()
case env.ErrBackupNotFound:
return errhand.BuildDError("error: unknown backup: '%s' ", b.Name).Build()
case env.ErrInvalidBackupURL:
return errhand.BuildDError("error: '%s' is not valid.", b.Url).AddCause(err).Build()
case env.ErrInvalidBackupName:
return errhand.BuildDError("error: invalid backup name: %s", b.Name).Build()
default:
return errhand.BuildDError("error: Unable to save changes.").AddCause(err).Build()
}
}
func restoreBackup(ctx context.Context, dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.VerboseError {
if apr.NArg() < 3 {
return errhand.BuildDError("").SetPrintUsage().Build()
}
apr.Args = apr.Args[1:]
restoredDB, urlStr, verr := parseArgs(apr)
if verr != nil {
return verr
}
// For error recovery, record whether EnvForClone created the directory, or just `.dolt/noms` within the directory.
userDirExisted, _ := dEnv.FS.Exists(restoredDB)
force := apr.Contains(cli.ForceFlag)
scheme, remoteUrl, err := env.GetAbsRemoteUrl(dEnv.FS, dEnv.Config, urlStr)
if err != nil {
return errhand.BuildDError("error: '%s' is not valid.", urlStr).Build()
}
var params map[string]string
params, verr = parseRemoteArgs(apr, scheme, remoteUrl)
if verr != nil {
return verr
}
r := env.NewRemote("", remoteUrl, params)
srcDb, err := r.GetRemoteDB(ctx, types.Format_Default, dEnv)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), dEnv.FS, dEnv.Version, dEnv)
if err != nil {
return errhand.BuildDError("error: Unable to list databases").AddCause(err).Build()
}
var existingDEnv *env.DoltEnv
err = mrEnv.Iter(func(dbName string, dEnv *env.DoltEnv) (stop bool, err error) {
if dbName == restoredDB {
existingDEnv = dEnv
return true, nil
}
return false, nil
})
if err != nil {
return errhand.BuildDError("error: Unable to list databases").AddCause(err).Build()
}
if existingDEnv != nil {
if !force {
return errhand.BuildDError("error: cannot restore backup into %s. A database with that name already exists. Did you mean to supply --force?", restoredDB).Build()
const errColumnExpectedStringFmt = "column '%s' expected string, got %v"
for _, row := range rows {
// Backup configuration name
nameStr, ok := row[0].(string)
if !ok {
return errhand.BuildDError(errColumnExpectedStringFmt, schema[0].Name, row[0]).Build()
}
tmpDir, err := existingDEnv.TempTableFilesDir()
if !showVerbose {
cli.Println(nameStr)
continue
}
// Remote backup location URL (aws://, gs://, file://, http[s]://)
urlStr, ok := row[1].(string)
if !ok {
return errhand.BuildDError(errColumnExpectedStringFmt, schema[1].Name, row[1]).Build()
}
// Backup connection parameters
jsonStr, err := getJsonAsString(queryEngine.Context, row[2])
if err != nil {
return errhand.VerboseErrorFromError(err)
return errhand.BuildDError(errColumnExpectedStringFmt, schema[2].Name, row[2]).AddCause(err).Build()
}
err = actions.SyncRoots(ctx, srcDb, existingDEnv.DoltDB(ctx), tmpDir, buildProgStarter(downloadLanguage), stopProgFuncs)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
} else {
// Create a new Dolt env for the clone; use env.NoRemote to avoid origin upstream
clonedEnv, err := actions.EnvForClone(ctx, srcDb.ValueReadWriter().Format(), env.NoRemote, restoredDB, dEnv.FS, dEnv.Version, env.GetCurrentUserHomeDir)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
// Nil out the old Dolt env so we don't accidentally use the wrong database
dEnv = nil
// still make empty repo state
_, err = env.CreateRepoState(clonedEnv.FS, env.DefaultInitBranch)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
tmpDir, err := clonedEnv.TempTableFilesDir()
if err != nil {
return errhand.VerboseErrorFromError(err)
}
err = actions.SyncRoots(ctx, srcDb, clonedEnv.DoltDB(ctx), tmpDir, buildProgStarter(downloadLanguage), stopProgFuncs)
if err != nil {
// If we're cloning into a directory that already exists do not erase it. Otherwise
// make best effort to delete the directory we created.
if userDirExisted {
_ = clonedEnv.FS.Delete(dbfactory.DoltDir, true)
} else {
_ = clonedEnv.FS.Delete(".", true)
}
return errhand.VerboseErrorFromError(err)
}
cli.Printf("%s %s %s\n", nameStr, urlStr, jsonStr)
}
return nil
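Net effect of the rewrite above: the backup CLI is now a thin wrapper over the SQL surface. Below is a hedged sketch of the two queries it ends up issuing for a bare `dolt backup` and for `dolt backup add nightly aws://[backup-table:backup-bucket]/mydb`; only the procedure and table names come from this diff, and the argument values are illustrative.

// Listing backups: the same query printDoltBackupsTable builds.
listQuery := fmt.Sprintf("SELECT * FROM `%s`", doltdb.BackupsTableName)

// Subcommands: forwarded to the dolt_backup stored procedure via callDoltBackupProc.
addQuery, err := interpolateStoredProcedureCall(
	dprocedures.DoltBackupProcedureName,
	[]string{"add", "nightly", "aws://[backup-table:backup-bucket]/mydb"}, // illustrative arguments
)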
+199 -15
@@ -86,6 +86,8 @@ The diffs displayed can be limited to show the first N by providing the paramete
To filter which data rows are displayed, use {{.EmphasisLeft}}--where <SQL expression>{{.EmphasisRight}}. Table column names in the filter expression must be prefixed with {{.EmphasisLeft}}from_{{.EmphasisRight}} or {{.EmphasisLeft}}to_{{.EmphasisRight}}, e.g. {{.EmphasisLeft}}to_COLUMN_NAME > 100{{.EmphasisRight}} or {{.EmphasisLeft}}from_COLUMN_NAME + to_COLUMN_NAME = 0{{.EmphasisRight}}.
To filter diff output by change type, use {{.EmphasisLeft}}--filter <type>{{.EmphasisRight}} where {{.EmphasisLeft}}<type>{{.EmphasisRight}} is one of {{.EmphasisLeft}}added{{.EmphasisRight}}, {{.EmphasisLeft}}modified{{.EmphasisRight}}, {{.EmphasisLeft}}renamed{{.EmphasisRight}}, or {{.EmphasisLeft}}dropped{{.EmphasisRight}}. The {{.EmphasisLeft}}added{{.EmphasisRight}} filter shows only additions (new tables or rows), {{.EmphasisLeft}}modified{{.EmphasisRight}} shows only schema modifications or row updates, {{.EmphasisLeft}}renamed{{.EmphasisRight}} shows only renamed tables, and {{.EmphasisLeft}}dropped{{.EmphasisRight}} shows only deletions (dropped tables or deleted rows). You can also use {{.EmphasisLeft}}removed{{.EmphasisRight}} as an alias for {{.EmphasisLeft}}dropped{{.EmphasisRight}}. For example, {{.EmphasisLeft}}dolt diff --filter=dropped{{.EmphasisRight}} shows only deleted rows and dropped tables.
The {{.EmphasisLeft}}--diff-mode{{.EmphasisRight}} argument controls how modified rows are presented when the output format is set to {{.EmphasisLeft}}tabular{{.EmphasisRight}}. When set to {{.EmphasisLeft}}row{{.EmphasisRight}}, modified rows are presented as old and new rows. When set to {{.EmphasisLeft}}line{{.EmphasisRight}}, modified rows are presented as a single row, and changes are presented using "+" and "-" within the column. When set to {{.EmphasisLeft}}in-place{{.EmphasisRight}}, modified rows are presented as a single row, and changes are presented side-by-side with a color distinction (requires a color-enabled terminal). When set to {{.EmphasisLeft}}context{{.EmphasisRight}}, rows that contain at least one column that spans multiple lines use {{.EmphasisLeft}}line{{.EmphasisRight}}, while all other rows use {{.EmphasisLeft}}row{{.EmphasisRight}}. The default value is {{.EmphasisLeft}}context{{.EmphasisRight}}.
`,
Synopsis: []string{
@@ -102,6 +104,7 @@ type diffDisplaySettings struct {
where string
skinny bool
includeCols []string
filter *diffTypeFilter
}
type diffDatasets struct {
@@ -130,6 +133,141 @@ type diffStatistics struct {
NewCellCount uint64
}
// diffTypeFilter manages which diff types should be included in the output.
// When filters is nil or empty, all types are included.
type diffTypeFilter struct {
// Map of diff type -> should include
// If nil or empty, includes all types
filters map[string]bool
}
// newDiffTypeFilter creates a filter for the specified diff type.
// Pass diff.DiffTypeAll or empty string to include all types.
// Accepts "removed" as an alias for "dropped" for user convenience.
func newDiffTypeFilter(filterType string) *diffTypeFilter {
if filterType == "" || filterType == diff.DiffTypeAll {
return &diffTypeFilter{filters: nil} // nil means include all
}
// Map "removed" to "dropped" (alias for user convenience)
internalFilterType := filterType
if filterType == "removed" {
internalFilterType = diff.DiffTypeDropped
}
return &diffTypeFilter{
filters: map[string]bool{
internalFilterType: true,
},
}
}
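// Illustration (mirrors the unit tests added with this commit): the user-facing
// "removed" value is an alias for "dropped", and an empty or "all" filter
// includes every diff type.
//
//	newDiffTypeFilter("removed").shouldInclude(diff.DiffTypeDropped)        // true
//	newDiffTypeFilter("removed").shouldInclude(diff.DiffTypeAdded)          // false
//	newDiffTypeFilter(diff.DiffTypeAll).shouldInclude(diff.DiffTypeRenamed) // true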
// shouldInclude checks if the given diff type should be included.
// Uses TableDeltaSummary.DiffType field for table-level filtering.
func (df *diffTypeFilter) shouldInclude(diffType string) bool {
// nil or empty filters means include everything
if df.filters == nil || len(df.filters) == 0 {
return true
}
return df.filters[diffType]
}
// isValid validates the filter configuration
func (df *diffTypeFilter) isValid() bool {
if df.filters == nil {
return true
}
for filterType := range df.filters {
if filterType != diff.DiffTypeAdded &&
filterType != diff.DiffTypeModified &&
filterType != diff.DiffTypeRenamed &&
filterType != diff.DiffTypeDropped {
return false
}
}
return true
}
// shouldSkipRow checks if a row should be skipped based on the filter settings.
// Uses the DiffType infrastructure for consistency with table-level filtering.
func shouldSkipRow(filter *diffTypeFilter, rowChangeType diff.ChangeType) bool {
if filter == nil {
return false
}
// Don't filter None - it represents "no row" on one side of the diff
if rowChangeType == diff.None {
return false
}
// Convert row-level ChangeType to table-level DiffType string
diffType := diff.ChangeTypeToDiffType(rowChangeType)
// Use the map-based shouldInclude method
return !filter.shouldInclude(diffType)
}
// shouldUseLazyHeader determines if we should delay printing the table header
// until we know there are rows to display. This prevents empty headers when
// all rows are filtered out in data-only diffs.
func shouldUseLazyHeader(dArgs *diffArgs, tableSummary diff.TableDeltaSummary) bool {
return dArgs.filter != nil && dArgs.filter.filters != nil &&
!tableSummary.SchemaChange && !tableSummary.IsRename()
}
// lazyRowWriter wraps a SqlRowDiffWriter and delays calling BeginTable
// until the first row is actually written. This prevents empty table headers
// when all rows are filtered out.
type lazyRowWriter struct {
writer diff.SqlRowDiffWriter
// Callback to invoke before first write
// Set to nil after first call
onFirstWrite func() error
}
// newLazyRowWriter creates a lazy writer that wraps the given writer.
// The onFirstWrite callback is invoked exactly once before the first write.
func newLazyRowWriter(writer diff.SqlRowDiffWriter, onFirstWrite func() error) *lazyRowWriter {
return &lazyRowWriter{
writer: writer,
onFirstWrite: onFirstWrite,
}
}
// WriteRow implements diff.SqlRowDiffWriter
func (l *lazyRowWriter) WriteRow(ctx *sql.Context, row sql.Row, diffType diff.ChangeType, colDiffTypes []diff.ChangeType) error {
// Initialize on first write
if l.onFirstWrite != nil {
if err := l.onFirstWrite(); err != nil {
return err
}
l.onFirstWrite = nil // Prevent double-initialization
}
return l.writer.WriteRow(ctx, row, diffType, colDiffTypes)
}
// WriteCombinedRow implements diff.SqlRowDiffWriter
func (l *lazyRowWriter) WriteCombinedRow(ctx *sql.Context, oldRow, newRow sql.Row, mode diff.Mode) error {
// Initialize on first write
if l.onFirstWrite != nil {
if err := l.onFirstWrite(); err != nil {
return err
}
l.onFirstWrite = nil
}
return l.writer.WriteCombinedRow(ctx, oldRow, newRow, mode)
}
// Close implements diff.SqlRowDiffWriter
func (l *lazyRowWriter) Close(ctx context.Context) error {
return l.writer.Close(ctx)
}
type DiffCmd struct{}
// Name returns the name of the Dolt CLI command. This is what is used on the command line to invoke the command
@@ -220,6 +358,15 @@ func (cmd DiffCmd) validateArgs(apr *argparser.ArgParseResults) errhand.VerboseE
return errhand.BuildDError("invalid output format: %s", f).Build()
}
filterValue, hasFilter := apr.GetValue(cli.FilterParam)
if hasFilter {
filter := newDiffTypeFilter(filterValue)
if !filter.isValid() {
return errhand.BuildDError("invalid filter: %s. Valid values are: %s, %s, %s, %s (or %s)",
filterValue, diff.DiffTypeAdded, diff.DiffTypeModified, diff.DiffTypeRenamed, diff.DiffTypeDropped, "removed").Build()
}
}
return nil
}
@@ -268,6 +415,9 @@ func parseDiffDisplaySettings(apr *argparser.ArgParseResults) *diffDisplaySettin
displaySettings.limit, _ = apr.GetInt(cli.LimitParam)
displaySettings.where = apr.GetValueOrDefault(cli.WhereParam, "")
filterValue := apr.GetValueOrDefault(cli.FilterParam, diff.DiffTypeAll)
displaySettings.filter = newDiffTypeFilter(filterValue)
return displaySettings
}
@@ -670,13 +820,13 @@ func getSchemaDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Contex
tableName = fromTable
}
case fromTable == "":
diffType = "added"
diffType = diff.DiffTypeAdded
tableName = toTable
case toTable == "":
diffType = "dropped"
diffType = diff.DiffTypeDropped
tableName = fromTable
case fromTable != "" && toTable != "" && fromTable != toTable:
diffType = "renamed"
diffType = diff.DiffTypeRenamed
tableName = toTable
default:
return nil, fmt.Errorf("error: unexpected schema diff case: fromTable='%s', toTable='%s'", fromTable, toTable)
@@ -738,13 +888,13 @@ func getDiffSummariesBetweenRefs(queryist cli.Queryist, sqlCtx *sql.Context, fro
}
switch summary.DiffType {
case "dropped":
case diff.DiffTypeDropped:
summary.TableName = summary.FromTableName
case "added":
case diff.DiffTypeAdded:
summary.TableName = summary.ToTableName
case "renamed":
case diff.DiffTypeRenamed:
summary.TableName = summary.ToTableName
case "modified":
case diff.DiffTypeModified:
summary.TableName = summary.FromTableName
default:
return nil, fmt.Errorf("error: unexpected diff type '%s'", summary.DiffType)
@@ -816,6 +966,16 @@ func diffUserTables(queryist cli.Queryist, sqlCtx *sql.Context, dArgs *diffArgs)
continue
}
// Apply table-level filtering based on diff type
if dArgs.filter != nil && dArgs.filter.filters != nil {
// For data-only changes (no schema/rename), always let them through for row-level filtering
isDataOnlyChange := !delta.SchemaChange && !delta.IsRename() && delta.DataChange
if !isDataOnlyChange && !dArgs.filter.shouldInclude(delta.DiffType) {
continue // Skip this table
}
}
if strings.HasPrefix(delta.ToTableName.Name, diff.DBPrefix) {
verr := diffDatabase(queryist, sqlCtx, delta, dArgs, dw)
if verr != nil {
@@ -1110,9 +1270,9 @@ func diffUserTable(
fromTable := tableSummary.FromTableName
toTable := tableSummary.ToTableName
if dArgs.diffParts&NameOnlyDiff == 0 {
if dArgs.diffParts&NameOnlyDiff == 0 && !shouldUseLazyHeader(dArgs, tableSummary) {
// TODO: schema names
err := dw.BeginTable(tableSummary.FromTableName.Name, tableSummary.ToTableName.Name, tableSummary.IsAdd(), tableSummary.IsDrop())
err := dw.BeginTable(sqlCtx, tableSummary.FromTableName.Name, tableSummary.ToTableName.Name, tableSummary.IsAdd(), tableSummary.IsDrop())
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1176,7 +1336,7 @@ func diffUserTable(
return errhand.BuildDError("cannot retrieve diff stats between '%s' and '%s'", dArgs.fromRef, dArgs.toRef).AddCause(err).Build()
}
err = dw.WriteTableDiffStats(diffStats, fromColLen, toColLen, areTablesKeyless)
err = dw.WriteTableDiffStats(sqlCtx, diffStats, fromColLen, toColLen, areTablesKeyless)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1184,7 +1344,7 @@ func diffUserTable(
}
if dArgs.diffParts&SchemaOnlyDiff != 0 {
err = dw.WriteTableSchemaDiff(fromTableInfo, toTableInfo, tableSummary)
err = dw.WriteTableSchemaDiff(sqlCtx, fromTableInfo, toTableInfo, tableSummary)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1327,7 +1487,7 @@ func diffDatabase(
return nil
}
err := dw.BeginTable(tableSummary.FromTableName.Name, tableSummary.ToTableName.Name, tableSummary.IsAdd(), tableSummary.IsDrop())
err := dw.BeginTable(sqlCtx, tableSummary.FromTableName.Name, tableSummary.ToTableName.Name, tableSummary.IsAdd(), tableSummary.IsDrop())
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1352,7 +1512,7 @@ func diffDatabase(
toTableInfo = &to
}
err = dw.WriteTableSchemaDiff(fromTableInfo, toTableInfo, tableSummary)
err = dw.WriteTableSchemaDiff(sqlCtx, fromTableInfo, toTableInfo, tableSummary)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1446,11 +1606,28 @@ func diffRows(
}
// We always instantiate a RowWriter in case the diffWriter needs it to close off any work from schema output
rowWriter, err := dw.RowWriter(fromTableInfo, toTableInfo, tableSummary, unionSch)
var rowWriter diff.SqlRowDiffWriter
realWriter, err := dw.RowWriter(sqlCtx, fromTableInfo, toTableInfo, tableSummary, unionSch)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
if shouldUseLazyHeader(dArgs, tableSummary) {
// Wrap with lazy writer to delay BeginTable until first row write
onFirstWrite := func() error {
return dw.BeginTable(
sqlCtx,
tableSummary.FromTableName.Name,
tableSummary.ToTableName.Name,
tableSummary.IsAdd(),
tableSummary.IsDrop(),
)
}
rowWriter = newLazyRowWriter(realWriter, onFirstWrite)
} else {
rowWriter = realWriter
}
// can't diff
if !diffable {
// TODO: this messes up some structured output if the user didn't redirect it
@@ -1547,7 +1724,7 @@ func diffRows(
}
// instantiate a new RowWriter with the new schema that only contains the columns with changes
rowWriter, err = dw.RowWriter(fromTableInfo, toTableInfo, tableSummary, filteredUnionSch)
rowWriter, err = dw.RowWriter(sqlCtx, fromTableInfo, toTableInfo, tableSummary, filteredUnionSch)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
@@ -1708,6 +1885,13 @@ func writeDiffResults(
return err
}
// Apply row-level filtering based on diff type
if dArgs.filter != nil {
if shouldSkipRow(dArgs.filter, oldRow.RowDiff) || shouldSkipRow(dArgs.filter, newRow.RowDiff) {
continue
}
}
if dArgs.skinny {
var filteredOldRow, filteredNewRow diff.RowDiff
for i, changeType := range newRow.ColDiffs {
+548
@@ -0,0 +1,548 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package commands
import (
"context"
"strings"
"testing"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
)
func TestDiffTypeFilter_IsValid(t *testing.T) {
tests := []struct {
name string
filterBy string
want bool
}{
{"valid: added", diff.DiffTypeAdded, true},
{"valid: modified", diff.DiffTypeModified, true},
{"valid: removed", diff.DiffTypeDropped, true},
{"valid: all", diff.DiffTypeAll, true},
{"invalid: empty string with nil filter", "", true}, // nil filter is valid
{"invalid: random string", "invalid", false},
{"invalid: uppercase", "ADDED", false},
{"invalid: typo addedd", "addedd", false},
{"invalid: plural adds", "adds", false},
{"invalid: typo modifiedd", "modifiedd", false},
{"invalid: typo removedd", "removedd", false},
{"invalid: insert instead of added", "insert", false},
{"invalid: update instead of modified", "update", false},
{"invalid: delete instead of removed", "delete", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
df := newDiffTypeFilter(tt.filterBy)
got := df.isValid()
if got != tt.want {
t.Errorf("isValid() = %v, want %v", got, tt.want)
}
})
}
}
func TestDiffTypeFilter_ShouldInclude(t *testing.T) {
tests := []struct {
name string
filterType string
checkType string
want bool
}{
// Testing with filter=added
{"filter=added, check added", diff.DiffTypeAdded, diff.DiffTypeAdded, true},
{"filter=added, check modified", diff.DiffTypeAdded, diff.DiffTypeModified, false},
{"filter=added, check removed", diff.DiffTypeAdded, diff.DiffTypeDropped, false},
// Testing with filter=modified
{"filter=modified, check added", diff.DiffTypeModified, diff.DiffTypeAdded, false},
{"filter=modified, check modified", diff.DiffTypeModified, diff.DiffTypeModified, true},
{"filter=modified, check removed", diff.DiffTypeModified, diff.DiffTypeDropped, false},
// Testing with filter=dropped
{"filter=dropped, check added", diff.DiffTypeDropped, diff.DiffTypeAdded, false},
{"filter=dropped, check modified", diff.DiffTypeDropped, diff.DiffTypeModified, false},
{"filter=dropped, check dropped", diff.DiffTypeDropped, diff.DiffTypeDropped, true},
{"filter=dropped, check renamed", diff.DiffTypeDropped, diff.DiffTypeRenamed, false},
// Testing with filter=renamed
{"filter=renamed, check added", diff.DiffTypeRenamed, diff.DiffTypeAdded, false},
{"filter=renamed, check modified", diff.DiffTypeRenamed, diff.DiffTypeModified, false},
{"filter=renamed, check dropped", diff.DiffTypeRenamed, diff.DiffTypeDropped, false},
{"filter=renamed, check renamed", diff.DiffTypeRenamed, diff.DiffTypeRenamed, true},
// Testing with "removed" alias (should map to dropped)
{"filter=removed (alias), check dropped", "removed", diff.DiffTypeDropped, true},
{"filter=removed (alias), check added", "removed", diff.DiffTypeAdded, false},
{"filter=removed (alias), check renamed", "removed", diff.DiffTypeRenamed, false},
// Testing with filter=all
{"filter=all, check added", diff.DiffTypeAll, diff.DiffTypeAdded, true},
{"filter=all, check modified", diff.DiffTypeAll, diff.DiffTypeModified, true},
{"filter=all, check removed", diff.DiffTypeAll, diff.DiffTypeDropped, true},
// Testing with empty filter (nil filters map)
{"filter=empty, check added", "", diff.DiffTypeAdded, true},
{"filter=empty, check modified", "", diff.DiffTypeModified, true},
{"filter=empty, check removed", "", diff.DiffTypeDropped, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
df := newDiffTypeFilter(tt.filterType)
got := df.shouldInclude(tt.checkType)
if got != tt.want {
t.Errorf("shouldInclude(%s) = %v, want %v", tt.checkType, got, tt.want)
}
})
}
}
func TestDiffTypeFilter_ConsistencyAcrossMethods(t *testing.T) {
// Test that filter=all returns true for all diff types
t.Run("filter=all returns true for all types", func(t *testing.T) {
df := newDiffTypeFilter(diff.DiffTypeAll)
if !df.shouldInclude(diff.DiffTypeAdded) {
t.Error("filter=all should include added")
}
if !df.shouldInclude(diff.DiffTypeDropped) {
t.Error("filter=all should include removed")
}
if !df.shouldInclude(diff.DiffTypeModified) {
t.Error("filter=all should include modified")
}
})
// Test that each specific filter only returns true for its type
t.Run("filter=added only includes added", func(t *testing.T) {
df := newDiffTypeFilter(diff.DiffTypeAdded)
if !df.shouldInclude(diff.DiffTypeAdded) {
t.Error("filter=added should include added")
}
if df.shouldInclude(diff.DiffTypeDropped) {
t.Error("filter=added should not include removed")
}
if df.shouldInclude(diff.DiffTypeModified) {
t.Error("filter=added should not include modified")
}
})
t.Run("filter=dropped only includes removed", func(t *testing.T) {
df := newDiffTypeFilter(diff.DiffTypeDropped)
if df.shouldInclude(diff.DiffTypeAdded) {
t.Error("filter=dropped should not include added")
}
if !df.shouldInclude(diff.DiffTypeDropped) {
t.Error("filter=dropped should include removed")
}
if df.shouldInclude(diff.DiffTypeModified) {
t.Error("filter=dropped should not include modified")
}
})
t.Run("filter=modified only includes modified", func(t *testing.T) {
df := newDiffTypeFilter(diff.DiffTypeModified)
if df.shouldInclude(diff.DiffTypeAdded) {
t.Error("filter=modified should not include added")
}
if df.shouldInclude(diff.DiffTypeDropped) {
t.Error("filter=modified should not include removed")
}
if !df.shouldInclude(diff.DiffTypeModified) {
t.Error("filter=modified should include modified")
}
})
}
func TestDiffTypeFilter_InvalidFilterBehavior(t *testing.T) {
// Test that invalid filters return false for isValid
invalidFilters := []string{"invalid", "ADDED", "addedd", "delete"}
for _, filterValue := range invalidFilters {
t.Run("invalid filter: "+filterValue, func(t *testing.T) {
df := newDiffTypeFilter(filterValue)
if df.isValid() {
t.Errorf("Filter %s should be invalid", filterValue)
}
})
}
}
func TestFilterConstants(t *testing.T) {
// Test that filter constants have expected values
tests := []struct {
name string
constant string
expected string
}{
{"DiffTypeAdded value", diff.DiffTypeAdded, "added"},
{"DiffTypeModified value", diff.DiffTypeModified, "modified"},
{"DiffTypeDropped value", diff.DiffTypeDropped, "dropped"},
{"DiffTypeAll value", diff.DiffTypeAll, "all"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.constant != tt.expected {
t.Errorf("Expected %s = %s, got %s", tt.name, tt.expected, tt.constant)
}
})
}
}
func TestFilterConstants_AreUnique(t *testing.T) {
// Test that all filter constants are unique
constants := []string{diff.DiffTypeAdded, diff.DiffTypeModified, diff.DiffTypeDropped, diff.DiffTypeAll}
seen := make(map[string]bool)
for _, c := range constants {
if seen[c] {
t.Errorf("Duplicate filter constant value: %s", c)
}
seen[c] = true
}
if len(seen) != 4 {
t.Errorf("Expected 4 unique filter constants, got %d", len(seen))
}
}
func TestFilterConstants_AreLowercase(t *testing.T) {
// Test that filter constants are lowercase (convention)
constants := []string{diff.DiffTypeAdded, diff.DiffTypeModified, diff.DiffTypeDropped, diff.DiffTypeAll}
for _, c := range constants {
if c != strings.ToLower(c) {
t.Errorf("Filter constant %s should be lowercase", c)
}
}
}
func TestShouldUseLazyHeader(t *testing.T) {
tests := []struct {
name string
filterType string
schemaChange bool
isRename bool
expectedResult bool
}{
{
name: "use lazy: filter active, data-only change",
filterType: diff.DiffTypeAdded,
schemaChange: false,
isRename: false,
expectedResult: true,
},
{
name: "don't use lazy: no filter",
filterType: "",
schemaChange: false,
isRename: false,
expectedResult: false,
},
{
name: "don't use lazy: filter is all",
filterType: diff.DiffTypeAll,
schemaChange: false,
isRename: false,
expectedResult: false,
},
{
name: "don't use lazy: schema changed",
filterType: diff.DiffTypeModified,
schemaChange: true,
isRename: false,
expectedResult: false,
},
{
name: "don't use lazy: table renamed",
filterType: diff.DiffTypeDropped,
schemaChange: false,
isRename: true,
expectedResult: false,
},
{
name: "don't use lazy: schema changed AND renamed",
filterType: diff.DiffTypeAdded,
schemaChange: true,
isRename: true,
expectedResult: false,
},
{
name: "use lazy: filter=modified, data-only",
filterType: diff.DiffTypeModified,
schemaChange: false,
isRename: false,
expectedResult: true,
},
{
name: "use lazy: filter=dropped, data-only",
filterType: diff.DiffTypeDropped,
schemaChange: false,
isRename: false,
expectedResult: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var filter *diffTypeFilter
if tt.filterType != "" {
filter = newDiffTypeFilter(tt.filterType)
}
dArgs := &diffArgs{
diffDisplaySettings: &diffDisplaySettings{
filter: filter,
},
}
tableSummary := diff.TableDeltaSummary{
SchemaChange: tt.schemaChange,
}
// Create a mock rename by setting different from/to names
if tt.isRename {
tableSummary.FromTableName = doltdb.TableName{Name: "old_table"}
tableSummary.ToTableName = doltdb.TableName{Name: "new_table"}
} else {
tableSummary.FromTableName = doltdb.TableName{Name: "table"}
tableSummary.ToTableName = doltdb.TableName{Name: "table"}
}
result := shouldUseLazyHeader(dArgs, tableSummary)
if result != tt.expectedResult {
t.Errorf("%s: expected %v, got %v", tt.name, tt.expectedResult, result)
}
})
}
}
// mockDiffWriter is a test implementation of diffWriter
type mockDiffWriter struct {
beginTableCalled bool
beginTableError error
}
var _ diffWriter = (*mockDiffWriter)(nil)
func (m *mockDiffWriter) BeginTable(_ /* ctx */ context.Context, _ /* fromTableName */, _ /* toTableName */ string, _ /* isAdd */, _ /* isDrop */ bool) error {
m.beginTableCalled = true
return m.beginTableError
}
func (m *mockDiffWriter) WriteTableSchemaDiff(_ /* ctx */ context.Context, _ /* fromTableInfo */, _ /* toTableInfo */ *diff.TableInfo, _ /* tds */ diff.TableDeltaSummary) error {
return nil
}
func (m *mockDiffWriter) WriteEventDiff(_ /* ctx */ context.Context, _ /* eventName */, _ /* oldDefn */, _ /* newDefn */ string) error {
return nil
}
func (m *mockDiffWriter) WriteTriggerDiff(_ /* ctx */ context.Context, _ /* triggerName */, _ /* oldDefn */, _ /* newDefn */ string) error {
return nil
}
func (m *mockDiffWriter) WriteViewDiff(_ /* ctx */ context.Context, _ /* viewName */, _ /* oldDefn */, _ /* newDefn */ string) error {
return nil
}
func (m *mockDiffWriter) WriteTableDiffStats(_ /* ctx */ context.Context, _ /* diffStats */ []diffStatistics, _ /* oldColLen */, _ /* newColLen */ int, _ /* areTablesKeyless */ bool) error {
return nil
}
func (m *mockDiffWriter) RowWriter(_ /* ctx */ context.Context, _ /* fromTableInfo */, _ /* toTableInfo */ *diff.TableInfo, _ /* tds */ diff.TableDeltaSummary, _ /* unionSch */ sql.Schema) (diff.SqlRowDiffWriter, error) {
return &mockRowWriter{}, nil
}
func (m *mockDiffWriter) Close(_ /* ctx */ context.Context) error {
return nil
}
// mockRowWriter is a test implementation of SqlRowDiffWriter
type mockRowWriter struct {
writeCalled bool
closeCalled bool
}
func (m *mockRowWriter) WriteRow(_ /* ctx */ *sql.Context, _ /* row */ sql.Row, _ /* diffType */ diff.ChangeType, _ /* colDiffTypes */ []diff.ChangeType) error {
m.writeCalled = true
return nil
}
func (m *mockRowWriter) WriteCombinedRow(_ /* ctx */ *sql.Context, _ /* oldRow */, _ /* newRow */ sql.Row, _ /* mode */ diff.Mode) error {
m.writeCalled = true
return nil
}
func (m *mockRowWriter) Close(_ /* ctx */ context.Context) error {
m.closeCalled = true
return nil
}
func TestLazyRowWriter_NoRowsWritten(t *testing.T) {
mockDW := &mockDiffWriter{}
realWriter := &mockRowWriter{}
onFirstWrite := func() error {
return mockDW.BeginTable(context.Background(), "fromTable", "toTable", false, false)
}
lazyWriter := newLazyRowWriter(realWriter, onFirstWrite)
// Close without writing any rows
err := lazyWriter.Close(context.Background())
if err != nil {
t.Fatalf("Close() returned error: %v", err)
}
// BeginTable should NEVER have been called
if mockDW.beginTableCalled {
t.Error("BeginTable() was called even though no rows were written - should have been lazy!")
}
}
func TestLazyRowWriter_RowsWritten(t *testing.T) {
mockDW := &mockDiffWriter{}
realWriter := &mockRowWriter{}
onFirstWrite := func() error {
return mockDW.BeginTable(context.Background(), "fromTable", "toTable", false, false)
}
lazyWriter := newLazyRowWriter(realWriter, onFirstWrite)
// Write a row
ctx := sql.NewEmptyContext()
err := lazyWriter.WriteRow(ctx, sql.Row{}, diff.Added, []diff.ChangeType{})
if err != nil {
t.Fatalf("WriteRow() returned error: %v", err)
}
// BeginTable should have been called on first write
if !mockDW.beginTableCalled {
t.Error("BeginTable() was NOT called after writing a row - should have been initialized!")
}
// Close
err = lazyWriter.Close(context.Background())
if err != nil {
t.Fatalf("Close() returned error: %v", err)
}
}
func TestLazyRowWriter_CombinedRowsWritten(t *testing.T) {
mockDW := &mockDiffWriter{}
realWriter := &mockRowWriter{}
onFirstWrite := func() error {
return mockDW.BeginTable(context.Background(), "fromTable", "toTable", false, false)
}
lazyWriter := newLazyRowWriter(realWriter, onFirstWrite)
// Write a combined row
ctx := sql.NewEmptyContext()
err := lazyWriter.WriteCombinedRow(ctx, sql.Row{}, sql.Row{}, diff.ModeRow)
if err != nil {
t.Fatalf("WriteCombinedRow() returned error: %v", err)
}
// BeginTable should have been called on first write
if !mockDW.beginTableCalled {
t.Error("BeginTable() was NOT called after writing combined row - should have been initialized!")
}
}
func TestLazyRowWriter_InitializedOnlyOnce(t *testing.T) {
callCount := 0
mockDW := &mockDiffWriter{}
realWriter := &mockRowWriter{}
onFirstWrite := func() error {
callCount++
return mockDW.BeginTable(context.Background(), "fromTable", "toTable", false, false)
}
lazyWriter := newLazyRowWriter(realWriter, onFirstWrite)
ctx := sql.NewEmptyContext()
// Write multiple rows
for i := 0; i < 5; i++ {
err := lazyWriter.WriteRow(ctx, sql.Row{}, diff.Added, []diff.ChangeType{})
if err != nil {
t.Fatalf("WriteRow() %d returned error: %v", i, err)
}
}
// BeginTable should have been called exactly ONCE (on first write only)
if callCount != 1 {
t.Errorf("BeginTable() called %d times, expected exactly 1", callCount)
}
}
func TestShouldSkipRow(t *testing.T) {
tests := []struct {
name string
filterType string
rowChangeType diff.ChangeType
expectedResult bool
}{
{"filter=added, row=Added", diff.DiffTypeAdded, diff.Added, false},
{"filter=added, row=Dropped", diff.DiffTypeAdded, diff.Removed, true},
{"filter=added, row=ModifiedOld", diff.DiffTypeAdded, diff.ModifiedOld, true},
{"filter=added, row=ModifiedNew", diff.DiffTypeAdded, diff.ModifiedNew, true},
{"filter=dropped, row=Added", diff.DiffTypeDropped, diff.Added, true},
{"filter=dropped, row=Dropped", diff.DiffTypeDropped, diff.Removed, false},
{"filter=dropped, row=ModifiedOld", diff.DiffTypeDropped, diff.ModifiedOld, true},
{"filter=modified, row=Added", diff.DiffTypeModified, diff.Added, true},
{"filter=modified, row=Dropped", diff.DiffTypeModified, diff.Removed, true},
{"filter=modified, row=ModifiedOld", diff.DiffTypeModified, diff.ModifiedOld, false},
{"filter=modified, row=ModifiedNew", diff.DiffTypeModified, diff.ModifiedNew, false},
{"filter=all, row=Added", diff.DiffTypeAll, diff.Added, false},
{"filter=all, row=Dropped", diff.DiffTypeAll, diff.Removed, false},
{"filter=all, row=ModifiedOld", diff.DiffTypeAll, diff.ModifiedOld, false},
{"nil filter, row=Added", "", diff.Added, false},
{"nil filter, row=Dropped", "", diff.Removed, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var filter *diffTypeFilter
if tt.filterType != "" {
filter = newDiffTypeFilter(tt.filterType)
}
result := shouldSkipRow(filter, tt.rowChangeType)
if result != tt.expectedResult {
t.Errorf("expected %v, got %v", tt.expectedResult, result)
}
})
}
}
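
The lazy-writer tests above assume a small wrapper that defers its onFirstWrite callback until a row is actually written, and runs it at most once. The real newLazyRowWriter lives elsewhere in this change, so the following is only a sketch of the pattern under that assumption (the rowWriter interface here is a stand-in, not the actual diff.SqlRowDiffWriter):

package lazysketch

// rowWriter stands in for the real SqlRowDiffWriter; only the shape matters here.
type rowWriter interface {
    WriteRow(row []string) error
    Close() error
}

// lazyRowWriter wraps a rowWriter and calls onFirstWrite exactly once,
// immediately before the first row is written. If no rows are written,
// onFirstWrite never runs, which is what the tests above verify.
type lazyRowWriter struct {
    inner        rowWriter
    onFirstWrite func() error
    initialized  bool
}

func newLazyRowWriter(inner rowWriter, onFirstWrite func() error) *lazyRowWriter {
    return &lazyRowWriter{inner: inner, onFirstWrite: onFirstWrite}
}

func (l *lazyRowWriter) WriteRow(row []string) error {
    if err := l.init(); err != nil {
        return err
    }
    return l.inner.WriteRow(row)
}

func (l *lazyRowWriter) init() error {
    if l.initialized {
        return nil
    }
    if err := l.onFirstWrite(); err != nil {
        return err
    }
    l.initialized = true
    return nil
}

func (l *lazyRowWriter) Close() error {
    // Deliberately does not call init: a table header should not be
    // emitted for a table whose rows were all filtered out.
    return l.inner.Close()
}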
+30 -24
View File
@@ -30,6 +30,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/diff"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtablefunctions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/json"
@@ -41,9 +42,9 @@ import (
// diffWriter is an interface that lets us write diffs in a variety of output formats
type diffWriter interface {
// BeginTable is called when a new table is about to be written, before any schema or row diffs are written
BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error
BeginTable(ctx context.Context, fromTableName, toTableName string, isAdd, isDrop bool) error
// WriteTableSchemaDiff is called to write a schema diff for the table given (if requested by args)
WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error
WriteTableSchemaDiff(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error
// WriteEventDiff is called to write an event diff
WriteEventDiff(ctx context.Context, eventName, oldDefn, newDefn string) error
// WriteTriggerDiff is called to write a trigger diff
@@ -51,10 +52,10 @@ type diffWriter interface {
// WriteViewDiff is called to write a view diff
WriteViewDiff(ctx context.Context, viewName, oldDefn, newDefn string) error
// WriteTableDiffStats is called to write the diff stats for the table given
WriteTableDiffStats(diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error
WriteTableDiffStats(ctx context.Context, diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error
// RowWriter returns a row writer for the table delta provided, which will have Close() called on it when rows are
// done being written.
RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error)
RowWriter(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error)
// Close finalizes the work of the writer
Close(ctx context.Context) error
}
@@ -91,7 +92,7 @@ func (t tabularDiffWriter) Close(ctx context.Context) error {
return nil
}
func (t tabularDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error {
func (t tabularDiffWriter) BeginTable(ctx context.Context, fromTableName, toTableName string, isAdd, isDrop bool) error {
bold := color.New(color.Bold)
if isDrop {
_, _ = bold.Printf("diff --dolt a/%s b/%s\n", fromTableName, fromTableName)
@@ -107,7 +108,7 @@ func (t tabularDiffWriter) BeginTable(fromTableName, toTableName string, isAdd,
return nil
}
func (t tabularDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error {
func (t tabularDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error {
var fromCreateStmt = ""
if fromTableInfo != nil {
fromCreateStmt = fromTableInfo.CreateStmt
@@ -141,7 +142,7 @@ func (t tabularDiffWriter) WriteViewDiff(ctx context.Context, viewName, oldDefn,
return nil
}
func (t tabularDiffWriter) WriteTableDiffStats(diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error {
func (t tabularDiffWriter) WriteTableDiffStats(ctx context.Context, diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error {
acc := diff.DiffStatProgress{}
eP := cli.NewEphemeralPrinter()
var pos int
@@ -221,7 +222,7 @@ func (t tabularDiffWriter) printKeylessStat(acc diff.DiffStatProgress) {
cli.Printf("%s\n", deletions)
}
func (t tabularDiffWriter) RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) {
func (t tabularDiffWriter) RowWriter(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) {
return tabular.NewFixedWidthDiffTableWriter(unionSch, iohelp.NopWrCloser(cli.CliOut), 100), nil
}
@@ -233,16 +234,17 @@ func (s sqlDiffWriter) Close(ctx context.Context) error {
return nil
}
func (s sqlDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error {
func (s sqlDiffWriter) BeginTable(ctx context.Context, fromTableName, toTableName string, isAdd, isDrop bool) error {
return nil
}
func (s sqlDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error {
func (s sqlDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error {
stmts := tds.AlterStmts
if tds.IsAdd() {
stmts = []string{toTableInfo.CreateStmt}
} else if tds.IsDrop() {
stmts = []string{sqlfmt.DropTableStmt(fromTableInfo.Name)}
formatter := overrides.SchemaFormatterFromContext(ctx)
stmts = []string{sqlfmt.DropTableStmt(formatter, fromTableInfo.Name)}
}
for _, stmt := range stmts {
if len(stmt) == 0 {
@@ -256,12 +258,13 @@ func (s sqlDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.Tab
func (s sqlDiffWriter) WriteEventDiff(ctx context.Context, eventName, oldDefn, newDefn string) error {
// definitions will already be semicolon terminated, no need to add additional ones
formatter := overrides.SchemaFormatterFromContext(ctx)
if oldDefn == "" {
cli.Println(newDefn)
} else if newDefn == "" {
cli.Println(fmt.Sprintf("DROP EVENT %s;", sql.QuoteIdentifier(eventName)))
cli.Println(fmt.Sprintf("DROP EVENT %s;", formatter.QuoteIdentifier(eventName)))
} else {
cli.Println(fmt.Sprintf("DROP EVENT %s;", sql.QuoteIdentifier(eventName)))
cli.Println(fmt.Sprintf("DROP EVENT %s;", formatter.QuoteIdentifier(eventName)))
cli.Println(newDefn)
}
@@ -270,12 +273,13 @@ func (s sqlDiffWriter) WriteEventDiff(ctx context.Context, eventName, oldDefn, n
func (s sqlDiffWriter) WriteTriggerDiff(ctx context.Context, triggerName, oldDefn, newDefn string) error {
// definitions will already be semicolon terminated, no need to add additional ones
formatter := overrides.SchemaFormatterFromContext(ctx)
if oldDefn == "" {
cli.Println(newDefn)
} else if newDefn == "" {
cli.Println(fmt.Sprintf("DROP TRIGGER %s;", sql.QuoteIdentifier(triggerName)))
cli.Println(fmt.Sprintf("DROP TRIGGER %s;", formatter.QuoteIdentifier(triggerName)))
} else {
cli.Println(fmt.Sprintf("DROP TRIGGER %s;", sql.QuoteIdentifier(triggerName)))
cli.Println(fmt.Sprintf("DROP TRIGGER %s;", formatter.QuoteIdentifier(triggerName)))
cli.Println(newDefn)
}
@@ -284,23 +288,24 @@ func (s sqlDiffWriter) WriteTriggerDiff(ctx context.Context, triggerName, oldDef
func (s sqlDiffWriter) WriteViewDiff(ctx context.Context, viewName, oldDefn, newDefn string) error {
// definitions will already be semicolon terminated, no need to add additional ones
formatter := overrides.SchemaFormatterFromContext(ctx)
if oldDefn == "" {
cli.Println(newDefn)
} else if newDefn == "" {
cli.Println(fmt.Sprintf("DROP VIEW %s;", sql.QuoteIdentifier(viewName)))
cli.Println(fmt.Sprintf("DROP VIEW %s;", formatter.QuoteIdentifier(viewName)))
} else {
cli.Println(fmt.Sprintf("DROP VIEW %s;", sql.QuoteIdentifier(viewName)))
cli.Println(fmt.Sprintf("DROP VIEW %s;", formatter.QuoteIdentifier(viewName)))
cli.Println(newDefn)
}
return nil
}
func (s sqlDiffWriter) WriteTableDiffStats(diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error {
func (s sqlDiffWriter) WriteTableDiffStats(ctx context.Context, diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error {
return errors.New("invalid output format: sql. SQL format diffs only rendered for schema or data changes")
}
func (s sqlDiffWriter) RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) {
func (s sqlDiffWriter) RowWriter(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) {
var targetSch schema.Schema
if toTableInfo != nil {
targetSch = toTableInfo.Sch
@@ -345,7 +350,7 @@ func (j *jsonDiffWriter) beginDocumentIfNecessary() error {
return nil
}
func (j *jsonDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, isDrop bool) error {
func (j *jsonDiffWriter) BeginTable(ctx context.Context, fromTableName, toTableName string, isAdd, isDrop bool) error {
err := j.beginDocumentIfNecessary()
if err != nil {
return err
@@ -375,7 +380,7 @@ func (j *jsonDiffWriter) BeginTable(fromTableName, toTableName string, isAdd, is
return err
}
func (j *jsonDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error {
func (j *jsonDiffWriter) WriteTableSchemaDiff(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary) error {
jsonSchDiffWriter, err := json.NewSchemaDiffWriter(iohelp.NopWrCloser(j.wr))
if err != nil {
return err
@@ -385,7 +390,8 @@ func (j *jsonDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.T
if tds.IsAdd() {
stmts = []string{toTableInfo.CreateStmt}
} else if tds.IsDrop() {
stmts = []string{sqlfmt.DropTableStmt(fromTableInfo.Name)}
formatter := overrides.SchemaFormatterFromContext(ctx)
stmts = []string{sqlfmt.DropTableStmt(formatter, fromTableInfo.Name)}
}
for _, stmt := range stmts {
@@ -401,7 +407,7 @@ func (j *jsonDiffWriter) WriteTableSchemaDiff(fromTableInfo, toTableInfo *diff.T
return jsonSchDiffWriter.Close()
}
func (j *jsonDiffWriter) RowWriter(fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) {
func (j *jsonDiffWriter) RowWriter(ctx context.Context, fromTableInfo, toTableInfo *diff.TableInfo, tds diff.TableDeltaSummary, unionSch sql.Schema) (diff.SqlRowDiffWriter, error) {
err := iohelp.WriteAll(j.wr, []byte(jsonDiffDataDiffHeader))
if err != nil {
return nil, err
@@ -590,7 +596,7 @@ func (j *jsonDiffWriter) WriteViewDiff(ctx context.Context, viewName, oldDefn, n
const jsonDiffStatsHeader = `"stats":{`
const jsonDiffStatsFooter = `}`
func (j *jsonDiffWriter) WriteTableDiffStats(diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error {
func (j *jsonDiffWriter) WriteTableDiffStats(ctx context.Context, diffStats []diffStatistics, oldColLen, newColLen int, areTablesKeyless bool) error {
acc := diff.DiffStatProgress{}
for _, diffStat := range diffStats {
acc.Adds += diffStat.RowsAdded
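
The writers above now resolve a schema formatter from the request context instead of calling sql.QuoteIdentifier directly. A minimal sketch of that lookup pattern, with hypothetical names (the real overrides package may differ in its interface and defaults):

package formattersketch

import (
    "context"
    "fmt"
    "strings"
)

// SchemaFormatter is a stand-in for whatever interface the overrides package
// exposes; QuoteIdentifier is the only method the diff writers use here.
type SchemaFormatter interface {
    QuoteIdentifier(name string) string
}

type mysqlFormatter struct{}

func (mysqlFormatter) QuoteIdentifier(name string) string {
    // Backtick-quote and escape embedded backticks, MySQL style.
    return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}

type formatterKey struct{}

// WithSchemaFormatter returns a context carrying a formatter override.
func WithSchemaFormatter(ctx context.Context, f SchemaFormatter) context.Context {
    return context.WithValue(ctx, formatterKey{}, f)
}

// SchemaFormatterFromContext returns the override if one was installed,
// falling back to a default MySQL-style formatter otherwise.
func SchemaFormatterFromContext(ctx context.Context) SchemaFormatter {
    if f, ok := ctx.Value(formatterKey{}).(SchemaFormatter); ok {
        return f
    }
    return mysqlFormatter{}
}

// DropTableStmt shows how a statement builder can take the formatter as an argument.
func DropTableStmt(f SchemaFormatter, table string) string {
    return fmt.Sprintf("DROP TABLE %s;", f.QuoteIdentifier(table))
}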
+1 -1
View File
@@ -458,7 +458,7 @@ func dumpViews(ctx *sql.Context, engine *engine.SqlEngine, root doltdb.RootValue
}
// We used to store just the SELECT part of a view, but now we store the entire CREATE VIEW statement
sqlEngine := engine.GetUnderlyingEngine()
binder := planbuilder.New(ctx, sqlEngine.Analyzer.Catalog, sqlEngine.EventScheduler, sqlEngine.Parser)
binder := planbuilder.New(ctx, sqlEngine.Analyzer.Catalog, sqlEngine.EventScheduler)
binder.SetParserOptions(sql.NewSqlModeFromString(sqlMode).ParserOptions())
fragCol, ok, err := sql.Unwrap[string](ctx, row[fragColIdx])
if err != nil {
+4 -3
View File
@@ -86,6 +86,7 @@ type SqlEngineConfig struct {
BinlogReplicaController binlogreplication.BinlogReplicaController
EventSchedulerStatus eventscheduler.SchedulerStatus
BranchActivityTracking bool
EngineOverrides sql.EngineOverrides
}
type SqlEngineConfigOption func(*SqlEngineConfig)
@@ -157,7 +158,7 @@ func NewSqlEngine(
}
b := env.GetDefaultInitBranch(mrEnv.Config())
pro, err := sqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations)
pro, err := sqle.NewDoltDatabaseProviderWithDatabases(b, mrEnv.FileSystem(), all, locations, config.EngineOverrides)
if err != nil {
return nil, err
}
@@ -172,7 +173,7 @@ func NewSqlEngine(
sqlEngine := &SqlEngine{}
// Create the engine
engine := gms.New(analyzer.NewBuilder(pro).Build(), &gms.Config{
engine := gms.New(analyzer.NewBuilder(pro).AddOverrides(config.EngineOverrides).Build(), &gms.Config{
IsReadOnly: config.IsReadOnly,
IsServerLocked: config.IsServerLocked,
}).WithBackgroundThreads(bThreads)
@@ -253,7 +254,7 @@ func NewSqlEngine(
branchActivityTracker := doltdb.NewBranchActivityTracker(ctx, config.BranchActivityTracking)
engine.Analyzer.ExecBuilder = rowexec.NewOverrideBuilder(kvexec.Builder{})
engine.Analyzer.ExecBuilder = rowexec.NewBuilder(kvexec.Builder{}, engine.Analyzer.Overrides)
sessFactory := doltSessionFactory(pro, statsPro, mrEnv.Config(), bcController, gcSafepointController, config.Autocommit, branchActivityTracker)
sqlEngine.provider = pro
sqlEngine.dsessFactory = sessFactory
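
The EngineOverrides value added to SqlEngineConfig is handed to both the database provider and the analyzer builder, so a single injected struct can alter behavior in several layers. Roughly, with generic names that are not the go-mysql-server API, the plumbing pattern looks like:

package overridesketch

// Overrides stands in for sql.EngineOverrides: optional behaviors a caller can
// inject once and have threaded through the engine's layers.
type Overrides struct {
    QuoteIdentifier func(name string) string
}

// Builder mimics the builder-style construction used above: overrides are
// recorded before Build and become part of the finished engine.
type Builder struct {
    overrides Overrides
}

func NewBuilder() *Builder { return &Builder{} }

func (b *Builder) AddOverrides(o Overrides) *Builder {
    b.overrides = o
    return b
}

type Engine struct {
    Overrides Overrides
}

func (b *Builder) Build() *Engine { return &Engine{Overrides: b.overrides} }

// Downstream components consult the same overrides, falling back to defaults.
func quote(e *Engine, name string) string {
    if e.Overrides.QuoteIdentifier != nil {
        return e.Overrides.QuoteIdentifier(name)
    }
    return "`" + name + "`"
}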
+1 -1
View File
@@ -351,7 +351,7 @@ func rebaseSqlEngine(ctx context.Context, dEnv *env.DoltEnv, root doltdb.RootVal
}
b := env.GetDefaultInitBranch(dEnv.Config)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS)
pro, err := dsqle.NewDoltDatabaseProviderWithDatabase(b, mrEnv.FileSystem(), db, dEnv.FS, sql.EngineOverrides{})
if err != nil {
return nil, nil, err
}
+219 -15
View File
@@ -15,14 +15,28 @@
package commands
import (
"bytes"
"context"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"sync"
"sync/atomic"
"github.com/fatih/color"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/earl"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/nbs"
"github.com/dolthub/dolt/go/store/types"
)
type FsckCmd struct{}
@@ -30,7 +44,7 @@ type FsckCmd struct{}
var _ cli.Command = FsckCmd{}
func (cmd FsckCmd) Description() string {
return "Verifies the contents of the database are not corrupted."
return "Verifies the contents of the database are not corrupted. Provides repair when possible."
}
var fsckDocs = cli.CommandDocumentationContent{
@@ -38,9 +52,14 @@ var fsckDocs = cli.CommandDocumentationContent{
LongDesc: "Verifies the contents of the database are not corrupted.",
Synopsis: []string{
"[--quiet]",
"--revive-journal-with-data-loss",
},
}
const (
journalReviveFlag = "revive-journal-with-data-loss"
)
func (cmd FsckCmd) Docs() *cli.CommandDocumentation {
return cli.NewCommandDocumentation(fsckDocs, cmd.ArgParser())
}
@@ -48,6 +67,10 @@ func (cmd FsckCmd) Docs() *cli.CommandDocumentation {
func (cmd FsckCmd) ArgParser() *argparser.ArgParser {
ap := argparser.NewArgParserWithMaxArgs(cmd.Name(), 0)
ap.SupportsFlag(cli.QuietFlag, "", "Don't show progress. Just print final report.")
ap.SupportsFlag(journalReviveFlag, "", `Revives a corrupted chunk journal by discarding unparsable data.
WARNING: This may result in data loss. Your original data will be preserved in a backup file. Use this option to restore
the ability to use your Dolt database. Please contact Dolt (https://github.com/dolthub/dolt/issues) for assistance.
`)
return ap
}
@@ -56,6 +79,10 @@ func (cmd FsckCmd) Name() string {
return "fsck"
}
// Exec re-loads the database, and verifies the integrity of all chunks in the local dolt database.
//
// We go to extra effort to load a new database because the default behavior of dolt is to self-heal for some types
// of corruption. For this reason we bypass any cached database and load a fresh one from disk.
func (cmd FsckCmd) Exec(ctx context.Context, commandStr string, args []string, dEnv *env.DoltEnv, _ cli.CliContext) int {
ap := cmd.ArgParser()
apr, _, terminate, status := ParseArgsOrPrintHelp(ap, commandStr, args, fsckDocs)
@@ -63,21 +90,78 @@ func (cmd FsckCmd) Exec(ctx context.Context, commandStr string, args []string, d
return status
}
if apr.Contains(journalReviveFlag) {
return reviveJournalWithDataLoss(dEnv)
}
quiet := apr.Contains(cli.QuietFlag)
progress := make(chan string, 32)
go fsckHandleProgress(ctx, progress, quiet)
// We expect these to work because the database has already been initialized in higher layers. We'll check anyway
// since it's possible something went sideways or this isn't a local database.
exists, isDir := dEnv.FS.Exists(dbfactory.DoltDataDir)
if !exists || !isDir {
cli.PrintErrln(fmt.Sprintf("Dolt data directory not found at %s", dbfactory.DoltDataDir))
return 1
}
var report *doltdb.FSCKReport
absPath, err := dEnv.FS.Abs(dbfactory.DoltDataDir)
if err != nil {
// This should never happen
cli.PrintErrln("Could not get absolute path for dolt data directory:", err.Error())
return 1
}
urlStr := earl.FileUrlFromPath(filepath.ToSlash(absPath), os.PathSeparator)
u, err := url.Parse(urlStr)
if err != nil {
panic(err)
}
var errs []error
params := make(map[string]interface{})
params[dbfactory.ChunkJournalParam] = struct{}{}
dbFact := dbfactory.FileFactory{}
ddb, _, _, err := dbFact.CreateDbNoCache(ctx, types.Format_Default, u, params, func(vErr error) {
errs = append(errs, vErr)
})
if err != nil {
if errors.Is(err, nbs.ErrJournalDataLoss) {
cli.PrintErrln("WARNING: Chunk journal is corrupted and some data may be lost.")
cli.PrintErrln("Run `dolt fsck --revive-journal-with-data-loss` to attempt to recover the journal by")
cli.PrintErrln("discarding invalid data blocks. Your original data will be preserved in a backup file.")
return 1
} else {
cli.PrintErrln(fmt.Sprintf("Could not open dolt database: %s", err.Error()))
}
return 1
}
gs, ok := datas.ChunkStoreFromDatabase(ddb).(*nbs.GenerationalNBS)
if !ok {
// This should never happen. Mainly a protection against future changes.
cli.PrintErrln(fmt.Sprintf("runtime error: FSCK requires *nbs.GenerationalNBS chunk store. Got: %T", datas.ChunkStoreFromDatabase(ddb)))
return 1
}
progress := make(chan string, 32)
done := make(chan struct{})
go func() {
fsckHandleProgress(ctx, progress, quiet)
close(done)
}()
var report *FSCKReport
terminate = func() bool {
defer close(progress)
var err error
report, err = dEnv.DoltDB(ctx).FSCK(ctx, progress)
report, err = fsckOnChunkStore(ctx, gs, errs, progress)
if err != nil {
// An error from FSCK is unexpected: ordinary corruption is captured in the report, and should not surface as an error here.
// So we print the error rather than the report.
cli.PrintErrln(err.Error())
return true
}
// skip printing the report is we were cancelled. Most likely we tripped on the error above first.
// skip printing the report if we were cancelled.
select {
case <-ctx.Done():
cli.PrintErrln(ctx.Err().Error())
@@ -86,6 +170,9 @@ func (cmd FsckCmd) Exec(ctx context.Context, commandStr string, args []string, d
return false
}
}()
// Wait for fsckHandleProgress to finish processing all messages
<-done
if terminate {
return 1
}
@@ -93,7 +180,26 @@ func (cmd FsckCmd) Exec(ctx context.Context, commandStr string, args []string, d
return printFSCKReport(report)
}
func printFSCKReport(report *doltdb.FSCKReport) int {
func reviveJournalWithDataLoss(dEnv *env.DoltEnv) int {
root, err := dEnv.FS.Abs("")
if err != nil {
cli.PrintErrln("Could not get absolute path for dolt data directory:", err.Error())
return 1
}
noms := filepath.Join(root, ".dolt", "noms")
path, err := nbs.ReviveJournalWithDataLoss(noms)
if err != nil {
cli.PrintErrln("Could not revive chunk journal:", err.Error())
return 1
}
cli.Printf("Revived chunk journal at:\n%s\n", path)
cli.Printf("For assistance recovering data, please file a ticket: https://github.com/dolthub/dolt/issues\n")
return 0
}
func printFSCKReport(report *FSCKReport) int {
cli.Printf("Chunks Scanned: %d\n", report.ChunkCount)
if len(report.Problems) == 0 {
cli.Println("No problems found.")
@@ -108,15 +214,113 @@ func printFSCKReport(report *doltdb.FSCKReport) int {
}
}
func fsckHandleProgress(ctx context.Context, progress chan string, quiet bool) {
func fsckHandleProgress(ctx context.Context, progress <-chan string, quiet bool) {
for item := range progress {
if !quiet {
// when ctx is canceled, keep draining but stop printing
if !quiet && ctx.Err() == nil {
cli.Println(item)
}
select {
case <-ctx.Done():
return
default:
}
}
}
type FSCKReport struct {
ChunkCount uint32
Problems []error
}
// fsckOnChunkStore performs a full file system check on the database. This is currently exposed in the CLI as `dolt fsck`.
// Problems found during the scan are returned in the report as a list of errors. The error returned by this function
// indicates a deeper issue, such as an inability to read from the underlying storage at all.
func fsckOnChunkStore(ctx context.Context, gs *nbs.GenerationalNBS, errs []error, progress chan string) (*FSCKReport, error) {
chunkCount, err := gs.OldGen().Count()
if err != nil {
return nil, err
}
chunkCount2, err := gs.NewGen().Count()
if err != nil {
return nil, err
}
chunkCount += chunkCount2
processedCnt := int64(0)
vs := types.NewValueStore(gs)
decodeMsg := func(chk chunks.Chunk) string {
hrs := ""
val, err := types.DecodeValue(chk, vs)
if err == nil {
hrs = val.HumanReadableString()
} else {
hrs = fmt.Sprintf("Unable to decode value: %s", err.Error())
}
return hrs
}
// Append safely to the slice of errors with a mutex.
errsLock := &sync.Mutex{}
appendErr := func(err error) {
errsLock.Lock()
defer errsLock.Unlock()
errs = append(errs, err)
}
// Callback for validating chunks. This code could be called concurrently, though that is not currently the case.
validationCallback := func(chunk chunks.Chunk) {
chunkOk := true
pCnt := atomic.AddInt64(&processedCnt, 1)
h := chunk.Hash()
raw := chunk.Data()
calcChkSum := hash.Of(raw)
if h != calcChkSum {
fuzzyMatch := false
// Special case for the journal chunk source. We may have an address which has 4 null bytes at the end.
if h[hash.ByteLen-1] == 0 && h[hash.ByteLen-2] == 0 && h[hash.ByteLen-3] == 0 && h[hash.ByteLen-4] == 0 {
// Now we'll just verify that the first 16 bytes match.
ln := hash.ByteLen - 4
fuzzyMatch = bytes.Compare(h[:ln], calcChkSum[:ln]) == 0
}
if !fuzzyMatch {
hrs := decodeMsg(chunk)
appendErr(errors.New(fmt.Sprintf("Chunk: %s content hash mismatch: %s\n%s", h.String(), calcChkSum.String(), hrs)))
chunkOk = false
}
}
if chunkOk {
// Round trip validation. Ensure that the top level store returns the same data.
c, err := gs.Get(ctx, h)
if err != nil {
appendErr(errors.New(fmt.Sprintf("Chunk: %s load failed with error: %s", h.String(), err.Error())))
chunkOk = false
} else if bytes.Compare(raw, c.Data()) != 0 {
hrs := decodeMsg(chunk)
appendErr(errors.New(fmt.Sprintf("Chunk: %s read with incorrect ID: %s\n%s", h.String(), c.Hash().String(), hrs)))
chunkOk = false
}
}
percentage := (float64(pCnt) * 100) / float64(chunkCount)
result := fmt.Sprintf("(%4.1f%% done)", percentage)
progStr := "OK: " + h.String()
if !chunkOk {
progStr = "FAIL: " + h.String()
}
progStr = result + " " + progStr
progress <- progStr
}
err = gs.OldGen().IterateAllChunks(ctx, validationCallback)
if err != nil {
return nil, err
}
err = gs.NewGen().IterateAllChunks(ctx, validationCallback)
if err != nil {
return nil, err
}
FSCKReport := FSCKReport{Problems: errs, ChunkCount: chunkCount}
return &FSCKReport, nil
}
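
The validation callback above recomputes each chunk's address from its raw bytes and tolerates journal addresses whose last four bytes are zero by comparing only the leading bytes. A self-contained sketch of that comparison, assuming 20-byte addresses derived from a truncated SHA-512 as in Dolt's hash package:

package fscksketch

import (
    "bytes"
    "crypto/sha512"
)

const addrLen = 20 // Dolt chunk addresses are 20 bytes

type addr [addrLen]byte

// addrOf mimics hash.Of: the address of a chunk is a truncated digest of its raw bytes.
func addrOf(raw []byte) addr {
    var a addr
    sum := sha512.Sum512(raw)
    copy(a[:], sum[:addrLen])
    return a
}

// chunkContentOK reports whether the stored address matches the chunk content.
// Journal chunk sources may record addresses whose trailing four bytes are
// zeroed, so in that case only the leading 16 bytes are compared.
func chunkContentOK(stored addr, raw []byte) bool {
    calc := addrOf(raw)
    if stored == calc {
        return true
    }
    if stored[addrLen-1] == 0 && stored[addrLen-2] == 0 && stored[addrLen-3] == 0 && stored[addrLen-4] == 0 {
        return bytes.Equal(stored[:addrLen-4], calc[:addrLen-4])
    }
    return false
}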
+12 -3
View File
@@ -53,6 +53,7 @@ The second syntax ({{.LessThan}}dolt merge --abort{{.GreaterThan}}) can only be
Synopsis: []string{
"[--squash] {{.LessThan}}branch{{.GreaterThan}}",
"--no-ff [-m message] {{.LessThan}}branch{{.GreaterThan}}",
"--ff-only {{.LessThan}}branch{{.GreaterThan}}",
"--abort",
},
}
@@ -197,6 +198,12 @@ func validateDoltMergeArgs(apr *argparser.ArgParseResults, usage cli.UsagePrinte
if apr.ContainsAll(cli.SquashParam, cli.NoFFParam) {
return HandleVErrAndExitCode(errhand.BuildDError(ErrConflictingFlags, cli.SquashParam, cli.NoFFParam).Build(), usage)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.NoFFParam) {
return HandleVErrAndExitCode(errhand.BuildDError(ErrConflictingFlags, cli.FFOnlyParam, cli.NoFFParam).Build(), usage)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.SquashParam) {
return HandleVErrAndExitCode(errhand.BuildDError(ErrConflictingFlags, cli.FFOnlyParam, cli.SquashParam).Build(), usage)
}
// This command may create a commit, so we need user identity
if !cli.CheckUserNameAndEmail(cliCtx.Config()) {
@@ -266,6 +273,8 @@ func constructInterpolatedDoltMergeQuery(apr *argparser.ArgParseResults, cliCtx
params = append(params, apr.Arg(0))
} else if apr.Contains(cli.NoFFParam) {
writeToBuffer("--no-ff", false)
} else if apr.Contains(cli.FFOnlyParam) {
writeToBuffer("--ff-only", false)
} else if apr.Contains(cli.AbortParam) {
writeToBuffer("--abort", false)
}
@@ -480,17 +489,17 @@ func calculateMergeStats(queryist cli.Queryist, sqlCtx *sql.Context, mergeStats
if strings.HasPrefix(summary.TableName.Name, diff.DBPrefix) {
continue
}
if summary.DiffType == "added" {
if summary.DiffType == diff.DiffTypeAdded {
allUnmodified = false
mergeStats[summary.TableName.Name] = &merge.MergeStats{
Operation: merge.TableAdded,
}
} else if summary.DiffType == "dropped" {
} else if summary.DiffType == diff.DiffTypeDropped {
allUnmodified = false
mergeStats[summary.TableName.Name] = &merge.MergeStats{
Operation: merge.TableRemoved,
}
} else if summary.DiffType == "modified" || summary.DiffType == "renamed" {
} else if summary.DiffType == diff.DiffTypeModified || summary.DiffType == diff.DiffTypeRenamed {
allUnmodified = false
mergeStats[summary.TableName.Name] = &merge.MergeStats{
Operation: merge.TableModified,
+11
View File
@@ -91,6 +91,14 @@ func (cmd PullCmd) Exec(ctx context.Context, commandStr string, args []string, d
verr := errhand.VerboseErrorFromError(errors.New(fmt.Sprintf(ErrConflictingFlags, cli.SquashParam, cli.NoFFParam)))
return HandleVErrAndExitCode(verr, usage)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.NoFFParam) {
verr := errhand.VerboseErrorFromError(errors.New(fmt.Sprintf(ErrConflictingFlags, cli.FFOnlyParam, cli.NoFFParam)))
return HandleVErrAndExitCode(verr, usage)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.SquashParam) {
verr := errhand.VerboseErrorFromError(errors.New(fmt.Sprintf(ErrConflictingFlags, cli.FFOnlyParam, cli.SquashParam)))
return HandleVErrAndExitCode(verr, usage)
}
// This command may create a commit, so we need user identity
if !cli.CheckUserNameAndEmail(cliCtx.Config()) {
bdr := errhand.BuildDError("Could not determine name and/or email.")
@@ -253,6 +261,9 @@ func constructInterpolatedDoltPullQuery(apr *argparser.ArgParseResults) (string,
if apr.Contains(cli.NoFFParam) {
args = append(args, "'--no-ff'")
}
if apr.Contains(cli.FFOnlyParam) {
args = append(args, "'--ff-only'")
}
if apr.Contains(cli.ForceFlag) {
args = append(args, "'--force'")
}
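
Both dolt merge and dolt pull now repeat the same pairwise exclusivity checks for --squash, --no-ff, and --ff-only. Purely as an illustration of the technique, and not something this change introduces, the checks could be driven by a small table of mutually exclusive pairs:

package flagsketch

import "fmt"

// argResults is a stand-in for argparser.ArgParseResults; only ContainsAll is needed.
type argResults interface {
    ContainsAll(flags ...string) bool
}

// firstConflict returns an error for the first pair of mutually exclusive flags
// that were both supplied, or nil if there is no conflict.
func firstConflict(apr argResults, pairs [][2]string) error {
    for _, p := range pairs {
        if apr.ContainsAll(p[0], p[1]) {
            return fmt.Errorf("error: flags '--%s' and '--%s' cannot be used together", p[0], p[1])
        }
    }
    return nil
}

// Example pairs mirroring the checks added to merge and pull.
var mergeExclusiveFlags = [][2]string{
    {"squash", "no-ff"},
    {"ff-only", "no-ff"},
    {"ff-only", "squash"},
}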
@@ -21,12 +21,13 @@ import (
"strings"
"time"
"github.com/dolthub/dolt/go/libraries/doltcore/servercfg"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/commands"
"github.com/dolthub/dolt/go/libraries/doltcore/servercfg"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
)
type commandLineServerConfig struct {
@@ -49,6 +50,7 @@ type commandLineServerConfig struct {
tlsKey string
tlsCert string
caCert string
requireClientCert bool
requireSecureTransport bool
maxLoggedQueryLen int
shouldEncodeLoggedQuery bool
@@ -321,6 +323,13 @@ func (cfg *commandLineServerConfig) CACert() string {
return cfg.caCert
}
// RequireClientCert is true if the server should reject any connections that don't present a certificate. When
// enabled, a client certificate is always required, and if a CA cert is also configured, then the client cert
// will also be verified. Enabling this option also means that non-TLS connections are not allowed.
func (cfg *commandLineServerConfig) RequireClientCert() bool {
return cfg.requireClientCert
}
// RequireSecureTransport is true if the server should reject non-TLS connections.
func (cfg *commandLineServerConfig) RequireSecureTransport() bool {
return cfg.requireSecureTransport
@@ -360,6 +369,26 @@ func (cfg *commandLineServerConfig) MetricsPort() int {
return servercfg.DefaultMetricsPort
}
func (cfg *commandLineServerConfig) MetricsTLSCert() string {
return ""
}
func (cfg *commandLineServerConfig) MetricsTLSKey() string {
return ""
}
func (cfg *commandLineServerConfig) MetricsTLSCA() string {
return ""
}
func (cfg *commandLineServerConfig) MetricsJwksConfig() *servercfg.JwksConfig {
return nil
}
func (cfg *commandLineServerConfig) MetricsJWTRequiredForLocalhost() bool {
return false
}
func (cfg *commandLineServerConfig) RemotesapiPort() *int {
return cfg.remotesapiPort
}
@@ -575,6 +604,10 @@ func (cfg *commandLineServerConfig) AutoGCBehavior() servercfg.AutoGCBehavior {
return stubAutoGCBehavior{}
}
func (cfg *commandLineServerConfig) Overrides() sql.EngineOverrides {
return sql.EngineOverrides{}
}
// DoltServerConfigReader is the default implementation of ServerConfigReader suitable for parsing Dolt config files
// and command line options.
type DoltServerConfigReader struct{}
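
RequireClientCert above implies mandatory TLS and, when a CA certificate is configured, verification of the presented client certificate. In terms of the standard library, that policy roughly maps onto crypto/tls ClientAuth settings like this (a sketch, not the server's actual wiring):

package tlssketch

import (
    "crypto/tls"
    "crypto/x509"
    "errors"
    "os"
)

// clientAuthConfig builds a server-side tls.Config fragment from the two knobs
// shown above: requireClientCert and an optional CA certificate path.
func clientAuthConfig(requireClientCert bool, caCertPath string) (*tls.Config, error) {
    cfg := &tls.Config{}
    if !requireClientCert {
        return cfg, nil
    }
    if caCertPath == "" {
        // A certificate must be presented, but there is no CA to verify it against.
        cfg.ClientAuth = tls.RequireAnyClientCert
        return cfg, nil
    }
    pem, err := os.ReadFile(caCertPath)
    if err != nil {
        return nil, err
    }
    pool := x509.NewCertPool()
    if !pool.AppendCertsFromPEM(pem) {
        return nil, errors.New("no certificates found in CA file")
    }
    cfg.ClientAuth = tls.RequireAndVerifyClientCert
    cfg.ClientCAs = pool
    return cfg, nil
}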
@@ -0,0 +1,91 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlserver
import (
"errors"
"fmt"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/dolthub/dolt/go/libraries/doltcore/servercfg"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
)
func validateJWT(jwksConfig *servercfg.JwksConfig, token string, reqTime time.Time) (bool, *jwtauth.Claims, error) {
if jwksConfig == nil {
return false, nil, errors.New("ValidateJWT: JWKS metrics config not found")
}
pr, err := getJWTProvider(jwksConfig.Claims, jwksConfig.LocationUrl)
if err != nil {
return false, nil, fmt.Errorf("unable to get JWT provider: %w", err)
}
vd, err := jwtauth.NewJWTValidator(pr)
if err != nil {
return false, nil, fmt.Errorf("unable to get JWT validator: %w", err)
}
privClaims, err := vd.ValidateJWT(token, reqTime)
if err != nil {
return false, nil, fmt.Errorf("unable to validate JWT token: %w", err)
}
if pr.Subject != privClaims.Subject {
return false, nil, fmt.Errorf("JWT token subject does not match subject claim")
}
var keyValPairs []string
for _, field := range jwksConfig.FieldsToLog {
keyValPairs = append(keyValPairs, fmt.Sprintf("'%s': '%s'", field, getClaimFromKey(privClaims, field)))
}
logrus.Info("Metrics Auth with JWT: " + strings.Join(keyValPairs, ", "))
return true, privClaims, nil
}
func getClaimFromKey(claims *jwtauth.Claims, field string) string {
switch field {
case "id":
return claims.ID
case "iss":
return claims.Issuer
case "sub":
return claims.Subject
case "on_behalf_of":
return claims.OnBehalfOf
}
return ""
}
func getJWTProvider(expectedClaimsMap map[string]string, url string) (jwtauth.JWTProvider, error) {
pr := jwtauth.JWTProvider{URL: url}
for name, claim := range expectedClaimsMap {
switch name {
case "iss":
pr.Issuer = claim
case "aud":
pr.Audience = claim
case "sub":
pr.Subject = claim
default:
return pr, errors.New("ValidateJWT: Unsupported claim found in user identity")
}
}
return pr, nil
}
@@ -16,6 +16,7 @@ package sqlserver
import (
"context"
"crypto/tls"
sql2 "database/sql"
"fmt"
"io"
@@ -36,9 +37,33 @@ import (
"github.com/dolthub/dolt/go/libraries/utils/filesys"
)
type QueryistTLSMode int
const (
QueryistTLSMode_Disabled QueryistTLSMode = iota
// Require TLS, verify the server certificate using the system
// trust store, do not allow fallback to plaintext.
//
// Used for `dolt --host ... sql ...` when `--no-tls` is not
// specified. Often used for connecting to Hosted DoltDB
// instances using the CLI commands posted on
// hosted.doltdb.com.
QueryistTLSMode_Enabled
// Used for local Dolt CLI queryist connecting to the running
// local server. In this mode, TLS is allowed but not required
// and the client does not verify the remote TLS
// certificate. It is assumed connecting to the port locally
// is secure and lands the client in the correct place, given
// the contents of sql-server.info, for example.
//
// This mode still does not allow the Dolt CLI to connect to a
// server which requires a client certificate.
QueryistTLSMode_NoVerify_FallbackToPlaintext
)
// BuildConnectionStringQueryist returns a Queryist that connects to the server specified by the given server config. Presence in this
// module isn't ideal, but it's the only way to get the server config into the queryist.
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, useTLS bool, dbRev string) (cli.LateBindQueryist, error) {
func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, creds *cli.UserPassword, apr *argparser.ArgParseResults, host string, port int, tlsMode QueryistTLSMode, dbRev string) (cli.LateBindQueryist, error) {
clientConfig, err := GetClientConfig(cwdFS, creds, apr)
if err != nil {
return nil, err
@@ -54,8 +79,13 @@ func BuildConnectionStringQueryist(ctx context.Context, cwdFS filesys.Filesys, c
parsedMySQLConfig.DBName = dbRev
parsedMySQLConfig.Addr = fmt.Sprintf("%s:%d", host, port)
if useTLS {
parsedMySQLConfig.TLSConfig = "true"
switch tlsMode {
case QueryistTLSMode_Disabled:
case QueryistTLSMode_Enabled:
parsedMySQLConfig.TLS = &tls.Config{}
case QueryistTLSMode_NoVerify_FallbackToPlaintext:
parsedMySQLConfig.TLS = &tls.Config{InsecureSkipVerify: true}
parsedMySQLConfig.AllowFallbackToPlaintext = true
}
mysqlConnector, err := mysql.NewConnector(parsedMySQLConfig)
+66 -7
View File
@@ -58,6 +58,7 @@ import (
"github.com/dolthub/dolt/go/libraries/events"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
httputils "github.com/dolthub/dolt/go/libraries/utils/http"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
"github.com/dolthub/dolt/go/store/chunks"
eventsapi "github.com/dolthub/eventsapi_schema/dolt/services/eventsapi/v1alpha1"
@@ -251,6 +252,7 @@ func ConfigureServices(
ClusterController: clusterController,
BinlogReplicaController: binlogreplication.DoltBinlogReplicaController,
SkipRootUserInitialization: cfg.SkipRootUserInit,
EngineOverrides: cfg.ServerConfig.Overrides(),
}
return nil
},
@@ -601,7 +603,6 @@ func ConfigureServices(
}
var metSrv SQLMetricsService
RunMetricsServer := &svcs.AnonService{
InitF: func(context.Context) (err error) {
if cfg.ServerConfig.MetricsHost() != "" && cfg.ServerConfig.MetricsPort() > 0 {
@@ -613,19 +614,72 @@ func ConfigureServices(
return err
}
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
metSrv.srv = &http.Server{
Addr: addr,
Handler: mux,
tlsConfig, err := servercfg.LoadMetricsTLSConfig(cfg.ServerConfig)
if err != nil {
return err
}
mux := http.NewServeMux()
metricsHandler := promhttp.Handler()
jwksConfig := cfg.ServerConfig.MetricsJwksConfig()
enableMetricsAuth := jwksConfig != nil
requireLocalhostAuth := cfg.ServerConfig.MetricsJWTRequiredForLocalhost()
logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s, require_localhost_auth = %t", enableMetricsAuth, addr, requireLocalhostAuth)
if enableMetricsAuth {
mux.Handle("/metrics", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !requireLocalhostAuth {
isLocal, err := httputils.IsLocalRequest(r)
logrus.Info("Metrics JWT not required for localhost isLocal:", isLocal, "err:", err)
if err != nil {
logrus.Warnf("error checking if request is local for /metrics (assuming remote) request: %v.", err)
} else if isLocal {
metricsHandler.ServeHTTP(w, r)
return
}
}
auth := r.Header.Get("Authorization")
if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`)
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
valid, _, err := validateJWT(jwksConfig, strings.TrimPrefix(auth, "Bearer "), time.Now())
if err != nil {
logrus.Warnf("JWT validation error for /metrics: %v", err)
http.Error(w, "auth failed", http.StatusUnauthorized)
return
} else if !valid {
logrus.Warnf("JWT validation error for /metrics: JWT token is invalid")
http.Error(w, "invalid token", http.StatusUnauthorized)
return
}
metricsHandler.ServeHTTP(w, r)
}))
} else {
mux.Handle("/metrics", metricsHandler)
}
metSrv.srv = &http.Server{
Addr: addr,
Handler: mux,
TLSConfig: tlsConfig,
}
}
return nil
},
RunF: func(context.Context) {
if metSrv.state.CompareAndSwap(svcs.ServiceState_Init, svcs.ServiceState_Run) {
_ = metSrv.srv.Serve(metSrv.lis)
if metSrv.srv.TLSConfig != nil {
_ = metSrv.srv.ServeTLS(metSrv.lis, "", "")
} else {
_ = metSrv.srv.Serve(metSrv.lis)
}
}
},
StopF: func() error {
@@ -1157,6 +1211,11 @@ func getConfigFromServerConfig(serverConfig servercfg.ServerConfig, plf server.P
serverConf.EncodeLoggedQuery = serverConfig.ShouldEncodeLoggedQuery()
serverConf.ProtocolListenerFactory = plf
// If client certs are required, then TLS connections are implicitly required
if serverConfig.RequireClientCert() {
serverConf.RequireSecureTransport = true
}
return serverConf, nil
}
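
The metrics handler above optionally gates /metrics behind a Bearer JWT, letting localhost requests through when that is allowed. Stripped of the Dolt-specific pieces, the wrapping is ordinary net/http middleware; a sketch with a pluggable validate function standing in for the JWKS-based validation:

package metricsauthsketch

import (
    "net/http"
    "strings"
    "time"
)

// requireBearerToken wraps next so that requests must carry a valid
// "Authorization: Bearer <token>" header; validate stands in for the JWT
// validation the server performs against its JWKS config.
func requireBearerToken(next http.Handler, validate func(token string, now time.Time) bool) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        auth := r.Header.Get("Authorization")
        if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
            w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`)
            http.Error(w, "unauthorized", http.StatusUnauthorized)
            return
        }
        if !validate(strings.TrimPrefix(auth, "Bearer "), time.Now()) {
            http.Error(w, "invalid token", http.StatusUnauthorized)
            return
        }
        next.ServeHTTP(w, r)
    })
}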
@@ -692,6 +692,9 @@ branch_control_file: dir1/dir2/abc.db
# labels: {}
# host: localhost
# port: 9091
# tls_cert: ""
# tls_key: ""
# tls_ca: ""
# cluster:
# standby_remotes:
+9 -5
View File
@@ -117,15 +117,19 @@ SUPPORTED CONFIG FILE FIELDS:
{{.EmphasisLeft}}listener.max_wait_connections_timeout{{.EmphasisRight}}: The maximum amount of time that a connection will block waiting for a connection before being rejected.
{{.EmphasisLeft}}listener.read_timeout_millis{{.EmphasisRight}}: The number of milliseconds that the server will wait for a read operation
{{.EmphasisLeft}}listener.read_timeout_millis{{.EmphasisRight}}: The number of milliseconds that the server will wait for a read operation.
{{.EmphasisLeft}}listener.write_timeout_millis{{.EmphasisRight}}: The number of milliseconds that the server will wait for a write operation
{{.EmphasisLeft}}listener.write_timeout_millis{{.EmphasisRight}}: The number of milliseconds that the server will wait for a write operation.
{{.EmphasisLeft}}listener.require_secure_transport{{.EmphasisRight}}: Boolean flag to turn on TLS/SSL transport
{{.EmphasisLeft}}listener.require_secure_transport{{.EmphasisRight}}: Boolean flag to turn on TLS/SSL transport.
{{.EmphasisLeft}}listener.tls_cert{{.EmphasisRight}}: The path to the TLS certicifcate used for secure transport
{{.EmphasisLeft}}listener.require_client_cert{{.EmphasisRight}}: Boolean flag to require that all connections present a certificate. This implies that all connections must be over TLS, so listener.tls_key and listener.tls_cert must also be set.
{{.EmphasisLeft}}listener.tls_key{{.EmphasisRight}}: The path to the TLS key used for secure transport
{{.EmphasisLeft}}listener.ca_cert{{.EmphasisRight}}: The path to a Certificate Authority (CA) certificate used to validate client certificates.
{{.EmphasisLeft}}listener.tls_cert{{.EmphasisRight}}: The path to the TLS certificate used for secure transport.
{{.EmphasisLeft}}listener.tls_key{{.EmphasisRight}}: The path to the TLS key used for secure transport.
{{.EmphasisLeft}}remotesapi.port{{.EmphasisRight}}: A port to listen for remote API operations on. If set to a positive integer, this server will accept connections from clients to clone, pull, etc. databases being served.
+2
View File
@@ -361,6 +361,8 @@ func GetTinyIntColAsBool(col interface{}) (bool, error) {
switch v := col.(type) {
case bool:
return v, nil
case byte:
return v == 1, nil
case int:
return v == 1, nil
case string:
+47 -20
View File
@@ -70,7 +70,6 @@ var dumpZshCommand = &commands.GenZshCompCmd{}
var commandsWithoutCliCtx = []cli.Command{
commands.CloneCmd{},
commands.BackupCmd{},
commands.LoginCmd{},
credcmds.Commands,
cvcmds.Commands,
@@ -115,6 +114,14 @@ var commandsWithoutCurrentDirWrites = []cli.Command{
commands.ProfileCmd{},
}
// commands that specifically skip the env.MultiEnvForDirectory loading step. These commands work in a context where
// we expect the database to fail loading, but we need to get through to the Exec call anyway. The dEnv created with LoadWithoutDB
// will be passed to these commands, so we can determine where the data root is, but commands will need to load the database
// on terms that make sense for their purpose.
var commandsSkippingDBLoad = []cli.Command{
commands.FsckCmd{},
}
func initCliContext(commandName string) bool {
for _, command := range commandsWithoutCliCtx {
if command.Name() == commandName {
@@ -142,6 +149,15 @@ func needsWriteAccess(commandName string) bool {
return true
}
func needsDBLoad(commandName string) bool {
for _, command := range commandsSkippingDBLoad {
if command.Name() == commandName {
return false
}
}
return true
}
var doltCommand = doltcmd.DoltCommand
var globalArgParser = cli.CreateGlobalArgParser("dolt")
var globalDocs = cli.CommandDocsForCommandString("dolt", doc, globalArgParser)
@@ -472,21 +488,23 @@ func runMain() int {
// will be lost. This is particularly confusing for database specific system
// variables like `${db_name}_default_branch` (maybe these should not be
// part of Dolt config in the first place!).
var mrEnv *env.MultiRepoEnv
if needsDBLoad(cfg.subCommand) {
mrEnv, err = env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), cfg.dataDirFS, dEnv.Version, dEnv)
if err != nil {
cli.PrintErrln("failed to load database names")
return 1
}
_ = mrEnv.Iter(func(dbName string, dEnv *env.DoltEnv) (stop bool, err error) {
dsess.DefineSystemVariablesForDB(dbName)
return false, nil
})
mrEnv, err := env.MultiEnvForDirectory(ctx, dEnv.Config.WriteableConfig(), cfg.dataDirFS, dEnv.Version, dEnv)
if err != nil {
cli.PrintErrln("failed to load database names")
return 1
}
_ = mrEnv.Iter(func(dbName string, dEnv *env.DoltEnv) (stop bool, err error) {
dsess.DefineSystemVariablesForDB(dbName)
return false, nil
})
// TODO: we set persisted vars here, and this should be deferred until after we know what command line arguments might change them
err = dsess.InitPersistedSystemVars(dEnv)
if err != nil {
cli.Printf("error: failed to load persisted global variables: %s\n", err.Error())
// TODO: we set persisted vars here, and this should be deferred until after we know what command line arguments might change them
err = dsess.InitPersistedSystemVars(dEnv)
if err != nil {
cli.Printf("error: failed to load persisted global variables: %s\n", err.Error())
}
}
var cliCtx cli.CliContext = nil
@@ -623,8 +641,11 @@ If you're interested in running this command against a remote host, hit us up on
if !hasPort {
port = 3306
}
useTLS := !apr.Contains(cli.NoTLSFlag)
return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, useTLS, useDb)
tlsMode := sqlserver.QueryistTLSMode_Enabled
if apr.Contains(cli.NoTLSFlag) {
tlsMode = sqlserver.QueryistTLSMode_Disabled
}
return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, host, port, tlsMode, useDb)
} else {
_, hasPort := apr.GetInt(cli.PortFlag)
if hasPort {
@@ -659,8 +680,14 @@ If you're interested in running this command against a remote host, hit us up on
//
// This is also allowed when --help is passed. So we defer the error
// until the caller tries to use the cli.LateBindQueryist.
isValidRepositoryRequired := subcommandName != "init" && subcommandName != "sql" && subcommandName != "sql-server" && subcommandName != "sql-client"
commandsNotRequiringRepo := map[string]bool{
"init": true,
"sql": true,
"sql-server": true,
"sql-client": true,
commands.DoltBackupCommandName: true,
}
isValidRepositoryRequired := !commandsNotRequiringRepo[subcommandName]
if noValidRepository && isValidRepositoryRequired {
return func(ctx context.Context, opts ...cli.LateBindQueryistOption) (res cli.LateBindQueryistResult, err error) {
err = errors.New("The current directory is not a valid dolt repository.")
@@ -712,7 +739,7 @@ If you're interested in running this command against a remote host, hit us up on
if !creds.Specified {
creds = &cli.UserPassword{Username: sqlserver.LocalConnectionUser, Password: localCreds.Secret, Specified: false}
}
return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", localCreds.Port, false, useDb)
return sqlserver.BuildConnectionStringQueryist(ctx, cwdFS, creds, apr, "localhost", localCreds.Port, sqlserver.QueryistTLSMode_NoVerify_FallbackToPlaintext, useDb)
}
}
+1 -1
View File
@@ -15,5 +15,5 @@
package doltversion
const (
Version = "1.78.1"
Version = "1.79.2"
)
+3 -3
View File
@@ -10,10 +10,10 @@ require (
github.com/bcicen/jstream v1.0.0
github.com/boltdb/bolt v1.3.1
github.com/denisbrodbeck/machineid v1.0.1
github.com/dolthub/fslock v0.0.3
github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20251107003339-843d10a6a8d4
github.com/dolthub/vitess v0.0.0-20251210200925-1d33d416d162
github.com/dustin/go-humanize v1.0.1
github.com/fatih/color v1.13.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568
@@ -61,7 +61,7 @@ require (
github.com/dolthub/dolt-mcp v0.2.2
github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.20.1-0.20251118232608-f06d88560cc2
github.com/dolthub/go-mysql-server v0.20.1-0.20260105202743-1b9a4010ea84
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
github.com/edsrzf/mmap-go v1.2.0
github.com/esote/minmaxheap v1.0.0
+6 -6
View File
@@ -191,12 +191,12 @@ github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca h1:BGFz/0
github.com/dolthub/eventsapi_schema v0.0.0-20250915094920-eadfd39051ca/go.mod h1:CoDLfgPqHyBtth0Cp+fi/CmC4R81zJNX4wPjShdZ+Bw=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318 h1:n+vdH5G5Db+1qnDCpRjSQMxlTewwvTzKuuq0nJm0AqI=
github.com/dolthub/fslock v0.0.0-20251215194149-ef20baba2318/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790 h1:zxMsH7RLiG+dlZ/y0LgJHTV26XoiSJcuWq+em6t6VVc=
github.com/dolthub/go-icu-regex v0.0.0-20250916051405-78a38d478790/go.mod h1:F3cnm+vMRK1HaU6+rNqQrOCyR03HHhR1GWG2gnPOqaE=
github.com/dolthub/go-mysql-server v0.20.1-0.20251118232608-f06d88560cc2 h1:FGVjkA6MdVZcMD4dNVQhXhTN/nMJc+0tqBGUTvlWb/g=
github.com/dolthub/go-mysql-server v0.20.1-0.20251118232608-f06d88560cc2/go.mod h1:HTOKSMPJWcbSgCe1DksDgNPlZyZP1usV+EoA7Utax+A=
github.com/dolthub/go-mysql-server v0.20.1-0.20260105202743-1b9a4010ea84 h1:5Mkyt+kQbSr7PmQpgkkIthtbSR5+OmC+NhLTnWa7eAY=
github.com/dolthub/go-mysql-server v0.20.1-0.20260105202743-1b9a4010ea84/go.mod h1:NjewWKoa5bVSLdKwL7fg7eAfrcIxDybWUKoWEHWRTw4=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE=
@@ -205,8 +205,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20251107003339-843d10a6a8d4 h1:vOF5qPLC0Yd4BN/FKJlRLNELIZZlev40TrckORQqzhA=
github.com/dolthub/vitess v0.0.0-20251107003339-843d10a6a8d4/go.mod h1:FLWqdXsAeeBQyFwDjmBVu0GnbjI2MKeRf3tRVdJEKlI=
github.com/dolthub/vitess v0.0.0-20251210200925-1d33d416d162 h1:6RW2VpUs/cUFdvk4mXSmJfQZLs9wJABVjke3CHGJBcs=
github.com/dolthub/vitess v0.0.0-20251210200925-1d33d416d162/go.mod h1:FLWqdXsAeeBQyFwDjmBVu0GnbjI2MKeRf3tRVdJEKlI=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84=
+22 -7
View File
@@ -135,6 +135,27 @@ func (fact FileFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat,
return s.ddb, s.vrw, s.ns, nil
}
ddb, vrw, ns, err := fact.CreateDbNoCache(ctx, nbf, urlObj, params, nbs.JournalParserLoggingWarningsCb)
if err != nil {
return nil, nil, nil, err
}
singletons[urlObj.Path] = singletonDB{
ddb: ddb,
vrw: vrw,
ns: ns,
}
return ddb, vrw, ns, nil
}
// CreateDbNoCache creates a local filesys backed database without using the singleton cache. This is used for a very specific
// case: the `dolt fsck` command. Since database loading normally happens before subcommand execution, `dolt fsck`
// needs to load the database itself so that it can report journal issues rather than have them merely logged.
//
// Furthermore, regular database loading uses this code path to construct the GenerationalCS, which is desired because we
// want the same underlying implementation.
func (fact FileFactory) CreateDbNoCache(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}, recCb func(error)) (datas.Database, types.ValueReadWriter, tree.NodeStore, error) {
path, err := url.PathUnescape(urlObj.Path)
if err != nil {
return nil, nil, nil, err
@@ -158,7 +179,7 @@ func (fact FileFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat,
var newGenSt *nbs.NomsBlockStore
q := nbs.NewUnlimitedMemQuotaProvider()
if useJournal && chunkJournalFeatureFlag {
newGenSt, err = nbs.NewLocalJournalingStore(ctx, nbf.VersionString(), path, q, mmapArchiveIndexes)
newGenSt, err = nbs.NewLocalJournalingStore(ctx, nbf.VersionString(), path, q, mmapArchiveIndexes, recCb)
} else {
newGenSt, err = nbs.NewLocalStore(ctx, nbf.VersionString(), path, defaultMemTableSize, q, mmapArchiveIndexes)
}
@@ -205,12 +226,6 @@ func (fact FileFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat,
ns := tree.NewNodeStore(st)
ddb := datas.NewTypesDatabase(vrw, ns)
singletons[urlObj.Path] = singletonDB{
ddb: ddb,
vrw: vrw,
ns: ns,
}
return ddb, vrw, ns, nil
}
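For illustration only (not part of this change), a minimal sketch of how a caller such as the fsck path might use the new callback parameter; fact, urlObj and params are assumed to be prepared the same way regular database loading prepares them, and types.Format_Default stands in for whatever format the caller resolved.
// Hypothetical caller: collect journal problems reported during load instead of logging them.
func loadForFsck(ctx context.Context, fact FileFactory, urlObj *url.URL, params map[string]interface{}) (datas.Database, []error, error) {
	var journalProblems []error
	collect := func(e error) { journalProblems = append(journalProblems, e) }
	ddb, _, _, err := fact.CreateDbNoCache(ctx, types.Format_Default, urlObj, params, collect)
	if err != nil {
		return nil, journalProblems, err
	}
	return ddb, journalProblems, nil
}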
+31 -4
View File
@@ -39,6 +39,17 @@ const (
RemovedTable
)
// Filter type constants for diff filtering.
// These correspond to the string values used in the --filter flag and
// are stored in the TableDeltaSummary.DiffType field.
const (
DiffTypeAdded = "added"
DiffTypeModified = "modified"
DiffTypeRenamed = "renamed"
DiffTypeDropped = "dropped"
DiffTypeAll = "all"
)
const DBPrefix = "__DATABASE__"
type TableInfo struct {
@@ -97,6 +108,22 @@ func (tds TableDeltaSummary) IsRename() bool {
return tds.FromTableName != tds.ToTableName
}
// ChangeTypeToDiffType converts a row-level ChangeType to a table-level DiffType string.
// This allows row-level filtering to use the same DiffType infrastructure as table-level filtering.
func ChangeTypeToDiffType(ct ChangeType) string {
switch ct {
case Added:
return DiffTypeAdded
case Removed:
return DiffTypeDropped
case ModifiedOld, ModifiedNew:
// Both ModifiedOld and ModifiedNew represent the same logical change: modified
return DiffTypeModified
default:
return ""
}
}
// GetStagedUnstagedTableDeltas represents staged and unstaged changes as TableDelta slices.
func GetStagedUnstagedTableDeltas(ctx context.Context, roots doltdb.Roots) (staged, unstaged []TableDelta, err error) {
staged, err = GetTableDeltas(ctx, roots.Head, roots.Staged)
@@ -689,7 +716,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error)
FromTableName: td.FromName,
DataChange: dataChange,
SchemaChange: true,
DiffType: "dropped",
DiffType: DiffTypeDropped,
}, nil
}
@@ -700,7 +727,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error)
ToTableName: td.ToName,
DataChange: dataChange,
SchemaChange: true,
DiffType: "added",
DiffType: DiffTypeAdded,
}, nil
}
@@ -712,7 +739,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error)
ToTableName: td.ToName,
DataChange: dataChange,
SchemaChange: true,
DiffType: "renamed",
DiffType: DiffTypeRenamed,
}, nil
}
@@ -727,7 +754,7 @@ func (td TableDelta) GetSummary(ctx context.Context) (*TableDeltaSummary, error)
ToTableName: td.ToName,
DataChange: dataChange,
SchemaChange: schemaChange,
DiffType: "modified",
DiffType: DiffTypeModified,
}, nil
}
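As an aside, a small sketch (not in this diff) of how the constants and ChangeTypeToDiffType might be combined to honor a --filter value; matchesFilter is a hypothetical helper.
// Hypothetical helper: report whether a row-level change should be shown for the
// --filter value the user passed (empty or DiffTypeAll matches everything).
func matchesFilter(ct ChangeType, filter string) bool {
	if filter == "" || filter == DiffTypeAll {
		return true
	}
	return ChangeTypeToDiffType(ct) == filter
}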
-114
View File
@@ -15,15 +15,12 @@
package doltdb
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/dolthub/go-mysql-server/sql"
@@ -2386,117 +2383,6 @@ func (ddb *DoltDB) PurgeCaches() {
ddb.ns.PurgeCaches()
}
type FSCKReport struct {
ChunkCount uint32
Problems []error
}
// FSCK performs a full file system check on the database. This is currently exposed in the CLI as `dolt fsck`.
// The success or failure of the scan is returned in the report as a list of errors. The error returned by this function
// indicates a deeper issue such as having a database in an old format.
func (ddb *DoltDB) FSCK(ctx context.Context, progress chan string) (*FSCKReport, error) {
cs := datas.ChunkStoreFromDatabase(ddb.db)
vs := types.NewValueStore(cs)
gs, ok := cs.(*nbs.GenerationalNBS)
if !ok {
return nil, errors.New("FSCK requires a local database")
}
chunkCount, err := gs.OldGen().Count()
if err != nil {
return nil, err
}
chunkCount2, err := gs.NewGen().Count()
if err != nil {
return nil, err
}
chunkCount += chunkCount2
proccessedCnt := int64(0)
var errs []error
decodeMsg := func(chk chunks.Chunk) string {
hrs := ""
val, err := types.DecodeValue(chk, vs)
if err == nil {
hrs = val.HumanReadableString()
} else {
hrs = fmt.Sprintf("Unable to decode value: %s", err.Error())
}
return hrs
}
// Append safely to the slice of errors with a mutex.
errsLock := &sync.Mutex{}
appendErr := func(err error) {
errsLock.Lock()
defer errsLock.Unlock()
errs = append(errs, err)
}
// Callback for validating chunks. This code could be called concurrently, though that is not currently the case.
validationCallback := func(chunk chunks.Chunk) {
chunkOk := true
pCnt := atomic.AddInt64(&proccessedCnt, 1)
h := chunk.Hash()
raw := chunk.Data()
calcChkSum := hash.Of(raw)
if h != calcChkSum {
fuzzyMatch := false
// Special case for the journal chunk source. We may have an address which has 4 null bytes at the end.
if h[hash.ByteLen-1] == 0 && h[hash.ByteLen-2] == 0 && h[hash.ByteLen-3] == 0 && h[hash.ByteLen-4] == 0 {
// Now we'll just verify that the first 16 bytes match.
ln := hash.ByteLen - 4
fuzzyMatch = bytes.Compare(h[:ln], calcChkSum[:ln]) == 0
}
if !fuzzyMatch {
hrs := decodeMsg(chunk)
appendErr(errors.New(fmt.Sprintf("Chunk: %s content hash mismatch: %s\n%s", h.String(), calcChkSum.String(), hrs)))
chunkOk = false
}
}
if chunkOk {
// Round trip validation. Ensure that the top level store returns the same data.
c, err := cs.Get(ctx, h)
if err != nil {
appendErr(errors.New(fmt.Sprintf("Chunk: %s load failed with error: %s", h.String(), err.Error())))
chunkOk = false
} else if bytes.Compare(raw, c.Data()) != 0 {
hrs := decodeMsg(chunk)
appendErr(errors.New(fmt.Sprintf("Chunk: %s read with incorrect ID: %s\n%s", h.String(), c.Hash().String(), hrs)))
chunkOk = false
}
}
percentage := (float64(pCnt) * 100) / float64(chunkCount)
result := fmt.Sprintf("(%4.1f%% done)", percentage)
progStr := "OK: " + h.String()
if !chunkOk {
progStr = "FAIL: " + h.String()
}
progStr = result + " " + progStr
progress <- progStr
}
err = gs.OldGen().IterateAllChunks(ctx, validationCallback)
if err != nil {
return nil, err
}
err = gs.NewGen().IterateAllChunks(ctx, validationCallback)
if err != nil {
return nil, err
}
FSCKReport := FSCKReport{Problems: errs, ChunkCount: chunkCount}
return &FSCKReport, nil
}
const (
DbRevisionDelimiter = "/"
)
+2 -6
View File
@@ -155,12 +155,13 @@ func GeneratedSystemTableNames() []string {
GetTableOfTablesWithViolationsName(),
GetCommitsTableName(),
GetCommitAncestorsTableName(),
GetStatusTableName(),
GetRemotesTableName(),
GetHelpTableName(),
GetBackupsTableName(),
GetStashesTableName(),
GetBranchActivityTableName(),
// [dtables.StatusTable] now uses [adapters.DoltTableAdapterRegistry] in its constructor for Doltgres.
StatusTableName,
}
}
@@ -367,11 +368,6 @@ var GetSchemaConflictsTableName = func() string {
return SchemaConflictsTableName
}
// GetStatusTableName returns the status system table name.
var GetStatusTableName = func() string {
return StatusTableName
}
// GetTagsTableName returns the tags table name
var GetTagsTableName = func() string {
return TagsTableName
+17 -15
View File
@@ -53,6 +53,8 @@ const (
tempTablesDir = "temptf"
TmpDirName = "tmp"
InvalidRemoteNameCharacters = " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|"
)
var zeroHashStr = (hash.Hash{}).String()
@@ -64,15 +66,15 @@ var ErrRemoteAlreadyExists = errors.New("remote already exists")
var ErrInvalidRemoteURL = errors.New("remote URL invalid")
var ErrRemoteNotFound = errors.New("remote not found")
var ErrInvalidRemoteName = errors.New("remote name invalid")
var ErrBackupAlreadyExists = errors.New("backup already exists")
var ErrInvalidBackupURL = errors.New("backup URL invalid")
var ErrBackupNotFound = errors.New("backup not found")
var ErrInvalidBackupName = errors.New("backup name invalid")
var ErrFailedToDeleteBackup = errors.New("failed to delete backup")
var ErrBackupAlreadyExists = goerrors.NewKind("backup '%s' already exists")
var ErrBackupInvalidUrl = goerrors.NewKind("backup URL '%s' is invalid")
var ErrBackupNotFound = goerrors.NewKind("backup '%s' not found")
var ErrBackupInvalidName = goerrors.NewKind("backup name '%s' is invalid")
var ErrBackupFailedDelete = goerrors.NewKind("backup '%s' failed to delete")
var ErrFailedToReadFromDb = errors.New("failed to read from db")
var ErrFailedToDeleteRemote = errors.New("failed to delete remote")
var ErrFailedToWriteRepoState = errors.New("failed to write repo state")
var ErrRemoteAddressConflict = errors.New("address conflict with a remote")
var ErrRemoteAddressConflict = goerrors.NewKind("address conflict with a remote: '%s' -> %s")
var ErrDoltRepositoryNotFound = errors.New("can no longer find .dolt dir on disk")
var ErrFailedToAccessDB = goerrors.NewKind("failed to access '%s' database: can no longer find .dolt dir on disk")
var ErrDatabaseIsLocked = errors.New("the database is locked by another dolt process")
@@ -1023,7 +1025,7 @@ func (dEnv *DoltEnv) AddRemote(r Remote) error {
return ErrRemoteAlreadyExists
}
if strings.IndexAny(r.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 {
if strings.IndexAny(r.Name, InvalidRemoteNameCharacters) != -1 {
return ErrInvalidRemoteName
}
@@ -1034,7 +1036,7 @@ func (dEnv *DoltEnv) AddRemote(r Remote) error {
// can have multiple remotes with the same address, but no conflicting backups
if rem, found := CheckRemoteAddressConflict(absRemoteUrl, nil, dEnv.RepoState.Backups); found {
return fmt.Errorf("%w: '%s' -> %s", ErrRemoteAddressConflict, rem.Name, rem.Url)
return ErrRemoteAddressConflict.New(rem.Name, rem.Url)
}
r.Url = absRemoteUrl
@@ -1052,21 +1054,21 @@ func (dEnv *DoltEnv) GetBackups() (*concurrentmap.Map[string, Remote], error) {
func (dEnv *DoltEnv) AddBackup(r Remote) error {
if _, ok := dEnv.RepoState.Backups.Get(r.Name); ok {
return ErrBackupAlreadyExists
return ErrBackupAlreadyExists.New(r.Name)
}
if strings.IndexAny(r.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 {
return ErrInvalidBackupName
if strings.IndexAny(r.Name, InvalidRemoteNameCharacters) != -1 {
return ErrBackupInvalidName.New(r.Name)
}
_, absRemoteUrl, err := GetAbsRemoteUrl(dEnv.FS, dEnv.Config, r.Url)
if err != nil {
return fmt.Errorf("%w; %s", ErrInvalidBackupURL, err.Error())
return ErrBackupInvalidUrl.New(r.Url, err.Error())
}
// no conflicting remote or backup addresses
if rem, found := CheckRemoteAddressConflict(absRemoteUrl, dEnv.RepoState.Remotes, dEnv.RepoState.Backups); found {
return fmt.Errorf("%w: '%s' -> %s", ErrRemoteAddressConflict, rem.Name, rem.Url)
if conflict, found := CheckRemoteAddressConflict(absRemoteUrl, dEnv.RepoState.Remotes, dEnv.RepoState.Backups); found {
return ErrRemoteAddressConflict.New(conflict.Name, conflict.Url)
}
r.Url = absRemoteUrl
@@ -1110,7 +1112,7 @@ func (dEnv *DoltEnv) RemoveRemote(ctx context.Context, name string) error {
func (dEnv *DoltEnv) RemoveBackup(ctx context.Context, name string) error {
backup, ok := dEnv.RepoState.Backups.Get(name)
if !ok {
return ErrBackupNotFound
return ErrBackupNotFound.New(name)
}
dEnv.RepoState.RemoveBackup(backup)
+12 -3
View File
@@ -169,7 +169,10 @@ func MultiEnvForDirectory(
} else {
dbErr := newDEnv.DBLoadError
if dbErr != nil {
if !errors.Is(dbErr, doltdb.ErrMissingDoltDataDir) {
if errors.Is(dbErr, nbs.ErrJournalDataLoss) {
logrus.Errorf("failed to load database %s with error: %s", dbName, dbErr.Error())
logrus.Errorf("please run 'dolt fsck' to assess the damage and attempt repairs")
} else if !errors.Is(dbErr, doltdb.ErrMissingDoltDataDir) {
logrus.Warnf("failed to load database with error: %s", dbErr.Error())
}
}
@@ -204,7 +207,10 @@ func MultiEnvForDirectory(
} else {
dbErr := newEnv.DBLoadError
if dbErr != nil {
if !errors.Is(dbErr, doltdb.ErrMissingDoltDataDir) {
if errors.Is(dbErr, nbs.ErrJournalDataLoss) {
logrus.Errorf("failed to load database at %s with error: %s", path, dbErr.Error())
logrus.Errorf("please run 'dolt fsck' to assess the damage and attempt repairs")
} else if !errors.Is(dbErr, doltdb.ErrMissingDoltDataDir) {
logrus.Warnf("failed to load database at %s with error: %s", path, dbErr.Error())
}
}
@@ -247,7 +253,10 @@ func (mrEnv *MultiRepoEnv) ReloadDBs(
if !dEnv.Valid() {
dbErr := dEnv.DBLoadError
if dbErr != nil {
if !errors.Is(dbErr, doltdb.ErrMissingDoltDataDir) {
if errors.Is(dbErr, nbs.ErrJournalDataLoss) {
logrus.Errorf("failed to load database at %s with error: %s", dEnv.urlStr, dbErr.Error())
logrus.Errorf("please run 'dolt fsck' to assess the damage and attempt repairs")
} else if !errors.Is(dbErr, doltdb.ErrMissingDoltDataDir) {
logrus.Warnf("failed to load database at %s with error: %s", dEnv.urlStr, dbErr.Error())
}
}
+11 -3
View File
@@ -29,6 +29,14 @@ import (
var ErrFailedToDetermineMergeability = errors.New("failed to determine mergeability")
type FastForwardMode int
const (
FastForwardDefault FastForwardMode = iota
FastForwardOnly
NoFastForward
)
type MergeSpec struct {
HeadH hash.Hash
MergeH hash.Hash
@@ -38,7 +46,7 @@ type MergeSpec struct {
StompedTblNames []doltdb.TableName
WorkingDiffs map[doltdb.TableName]hash.Hash
Squash bool
NoFF bool
FFMode FastForwardMode
NoCommit bool
NoEdit bool
Force bool
@@ -49,9 +57,9 @@ type MergeSpec struct {
type MergeSpecOpt func(*MergeSpec)
func WithNoFF(noFF bool) MergeSpecOpt {
func WithFastForwardMode(mode FastForwardMode) MergeSpecOpt {
return func(ms *MergeSpec) {
ms.NoFF = noFF
ms.FFMode = mode
}
}
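A brief sketch of the call-site migration this implies, assuming the usual pattern of applying MergeSpecOpts to a spec; the literal spec construction below is illustrative only.
// Hypothetical call site: previously WithNoFF(true), now an explicit mode.
opts := []MergeSpecOpt{WithFastForwardMode(NoFastForward)}
spec := &MergeSpec{}
for _, opt := range opts {
	opt(spec)
}
// spec.FFMode is now NoFastForward instead of a bare boolean.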
@@ -63,7 +63,6 @@ func createFulltextTable(ctx *sql.Context, name string, root doltdb.RootValue) (
gmsDb := memory.NewDatabase("gms_db")
gmsTable := memory.NewLocalTable(gmsDb, name, sqlSch, nil)
gmsTable.EnablePrimaryKeyIndexes()
return &fulltextTable{
GMSTable: gmsTable,
Table: tbl,
@@ -1747,7 +1747,7 @@ func convertValueToNewType(ctx *sql.Context, value interface{}, newTypeInfo type
if err != nil {
return nil, err
}
if !inRange {
if inRange != sql.InRange {
return nil, fmt.Errorf("out of range conversion for value %v to type %s", value, newTypeInfo.String())
}
return newValue, nil
@@ -2154,6 +2154,10 @@ func (m *valueMerger) processColumn(ctx *sql.Context, i int, left, right, base v
return nil, true, err
}
if _, ok := sqlType.(types.JsonType); ok && !disallowJsonMerge {
// if any of the values are NULL, this is an unresolvable conflict
if baseCol == nil || leftCol == nil || rightCol == nil {
return nil, true, nil
}
return m.mergeJSONAddr(ctx, baseCol, leftCol, rightCol)
}
// otherwise, this is a conflict.
@@ -2228,46 +2232,14 @@ func mergeJSON(ctx context.Context, ns tree.NodeStore, base, left, right sql.JSO
}
}
indexedBase, isBaseIndexed := base.(tree.IndexedJsonDocument)
indexedLeft, isLeftIndexed := left.(tree.IndexedJsonDocument)
indexedRight, isRightIndexed := right.(tree.IndexedJsonDocument)
// We only do three way merges on values read from tables right now, which are read in as tree.IndexedJsonDocument.
var leftDiffer tree.IJsonDiffer
if isBaseIndexed && isLeftIndexed {
leftDiffer, err = tree.NewIndexedJsonDiffer(ctx, indexedBase, indexedLeft)
if err != nil {
return nil, true, err
}
} else {
baseObject, err := base.ToInterface(ctx)
if err != nil {
return nil, true, err
}
leftObject, err := left.ToInterface(ctx)
if err != nil {
return nil, true, err
}
leftDiffer = tree.NewJsonDiffer(baseObject.(types.JsonObject), leftObject.(types.JsonObject))
leftDiffer, err := tree.NewJsonDiffer(ctx, base, left)
if err != nil {
return nil, true, err
}
var rightDiffer tree.IJsonDiffer
if isBaseIndexed && isRightIndexed {
rightDiffer, err = tree.NewIndexedJsonDiffer(ctx, indexedBase, indexedRight)
if err != nil {
return nil, true, err
}
} else {
baseObject, err := base.ToInterface(ctx)
if err != nil {
return nil, true, err
}
rightObject, err := right.ToInterface(ctx)
if err != nil {
return nil, true, err
}
rightDiffer = tree.NewJsonDiffer(baseObject.(types.JsonObject), rightObject.(types.JsonObject))
rightDiffer, err := tree.NewJsonDiffer(ctx, base, right)
if err != nil {
return nil, true, err
}
threeWayDiffer := ThreeWayJsonDiffer{
@@ -1250,6 +1250,20 @@ var jsonMergeTests = []schemaMergeTest{
right: singleRow(1, 1, 2, `{ "key1": "value3" }`),
dataConflict: true,
},
{
name: `divergent modification with NULL ancestor`,
ancestor: singleRow(1, 1, 1, nil),
left: singleRow(1, 2, 1, `{ "key1": "value2" }`),
right: singleRow(1, 1, 2, `{ "key1": "value3" }`),
dataConflict: true,
},
{
name: `divergent modification with NULL child`,
ancestor: singleRow(1, 1, 1, `{ "key1": "value1"}`),
left: singleRow(1, 2, 1, nil),
right: singleRow(1, 1, 2, `{ "key1": "value3" }`),
dataConflict: true,
},
{
name: `divergent modification and deletion`,
ancestor: singleRow(1, 1, 1, `{ "key1": "value1"}`),
@@ -1428,7 +1442,7 @@ func jsonMergeLargeDocumentTests(t *testing.T) []schemaMergeTest {
insert := func(document sqltypes.MutableJSON, path string, val interface{}) sqltypes.MutableJSON {
jsonVal, inRange, err := sqltypes.JSON.Convert(ctx, val)
require.NoError(t, err)
require.True(t, (bool)(inRange))
require.True(t, inRange == sql.InRange)
newDoc, changed, err := document.Insert(ctx, path, jsonVal.(sql.JSONWrapper))
require.NoError(t, err)
require.True(t, changed)
@@ -1438,7 +1452,7 @@ func jsonMergeLargeDocumentTests(t *testing.T) []schemaMergeTest {
set := func(document sqltypes.MutableJSON, path string, val interface{}) sqltypes.MutableJSON {
jsonVal, inRange, err := sqltypes.JSON.Convert(ctx, val)
require.NoError(t, err)
require.True(t, (bool)(inRange))
require.True(t, inRange == sql.InRange)
newDoc, changed, err := document.Replace(ctx, path, jsonVal.(sql.JSONWrapper))
require.NoError(t, err)
require.True(t, changed)
@@ -26,7 +26,7 @@ import (
)
type ThreeWayJsonDiffer struct {
leftDiffer, rightDiffer tree.IJsonDiffer
leftDiffer, rightDiffer tree.JsonDiffer
leftCurrentDiff, rightCurrentDiff *tree.JsonDiff
leftIsDone, rightIsDone bool
ns tree.NodeStore
@@ -37,7 +37,7 @@ type sqlEngineTableReader struct {
}
func NewSqlEngineReader(ctx *sql.Context, engine *sqle.Engine, root doltdb.RootValue, tableName string) (*sqlEngineTableReader, error) {
binder := planbuilder.New(ctx, engine.Analyzer.Catalog, engine.EventScheduler, engine.Parser)
binder := planbuilder.New(ctx, engine.Analyzer.Catalog, engine.EventScheduler)
ret, _, _, _, err := binder.Parse(fmt.Sprintf("show create table `%s`", tableName), nil, false)
if err != nil {
return nil, err
@@ -29,6 +29,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/transform"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/typed/noms"
"github.com/dolthub/dolt/go/store/types"
@@ -143,7 +144,8 @@ func (s *SqlEngineTableWriter) WriteRows(ctx context.Context, inputChannel chan
return err
}
iter, err := rowexec.DefaultBuilder.Build(s.sqlCtx, insertOrUpdateOperation, nil)
engOverrides := overrides.EngineOverridesFromContext(ctx)
iter, err := rowexec.NewBuilder(nil, engOverrides).Build(s.sqlCtx, insertOrUpdateOperation, nil)
if err != nil {
return err
}
@@ -299,21 +301,22 @@ func (s *SqlEngineTableWriter) createOrEmptyTableIfNeeded() error {
func (s *SqlEngineTableWriter) createTable() error {
// TODO don't use internal interfaces to do this, we had to have a sql.Schema somewhere
// upstream to make the dolt schema
formatter := overrides.SchemaFormatterFromContext(s.sqlCtx)
sqlCols := make([]string, len(s.tableSchema.Schema))
for i, c := range s.tableSchema.Schema {
sqlCols[i] = sql.GenerateCreateTableColumnDefinition(c, c.Default.String(), c.OnUpdate.String(), sql.Collation_Default)
sqlCols[i] = formatter.GenerateCreateTableColumnDefinition(c, c.Default.String(), c.OnUpdate.String(), sql.Collation_Default)
}
var pks string
var sep string
for _, i := range s.tableSchema.PkOrdinals {
pks += sep + sql.QuoteIdentifier(s.tableSchema.Schema[i].Name)
pks += sep + formatter.QuoteIdentifier(s.tableSchema.Schema[i].Name)
sep = ", "
}
if len(sep) > 0 {
sqlCols = append(sqlCols, fmt.Sprintf("PRIMARY KEY (%s)", pks))
}
createTable := sql.GenerateCreateTableStatement(s.tableName, sqlCols, "", "", sql.CharacterSet_utf8mb4.String(), sql.Collation_Default.String(), "")
createTable := formatter.GenerateCreateTableStatement(s.tableName, sqlCols, "", "", sql.CharacterSet_utf8mb4.String(), sql.Collation_Default.String(), "")
_, iter, _, err := s.se.Query(s.sqlCtx, createTable)
if err != nil {
return err
@@ -325,6 +328,7 @@ func (s *SqlEngineTableWriter) createTable() error {
// createInsertImportNode creates the relevant/analyzed insert node given the import option. This insert node is wrapped
// with an error handler.
func (s *SqlEngineTableWriter) getInsertNode(inputChannel chan sql.Row, replace bool) (sql.Node, error) {
formatter := overrides.SchemaFormatterFromContext(s.sqlCtx)
update := s.importOption == UpdateOp
colNames := ""
values := ""
@@ -334,7 +338,7 @@ func (s *SqlEngineTableWriter) getInsertNode(inputChannel chan sql.Row, replace
}
sep := ""
for _, col := range s.rowOperationSchema.Schema {
colNames += fmt.Sprintf("%s%s", sep, sql.QuoteIdentifier(col.Name))
colNames += fmt.Sprintf("%s%s", sep, formatter.QuoteIdentifier(col.Name))
values += fmt.Sprintf("%s1", sep)
if update {
duplicate += fmt.Sprintf("%s`%s` = VALUES(`%s`)", sep, col.Name, col.Name)
@@ -343,7 +347,7 @@ func (s *SqlEngineTableWriter) getInsertNode(inputChannel chan sql.Row, replace
}
sqlEngine := s.se
binder := planbuilder.New(s.sqlCtx, sqlEngine.Analyzer.Catalog, sqlEngine.EventScheduler, sqlEngine.Parser)
binder := planbuilder.New(s.sqlCtx, sqlEngine.Analyzer.Catalog, sqlEngine.EventScheduler)
insert := fmt.Sprintf("insert into `%s` (%s) VALUES (%s)%s", s.tableName, colNames, values, duplicate)
parsed, _, _, qFlags, err := binder.Parse(insert, nil, false)
if err != nil {
+5 -5
View File
@@ -137,11 +137,11 @@ func NewCheck(name, expression string, enforced bool) check {
}
}
func (c checkCollection) Copy() CheckCollection {
checks := make([]check, len(c.checks))
func (c *checkCollection) Copy() CheckCollection {
newC := *c
newC.checks = make([]check, len(c.checks))
for i, check := range c.checks {
checks[i] = NewCheck(check.name, check.expression, check.enforced)
newC.checks[i] = NewCheck(check.name, check.expression, check.enforced)
}
return &c
return &newC
}
@@ -0,0 +1,44 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package schema
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCopy(t *testing.T) {
var original = &checkCollection{
checks: []check{{"check1", "expr1", true},
{"check2", "expr2", false}},
}
var copy = original.Copy()
// Assert copy doesn't reuse the same check instances
original.checks[0].name = "XXX"
original.checks[0].expression = "XXX"
original.checks[0].enforced = false
original.checks[1].name = "XXX"
original.checks[1].expression = "XXX"
original.checks[1].enforced = true
assert.Equal(t, "check1", copy.AllChecks()[0].Name())
assert.Equal(t, "expr1", copy.AllChecks()[0].Expression())
assert.Equal(t, true, copy.AllChecks()[0].Enforced())
assert.Equal(t, "check2", copy.AllChecks()[1].Name())
assert.Equal(t, "expr2", copy.AllChecks()[1].Expression())
assert.Equal(t, false, copy.AllChecks()[1].Enforced())
}
@@ -33,15 +33,17 @@ var _ val.TupleComparator = CollationTupleComparator{}
// Compare implements TupleComparator
func (c CollationTupleComparator) Compare(ctx context.Context, left, right val.Tuple, desc *val.TupleDesc) (cmp int) {
fast := desc.GetFixedAccess()
for i := range fast {
start, stop := fast[i][0], fast[i][1]
off := len(fast)
var start, stop val.ByteSize
for i := 0; i < off; i++ {
stop = fast[i]
cmp = collationCompare(ctx, desc.Types[i], c.Collations[i], left[start:stop], right[start:stop])
if cmp != 0 {
return cmp
}
start = stop
}
off := len(fast)
for i, typ := range desc.Types[off:] {
j := i + off
cmp = collationCompare(ctx, typ, c.Collations[j], left.GetField(j), right.GetField(j))
@@ -111,7 +113,7 @@ func collationCompare(ctx context.Context, typ val.Type, collation sql.Collation
if typ.Enc == val.StringEnc {
return compareCollatedStrings(collation, left[:len(left)-1], right[:len(right)-1])
} else {
return val.DefaultTupleComparator{}.CompareValues(ctx, 0, left, right, typ)
return (&val.DefaultTupleComparator{}).CompareValues(ctx, 0, left, right, typ)
}
}
@@ -381,7 +381,7 @@ func (e extendedType) Convert(ctx context.Context, i interface{}) (interface{},
panic("unimplemented")
}
func (e extendedType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val any) (any, error) {
func (e extendedType) ConvertToType(ctx *sql.Context, typ sql.ExtendedType, val any) (any, sql.ConvertInRange, error) {
panic("unimplemented")
}
+91 -48
View File
@@ -185,6 +185,10 @@ type ServerConfig interface {
CACert() string
// RequireSecureTransport is true if the server should reject non-TLS connections.
RequireSecureTransport() bool
// RequireClientCert is true if the server should reject any connections that don't present a certificate. When
// enabled, a client certificate is always required, and if a CA cert is also configured, then the client cert
// will also be verified. Enabling this option also means that non-TLS connections are not allowed.
RequireClientCert() bool
// MaxLoggedQueryLen is the max length of queries written to the logs. Queries longer than this number are truncated.
// If this value is 0 then the query is not truncated and will be written to the logs in its entirety. If the value
// is less than 0 then the queries will be omitted from the logs completely
@@ -201,6 +205,12 @@ type ServerConfig interface {
MetricsLabels() map[string]string
MetricsHost() string
MetricsPort() int
MetricsTLSCert() string
MetricsTLSKey() string
MetricsTLSCA() string
MetricsJwksConfig() *JwksConfig
MetricsJWTRequiredForLocalhost() bool
// PrivilegeFilePath returns the path to the file which contains all needed privilege information in the form of a
// JSON string.
PrivilegeFilePath() string
@@ -239,6 +249,8 @@ type ServerConfig interface {
ValueSet(value string) bool
// AutoGCBehavior defines parameters around how auto-GC works for the running server.
AutoGCBehavior() AutoGCBehavior
// Overrides returns any overrides that are defined. This is primarily used by Doltgres.
Overrides() sql.EngineOverrides
}
// DefaultServerConfig creates a `*ServerConfig` that has all of the options set to their default values.
@@ -324,43 +336,48 @@ func ValidateConfig(config ServerConfig) error {
}
const (
HostKey = "host"
PortKey = "port"
UserKey = "user"
PasswordKey = "password"
ReadTimeoutKey = "net_read_timeout"
WriteTimeoutKey = "net_write_timeout"
ReadOnlyKey = "read_only"
LogLevelKey = "log_level"
LogFormatKey = "log_format"
AutoCommitKey = "autocommit"
DoltTransactionCommitKey = "dolt_transaction_commit"
BranchActivityTrackingKey = "branch_activity_tracking"
DataDirKey = "data_dir"
CfgDirKey = "cfg_dir"
MaxConnectionsKey = "max_connections"
MaxWaitConnectionsKey = "back_log"
MaxWaitConnectionsTimeoutKey = "max_connections_timeout"
TLSKeyKey = "tls_key"
TLSCertKey = "tls_cert"
RequireSecureTransportKey = "require_secure_transport"
MaxLoggedQueryLenKey = "max_logged_query_len"
ShouldEncodeLoggedQueryKey = "should_encode_logged_query"
DisableClientMultiStatementsKey = "disable_client_multi_statements"
MetricsLabelsKey = "metrics_labels"
MetricsHostKey = "metrics_host"
MetricsPortKey = "metrics_port"
PrivilegeFilePathKey = "privilege_file_path"
BranchControlFilePathKey = "branch_control_file_path"
UserVarsKey = "user_vars"
SystemVarsKey = "system_vars"
JwksConfigKey = "jwks_config"
AllowCleartextPasswordsKey = "allow_cleartext_passwords"
SocketKey = "socket"
RemotesapiPortKey = "remotesapi_port"
RemotesapiReadOnlyKey = "remotesapi_read_only"
ClusterConfigKey = "cluster_config"
EventSchedulerKey = "event_scheduler"
HostKey = "host"
PortKey = "port"
UserKey = "user"
PasswordKey = "password"
ReadTimeoutKey = "net_read_timeout"
WriteTimeoutKey = "net_write_timeout"
ReadOnlyKey = "read_only"
LogLevelKey = "log_level"
LogFormatKey = "log_format"
AutoCommitKey = "autocommit"
DoltTransactionCommitKey = "dolt_transaction_commit"
BranchActivityTrackingKey = "branch_activity_tracking"
DataDirKey = "data_dir"
CfgDirKey = "cfg_dir"
MaxConnectionsKey = "max_connections"
MaxWaitConnectionsKey = "back_log"
MaxWaitConnectionsTimeoutKey = "max_connections_timeout"
TLSKeyKey = "tls_key"
TLSCertKey = "tls_cert"
RequireSecureTransportKey = "require_secure_transport"
MaxLoggedQueryLenKey = "max_logged_query_len"
ShouldEncodeLoggedQueryKey = "should_encode_logged_query"
DisableClientMultiStatementsKey = "disable_client_multi_statements"
MetricsLabelsKey = "metrics_labels"
MetricsHostKey = "metrics_host"
MetricsPortKey = "metrics_port"
MetricsTLSCertKey = "metrics_tls_cert"
MetricsTLSKeyKey = "metrics_tls_key"
MetricsTLSCAKey = "metrics_tls_ca"
MetricsJwksConfigKey = "metrics_jwks_config"
MetricsJWTRequiredForLocalhostKey = "metrics_jwt_required_for_localhost"
PrivilegeFilePathKey = "privilege_file_path"
BranchControlFilePathKey = "branch_control_file_path"
UserVarsKey = "user_vars"
SystemVarsKey = "system_vars"
JwksConfigKey = "jwks_config"
AllowCleartextPasswordsKey = "allow_cleartext_passwords"
SocketKey = "socket"
RemotesapiPortKey = "remotesapi_port"
RemotesapiReadOnlyKey = "remotesapi_read_only"
ClusterConfigKey = "cluster_config"
EventSchedulerKey = "event_scheduler"
)
type SystemVariableTarget interface {
@@ -467,38 +484,64 @@ func ConfigInfo(config ServerConfig) string {
config.ReadTimeout(), config.ReadOnly(), config.LogLevel(), socket)
}
// LoadTLSConfig loads the certificate chain from config.TLSKey() and config.TLSCert() and returns
// a *tls.Config configured for its use. Returns `nil` if key and cert are `""`.
func LoadTLSConfig(cfg ServerConfig) (*tls.Config, error) {
if cfg.TLSKey() == "" && cfg.TLSCert() == "" {
return nil, nil
func getTLSConfig(cert, key, ca string, requireClientCert bool) (*tls.Config, error) {
if key == "" && cert == "" {
if requireClientCert {
return nil, fmt.Errorf("must supply tls_cert and tls_key when require_client_cert is enabled")
} else {
// No TLS configuration needed
return nil, nil
}
}
c, err := tls.LoadX509KeyPair(cfg.TLSCert(), cfg.TLSKey())
c, err := tls.LoadX509KeyPair(cert, key)
if err != nil {
return nil, err
return nil, fmt.Errorf("tls.LoadX509KeyPair(%v, %v) failed: %w", cert, key, err)
}
var caCertPool *x509.CertPool
if cfg.CACert() != "" {
caCertPEM, err := os.ReadFile(cfg.CACert())
if ca != "" {
caCertPEM, err := os.ReadFile(ca)
if err != nil {
return nil, fmt.Errorf("unable to read CA file at: %s", cfg.CACert())
return nil, fmt.Errorf("unable to read CA file at %s: %w", ca, err)
}
caCertPool = x509.NewCertPool()
if ok := caCertPool.AppendCertsFromPEM(caCertPEM); !ok {
return nil, fmt.Errorf("unable to add CA cert to cert pool")
}
}
clientAuthType := tls.VerifyClientCertIfGiven
if requireClientCert {
// If a CA cert has been specified, then in addition to requiring
// a client cert, also verify it, otherwise allow any client cert.
if ca != "" {
clientAuthType = tls.RequireAndVerifyClientCert
} else {
clientAuthType = tls.RequireAnyClientCert
}
}
return &tls.Config{
Certificates: []tls.Certificate{c},
// tlsVerifyClientCertIfGiven will request a client cert from the client,
// and if provided, will validate it against the specified client CAs.
ClientAuth: tls.VerifyClientCertIfGiven,
ClientAuth: clientAuthType,
ClientCAs: caCertPool,
}, nil
}
// LoadTLSConfig loads the certificate chain from config.TLSKey() and config.TLSCert() and returns
// a *tls.Config configured for its use. Returns `nil` if key and cert are `""`.
func LoadTLSConfig(cfg ServerConfig) (*tls.Config, error) {
return getTLSConfig(cfg.TLSCert(), cfg.TLSKey(), cfg.CACert(), cfg.RequireClientCert())
}
func LoadMetricsTLSConfig(cfg ServerConfig) (*tls.Config, error) {
return getTLSConfig(cfg.MetricsTLSCert(), cfg.MetricsTLSKey(), cfg.MetricsTLSCA(), false)
}
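A rough sketch of how a serving path might consume these helpers; the listener wiring shown is hypothetical and only illustrates that the metrics endpoint now gets its own TLS material.
// Hypothetical wiring: wrap the SQL listener when TLS is configured, and build a
// separate config for the metrics endpoint from its own cert/key/CA settings.
tlsCfg, err := LoadTLSConfig(cfg)
if err != nil {
	return err
}
if tlsCfg != nil {
	listener = tls.NewListener(listener, tlsCfg)
}
metricsTLS, err := LoadMetricsTLSConfig(cfg)
if err != nil {
	return err
}
_ = metricsTLS // when non-nil, serve metrics over HTTPS with this config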
// CheckForUnixSocket evaluates ServerConfig for whether the unix socket is to be used or not.
// If user defined socket flag or host is 'localhost', it returns the unix socket file location
// either user-defined or the default if it was not defined.
@@ -31,6 +31,7 @@ ListenerConfig servercfg.ListenerYAMLConfig 0.0.0 listener,omitempty
-TLSCert *string 0.0.0 tls_cert,omitempty
-CACert *string 1.77.0 ca_cert,omitempty
-RequireSecureTransport *bool 0.0.0 require_secure_transport,omitempty
-RequireClientCert *bool 1.78.3 require_client_cert,omitempty
-AllowCleartextPasswords *bool 0.0.0 allow_cleartext_passwords,omitempty
-Socket *string 0.0.0 socket,omitempty
PerformanceConfig *servercfg.PerformanceYAMLConfig 0.0.0 performance,omitempty
@@ -61,6 +62,15 @@ MetricsConfig servercfg.MetricsYAMLConfig 0.0.0 metrics,omitempty
-Labels map[string]string 0.0.0 labels
-Host *string 0.0.0 host,omitempty
-Port *int 0.0.0 port,omitempty
-TlsCert *string 1.78.2 tls_cert,omitempty
-TlsKey *string 1.78.2 tls_key,omitempty
-TlsCa *string 1.78.2 tls_ca,omitempty
-Jwks *servercfg.JwksConfig 1.79.0 jwks,omitempty
--Name string 0.0.0 name
--LocationUrl string 0.0.0 location_url
--Claims map[string]string 0.0.0 claims
--FieldsToLog []string 0.0.0 fields_to_log
-JWTRequiredForLocalhost *bool 1.79.0 jwt_required_for_localhost,omitempty
ClusterCfg *servercfg.ClusterYAMLConfig 0.0.0 cluster,omitempty
-StandbyRemotes_ []servercfg.StandbyRemoteYAMLConfig 0.0.0 standby_remotes
--Name_ string 0.0.0 name
+85 -9
View File
@@ -22,6 +22,7 @@ import (
"unicode"
"unicode/utf8"
"github.com/dolthub/go-mysql-server/sql"
"gopkg.in/yaml.v2"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
@@ -95,6 +96,9 @@ type ListenerYAMLConfig struct {
CACert *string `yaml:"ca_cert,omitempty" minver:"1.77.0"`
// RequireSecureTransport can enable a mode where non-TLS connections are turned away.
RequireSecureTransport *bool `yaml:"require_secure_transport,omitempty"`
// RequireClientCert enables a mode where all clients must present a certificate. If a CA
// cert is also provided, the client cert will also be verified.
RequireClientCert *bool `yaml:"require_client_cert,omitempty" minver:"1.78.3"`
// AllowCleartextPasswords enables use of cleartext passwords.
AllowCleartextPasswords *bool `yaml:"allow_cleartext_passwords,omitempty"`
// Socket is unix socket file path
@@ -108,9 +112,14 @@ type PerformanceYAMLConfig struct {
}
type MetricsYAMLConfig struct {
Labels map[string]string `yaml:"labels"`
Host *string `yaml:"host,omitempty"`
Port *int `yaml:"port,omitempty"`
Labels map[string]string `yaml:"labels"`
Host *string `yaml:"host,omitempty"`
Port *int `yaml:"port,omitempty"`
TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"`
TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"`
TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"`
Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"1.79.0"`
JWTRequiredForLocalhost *bool `yaml:"jwt_required_for_localhost,omitempty" minver:"1.79.0"`
}
type RemotesapiYAMLConfig struct {
@@ -227,9 +236,14 @@ func ServerConfigAsYAMLConfig(cfg ServerConfig) *YAMLConfig {
DataDirStr: ptr(cfg.DataDir()),
CfgDirStr: ptr(cfg.CfgDir()),
MetricsConfig: MetricsYAMLConfig{
Labels: cfg.MetricsLabels(),
Host: nillableStrPtr(cfg.MetricsHost()),
Port: ptr(cfg.MetricsPort()),
Labels: cfg.MetricsLabels(),
Host: nillableStrPtr(cfg.MetricsHost()),
Port: ptr(cfg.MetricsPort()),
TlsCert: ptr(cfg.MetricsTLSCert()),
TlsKey: ptr(cfg.MetricsTLSKey()),
TlsCa: ptr(cfg.MetricsTLSCA()),
Jwks: cfg.MetricsJwksConfig(),
JWTRequiredForLocalhost: ptr(cfg.MetricsJWTRequiredForLocalhost()),
},
RemotesapiConfig: RemotesapiYAMLConfig{
Port_: cfg.RemotesapiPort(),
@@ -300,9 +314,14 @@ func ServerConfigSetValuesAsYAMLConfig(cfg ServerConfig) *YAMLConfig {
DataDirStr: zeroIf(ptr(cfg.DataDir()), !cfg.ValueSet(DataDirKey)),
CfgDirStr: zeroIf(ptr(cfg.CfgDir()), !cfg.ValueSet(CfgDirKey)),
MetricsConfig: MetricsYAMLConfig{
Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)),
Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)),
Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)),
Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)),
Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)),
Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)),
TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)),
TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)),
TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)),
Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)),
JWTRequiredForLocalhost: zeroIf(ptr(cfg.MetricsJWTRequiredForLocalhost()), !cfg.ValueSet(MetricsJWTRequiredForLocalhostKey)),
},
RemotesapiConfig: RemotesapiYAMLConfig{
Port_: zeroIf(cfg.RemotesapiPort(), !cfg.ValueSet(RemotesapiPortKey)),
@@ -403,6 +422,15 @@ func (cfg YAMLConfig) withPlaceholdersFilledIn() YAMLConfig {
if withPlaceholders.MetricsConfig.Port == nil {
withPlaceholders.MetricsConfig.Port = ptr(9091)
}
if withPlaceholders.MetricsConfig.TlsCert == nil {
withPlaceholders.MetricsConfig.TlsCert = ptr("")
}
if withPlaceholders.MetricsConfig.TlsKey == nil {
withPlaceholders.MetricsConfig.TlsKey = ptr("")
}
if withPlaceholders.MetricsConfig.TlsCa == nil {
withPlaceholders.MetricsConfig.TlsCa = ptr("")
}
if withPlaceholders.RemotesapiConfig.Port_ == nil {
withPlaceholders.RemotesapiConfig.Port_ = ptr(8000)
@@ -759,6 +787,40 @@ func (cfg YAMLConfig) MetricsPort() int {
return *cfg.MetricsConfig.Port
}
func (cfg YAMLConfig) MetricsTLSCert() string {
if cfg.MetricsConfig.TlsCert == nil {
return ""
}
return *cfg.MetricsConfig.TlsCert
}
func (cfg YAMLConfig) MetricsTLSKey() string {
if cfg.MetricsConfig.TlsKey == nil {
return ""
}
return *cfg.MetricsConfig.TlsKey
}
func (cfg YAMLConfig) MetricsTLSCA() string {
if cfg.MetricsConfig.TlsCa == nil {
return ""
}
return *cfg.MetricsConfig.TlsCa
}
func (cfg YAMLConfig) MetricsJwksConfig() *JwksConfig {
return cfg.MetricsConfig.Jwks
}
func (cfg YAMLConfig) MetricsJWTRequiredForLocalhost() bool {
if cfg.MetricsConfig.JWTRequiredForLocalhost == nil {
return false
}
return *cfg.MetricsConfig.JWTRequiredForLocalhost
}
func (cfg YAMLConfig) RemotesapiPort() *int {
return cfg.RemotesapiConfig.Port_
}
@@ -873,6 +935,16 @@ func (cfg YAMLConfig) CACert() string {
return *cfg.ListenerConfig.CACert
}
// RequireClientCert is true if the server should reject any connections that don't present a certificate. When
// enabled, a client certificate is always required, and if a CA cert is also configured, then the client cert
// will also be verified. Enabling this option also means that non-TLS connections are not allowed.
func (cfg YAMLConfig) RequireClientCert() bool {
if cfg.ListenerConfig.RequireClientCert == nil {
return false
}
return *cfg.ListenerConfig.RequireClientCert
}
// RequireSecureTransport is true if the server should reject non-TLS connections.
func (cfg YAMLConfig) RequireSecureTransport() bool {
if cfg.ListenerConfig.RequireSecureTransport == nil {
@@ -963,6 +1035,10 @@ func (cfg YAMLConfig) EventSchedulerStatus() string {
}
}
func (cfg YAMLConfig) Overrides() sql.EngineOverrides {
return sql.EngineOverrides{}
}
type ClusterYAMLConfig struct {
StandbyRemotes_ []StandbyRemoteYAMLConfig `yaml:"standby_remotes"`
BootstrapRole_ string `yaml:"bootstrap_role"`
@@ -57,6 +57,16 @@ metrics:
label1: value1
label2: 2
label3: true
tls_cert: /path/to/file.cert
tls_key: /path/to/file.key
tls_ca: /path/to/ca.cert
jwks:
name: jwks_name
location_url: https://website.com
claims:
iss: dolthub.com
aud: metrics
fields_to_log: [iss, aud]
user_session_vars:
- name: user0
@@ -104,6 +114,18 @@ jwks:
"label2": "2",
"label3": "true",
},
TlsCert: ptr("/path/to/file.cert"),
TlsKey: ptr("/path/to/file.key"),
TlsCa: ptr("/path/to/ca.cert"),
Jwks: &JwksConfig{
Name: "jwks_name",
LocationUrl: "https://website.com",
Claims: map[string]string{
"iss": "dolthub.com",
"aud": "metrics",
},
FieldsToLog: []string{"iss", "aud"},
},
}
expected.DataDirStr = ptr("some nonsense")
expected.SystemVars_ = nil
@@ -347,6 +369,9 @@ func TestYAMLConfigDefaults(t *testing.T) {
assert.Equal(t, DefaultMetricsHost, cfg.MetricsHost())
assert.Equal(t, DefaultMetricsPort, cfg.MetricsPort())
assert.Nil(t, cfg.MetricsConfig.Labels)
assert.Equal(t, "", cfg.MetricsTLSCert())
assert.Equal(t, "", cfg.MetricsTLSKey())
assert.Equal(t, "", cfg.MetricsTLSCA())
assert.Equal(t, DefaultAllowCleartextPasswords, cfg.AllowCleartextPasswords())
assert.Nil(t, cfg.RemotesapiPort())
@@ -0,0 +1,85 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package adapters
import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
)
// TableAdapter provides a hook for extensions to customize or wrap table implementations. For example, this allows
// libraries like Doltgres to intercept system table creation and apply type conversions, schema modifications, or other
// customizations without modifying the core Dolt implementation.
type TableAdapter interface {
// NewTable creates or wraps a system table. The function receives all necessary parameters to construct the table
// and can either build it from scratch or call the default Dolt constructor and wrap it.
NewTable(ctx *sql.Context, tableName string, dDb *doltdb.DoltDB, workingSet *doltdb.WorkingSet, rootsProvider env.RootsProvider[*sql.Context]) sql.Table
// TableName returns the preferred name for the adapter's table. This allows extensions to rename tables while
// preserving the underlying implementation. For example, Doltgres uses "status" while Dolt uses "dolt_status",
// enabling cleaner Postgres-style naming.
TableName() string
}
var DoltTableAdapterRegistry = newDoltTableAdapterRegistry()
// doltTableAdapterRegistry is a map from Dolt table names to TableAdapters. Integrators populate this registry during package
// initialization, and it's intended to be read-only thereafter. The registry maps existing Dolt system table names to the
// integrator's implementations and records internal aliases (the integrator's own names for those Dolt tables) so that
// either form of the name can be resolved.
type doltTableAdapterRegistry struct {
Adapters map[string]TableAdapter
internalAliases map[string]string
}
// newDoltTableAdapterRegistry constructs Dolt table adapter registry with empty internal alias and adapter maps.
func newDoltTableAdapterRegistry() *doltTableAdapterRegistry {
return &doltTableAdapterRegistry{
Adapters: make(map[string]TableAdapter),
internalAliases: make(map[string]string),
}
}
// AddAdapter maps |doltTableName| to an |adapter| in the Dolt table adapter registry, with optional |internalAliases|.
func (as *doltTableAdapterRegistry) AddAdapter(doltTableName string, adapter TableAdapter, internalAliases ...string) {
for _, alias := range internalAliases {
as.internalAliases[alias] = doltTableName
}
as.Adapters[doltTableName] = adapter
}
// GetAdapter gets a Dolt TableAdapter mapped to |name|, which can be the dolt table name or internal alias.
func (as *doltTableAdapterRegistry) GetAdapter(name string) (TableAdapter, bool) {
adapter, ok := as.Adapters[name]
if !ok {
name = as.internalAliases[name]
adapter, ok = as.Adapters[name]
}
return adapter, ok
}
// NormalizeName resolves |name| to the underlying Dolt table name if it is a registered internal alias. If no match is
// found, |name| is returned as-is.
func (as *doltTableAdapterRegistry) NormalizeName(name string) string {
doltTableName, ok := as.internalAliases[name]
if !ok {
return name
}
return doltTableName
}
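A sketch (not part of this change) of how an integrator might populate the registry during package initialization; statusTableAdapter is a hypothetical type that implements TableAdapter.
// Hypothetical integrator code, e.g. in a Doltgres-style extension package.
func init() {
	// Map the Dolt system table name to the extension's adapter, and register
	// "status" as the extension's internal alias for the same table.
	DoltTableAdapterRegistry.AddAdapter(doltdb.StatusTableName, statusTableAdapter{}, "status")
}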
@@ -0,0 +1,77 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package adapters
import (
"testing"
"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/require"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
)
type mockAdapter struct {
name string
}
func (m mockAdapter) NewTable(_ *sql.Context, _ string, _ *doltdb.DoltDB, _ *doltdb.WorkingSet, _ env.RootsProvider[*sql.Context]) sql.Table {
return nil
}
func (m mockAdapter) TableName() string {
return m.name
}
func TestDoltTableAdapterRegistry(t *testing.T) {
registry := newDoltTableAdapterRegistry()
statusAdapter := mockAdapter{name: "status"}
logAdapter := mockAdapter{name: "log"}
registry.AddAdapter(doltdb.StatusTableName, statusAdapter, "status")
registry.AddAdapter(doltdb.LogTableName, logAdapter, "log")
t.Run("GetAdapter", func(t *testing.T) {
adapter, ok := registry.GetAdapter("dolt_status")
require.True(t, ok)
require.Equal(t, "status", adapter.TableName())
adapter, ok = registry.GetAdapter("status")
require.True(t, ok)
require.Equal(t, "status", adapter.TableName())
_, ok = registry.GetAdapter("unknown_alias")
require.False(t, ok)
_, ok = registry.GetAdapter("dolt_unknown")
require.False(t, ok)
})
t.Run("NormalizeName", func(t *testing.T) {
normalized := registry.NormalizeName("status")
require.Equal(t, "dolt_status", normalized)
normalized = registry.NormalizeName("log")
require.Equal(t, "dolt_log", normalized)
normalized = registry.NormalizeName("dolt_status")
require.Equal(t, "dolt_status", normalized)
normalized = registry.NormalizeName("unknown_table")
require.Equal(t, "unknown_table", normalized)
})
}
@@ -292,11 +292,7 @@ func makePeopleTable(ctx context.Context, dEnv *env.DoltEnv) (*env.DoltEnv, erro
}
func mustStringToColumnDefault(defaultString string) *sql.ColumnDefaultValue {
def, err := planbuilder.StringToColumnDefaultValue(sql.NewEmptyContext(), defaultString)
if err != nil {
panic(err)
}
return def
return planbuilder.MustStringToColumnDefaultValue(sql.NewEmptyContext(), defaultString, nil, true)
}
func schemaNewColumnWithDefault(name string, tag uint64, kind types.NomsKind, partOfPK bool, defaultVal string, constraints ...schema.ColConstraint) schema.Column {
@@ -40,6 +40,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
)
@@ -965,7 +966,7 @@ func convertVitessJsonExpressionString(ctx *sql.Context, value sqltypes.Value) (
return nil, fmt.Errorf("unable to access running SQL server")
}
binder := planbuilder.New(ctx, server.Engine.Analyzer.Catalog, server.Engine.EventScheduler, server.Engine.Parser)
binder := planbuilder.New(ctx, server.Engine.Analyzer.Catalog, server.Engine.EventScheduler)
node, _, _, qFlags, err := binder.Parse("SELECT "+strValue, nil, false)
if err != nil {
return nil, err
@@ -976,7 +977,8 @@ func convertVitessJsonExpressionString(ctx *sql.Context, value sqltypes.Value) (
return nil, err
}
rowIter, err := rowexec.DefaultBuilder.Build(ctx, analyze, nil)
engOverrides := overrides.EngineOverridesFromContext(ctx)
rowIter, err := rowexec.NewBuilder(nil, engOverrides).Build(ctx, analyze, nil)
if err != nil {
return nil, err
}
+12 -6
View File
@@ -44,10 +44,12 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/rebase"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/adapters"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dprocedures"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
@@ -478,8 +480,11 @@ func (db Database) getTableInsensitiveWithRoot(ctx *sql.Context, head *doltdb.Co
} else if err != nil {
return nil, false, err
}
dt, err := dtables.NewCommitDiffTable(ctx, db.Name(), tname, db.ddb, root, stagedRoot)
headRef, err := db.rsr.CWBHeadRef(ctx)
if err != nil {
return nil, false, err
}
dt, err := dtables.NewCommitDiffTable(ctx, db.Name(), tname, db.ddb, root, stagedRoot, headRef)
if err != nil {
return nil, false, err
}
@@ -618,7 +623,7 @@ func (db Database) getTableInsensitiveWithRoot(ctx *sql.Context, head *doltdb.Co
var dt sql.Table
found := false
tname := doltdb.TableName{Name: lwrName, Schema: db.schemaName}
switch lwrName {
switch adapters.DoltTableAdapterRegistry.NormalizeName(lwrName) {
case doltdb.GetLogTableName(), doltdb.LogTableName:
isDoltgresSystemTable, err := resolve.IsDoltgresSystemTable(ctx, tname, root)
if err != nil {
@@ -747,7 +752,7 @@ func (db Database) getTableInsensitiveWithRoot(ctx *sql.Context, head *doltdb.Co
if !resolve.UseSearchPath || isDoltgresSystemTable {
dt, found = dtables.NewCommitAncestorsTable(ctx, db.Name(), lwrName, db.ddb), true
}
case doltdb.GetStatusTableName(), doltdb.StatusTableName:
case doltdb.StatusTableName:
isDoltgresSystemTable, err := resolve.IsDoltgresSystemTable(ctx, tname, root)
if err != nil {
return nil, false, err
@@ -2670,7 +2675,7 @@ func (db Database) doltSchemaTableHash(ctx *sql.Context) (hash.Hash, error) {
// createEventDefinitionFromFragment creates an EventDefinition instance from the schema fragment |frag|.
func (db Database) createEventDefinitionFromFragment(ctx *sql.Context, frag schemaFragment) (*sql.EventDefinition, error) {
b := planbuilder.New(ctx, db.getCatalog(ctx), db.getEventScheduler(ctx), nil)
b := planbuilder.New(ctx, db.getCatalog(ctx), db.getEventScheduler(ctx))
b.SetParserOptions(sql.NewSqlModeFromString(frag.sqlMode).ParserOptions())
parsed, _, _, _, err := b.Parse(updateEventStatusTemporarilyForNonDefaultBranch(db.revision, frag.fragment), nil, false)
if err != nil {
@@ -3027,7 +3032,8 @@ func (db Database) LoadRebasePlan(ctx *sql.Context) (*rebase.RebasePlan, error)
Column: expression.NewGetField(0, rebaseSchema[0].Type, "rebase_order", false),
Order: sql.Ascending,
}}, resolvedTable)
iter, err := rowexec.DefaultBuilder.Build(ctx, sort, nil)
engOverrides := overrides.EngineOverridesFromContext(ctx)
iter, err := rowexec.NewBuilder(nil, engOverrides).Build(ctx, sort, nil)
if err != nil {
return nil, err
}
@@ -61,6 +61,7 @@ type DoltDatabaseProvider struct {
isStandby *bool
mu *sync.RWMutex
droppedDatabaseManager *droppedDatabaseManager
overrides sql.EngineOverrides
defaultBranch string
dbFactoryUrl string
@@ -92,21 +93,21 @@ func (p *DoltDatabaseProvider) WithTableFunctions(fns ...sql.TableFunction) (sql
// NewDoltDatabaseProvider returns a new provider, initialized without any databases, along with any
// errors that occurred while trying to create the database provider.
func NewDoltDatabaseProvider(defaultBranch string, fs filesys.Filesys) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, nil, nil)
func NewDoltDatabaseProvider(defaultBranch string, fs filesys.Filesys, overrides sql.EngineOverrides) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, nil, nil, overrides)
}
// NewDoltDatabaseProviderWithDatabase returns a new provider, initialized with one database at the
// specified location, and any error that occurred along the way.
func NewDoltDatabaseProviderWithDatabase(defaultBranch string, fs filesys.Filesys, database dsess.SqlDatabase, dbLocation filesys.Filesys) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, []dsess.SqlDatabase{database}, []filesys.Filesys{dbLocation})
func NewDoltDatabaseProviderWithDatabase(defaultBranch string, fs filesys.Filesys, database dsess.SqlDatabase, dbLocation filesys.Filesys, overrides sql.EngineOverrides) (*DoltDatabaseProvider, error) {
return NewDoltDatabaseProviderWithDatabases(defaultBranch, fs, []dsess.SqlDatabase{database}, []filesys.Filesys{dbLocation}, overrides)
}
// NewDoltDatabaseProviderWithDatabases returns a new provider, initialized with the specified databases,
// at the specified locations. For every database specified, there must be a corresponding filesystem
// specified that represents where the database is located. If the number of specified databases is not the
// same as the number of specified locations, an error is returned.
func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys) (*DoltDatabaseProvider, error) {
func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Filesys, databases []dsess.SqlDatabase, locations []filesys.Filesys, overrides sql.EngineOverrides) (*DoltDatabaseProvider, error) {
if len(databases) != len(locations) {
return nil, fmt.Errorf("unable to create DoltDatabaseProvider: "+
"incorrect number of databases (%d) and database locations (%d) specified", len(databases), len(locations))
@@ -159,6 +160,7 @@ func NewDoltDatabaseProviderWithDatabases(defaultBranch string, fs filesys.Files
dbFactoryUrl: dbFactoryUrl,
isStandby: new(bool),
droppedDatabaseManager: newDroppedDatabaseManager(fs),
overrides: overrides,
}, nil
}
@@ -1465,6 +1467,11 @@ func (p *DoltDatabaseProvider) ensureReplicaHeadExists(ctx *sql.Context, branch
return db.CreateLocalBranchFromRemote(ctx, ref.NewBranchRef(branch))
}
// EngineOverrides returns the overrides that were given during the creation of the provider.
func (p *DoltDatabaseProvider) EngineOverrides() sql.EngineOverrides {
return p.overrides
}
// isBranch returns whether a branch with the given name is in scope for the database given
func isBranch(ctx context.Context, db dsess.SqlDatabase, branchName string) (string, bool, error) {
ddbs := db.DoltDatabases()
@@ -32,7 +32,7 @@ const DeprecatedHashOfFuncName = "hashof"
const HashOfFuncName = "dolt_hashof"
type HashOf struct {
expression.UnaryExpression
expression.UnaryExpressionStub
name string
}
@@ -47,7 +47,7 @@ func NewHashOfFunc(name string) sql.CreateFunc1Args {
// newHashOf creates a new HashOf expression.
func newHashOf(e sql.Expression, name string) sql.Expression {
return &HashOf{expression.UnaryExpression{Child: e}, name}
return &HashOf{expression.UnaryExpressionStub{Child: e}, name}
}
// Eval implements the Expression interface.
@@ -29,14 +29,14 @@ import (
const HashOfTableFuncName = "dolt_hashof_table"
type HashOfTable struct {
expression.UnaryExpression
expression.UnaryExpressionStub
}
var _ sql.FunctionExpression = (*HashOfTable)(nil)
// NewHashOfTable creates a new HashOfTable expression.
func NewHashOfTable(e sql.Expression) sql.Expression {
return &HashOfTable{expression.UnaryExpression{Child: e}}
return &HashOfTable{expression.UnaryExpressionStub{Child: e}}
}
// Eval implements the Expression interface.
@@ -79,7 +79,7 @@ func (c *JoinCost) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
pro := dSess.Provider()
eng := gms.NewDefault(pro)
binder := planbuilder.New(ctx, eng.Analyzer.Catalog, eng.EventScheduler, eng.Parser)
binder := planbuilder.New(ctx, eng.Analyzer.Catalog, eng.EventScheduler)
parsed, _, _, qFlags, err := binder.Parse(q, nil, false)
if err != nil {
return nil, err
@@ -15,260 +15,340 @@
package dprocedures
import (
"errors"
"fmt"
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/cmd/dolt/doltversion"
"github.com/dolthub/dolt/go/cmd/dolt/errhand"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/datas/pull"
"github.com/dolthub/dolt/go/store/types"
)
const (
DoltBackupFuncName = "dolt_backup"
DoltBackupProcedureName = "dolt_backup"
statusOk = 0
statusErr = 1
DoltBackupParamAdd = "add"
DoltBackupParamRemove = "remove"
DoltBackupParamRm = "rm"
DoltBackupParamSync = "sync"
DoltBackupParamSyncUrl = "sync-url"
DoltBackupParamRestore = "restore"
)
// doltBackup is the stored procedure version for the CLI command `dolt backup`.
var awsParamsUsage = []string{
fmt.Sprintf("--%s=<region>", dbfactory.AWSRegionParam),
fmt.Sprintf("--%s=<type>", dbfactory.AWSCredsTypeParam),
fmt.Sprintf("--%s=<file>", dbfactory.AWSCredsFileParam),
fmt.Sprintf("--%s=<profile>", dbfactory.AWSCredsProfile),
}
// doltBackup implements backup operations for Dolt databases. It routes |args| to the appropriate operation handler
// based on the first argument. The procedure requires superuser privileges and write access to the current database.
// Supported operations are: add, remove/rm, sync, sync-url, and restore.
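//
// Illustrative usage (a sketch; the backup name and file URL are placeholders, not part of this change):
//
//	CALL dolt_backup('add', 'backup1', 'file:///var/backups/mydb');
//	CALL dolt_backup('sync', 'backup1');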
func doltBackup(ctx *sql.Context, args ...string) (sql.RowIter, error) {
res, err := doDoltBackup(ctx, args)
apr, err := cli.CreateBackupArgParser().Parse(args)
if err != nil {
return nil, err
}
return rowToIter(int64(res)), nil
}
func doDoltBackup(ctx *sql.Context, args []string) (int, error) {
dbName := ctx.GetCurrentDatabase()
if len(dbName) == 0 {
return statusErr, fmt.Errorf("Empty database name.")
}
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return statusErr, err
if apr.NArg() == 0 || (apr.NArg() == 1 && apr.Contains(cli.VerboseFlag)) {
return nil, fmt.Errorf("use '%s' table to list backups", doltdb.BackupsTableName)
}
apr, err := cli.CreateBackupArgParser().Parse(args)
var dbName string
funcParam := apr.Arg(0)
if funcParam != DoltBackupParamRestore {
dbName = ctx.GetCurrentDatabase()
if dbName == "" {
return nil, fmt.Errorf("empty database name")
}
}
err = branch_control.CheckAccess(ctx, branch_control.Permissions_Write)
if err != nil {
return statusErr, err
return nil, err
}
invalidParams := []string{dbfactory.AWSCredsFileParam, dbfactory.AWSCredsProfile, dbfactory.AWSCredsTypeParam, dbfactory.AWSRegionParam}
for _, param := range invalidParams {
if apr.Contains(param) {
return statusErr, fmt.Errorf("parameter '%s' is not supported when running this command via SQL", param)
}
if sqlserver.RunningInServerMode() && apr.ContainsAny(cli.AwsParams...) {
return nil, fmt.Errorf("AWS parameters are unavailable when running in server mode")
}
sess := dsess.DSessFromSess(ctx.Session)
dbData, ok := sess.GetDbData(ctx, dbName)
if !ok {
return statusErr, sql.ErrDatabaseNotFound.New(dbName)
doltSess := dsess.DSessFromSess(ctx.Session)
dbData, ok := doltSess.GetDbData(ctx, dbName)
if !ok && funcParam != DoltBackupParamRestore {
return nil, sql.ErrDatabaseNotFound.New(dbName)
}
if apr.NArg() == 0 {
return statusErr, fmt.Errorf("listing existing backup endpoints in sql is not currently implemented. Let us know if you need this by opening a GitHub issue: https://github.com/dolthub/dolt/issues")
}
switch apr.Arg(0) {
case cli.AddBackupId:
err = addBackup(ctx, dbData, apr)
if err != nil {
return statusErr, fmt.Errorf("error adding backup: %w", err)
switch funcParam {
case DoltBackupParamAdd:
if apr.NArg() != 3 {
return nil, errDoltBackupUsage(funcParam, []string{"name", "url"}, awsParamsUsage)
}
case cli.RemoveBackupId, cli.RemoveBackupShortId:
err = removeBackup(ctx, dbData, apr)
if err != nil {
return statusErr, fmt.Errorf("error removing backup: %w", err)
err = doltBackupAdd(ctx, dbData, doltSess, apr)
case DoltBackupParamRemove, DoltBackupParamRm:
if apr.NArg() != 2 {
return nil, errDoltBackupUsage(funcParam, []string{"name"}, nil)
}
case cli.RestoreBackupId:
if err = restoreBackup(ctx, dbData, apr); err != nil {
return statusErr, fmt.Errorf("error restoring backup: %w", err)
name := apr.Arg(1)
err = dbData.Rsw.RemoveBackup(ctx, name)
case DoltBackupParamSync:
if apr.NArg() != 2 {
return nil, errDoltBackupUsage(funcParam, []string{"name"}, nil)
}
case cli.SyncBackupUrlId:
err = syncBackupViaUrl(ctx, dbData, sess, apr)
if err != nil {
return statusErr, fmt.Errorf("error syncing backup url: %w", err)
name := apr.Arg(1)
err = doltBackupSync(ctx, dbData, doltSess, name)
case DoltBackupParamSyncUrl:
if apr.NArg() != 2 {
return nil, errDoltBackupUsage(funcParam, []string{"remote_url"}, awsParamsUsage)
}
case cli.SyncBackupId:
err = syncBackupViaName(ctx, dbData, sess, apr)
if err != nil {
return statusErr, fmt.Errorf("error syncing backup: %w", err)
err = doltBackupSyncUrl(ctx, dbData, doltSess, apr)
case DoltBackupParamRestore:
if apr.NArg() != 3 {
forceParamUsage := []string{fmt.Sprintf("--%s", cli.ForceFlag)}
return nil, errDoltBackupUsage(funcParam, []string{"remote_url", "new_db_name"}, append(forceParamUsage, awsParamsUsage...))
}
err = doltBackupRestore(ctx, dbData, doltSess, apr)
default:
return statusErr, fmt.Errorf("unrecognized dolt_backup parameter: %s", apr.Arg(0))
return nil, fmt.Errorf("unrecognized %s parameter '%s'", DoltBackupProcedureName, funcParam)
}
return statusOk, nil
return rowToIter(int64(0)), err
}
func addBackup(ctx *sql.Context, dbData env.DbData[*sql.Context], apr *argparser.ArgParseResults) error {
if apr.NArg() != 3 {
return fmt.Errorf("usage: dolt_backup('add', 'backup_name', 'backup-url')")
}
backupName := strings.TrimSpace(apr.Arg(1))
backupUrl := apr.Arg(2)
cfg := loadConfig(ctx)
scheme, absBackupUrl, err := env.GetAbsRemoteUrl(filesys.LocalFS, cfg, backupUrl)
if err != nil {
return fmt.Errorf("error: '%s' is not valid, %s", backupUrl, err.Error())
} else if scheme == dbfactory.HTTPScheme || scheme == dbfactory.HTTPSScheme {
// not sure how to get the dialer so punting on this
return fmt.Errorf("sync-url does not support http or https backup locations currently")
}
params, err := cli.ProcessBackupArgs(apr, scheme, absBackupUrl)
// doltBackupAdd adds a new backup entry with the name and URL specified in |apr|. The URL is normalized to an absolute
// path. AWS parameters are extracted from command-line flags in |apr| if present, otherwise they are loaded from
// session variables if the URL scheme matches.
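//
// Illustrative call (the backup name and URL are placeholders):
//
//	CALL dolt_backup('add', 'backup1', 'file:///var/backups/mydb');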
func doltBackupAdd(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
backupName := apr.Arg(1)
backupUrlScheme, backupUrl, err := newAbsRemoteUrl(dsess, apr.Arg(2))
if err != nil {
return err
}
r := env.NewRemote(backupName, absBackupUrl, params)
err = dbData.Rsw.AddBackup(r)
switch err {
case nil:
return nil
case env.ErrBackupAlreadyExists:
return fmt.Errorf("error: a backup named '%s' already exists, remove it before running this command again", r.Name)
case env.ErrBackupNotFound:
return fmt.Errorf("error: unknown backup: '%s' ", r.Name)
case env.ErrInvalidBackupURL:
return fmt.Errorf("error: '%s' is not valid, cause: %s", r.Url, err.Error())
case env.ErrInvalidBackupName:
return fmt.Errorf("error: invalid backup name: '%s'", r.Name)
default:
return fmt.Errorf("error: Unable to save changes, cause: %s", err.Error())
}
}
func restoreBackup(ctx *sql.Context, _ env.DbData[*sql.Context], apr *argparser.ArgParseResults) error {
if apr.NArg() != 3 {
return fmt.Errorf("usage: dolt_backup('restore', 'backup_url', 'database_name')")
}
// Only allow admins to restore a database
if err := checkBackupRestorePrivs(ctx); err != nil {
return err
}
backupUrl := strings.TrimSpace(apr.Arg(1))
dbName := strings.TrimSpace(apr.Arg(2))
force := apr.Contains(cli.ForceFlag)
sess := dsess.DSessFromSess(ctx.Session)
params, err := loadAwsParams(ctx, sess, apr, backupUrl, "restore")
backupParams, err := newParams(apr, backupUrl, backupUrlScheme)
if err != nil {
return err
}
r := env.NewRemote("", backupUrl, params)
srcDb, err := r.GetRemoteDB(ctx, types.Format_Default, nil)
if err != nil {
return err
}
existingDbData, restoringExistingDb := sess.GetDbData(ctx, dbName)
if restoringExistingDb {
if !force {
return fmt.Errorf("error: cannot restore backup into %s. "+
"A database with that name already exists. Did you mean to supply --force?", dbName)
}
return syncRootsFromBackup(ctx, existingDbData, sess, r)
} else {
// Track whether the db directory existed before we tried to create it, so we can clean up on errors
userDirExisted, _ := sess.Provider().FileSystem().Exists(dbName)
// Create a new Dolt env for the clone; use env.NoRemote to avoid origin upstream
clonedEnv, err := actions.EnvForClone(ctx, srcDb.ValueReadWriter().Format(), env.NoRemote, dbName,
sess.Provider().FileSystem(), doltversion.Version, env.GetCurrentUserHomeDir)
if err != nil {
return errhand.VerboseErrorFromError(err)
}
// make empty repo state
_, err = env.CreateRepoState(clonedEnv.FS, env.DefaultInitBranch)
if len(backupParams) == 0 && backupUrlScheme == dbfactory.AWSScheme {
backupParams, err = newParamsWithAwsSessionVars(ctx, backupUrlScheme)
if err != nil {
return err
}
}
if err = syncRootsFromBackup(ctx, clonedEnv.DbData(ctx), sess, r); err != nil {
// If we're cloning into a directory that already exists do not erase it.
// Otherwise, make a best effort to delete any directory we created.
if userDirExisted {
_ = clonedEnv.FS.Delete(dbfactory.DoltDir, true)
} else {
_ = clonedEnv.FS.Delete(".", true)
}
}
backupRemote := env.NewRemote(backupName, backupUrl, backupParams)
err = dbData.Rsw.AddBackup(backupRemote)
return err
}
// doltBackupSync syncs the current database to the existing backup identified by |backupName|. The backup is looked up
// from the repository state via |dbData.Rsr|. The sync operation copies all roots from the current database to the
// backup location, overwriting any existing data.
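//
// Illustrative call, assuming a backup named 'backup1' was added earlier:
//
//	CALL dolt_backup('sync', 'backup1');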
func doltBackupSync(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, backupName string) error {
backups, err := dbData.Rsr.GetBackups()
if err != nil {
return err
}
}
func removeBackup(ctx *sql.Context, dbData env.DbData[*sql.Context], apr *argparser.ArgParseResults) error {
if apr.NArg() != 2 {
return fmt.Errorf("usage: dolt_backup('remove', 'backup_name')")
backupRemote, ok := backups.Get(backupName)
if !ok {
return env.ErrBackupNotFound.New(backupName)
}
backupName := strings.TrimSpace(apr.Arg(1))
err := dbData.Rsw.RemoveBackup(ctx, backupName)
switch err {
case nil:
return nil
case env.ErrFailedToWriteRepoState:
return fmt.Errorf("error: failed to save change to repo state, cause: %s", err.Error())
case env.ErrFailedToDeleteBackup:
return fmt.Errorf("error: failed to delete backup tracking ref, cause: %s", err.Error())
case env.ErrFailedToReadFromDb:
return fmt.Errorf("error: failed to read from db, cause: %s", err.Error())
case env.ErrBackupNotFound:
return fmt.Errorf("error: unknown backup: '%s' ", backupName)
default:
return fmt.Errorf("error: unknown error, cause: %s", err.Error())
}
return syncRemote(ctx, dbData, dsess, backupRemote)
}
func loadAwsParams(ctx *sql.Context, sess *dsess.DoltSession, apr *argparser.ArgParseResults, backupUrl, backupCmd string) (map[string]string, error) {
cfg := loadConfig(ctx)
scheme, absBackupUrl, err := env.GetAbsRemoteUrl(filesys.LocalFS, cfg, backupUrl)
// doltBackupSyncUrl syncs the current database to a remote URL specified in |apr| without requiring the remote to exist
// in the backups list. The URL is normalized to an absolute path. AWS parameters are extracted from command-line flags
// in |apr| if present, otherwise they are loaded from session variables if the URL scheme matches. The sync operation
// copies all roots from the current database to the remote location, overwriting any existing data.
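//
// Illustrative call (the URL is a placeholder):
//
//	CALL dolt_backup('sync-url', 'file:///var/backups/mydb');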
func doltBackupSyncUrl(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
remoteUrlScheme, remoteUrl, err := newAbsRemoteUrl(dsess, apr.Arg(1))
if err != nil {
return nil, fmt.Errorf("error: '%s' is not valid.", backupUrl)
} else if scheme == dbfactory.HTTPScheme || scheme == dbfactory.HTTPSScheme {
// not sure how to get the dialer so punting on this
return nil, fmt.Errorf("%s does not support http or https backup locations currently", backupCmd)
return err
}
params, err := cli.ProcessBackupArgs(apr, scheme, absBackupUrl)
remoteParams, err := newParams(apr, remoteUrl, remoteUrlScheme)
if err != nil {
return err
}
if len(remoteParams) == 0 && remoteUrlScheme == dbfactory.AWSScheme {
remoteParams, err = newParamsWithAwsSessionVars(ctx, remoteUrlScheme)
if err != nil {
return err
}
}
remote := env.NewRemote(DoltBackupParamSyncUrl, remoteUrl, remoteParams)
return syncRemote(ctx, dbData, dsess, remote)
}
// doltBackupRestore clones a database from the remote URL specified in |apr| into a new database with the name
// specified. The URL is normalized to an absolute path. AWS parameters are extracted from command-line flags in |apr|
// if present. If no command-line parameters are provided, AWS parameters are loaded from session variables if the URL
// scheme matches.
//
// If the target database already exists, the restore operation fails unless the --force flag is provided, in which case
// the existing database is dropped before cloning.
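//
// Illustrative call (the URL and new database name are placeholders); pass '--force' as an extra argument to overwrite
// an existing database:
//
//	CALL dolt_backup('restore', 'file:///var/backups/mydb', 'mydb_restored');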
func doltBackupRestore(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
remoteUrlScheme, remoteUrl, err := newAbsRemoteUrl(dsess, apr.Arg(1))
if err != nil {
return err
}
remoteParams, err := newParams(apr, remoteUrl, remoteUrlScheme)
if err != nil {
return err
}
if len(remoteParams) == 0 && remoteUrlScheme == dbfactory.AWSScheme {
remoteParams, err = newParamsWithAwsSessionVars(ctx, remoteUrlScheme)
if err != nil {
return err
}
}
remote := env.NewRemote(DoltBackupParamRestore, remoteUrl, remoteParams)
// Use default format if no database context is available (e.g., when run from invalid directory).
format := types.Format_Default
if dbData.Ddb != nil {
format = dbData.Ddb.Format()
}
remoteDb, err := dsess.Provider().GetRemoteDB(ctx, format, remote, true)
if err != nil {
return err
}
lookupDbName := apr.Arg(2)
hasLookupDb := dsess.Provider().HasDatabase(ctx, lookupDbName)
// We can't only check the databases from memory since this command can be run from subdirectories.
fileSys := dsess.GetFileSystem()
lookupDbInFileSys, _ := fileSys.Exists(lookupDbName)
forceRestore := apr.Contains(cli.ForceFlag)
if (hasLookupDb || lookupDbInFileSys) && !forceRestore {
return fmt.Errorf("database '%s' already exists, use '--%s' to overwrite", lookupDbName, cli.ForceFlag)
}
if hasLookupDb {
err = dsess.Provider().DropDatabase(ctx, lookupDbName)
if err != nil {
return err
}
}
if lookupDbInFileSys && !hasLookupDb {
err = fileSys.Delete(lookupDbName, forceRestore)
if err != nil {
return err
}
}
err = dsess.Provider().CreateDatabase(ctx, lookupDbName)
if err != nil {
return err
}
newDb, _, err := dsess.Provider().SessionDatabase(ctx, lookupDbName)
if err != nil {
return err
}
// Unlike CloneDatabaseFromRemote, which clones tracking branches (remote refs), we need all local changes.
return actions.SyncRoots(ctx, remoteDb, newDb.DbData().Ddb, fileSys.TempDir(), runProgFuncs, stopProgFuncs)
}
// syncRemote syncs the roots from |dbData| to the remote specified by |remote|. It prepares the remote database
// location using PrepareDB, which creates directories for file:// URLs if they do not exist. The sync operation copies
// all chunks from the source database to the destination, effectively overwriting the destination to match the source.
func syncRemote(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, remote env.Remote) error {
// Commit the current session's working set to the persistent chunk store. This ensures that uncommitted transaction
// changes (e.g. INSERTs) are usually visible to the backup procedure, which reads directly from the roots.
err := dsess.CommitWorkingSet(ctx, ctx.GetCurrentDatabase(), ctx.GetTransaction())
if err != nil {
return err
}
params := map[string]interface{}{}
for k, v := range remote.Params {
params[k] = v
}
// This fails with unsupported schemes (i.e. http[s]), but in such cases we shouldn't have to prepare the database.
// We primarily use this to initialize the directory for file URLs without a directory.
_ = dbfactory.PrepareDB(ctx, dbData.Ddb.Format(), remote.Url, params)
destDb, err := dsess.Provider().GetRemoteDB(ctx, dbData.Ddb.Format(), remote, true)
if err != nil {
return err
}
err = actions.SyncRoots(ctx, dbData.Ddb, destDb, dsess.GetFileSystem().TempDir(), runProgFuncs, stopProgFuncs)
if err != nil && !errors.Is(err, pull.ErrDBUpToDate) {
return err
}
return nil
}
// newParams extracts AWS-specific parameters from command-line flags in |apr| if |urlScheme| is AWS. If the scheme is
// not AWS, it verifies that no AWS parameters are present in |apr|.
func newParams(apr *argparser.ArgParseResults, url string, urlScheme string) (map[string]string, error) {
params := map[string]string{}
var err error
switch urlScheme {
case dbfactory.AWSScheme:
err = cli.AddAWSParams(url, apr, params)
case dbfactory.OSSScheme:
// TODO(elianddb): This function mainly interfaces with apr to set the OSS key-vals in params, but the backup arg
// parser does not define any OSS-related flags; presumably they are processed elsewhere.
err = cli.AddOSSParams(url, apr, params)
default:
err = cli.VerifyNoAwsParams(apr)
}
return params, err
}
// newParamsWithAwsSessionVars extracts AWS-specific parameters from read-only session variables in |ctx|. It reads
// aws_credentials_file, aws_credentials_profile, and aws_credentials_region session variables and builds a parameter
// map. If the URL scheme is not AWS, an empty parameter map is returned.
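//
// Illustrative session setup (a sketch; the exact variable names and scoping follow the dsess constants read below):
//
//	SET @@aws_credentials_file = '/path/to/credentials';
//	SET @@aws_credentials_profile = 'backups';
//	SET @@aws_credentials_region = 'us-west-2';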
func newParamsWithAwsSessionVars(ctx *sql.Context, urlScheme string) (map[string]string, error) {
params := map[string]string{}
credsFile, err := ctx.Session.GetSessionVariable(ctx, dsess.AwsCredsFile)
if err != nil {
return nil, err
}
credsFile, _ := sess.GetSessionVariable(ctx, dsess.AwsCredsFile)
credsFileStr, isStr := credsFile.(string)
if isStr && len(credsFileStr) > 0 {
params[dbfactory.AWSCredsFileParam] = credsFileStr
}
credsProfile, err := sess.GetSessionVariable(ctx, dsess.AwsCredsProfile)
credsProfile, err := ctx.Session.GetSessionVariable(ctx, dsess.AwsCredsProfile)
if err != nil {
return nil, err
}
profStr, isStr := credsProfile.(string)
if isStr && len(profStr) > 0 {
params[dbfactory.AWSCredsProfile] = profStr
}
credsRegion, err := sess.GetSessionVariable(ctx, dsess.AwsCredsRegion)
credsRegion, err := ctx.Session.GetSessionVariable(ctx, dsess.AwsCredsRegion)
if err != nil {
return nil, err
}
regionStr, isStr := credsRegion.(string)
if isStr && len(regionStr) > 0 {
params[dbfactory.AWSRegionParam] = regionStr
@@ -277,104 +357,46 @@ func loadAwsParams(ctx *sql.Context, sess *dsess.DoltSession, apr *argparser.Arg
return params, nil
}
func syncBackupViaUrl(ctx *sql.Context, dbData env.DbData[*sql.Context], sess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
if apr.NArg() != 2 {
return fmt.Errorf("usage: dolt_backup('sync-url', BACKUP_URL)")
// newAbsRemoteUrl normalizes the |url| to an absolute path and returns the URL scheme and the normalized URL. It loads
// the Dolt CLI configuration from the filesystem accessible via |dsess| and uses GetAbsRemoteUrl to perform the
// normalization. HTTPS URLs without an explicit scheme default to the configured remotes API host.
func newAbsRemoteUrl(dsess *dsess.DoltSession, url string) (string, string, error) {
if url == "" {
return "", "", env.ErrBackupInvalidUrl.New(url)
}
backupUrl := strings.TrimSpace(apr.Arg(1))
params, err := loadAwsParams(ctx, sess, apr, backupUrl, "sync-url")
config, err := env.LoadDoltCliConfig(env.GetCurrentUserHomeDir, dsess.GetFileSystem())
if err != nil {
return err
return "", "", err
}
b := env.NewRemote("__temp__", backupUrl, params)
return syncRootsToBackup(ctx, dbData, sess, b)
return env.GetAbsRemoteUrl(dsess.GetFileSystem(), config, url)
}
func syncBackupViaName(ctx *sql.Context, dbData env.DbData[*sql.Context], sess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
if apr.NArg() != 2 {
return fmt.Errorf("usage: dolt_backup('sync', BACKUP_NAME)")
// errDoltBackupUsage constructs a usage error message for the dolt_backup procedure. It formats |funcParam| as the
// operation, |requiredParams| as required positional arguments, and |optionalParams| as optional flag arguments. The
// resulting error message follows the format:
// "usage: dolt_backup('<param>', '<required1>', ..., ['<optional1>'], ...)".
func errDoltBackupUsage(funcParam string, requiredParams, optionalParams []string) error {
var builder strings.Builder
builder.WriteString("usage: ")
builder.WriteString(DoltBackupProcedureName)
builder.WriteString("('")
builder.WriteString(funcParam)
builder.WriteByte('\'')
for _, req := range requiredParams {
builder.WriteString(", '")
builder.WriteString(req)
builder.WriteByte('\'')
}
backupName := strings.TrimSpace(apr.Arg(1))
backups, err := dbData.Rsr.GetBackups()
if err != nil {
return err
for _, opt := range optionalParams {
builder.WriteString(", ['")
builder.WriteString(opt)
builder.WriteString("']")
}
b, ok := backups.Get(backupName)
if !ok {
return fmt.Errorf("error: unknown backup: '%s'; %v", backupName, backups)
}
builder.WriteByte(')')
return syncRootsToBackup(ctx, dbData, sess, b)
}
// syncRootsToBackup syncs the roots from |dbData| to the backup specified by |backup|.
func syncRootsToBackup(ctx *sql.Context, dbData env.DbData[*sql.Context], sess *dsess.DoltSession, backup env.Remote) error {
destDb, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), backup, true)
if err != nil {
return fmt.Errorf("error loading backup destination: %w", err)
}
tmpDir, err := dbData.Rsw.TempTableFilesDir()
if err != nil {
return err
}
err = actions.SyncRoots(ctx, dbData.Ddb, destDb, tmpDir, runProgFuncs, stopProgFuncs)
if err != nil && err != pull.ErrDBUpToDate {
return fmt.Errorf("error syncing backup: %w", err)
}
return nil
}
// syncRootsFromBackup syncs the roots from the backup specified by |backup| to |dbData|.
func syncRootsFromBackup[C doltdb.Context](ctx *sql.Context, dbData env.DbData[C], sess *dsess.DoltSession, backup env.Remote) error {
destDb, err := sess.Provider().GetRemoteDB(ctx, dbData.Ddb.ValueReadWriter().Format(), backup, true)
if err != nil {
return fmt.Errorf("error loading backup destination: %w", err)
}
tmpDir, err := dbData.Rsw.TempTableFilesDir()
if err != nil {
return err
}
err = actions.SyncRoots(ctx, destDb, dbData.Ddb, tmpDir, runProgFuncs, stopProgFuncs)
if err != nil && err != pull.ErrDBUpToDate {
return fmt.Errorf("error syncing backup: %w", err)
}
return nil
}
// UserHasSuperAccess returns whether the current user has SUPER access. This is used by
// Doltgres to check the user role by its own authentication methods.
var UserHasSuperAccess = userHasSuperAccess
func userHasSuperAccess(ctx *sql.Context) (bool, error) {
privs, counter := ctx.GetPrivilegeSet()
if counter == 0 {
return false, fmt.Errorf("unable to check user privileges")
}
return privs.Has(sql.PrivilegeType_Super) == true, nil
}
// checkBackupRestorePrivs returns an error if the user requesting to restore a database
// does not have SUPER access. Since this is a potentially destructive operation, we restrict it to admins,
// even though the SUPER privilege has been deprecated, since there isn't another appropriate global privilege.
func checkBackupRestorePrivs(ctx *sql.Context) error {
isSuper, err := UserHasSuperAccess(ctx)
if err != nil {
return fmt.Errorf("error in dolt_backup() restore subcommand: %w", err)
}
if !isSuper {
return sql.ErrPrivilegeCheckFailed.New(ctx.Session.Client().User)
}
return nil
return errors.New(builder.String())
}
@@ -41,7 +41,7 @@ func doDoltClean(ctx *sql.Context, args []string) (int, error) {
return 1, fmt.Errorf("Empty database name.")
}
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return statusErr, err
return 1, err
}
dSess := dsess.DSessFromSess(ctx.Session)
@@ -118,7 +118,13 @@ func doDoltMerge(ctx *sql.Context, args []string) (string, int, int, string, err
}
if apr.ContainsAll(cli.SquashParam, cli.NoFFParam) {
return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together.\n", cli.SquashParam, cli.NoFFParam)
return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together", cli.SquashParam, cli.NoFFParam)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.NoFFParam) {
return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together", cli.FFOnlyParam, cli.NoFFParam)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.SquashParam) {
return "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together", cli.FFOnlyParam, cli.SquashParam)
}
ws, err := sess.WorkingSet(ctx, dbName)
@@ -226,7 +232,7 @@ func performMerge(
}
if canFF {
if spec.NoFF {
if spec.FFMode == merge.NoFastForward {
var commit *doltdb.Commit
ws, commit, err = executeNoFFMerge(ctx, sess, spec, msg, dbName, ws, noCommit)
if err == doltdb.ErrUnresolvedConflictsOrViolations {
@@ -264,6 +270,10 @@ func performMerge(
return ws, h.String(), noConflictsOrViolations, fastForwardMerge, "merge successful", nil
}
if spec.FFMode == merge.FastForwardOnly {
return ws, "", noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("fatal: Not possible to fast-forward, aborting")
}
dbState, ok, err := sess.LookupDbState(ctx, dbName)
if err != nil {
return ws, "", noConflictsOrViolations, threeWayMerge, "", err
@@ -482,6 +492,15 @@ func createMergeSpec(ctx *sql.Context, sess *dsess.DoltSession, dbName string, a
if apr.Contains(cli.NoCommitFlag) && apr.Contains(cli.CommitFlag) {
return nil, errors.New("cannot define both 'commit' and 'no-commit' flags at the same time")
}
// Determine FastForwardMode based on flags. Validation of mutually exclusive flags is done earlier.
var ffMode merge.FastForwardMode = merge.FastForwardDefault
if apr.Contains(cli.NoFFParam) {
ffMode = merge.NoFastForward
} else if apr.Contains(cli.FFOnlyParam) {
ffMode = merge.FastForwardOnly
}
return merge.NewMergeSpec(
ctx,
dbData.Rsr,
@@ -492,7 +511,7 @@ func createMergeSpec(ctx *sql.Context, sess *dsess.DoltSession, dbName string, a
commitSpecStr,
t,
merge.WithSquash(apr.Contains(cli.SquashParam)),
merge.WithNoFF(apr.Contains(cli.NoFFParam)),
merge.WithFastForwardMode(ffMode),
merge.WithForce(apr.Contains(cli.ForceFlag)),
merge.WithNoCommit(apr.Contains(cli.NoCommitFlag)),
merge.WithNoEdit(apr.Contains(cli.NoEditFlag)),
@@ -96,6 +96,14 @@ func doDoltPull(ctx *sql.Context, args []string) (int, int, string, error) {
return noConflictsOrViolations, threeWayMerge, "", actions.ErrInvalidPullArgs
}
// Validate conflicting flags
if apr.ContainsAll(cli.FFOnlyParam, cli.NoFFParam) {
return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together", cli.FFOnlyParam, cli.NoFFParam)
}
if apr.ContainsAll(cli.FFOnlyParam, cli.SquashParam) {
return noConflictsOrViolations, threeWayMerge, "", fmt.Errorf("error: Flags '--%s' and '--%s' cannot be used together", cli.FFOnlyParam, cli.SquashParam)
}
var remoteName, remoteRefName string
if apr.NArg() == 1 {
remoteName = apr.Arg(0)
@@ -27,11 +27,6 @@ func doltPurgeDroppedDatabases(ctx *sql.Context, args ...string) (sql.RowIter, e
return nil, fmt.Errorf("dolt_purge_dropped_databases does not take any arguments")
}
// Only allow admins to purge dropped databases
if err := checkDoltPurgeDroppedDatabasesPrivs(ctx); err != nil {
return nil, err
}
doltSession := dsess.DSessFromSess(ctx.Session)
err := doltSession.Provider().PurgeDroppedDatabases(ctx)
if err != nil {
@@ -40,18 +35,3 @@ func doltPurgeDroppedDatabases(ctx *sql.Context, args ...string) (sql.RowIter, e
return rowToIter(int64(cmdSuccess)), nil
}
// checkDoltPurgeDroppedDatabasesPrivs returns an error if the user requesting to purge dropped databases
// does not have SUPER access. Since this is a permanent and destructive operation, we restrict it to admins,
// even though the SUPER privilege has been deprecated, since there isn't another appropriate global privilege.
func checkDoltPurgeDroppedDatabasesPrivs(ctx *sql.Context) error {
privs, counter := ctx.GetPrivilegeSet()
if counter == 0 {
return fmt.Errorf("unable to check user privileges for dolt_purge_dropped_databases procedure")
}
if privs.Has(sql.PrivilegeType_Super) == false {
return sql.ErrPrivilegeCheckFailed.New(ctx.Session.Client().User)
}
return nil
}
@@ -318,3 +318,7 @@ func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName s
func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) {
return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
}
func (e emptyRevisionDatabaseProvider) EngineOverrides() sql.EngineOverrides {
return sql.EngineOverrides{}
}
@@ -0,0 +1,45 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dsess
import (
"context"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
)
func init() {
// Due to import cycles, we need to set this function on a separate package, since this package is referenced in too
// many locations. We also can't define this in a higher-level package (such as sqle), again because of import cycles.
overrides.EngineOverridesFromContext = EngineOverridesFromContext
}
// EngineOverridesFromContext is defined here due to import cycles.
func EngineOverridesFromContext(ctx context.Context) sql.EngineOverrides {
if ctx == nil {
return sql.EngineOverrides{}
}
sqlCtx, ok := ctx.(*sql.Context)
if !ok || sqlCtx == nil {
return sql.EngineOverrides{}
}
dsess, ok := sqlCtx.Session.(*DoltSession)
if !ok || dsess == nil {
return sql.EngineOverrides{}
}
return dsess.provider.EngineOverrides()
}
@@ -111,6 +111,8 @@ type DoltDatabaseProvider interface {
// PurgeDroppedDatabases permanently deletes any dropped databases that are being held in temporary storage
// in case they need to be restored. This operation is not reversible, so use with caution!
PurgeDroppedDatabases(ctx *sql.Context) error
// EngineOverrides returns the overrides that were given during the creation of the provider.
EngineOverrides() sql.EngineOverrides
}
type SessionDatabaseBranchSpec struct {
@@ -121,8 +121,8 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
return env.ErrRemoteAlreadyExists
}
if strings.IndexAny(remote.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 {
return env.ErrInvalidBackupName
if strings.IndexAny(remote.Name, env.InvalidRemoteNameCharacters) != -1 {
return env.ErrInvalidRemoteName
}
fs, err := s.session.Provider().FileSystemForDatabase(s.dbName)
@@ -137,7 +137,7 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
// can have multiple remotes with the same address, but no conflicting backups
if rem, found := env.CheckRemoteAddressConflict(remote.Url, nil, repoState.Backups); found {
return fmt.Errorf("%w: '%s' -> %s", env.ErrRemoteAddressConflict, rem.Name, rem.Url)
return env.ErrRemoteAddressConflict.New(rem.Name, rem.Url)
}
s.remotes.Set(remote.Name, remote)
@@ -145,33 +145,34 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
return repoState.Save(fs)
}
func (s SessionStateAdapter) AddBackup(backup env.Remote) error {
if _, ok := s.backups.Get(backup.Name); ok {
return env.ErrBackupAlreadyExists
func (s SessionStateAdapter) AddBackup(remote env.Remote) error {
if remote.Name == "" || strings.IndexAny(remote.Name, env.InvalidRemoteNameCharacters) != -1 {
return env.ErrBackupInvalidName.New(remote.Name)
}
if strings.IndexAny(backup.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 {
return env.ErrInvalidBackupName
if _, ok := s.backups.Get(remote.Name); ok {
return env.ErrBackupAlreadyExists.New(remote.Name)
}
fs, err := s.session.Provider().FileSystemForDatabase(s.dbName)
if conflict, found := env.CheckRemoteAddressConflict(remote.Url, s.remotes, s.backups); found {
return env.ErrRemoteAddressConflict.New(conflict.Name, conflict.Url)
}
s.backups.Set(remote.Name, remote)
fileSys, err := s.session.Provider().FileSystemForDatabase(s.dbName)
if err != nil {
return err
}
parsedRepoState, err := env.LoadRepoState(fileSys)
if err != nil {
return err
}
repoState, err := env.LoadRepoState(fs)
if err != nil {
return err
}
// no conflicting remote or backup addresses
if bac, found := env.CheckRemoteAddressConflict(backup.Url, repoState.Remotes, repoState.Backups); found {
return fmt.Errorf("%w: '%s' -> %s", env.ErrRemoteAddressConflict, bac.Name, bac.Url)
}
s.backups.Set(backup.Name, backup)
repoState.AddBackup(backup)
return repoState.Save(fs)
// TODO(elianddb): This is a known limitation of repo_state.json; may lose concurrent modifications.
// See: https://www.dolthub.com/blog/2021-08-06-long-dark-rewrite-of-the-soul/
parsedRepoState.Backups = s.backups
return parsedRepoState.Save(fileSys)
}
func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error {
@@ -201,29 +202,26 @@ func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error
}
func (s SessionStateAdapter) RemoveBackup(_ context.Context, name string) error {
backup, ok := s.backups.Get(name)
_, ok := s.backups.Get(name)
if !ok {
return env.ErrBackupNotFound
return env.ErrBackupNotFound.New(name)
}
s.backups.Delete(backup.Name)
s.backups.Delete(name)
fs, err := s.session.Provider().FileSystemForDatabase(s.dbName)
if err != nil {
return err
}
repoState, err := env.LoadRepoState(fs)
parsedRepoState, err := env.LoadRepoState(fs)
if err != nil {
return err
}
backup, ok = repoState.Backups.Get(name)
if !ok {
// sanity check
return env.ErrBackupNotFound
}
repoState.Backups.Delete(name)
return repoState.Save(fs)
// TODO(elianddb): This is a known limitation of repo_state.json; may lose concurrent modifications.
// See: https://www.dolthub.com/blog/2021-08-06-long-dark-rewrite-of-the-soul/
parsedRepoState.Backups = s.backups
return parsedRepoState.Save(fs)
}
func (s SessionStateAdapter) TempTableFilesDir() (string, error) {
@@ -0,0 +1,190 @@
// Copyright 2020-2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dtablefunctions
import (
"fmt"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/store/prolly/tree"
)
const jsonDiffTableDefaultRowCount = 1000
var _ sql.TableFunction = (*DiffTableFunction)(nil)
var _ sql.ExecSourceRel = (*DiffTableFunction)(nil)
var _ sql.AuthorizationCheckerNode = (*DiffTableFunction)(nil)
// JsonDiffTableFunction implements the DOLT_JSON_DIFF table function.
// It takes two arguments, which it interprets as JSON objects.
// Each row of the result table represents a key that has changed between the two documents.
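//
// Illustrative query (a sketch; the JSON literals are placeholders):
//
//	SELECT diff_type, path, from_value, to_value FROM DOLT_JSON_DIFF('{"a": 1, "b": 2}', '{"a": 1, "b": 3}');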
type JsonDiffTableFunction struct {
fromExpr sql.Expression
toExpr sql.Expression
database sql.Database
}
// NewInstance creates a new instance of TableFunction interface
func (dtf *JsonDiffTableFunction) NewInstance(ctx *sql.Context, database sql.Database, expressions []sql.Expression) (sql.Node, error) {
newInstance := &JsonDiffTableFunction{
database: database,
}
node, err := newInstance.WithExpressions(expressions...)
if err != nil {
return nil, err
}
return node, nil
}
func (dtf *JsonDiffTableFunction) DataLength(ctx *sql.Context) (uint64, error) {
numBytesPerRow := schema.SchemaAvgLength(dtf.Schema())
numRows, _, err := dtf.RowCount(ctx)
if err != nil {
return 0, err
}
return numBytesPerRow * numRows, nil
}
func (dtf *JsonDiffTableFunction) RowCount(_ *sql.Context) (uint64, bool, error) {
return jsonDiffTableDefaultRowCount, false, nil
}
// Database implements the sql.Databaser interface
func (dtf *JsonDiffTableFunction) Database() sql.Database {
return dtf.database
}
// WithDatabase implements the sql.Databaser interface
func (dtf *JsonDiffTableFunction) WithDatabase(database sql.Database) (sql.Node, error) {
ndtf := *dtf
ndtf.database = database
return &ndtf, nil
}
// Expressions implements the sql.Expressioner interface
func (dtf *JsonDiffTableFunction) Expressions() []sql.Expression {
exprs := []sql.Expression{dtf.fromExpr, dtf.toExpr}
return exprs
}
// WithExpressions implements the sql.Expressioner interface
func (dtf *JsonDiffTableFunction) WithExpressions(expressions ...sql.Expression) (sql.Node, error) {
if len(expressions) != 2 {
return nil, sql.ErrInvalidArgumentNumber.New(dtf, len(expressions), 2)
}
newDtf := *dtf
newDtf.fromExpr = expressions[0]
newDtf.toExpr = expressions[1]
return &newDtf, nil
}
// Children implements the sql.Node interface
func (dtf *JsonDiffTableFunction) Children() []sql.Node {
return nil
}
// RowIter implements the sql.Node interface
func (dtf *JsonDiffTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
fromValue, err := dtf.fromExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
fromJson, _, err := gmstypes.JSON.Convert(ctx, fromValue)
if err != nil {
return nil, err
}
toValue, err := dtf.toExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
toJson, _, err := gmstypes.JSON.Convert(ctx, toValue)
if err != nil {
return nil, err
}
differ, err := tree.NewJsonDiffer(ctx, fromJson.(sql.JSONWrapper), toJson.(sql.JSONWrapper))
if err != nil {
return nil, err
}
return jsonDiffRowIter{differ}, nil
}
type jsonDiffRowIter struct {
differ tree.JsonDiffer
}
func (j jsonDiffRowIter) Next(ctx *sql.Context) (sql.Row, error) {
jsonDiff, err := j.differ.Next(ctx)
if err != nil {
return nil, err
}
mySqlJsonPath := tree.MySqlJsonPathFromKey(jsonDiff.Key)
return sql.NewRow(jsonDiff.Type.DiffTypeString(), mySqlJsonPath, jsonDiff.From, jsonDiff.To), nil
}
func (j jsonDiffRowIter) Close(ctx *sql.Context) error {
return nil
}
var _ sql.RowIter = jsonDiffRowIter{}
// WithChildren implements the sql.Node interface
func (dtf *JsonDiffTableFunction) WithChildren(node ...sql.Node) (sql.Node, error) {
if len(node) != 0 {
return nil, fmt.Errorf("unexpected children")
}
return dtf, nil
}
var jsonDiffTableSchema = sql.Schema{
&sql.Column{Name: "diff_type", Type: gmstypes.Text},
&sql.Column{Name: "path", Type: gmstypes.Text},
&sql.Column{Name: "from_value", Type: gmstypes.JSON},
&sql.Column{Name: "to_value", Type: gmstypes.JSON},
}
// Schema implements the sql.Node interface
func (dtf *JsonDiffTableFunction) Schema() sql.Schema {
return jsonDiffTableSchema
}
// Resolved implements the sql.Resolvable interface
func (dtf *JsonDiffTableFunction) Resolved() bool {
return dtf.fromExpr.Resolved() && dtf.toExpr.Resolved()
}
func (dtf *JsonDiffTableFunction) IsReadOnly() bool {
return true
}
// String implements the Stringer interface
func (dtf *JsonDiffTableFunction) String() string {
return fmt.Sprintf("DOLT_JSON_DIFF(%s, %s)", dtf.fromExpr.String(), dtf.toExpr.String())
}
// Name implements the sql.TableFunction interface
func (dtf *JsonDiffTableFunction) Name() string {
return "dolt_json_diff"
}
@@ -0,0 +1,64 @@
// Copyright 2025 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dtablefunctions
import (
"io"
"testing"
"github.com/dolthub/dolt/go/store/prolly/tree"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/stretchr/testify/require"
)
// TestDoltJsonDiffTableFunction uses the same test cases as TestJsonDiff,
// but transforms them to use the DOLT_JSON_DIFF system table function
// instead.
func TestDoltJsonDiffTableFunction(t *testing.T) {
for _, testCase := range tree.SimpleJsonDiffTests {
ctx := sql.NewEmptyContext()
t.Run(testCase.Name, func(t *testing.T) {
jsonDiffTableFunction := JsonDiffTableFunction{
fromExpr: expression.NewLiteral(testCase.From, types.JSON),
toExpr: expression.NewLiteral(testCase.To, types.JSON),
database: nil,
}
rowIter, err := jsonDiffTableFunction.RowIter(ctx, nil)
require.NoError(t, err)
var expectedRows []sql.Row
for _, expectedDiff := range testCase.ExpectedDiffs {
expectedRows = append(expectedRows, sql.NewRow(
expectedDiff.Type.DiffTypeString(),
tree.MySqlJsonPathFromKey(expectedDiff.Key),
expectedDiff.From,
expectedDiff.To,
))
}
var actualRows []sql.Row
for {
row, err := rowIter.Next(ctx)
if err == io.EOF {
break
}
require.NoError(t, err)
actualRows = append(actualRows, row)
}
require.Equal(t, expectedRows, actualRows)
})
}
}
@@ -177,7 +177,7 @@ var logTableSchema = sql.Schema{
&sql.Column{Name: "commit_hash", Type: types.Text},
&sql.Column{Name: "committer", Type: types.Text},
&sql.Column{Name: "email", Type: types.Text},
&sql.Column{Name: "date", Type: types.Datetime},
&sql.Column{Name: "date", Type: types.Datetime3},
&sql.Column{Name: "message", Type: types.Text},
&sql.Column{Name: "commit_order", Type: types.Uint64},
}
@@ -35,6 +35,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dtables"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
@@ -502,7 +503,8 @@ func getPatchNodes(ctx *sql.Context, dbData env.DbData[*sql.Context], tableDelta
if cerr != nil {
return nil, cerr
}
alterDBCollStmt := sqlfmt.AlterDatabaseCollateStmt(dbName, fromColl, toColl)
formatter := overrides.SchemaFormatterFromContext(ctx)
alterDBCollStmt := sqlfmt.AlterDatabaseCollateStmt(formatter, dbName, fromColl, toColl)
patches = append(patches, &patchNode{
tblName: td.FromName,
schemaPatchStmts: []string{alterDBCollStmt},
@@ -315,14 +315,14 @@ func (ds *SchemaDiffTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.R
var fromCreate, toCreate string
if delta.FromTable != nil {
fromCreate, err = sqlfmt.GenerateCreateTableStatement(delta.FromName.Name, delta.FromSch, delta.FromFks, delta.FromFksParentSch)
fromCreate, err = sqlfmt.GenerateCreateTableStatement(ctx, delta.FromName.Name, delta.FromSch, delta.FromFks, delta.FromFksParentSch)
if err != nil {
return nil, err
}
}
if delta.ToTable != nil {
toCreate, err = sqlfmt.GenerateCreateTableStatement(delta.ToName.Name, delta.ToSch, delta.ToFks, delta.ToFksParentSch)
toCreate, err = sqlfmt.GenerateCreateTableStatement(ctx, delta.ToName.Name, delta.ToSch, delta.ToFks, delta.ToFksParentSch)
if err != nil {
return nil, err
}
@@ -29,6 +29,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
)
const testsRunDefaultRowCount = 10
@@ -303,9 +304,10 @@ func (trtf *TestsRunTableFunction) getDoltTestsData(arg string) ([]sql.Row, erro
}
func IsWriteQuery(query string, ctx *sql.Context, catalog sql.Catalog) (bool, error) {
builder := planbuilder.New(ctx, catalog, nil, nil)
builder := planbuilder.New(ctx, catalog, nil)
parsed, _, _, err := sql.GlobalParser.Parse(ctx, query, false)
parser := overrides.ParserFromContext(ctx)
parsed, _, _, err := parser.Parse(ctx, query, false)
if err != nil {
return false, err
}
@@ -29,4 +29,5 @@ var DoltTableFunctions = []sql.TableFunction{
&ReflogTableFunction{},
&QueryDiffTableFunction{},
&TestsRunTableFunction{},
&JsonDiffTableFunction{},
}
@@ -47,11 +47,11 @@ func (bt BackupsTable) String() string {
}
func (bt BackupsTable) Schema() sql.Schema {
columns := []*sql.Column{
{Name: "name", Type: types.Text, Source: bt.tableName, PrimaryKey: true, Nullable: false, DatabaseSource: bt.db.Name()},
{Name: "url", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: false, DatabaseSource: bt.db.Name()},
return []*sql.Column{
{Name: "name", Type: types.Text, PrimaryKey: true, Nullable: false},
{Name: "url", Type: types.Text, PrimaryKey: false, Nullable: false},
{Name: "params", Type: types.JSON, PrimaryKey: false, Nullable: false},
}
return columns
}
func (bt BackupsTable) Collation() sql.CollationID {
@@ -67,9 +67,10 @@ func (bt BackupsTable) PartitionRows(context *sql.Context, _ sql.Partition) (sql
}
type backupsItr struct {
urls map[string]string
names []string
idx int
names []string
urls map[string]string
params map[string]map[string]string
idx int
}
var _ sql.RowIter = (*backupsItr)(nil)
@@ -78,7 +79,14 @@ func (bi *backupsItr) Next(ctx *sql.Context) (sql.Row, error) {
if bi.idx < len(bi.names) {
bi.idx++
name := bi.names[bi.idx-1]
return sql.NewRow(name, bi.urls[name]), nil
url := bi.urls[name]
params, _, err := types.JSON.Convert(ctx, bi.params[name])
if err != nil {
return nil, err
}
return sql.NewRow(name, url, params), nil
}
return nil, io.EOF
}
@@ -103,14 +111,16 @@ func newBackupsIter(ctx *sql.Context, dbName string) (*backupsItr, error) {
names := make([]string, 0)
urls := map[string]string{}
params := map[string]map[string]string{}
backups.Iter(func(key string, val env.Remote) bool {
names = append(names, key)
urls[key] = val.Url
params[key] = val.Params
return true
})
sort.Strings(names)
return &backupsItr{names: names, urls: urls, idx: 0}, nil
return &backupsItr{names: names, urls: urls, params: params, idx: 0}, nil
}
@@ -22,6 +22,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/overrides"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlfmt"
)
@@ -80,7 +81,7 @@ func NewBlameView(ctx *sql.Context, tableName doltdb.TableName, root doltdb.Root
return "", nil
}
blameViewExpression, err := createDoltBlameViewExpression(tableName.Name, sch.GetPKCols().GetColumns())
blameViewExpression, err := createDoltBlameViewExpression(ctx, tableName.Name, sch.GetPKCols().GetColumns())
if err != nil {
return "", err
}
@@ -91,7 +92,7 @@ func NewBlameView(ctx *sql.Context, tableName doltdb.TableName, root doltdb.Root
// createDoltBlameViewExpression creates a view expression string to generate the DOLT_BLAME system
// view for the specified table, with the specified primary keys. The DOLT_BLAME system view is built
// from the data in the DOLT_DIFF system table for the same specified table name.
func createDoltBlameViewExpression(tableName string, pks []schema.Column) (string, error) {
func createDoltBlameViewExpression(ctx *sql.Context, tableName string, pks []schema.Column) (string, error) {
if len(pks) == 0 {
return "", errUnblameableTable
}
@@ -101,6 +102,7 @@ func createDoltBlameViewExpression(tableName string, pks []schema.Column) (strin
pksOrderByExpression := ""
pksSelectExpression := ""
formatter := overrides.SchemaFormatterFromContext(ctx)
for i, pk := range pks {
if i > 0 {
allToPks += ", "
@@ -108,13 +110,13 @@ func createDoltBlameViewExpression(tableName string, pks []schema.Column) (strin
pksOrderByExpression += ", "
}
toPk := sqlfmt.QuoteIdentifier("to_" + pk.Name)
fromPk := sqlfmt.QuoteIdentifier("from_" + pk.Name)
toPk := sqlfmt.QuoteIdentifier(ctx, "to_"+pk.Name)
fromPk := sqlfmt.QuoteIdentifier(ctx, "from_"+pk.Name)
allToPks += toPk
pksPartitionByExpression += fmt.Sprintf("coalesce(%s, %s)", toPk, fromPk)
pksOrderByExpression += fmt.Sprintf("sd.%s ASC ", toPk)
pksSelectExpression += fmt.Sprintf("sd.%s AS %s, ", toPk, sqlfmt.QuoteIdentifier(pk.Name))
pksSelectExpression += fmt.Sprintf("sd.%s AS %s, ", toPk, formatter.QuoteIdentifier(pk.Name))
}
return fmt.Sprintf(viewExpressionTemplate, allToPks, pksPartitionByExpression, tableName,
@@ -50,10 +50,10 @@ func (bat *BranchActivityTable) String() string {
func (bat *BranchActivityTable) Schema() sql.Schema {
return []*sql.Column{
{Name: "branch", Type: types.Text, Source: bat.tableName, PrimaryKey: true, Nullable: false, DatabaseSource: bat.db.Name()},
{Name: "last_read", Type: types.Datetime, Source: bat.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bat.db.Name()},
{Name: "last_write", Type: types.Datetime, Source: bat.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bat.db.Name()},
{Name: "last_read", Type: types.DatetimeMaxPrecision, Source: bat.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bat.db.Name()},
{Name: "last_write", Type: types.DatetimeMaxPrecision, Source: bat.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bat.db.Name()},
{Name: "active_sessions", Type: types.Int32, Source: bat.tableName, PrimaryKey: false, Nullable: false, DatabaseSource: bat.db.Name()},
{Name: "system_start_time", Type: types.Datetime, Source: bat.tableName, PrimaryKey: false, Nullable: false, DatabaseSource: bat.db.Name()},
{Name: "system_start_time", Type: types.DatetimeMaxPrecision, Source: bat.tableName, PrimaryKey: false, Nullable: false, DatabaseSource: bat.db.Name()},
}
}
@@ -235,7 +235,7 @@ func (bt *BranchesTable) Schema() sql.Schema {
{Name: "hash", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: false, DatabaseSource: bt.db.Name()},
{Name: "latest_committer", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bt.db.Name()},
{Name: "latest_committer_email", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bt.db.Name()},
{Name: "latest_commit_date", Type: types.Datetime, Source: bt.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bt.db.Name()},
{Name: "latest_commit_date", Type: types.Datetime3, Source: bt.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bt.db.Name()},
{Name: "latest_commit_message", Type: types.Text, Source: bt.tableName, PrimaryKey: false, Nullable: true, DatabaseSource: bt.db.Name()},
}
if !bt.remote {
@@ -24,6 +24,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/rowconv"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
@@ -40,6 +41,7 @@ var ErrInvalidCommitDiffTableArgs = errors.New("commit_diff_<table> requires one
type CommitDiffTable struct {
workingRoot doltdb.RootValue
stagedRoot doltdb.RootValue
headRef ref.DoltRef
requiredFilterErr error
ddb *doltdb.DoltDB
table *doltdb.Table
@@ -59,7 +61,7 @@ var _ sql.Table = (*CommitDiffTable)(nil)
var _ sql.IndexAddressable = (*CommitDiffTable)(nil)
var _ sql.StatisticsTable = (*CommitDiffTable)(nil)
func NewCommitDiffTable(ctx *sql.Context, dbName string, tblName doltdb.TableName, ddb *doltdb.DoltDB, wRoot, sRoot doltdb.RootValue) (sql.Table, error) {
func NewCommitDiffTable(ctx *sql.Context, dbName string, tblName doltdb.TableName, ddb *doltdb.DoltDB, wRoot, sRoot doltdb.RootValue, headRef ref.DoltRef) (sql.Table, error) {
diffTblName := doltdb.DoltCommitDiffTablePrefix + tblName.Name
var table *doltdb.Table
@@ -91,6 +93,7 @@ func NewCommitDiffTable(ctx *sql.Context, dbName string, tblName doltdb.TableNam
ddb: ddb,
workingRoot: wRoot,
stagedRoot: sRoot,
headRef: headRef,
joiner: j,
sqlSch: sqlSch,
targetSchema: sch,
@@ -287,7 +290,7 @@ func (dt *CommitDiffTable) rootValForHash(ctx *sql.Context, hashStr string) (dol
return nil, "", nil, err
}
optCmt, err := dt.ddb.Resolve(ctx, cs, nil)
optCmt, err := dt.ddb.Resolve(ctx, cs, dt.headRef)
if err != nil {
return nil, "", nil, err
}
@@ -76,7 +76,7 @@ func (ct *CommitsTable) Schema() sql.Schema {
{Name: "commit_hash", Type: types.Text, Source: ct.tableName, PrimaryKey: true, DatabaseSource: ct.dbName},
{Name: "committer", Type: types.Text, Source: ct.tableName, PrimaryKey: false, DatabaseSource: ct.dbName},
{Name: "email", Type: types.Text, Source: ct.tableName, PrimaryKey: false, DatabaseSource: ct.dbName},
{Name: "date", Type: types.Datetime, Source: ct.tableName, PrimaryKey: false, DatabaseSource: ct.dbName},
{Name: "date", Type: types.Datetime3, Source: ct.tableName, PrimaryKey: false, DatabaseSource: ct.dbName},
{Name: "message", Type: types.Text, Source: ct.tableName, PrimaryKey: false, DatabaseSource: ct.dbName},
}
}
@@ -90,7 +90,7 @@ func (dt *LogTable) Schema() sql.Schema {
{Name: "commit_hash", Type: types.Text, Source: dt.tableName, PrimaryKey: true, DatabaseSource: dt.dbName},
{Name: "committer", Type: types.Text, Source: dt.tableName, PrimaryKey: false, DatabaseSource: dt.dbName},
{Name: "email", Type: types.Text, Source: dt.tableName, PrimaryKey: false, DatabaseSource: dt.dbName},
{Name: "date", Type: types.Datetime, Source: dt.tableName, PrimaryKey: false, DatabaseSource: dt.dbName},
{Name: "date", Type: types.Datetime3, Source: dt.tableName, PrimaryKey: false, DatabaseSource: dt.dbName},
{Name: "message", Type: types.Text, Source: dt.tableName, PrimaryKey: false, DatabaseSource: dt.dbName},
{Name: "commit_order", Type: types.Uint64, Source: dt.tableName, PrimaryKey: false, DatabaseSource: dt.dbName},
}
@@ -139,7 +139,7 @@ func (c ProllyRowConverter) putFields(ctx context.Context, tup val.Tuple, proj v
c.warnFn(rowconv.DatatypeCoercionFailureWarningCode, rowconv.DatatypeCoercionFailureWarning, col.Name)
dstRow[j] = nil
err = nil
} else if !inRange {
} else if inRange != sql.InRange && c.warnFn != nil {
c.warnFn(rowconv.TruncatedOutOfRangeValueWarningCode, rowconv.TruncatedOutOfRangeValueWarning, t, f)
dstRow[j] = nil
} else if err != nil {
@@ -171,7 +171,7 @@ func newSchemaConflict(ctx *sql.Context, table doltdb.TableName, baseRoot doltdb
var base string
if baseSch != nil {
var err error
base, err = getCreateTableStatement(table.Name, baseSch, baseFKs, bs)
base, err = getCreateTableStatement(ctx, table.Name, baseSch, baseFKs, bs)
if err != nil {
return schemaConflict{}, err
}
@@ -182,7 +182,7 @@ func newSchemaConflict(ctx *sql.Context, table doltdb.TableName, baseRoot doltdb
var ours string
if c.ToSch != nil {
var err error
ours, err = getCreateTableStatement(table.Name, c.ToSch, c.ToFks, c.ToParentSchemas)
ours, err = getCreateTableStatement(ctx, table.Name, c.ToSch, c.ToFks, c.ToParentSchemas)
if err != nil {
return schemaConflict{}, err
}
@@ -199,7 +199,7 @@ func newSchemaConflict(ctx *sql.Context, table doltdb.TableName, baseRoot doltdb
var theirs string
if c.FromSch != nil {
var err error
theirs, err = getCreateTableStatement(table.Name, c.FromSch, c.FromFks, c.FromParentSchemas)
theirs, err = getCreateTableStatement(ctx, table.Name, c.FromSch, c.FromFks, c.FromParentSchemas)
if err != nil {
return schemaConflict{}, err
}
@@ -237,8 +237,8 @@ func newSchemaConflict(ctx *sql.Context, table doltdb.TableName, baseRoot doltdb
}, nil
}
func getCreateTableStatement(table string, sch schema.Schema, fks []doltdb.ForeignKey, parents map[doltdb.TableName]schema.Schema) (string, error) {
return sqlfmt.GenerateCreateTableStatement(table, sch, fks, parents)
func getCreateTableStatement(ctx *sql.Context, table string, sch schema.Schema, fks []doltdb.ForeignKey, parents map[doltdb.TableName]schema.Schema) (string, error) {
return sqlfmt.GenerateCreateTableStatement(ctx, table, sch, fks, parents)
}
func getSchemaConflictDescription(ctx *sql.Context, table doltdb.TableName, base, ours, theirs schema.Schema) (string, error) {
@@ -25,6 +25,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/adapters"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
)
@@ -61,20 +62,12 @@ func (st StatusTable) String() string {
return st.tableName
}
func getDoltStatusSchema(tableName string) sql.Schema {
return []*sql.Column{
{Name: "table_name", Type: types.Text, Source: tableName, PrimaryKey: true, Nullable: false},
{Name: "staged", Type: types.Boolean, Source: tableName, PrimaryKey: true, Nullable: false},
{Name: "status", Type: types.Text, Source: tableName, PrimaryKey: true, Nullable: false},
}
}
// GetDoltStatusSchema returns the schema of the dolt_status system table. This is used
// by Doltgres to update the dolt_status schema using Doltgres types.
var GetDoltStatusSchema = getDoltStatusSchema
func (st StatusTable) Schema() sql.Schema {
return GetDoltStatusSchema(st.tableName)
return []*sql.Column{
{Name: "table_name", Type: types.Text, Source: doltdb.StatusTableName, PrimaryKey: true, Nullable: false},
{Name: "staged", Type: types.Boolean, Source: doltdb.StatusTableName, PrimaryKey: true, Nullable: false},
{Name: "status", Type: types.Text, Source: doltdb.StatusTableName, PrimaryKey: true, Nullable: false},
}
}
func (st StatusTable) Collation() sql.CollationID {
@@ -89,8 +82,19 @@ func (st StatusTable) PartitionRows(context *sql.Context, _ sql.Partition) (sql.
return newStatusItr(context, &st)
}
// NewStatusTable creates a StatusTable
func NewStatusTable(_ *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table {
// NewStatusTable creates a new StatusTable using either an integrator's [adapters.TableAdapter] or the
// NewStatusTableWithNoAdapter constructor (the default implementation provided by Dolt).
func NewStatusTable(ctx *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table {
adapter, ok := adapters.DoltTableAdapterRegistry.GetAdapter(tableName)
if ok {
return adapter.NewTable(ctx, tableName, ddb, ws, rp)
}
return NewStatusTableWithNoAdapter(ctx, tableName, ddb, ws, rp)
}
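// Illustrative sketch only (not part of this change): what an integrator-supplied adapter
// for the registry used above might look like. The adapters.TableAdapter contract is
// inferred from the adapter.NewTable call site; the interface name, exact method set, and
// the exampleStatusTableAdapter type are assumptions for illustration.
type exampleStatusTableAdapter struct{}

func (exampleStatusTableAdapter) NewTable(ctx *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table {
	// Delegate to Dolt's default implementation; a real integrator (e.g. Doltgres) would
	// return a table that uses its own column types instead.
	return NewStatusTableWithNoAdapter(ctx, tableName, ddb, ws, rp)
}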
// NewStatusTableWithNoAdapter returns a new StatusTable.
func NewStatusTableWithNoAdapter(_ *sql.Context, tableName string, ddb *doltdb.DoltDB, ws *doltdb.WorkingSet, rp env.RootsProvider[*sql.Context]) sql.Table {
return &StatusTable{
tableName: tableName,
ddb: ddb,
@@ -107,7 +111,8 @@ type StatusItr struct {
type statusTableRow struct {
tableName string
status string
isStaged bool
isStaged byte // not a bool because the wire protocol confuses bools and tinyint(1), resulting in inconsistent display
// of this table when using local vs remote sql connections.
}
func containsTableName(name string, names []doltdb.TableName) bool {
@@ -174,7 +179,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
for _, tbl := range ms.TablesWithSchemaConflicts() {
rows = append(rows, statusTableRow{
tableName: tbl.String(),
isStaged: false,
isStaged: byte(0),
status: "schema conflict",
})
}
@@ -182,7 +187,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
for _, tbl := range ms.MergedTables() {
rows = append(rows, statusTableRow{
tableName: tbl.String(),
isStaged: true,
isStaged: byte(1),
status: mergedStatus,
})
}
@@ -209,7 +214,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
}
rows = append(rows, statusTableRow{
tableName: tblName,
isStaged: true,
isStaged: byte(1),
status: statusString(td),
})
}
@@ -223,7 +228,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
}
rows = append(rows, statusTableRow{
tableName: tblName,
isStaged: false,
isStaged: byte(0),
status: statusString(td),
})
}
@@ -231,7 +236,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
for _, sd := range stagedSchemas {
rows = append(rows, statusTableRow{
tableName: sd.CurName(),
isStaged: true,
isStaged: byte(1),
status: schemaStatusString(sd),
})
}
@@ -239,7 +244,7 @@ func newStatusItr(ctx *sql.Context, st *StatusTable) (*StatusItr, error) {
for _, sd := range unstagedSchemas {
rows = append(rows, statusTableRow{
tableName: sd.CurName(),
isStaged: false,
isStaged: byte(0),
status: schemaStatusString(sd),
})
}
@@ -71,7 +71,7 @@ func (tt *TagsTable) Schema() sql.Schema {
{Name: "tag_hash", Type: types.Text, Source: tt.tableName, PrimaryKey: true},
{Name: "tagger", Type: types.Text, Source: tt.tableName, PrimaryKey: false},
{Name: "email", Type: types.Text, Source: tt.tableName, PrimaryKey: false},
{Name: "date", Type: types.Datetime, Source: tt.tableName, PrimaryKey: false},
{Name: "date", Type: types.Datetime3, Source: tt.tableName, PrimaryKey: false},
{Name: "message", Type: types.Text, Source: tt.tableName, PrimaryKey: false},
}
}
@@ -158,6 +158,7 @@ var BranchActivityTests = []queries.ScriptTest{
"SELECT last_write INTO @lw FROM dolt_branch_activity WHERE branch = 'other_branch'",
"SELECT SLEEP(2)", // Ensure time stamp difference is noticeable
"UPDATE `mydb/other_branch`.t SET v='baz' WHERE id=1",
"SELECT SLEEP(0.5)", // branch activity update is async, give it a moment.
},
Assertions: []queries.ScriptTestAssertion{
{
@@ -1420,7 +1420,7 @@ func TestBranchControl(t *testing.T) {
defer engine.Close()
ctx := enginetest.NewContext(harness)
ctx.NewCtxWithClient(sql.Client{
ctx.WithClient(sql.Client{
User: "root",
Address: "localhost",
})
@@ -1440,7 +1440,7 @@ func TestBranchControl(t *testing.T) {
if host == "" {
host = "localhost"
}
ctx = ctx.NewCtxWithClient(sql.Client{
ctx = ctx.WithClient(sql.Client{
User: user,
Address: host,
})
@@ -1478,7 +1478,7 @@ func TestBranchControlBlocks(t *testing.T) {
defer engine.Close()
rootCtx := enginetest.NewContext(harness)
rootCtx.NewCtxWithClient(sql.Client{
rootCtx.WithClient(sql.Client{
User: "root",
Address: "localhost",
})
@@ -1522,7 +1522,7 @@ func TestBranchControlBlocks(t *testing.T) {
defer engine.Close()
rootCtx := enginetest.NewContext(harness)
rootCtx.NewCtxWithClient(sql.Client{
rootCtx.WithClient(sql.Client{
User: "root",
Address: "localhost",
})
@@ -603,7 +603,7 @@ func TestIndexedAccess(t *testing.T, e enginetest.QueryEngine, harness enginetes
}
func analyzeQuery(ctx *sql.Context, e enginetest.QueryEngine, query string) (sql.Node, error) {
binder := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, e.EngineEventScheduler(), nil)
binder := planbuilder.New(ctx, e.EngineAnalyzer().Catalog, e.EngineEventScheduler())
parsed, _, _, qFlags, err := binder.Parse(query, nil, false)
if err != nil {
return nil, err
@@ -1595,6 +1595,16 @@ func TestLogTableFunctionPrepared(t *testing.T) {
RunLogTableFunctionTestsPrepared(t, harness)
}
func TestJsonDiffTableFunction(t *testing.T) {
harness := newDoltEnginetestHarness(t)
RunJsonDiffTableFunctionTests(t, harness)
}
func TestJsonDiffTableFunctionPrepared(t *testing.T) {
harness := newDoltEnginetestHarness(t)
RunJsonDiffTableFunctionTestsPrepared(t, harness)
}
func TestBranchStatusTableFunction(t *testing.T) {
harness := newDoltEnginetestHarness(t)
RunBranchStatusTableFunctionTests(t, harness)
@@ -2191,3 +2201,36 @@ func TestBranchActivity(t *testing.T) {
defer h.Close()
RunBranchActivityTests(t, h)
}
// TestDriverExecution verifies that queries work in the dolt driver, where the MySQLDb has no users configured.
func TestDriverExecution(t *testing.T) {
h := newDoltHarness(t)
h.UseLocalFileSystem()
defer h.Close()
engine, err := h.NewEngine(t)
if err != nil {
t.Fatal(err)
}
defer engine.Close()
// Simulate the driver environment. The MySQLDb is initialized but has no users (not even root), while the context
// user is still "root". This mimics the dolthub/driver initialization of the engine (no PrivFilePath provided).
engine.EngineAnalyzer().Catalog.MySQLDb = mysql_db.CreateEmptyMySQLDb()
ctx := enginetest.NewContextWithClient(h, sql.Client{
User: "root",
Address: "localhost",
})
q := "call dolt_backup('add', 'backup1', 'file:///tmp/backup1');"
enginetest.TestQueryWithContext(t, ctx, engine, h, q, []sql.Row{{0}}, nil, nil, nil)
q = "select name from dolt_backups where name = 'backup1'"
enginetest.TestQueryWithContext(t, ctx, engine, h, q, []sql.Row{{"backup1"}}, nil, nil, nil)
q = "call dolt_backup('sync-url', 'file:///tmp/backup_sync_url');"
enginetest.TestQueryWithContext(t, ctx, engine, h, q, []sql.Row{{0}}, nil, nil, nil)
q = "call dolt_backup('remove', 'backup1');"
enginetest.TestQueryWithContext(t, ctx, engine, h, q, []sql.Row{{0}}, nil, nil, nil)
}
@@ -516,6 +516,7 @@ func RunDoltStoredProceduresTest(t *testing.T, h DoltEnginetestHarness) {
for _, script := range DoltProcedureTests {
func() {
h := h.NewHarness(t)
h.UseLocalFileSystem()
defer h.Close()
enginetest.TestScript(t, h, script)
}()
@@ -526,6 +527,7 @@ func RunDoltStoredProceduresPreparedTest(t *testing.T, h DoltEnginetestHarness)
for _, script := range DoltProcedureTests {
func() {
h := h.NewHarness(t)
h.UseLocalFileSystem()
defer h.Close()
enginetest.TestScriptPrepared(t, h, script)
}()
@@ -1324,6 +1326,26 @@ func RunLogTableFunctionTestsPrepared(t *testing.T, harness DoltEnginetestHarnes
}
}
func RunJsonDiffTableFunctionTests(t *testing.T, harness DoltEnginetestHarness) {
for _, test := range JsonDiffTableFunctionScriptTests {
harness = harness.NewHarness(t)
defer harness.Close()
harness.Setup(setup.MydbData)
harness.SkipSetupCommit()
enginetest.TestScript(t, harness, test)
}
}
func RunJsonDiffTableFunctionTestsPrepared(t *testing.T, harness DoltEnginetestHarness) {
for _, test := range JsonDiffTableFunctionScriptTests {
harness = harness.NewHarness(t)
defer harness.Close()
harness.Setup(setup.MydbData)
harness.SkipSetupCommit()
enginetest.TestScriptPrepared(t, harness, test)
}
}
func RunBranchStatusTableFunctionTests(t *testing.T, harness DoltEnginetestHarness) {
for _, test := range BranchStatusTableFunctionScriptTests {
t.Run(test.Name, func(t *testing.T) {
@@ -275,7 +275,7 @@ func (d *DoltHarness) NewEngine(t *testing.T) (enginetest.QueryEngine, error) {
if err != nil {
return nil, err
}
e.Analyzer.ExecBuilder = rowexec.NewOverrideBuilder(kvexec.Builder{})
e.Analyzer.ExecBuilder = rowexec.NewBuilder(kvexec.Builder{}, e.Analyzer.Overrides)
d.engine = e
sqlCtx := enginetest.NewContext(d)
@@ -521,7 +521,7 @@ func (d *DoltHarness) NewReadOnlyEngine(provider sql.DatabaseProvider) (enginete
locations[i] = loc
}
readOnlyProvider, err := sqle.NewDoltDatabaseProviderWithDatabases("main", ddp.FileSystem(), dbs, locations)
readOnlyProvider, err := sqle.NewDoltDatabaseProviderWithDatabases("main", ddp.FileSystem(), dbs, locations, sql.EngineOverrides{})
if err != nil {
return nil, err
}
@@ -569,7 +569,7 @@ func (d *DoltHarness) newProvider(ctx context.Context) sql.MutableDatabaseProvid
d.multiRepoEnv = mrEnv
b := env.GetDefaultInitBranch(d.multiRepoEnv.Config())
pro, err := sqle.NewDoltDatabaseProvider(b, d.multiRepoEnv.FileSystem())
pro, err := sqle.NewDoltDatabaseProvider(b, d.multiRepoEnv.FileSystem(), sql.EngineOverrides{})
require.NoError(d.t, err)
return pro
+412 -1
View File
@@ -15,12 +15,423 @@
package enginetest
import (
"fmt"
"os"
"path/filepath"
"github.com/dolthub/go-mysql-server/enginetest/queries"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
)
func init() {
DoltProcedureTests = append(DoltProcedureTests, DoltBackupProcedureScripts...)
}
// fileUrl returns a file:// URL for the given path rooted in the OS temp directory.
func fileUrl(path string) string {
path = filepath.Join(os.TempDir(), path)
return "file://" + filepath.ToSlash(filepath.Clean(path))
}
// awsUrl returns an aws:// URL with the given dynamo table, S3 bucket, and database path.
func awsUrl(dynamoTable, s3Bucket, path string) string {
return fmt.Sprintf("aws://[%s:%s]/%s", dynamoTable, s3Bucket, path)
}
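// Illustrative sketch only (not part of this change): what the two URL helpers above
// produce, assuming os.TempDir() returns "/tmp" on the host running the tests.
func exampleBackupUrls() {
	fmt.Println(fileUrl("dolt_backup1"))
	// assumed output: file:///tmp/dolt_backup1
	fmt.Println(awsUrl("test-dynamo", "test-bucket", "testdb-params"))
	// assumed output: aws://[test-dynamo:test-bucket]/testdb-params
}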
var DoltBackupProcedureScripts = []queries.ScriptTest{
{
Name: "dolt_backup add",
SetUpScript: []string{
fmt.Sprintf("call dolt_backup('add', 'bak1', '%s');", fileUrl("dolt_backup1")),
},
Assertions: []queries.ScriptTestAssertion{
{
Query: fmt.Sprintf("call dolt_backup('add', 'bak2', '%s');", fileUrl("dolt_backup2")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from dolt_backups order by name;",
Expected: []sql.Row{
{"bak1", fileUrl("dolt_backup1"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
{"bak2", fileUrl("dolt_backup2"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
},
},
{
// Invalid URLs are accepted but will fail when used in 'sync'.
Query: "call dolt_backup('add', 'bak3', 'invalid://url');",
Expected: []sql.Row{{0}},
},
{
Query: "select * from dolt_backups order by name",
Expected: []sql.Row{
{"bak1", fileUrl("dolt_backup1"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
{"bak2", fileUrl("dolt_backup2"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
{"bak3", "invalid://url", gmstypes.JSONDocument{Val: map[string]interface{}{}}},
},
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'aws_params', '%s', '--aws-region=<region>', '--aws-creds-type=file', '--aws-creds-file=<file>', '--aws-creds-profile=<profile>');", awsUrl("test-dynamo", "test-bucket", "testdb-params")),
Expected: []sql.Row{{0}},
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'aws_partial', '%s', '--aws-region=eu-west-1', '--aws-creds-profile=<profile>');", awsUrl("test-dynamo", "test-bucket", "testdb-partial")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from dolt_backups where url like 'aws://%' order by name;",
Expected: []sql.Row{
{
"aws_params",
awsUrl("test-dynamo", "test-bucket", "testdb-params"),
gmstypes.JSONDocument{
Val: map[string]interface{}{
"aws-region": "<region>",
"aws-creds-type": "file",
"aws-creds-file": "<file>",
"aws-creds-profile": "<profile>",
},
},
},
{
"aws_partial",
awsUrl("test-dynamo", "test-bucket", "testdb-partial"),
gmstypes.JSONDocument{
Val: map[string]interface{}{
"aws-creds-profile": "<profile>",
"aws-region": "eu-west-1",
},
},
},
},
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'aws_conflict', '%s', '--aws-region=<region>', '--aws-creds-type=file', '--aws-creds-file=<file>', '--aws-creds-profile=<profile>');", awsUrl("test-dynamo", "test-bucket", "testdb-params")),
ExpectedErr: env.ErrRemoteAddressConflict,
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'aws_conflict', '%s', '--aws-creds-type=<err>');", awsUrl("test-dynamo", "test-bucket", "testdb-params")),
ExpectedErrStr: "<err> is not a valid option for 'aws-creds-type'. valid options are: role|env|file",
},
{
Query: "call dolt_backup('add', 'bak2');",
ExpectedErrStr: "usage: dolt_backup('add', 'name', 'url', ['--aws-region=<region>'], ['--aws-creds-type=<type>'], ['--aws-creds-file=<file>'], ['--aws-creds-profile=<profile>'])",
},
{
Query: "call dolt_backup('add');",
ExpectedErrStr: "usage: dolt_backup('add', 'name', 'url', ['--aws-region=<region>'], ['--aws-creds-type=<type>'], ['--aws-creds-file=<file>'], ['--aws-creds-profile=<profile>'])",
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'bak1', '%s');", fileUrl("dolt_backup1")),
ExpectedErrStr: "backup 'bak1' already exists",
},
{
Query: fmt.Sprintf("call dolt_backup('add', '', '%s');", fileUrl("dolt_backup2")),
ExpectedErrStr: "backup name '' is invalid",
},
{
Query: "call dolt_backup('add', 'bak2', '');",
ExpectedErrStr: "backup URL '' is invalid",
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'backup with spaces', '%s');", fileUrl("dolt_backup2")),
ExpectedErrStr: "backup name 'backup with spaces' is invalid",
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'backup/slash', '%s');", fileUrl("dolt_backup2")),
ExpectedErrStr: "backup name 'backup/slash' is invalid",
},
},
},
{
Name: "dolt_backup remove",
SetUpScript: []string{
fmt.Sprintf("call dolt_backup('add', 'bak1', '%s');", fileUrl("dolt_backup1")),
fmt.Sprintf("call dolt_backup('add', 'bak2', '%s');", fileUrl("dolt_backup2")),
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "select * from dolt_backups order by name;",
Expected: []sql.Row{
{"bak1", fileUrl("dolt_backup1"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
{"bak2", fileUrl("dolt_backup2"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
},
},
{
Query: "call dolt_backup('rm', 'bak2');",
Expected: []sql.Row{{0}},
},
{
Query: "select * from dolt_backups order by name;",
Expected: []sql.Row{
{"bak1", fileUrl("dolt_backup1"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
},
},
{
Query: fmt.Sprintf("call dolt_backup('add', 'bak2', '%s');", fileUrl("dolt_backup2")),
Expected: []sql.Row{{0}},
},
{
Query: "call dolt_backup('remove', 'bak1');",
Expected: []sql.Row{{0}},
},
{
Query: "select * from dolt_backups;",
Expected: []sql.Row{
{"bak2", fileUrl("dolt_backup2"), gmstypes.JSONDocument{Val: map[string]interface{}{}}},
},
},
{
Query: "create table t (t text);",
Expected: []sql.Row{{gmstypes.OkResult{}}},
},
{
// Removing a backup only affects the dolt_backups table; the original backup data stays intact.
Query: "call dolt_backup('sync', 'bak2')",
Expected: []sql.Row{{0}},
},
{
Query: "call dolt_backup('remove', 'bak2');",
Expected: []sql.Row{{0}},
},
{
Query: "select * from dolt_backups;",
Expected: []sql.Row{},
},
{
Query: "drop table t;",
Expected: []sql.Row{{gmstypes.OkResult{}}},
},
{
Query: fmt.Sprintf("call dolt_backup('restore', '%s', 'restored_db');", fileUrl("dolt_backup2")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from restored_db.t;",
Expected: []sql.Row{},
},
{
Query: "call dolt_backup('remove', 'nonexistent');",
ExpectedErrStr: "backup 'nonexistent' not found",
},
{
Query: "call dolt_backup('remove');",
ExpectedErrStr: "usage: dolt_backup('remove', 'name')",
},
{
Query: "call dolt_backup('remove', 'bak1', 'extra');",
ExpectedErrStr: "usage: dolt_backup('remove', 'name')",
},
{
Query: "call dolt_backup('remove', '');",
ExpectedErrStr: "backup '' not found",
},
},
},
{
Name: "dolt_backup sync",
SetUpScript: []string{
fmt.Sprintf("call dolt_backup('add', 'bak1', '%s');", fileUrl("dolt_backup1")),
"create table t(a int primary key, b int);",
"insert into t values (1, 100), (2, 200);",
"call dolt_add('t');",
"call dolt_commit('-m', 'initial commit');",
"call dolt_backup('add', 'invalid_backup', 'invalid://url');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "call dolt_backup('sync', 'bak1');",
Expected: []sql.Row{{0}},
},
{
Query: "select * from t;",
Expected: []sql.Row{{1, 100}, {2, 200}},
},
{
Query: "call dolt_backup('sync', 'nonexistent');",
ExpectedErrStr: "backup 'nonexistent' not found",
},
{
Query: "call dolt_backup('sync');",
ExpectedErrStr: "usage: dolt_backup('sync', 'name')",
},
{
Query: "call dolt_backup('sync', 'dolt_backup1', 'extra');",
ExpectedErrStr: "usage: dolt_backup('sync', 'name')",
},
{
Query: "call dolt_backup('sync', 'invalid_backup');",
ExpectedErrStr: "unknown url scheme: 'invalid'",
},
{
Query: "call dolt_backup('sync', '');",
ExpectedErrStr: "backup '' not found",
},
},
},
{
Name: "dolt_backup sync-url",
SetUpScript: []string{
fmt.Sprintf("call dolt_backup('add', 'bak1', '%s');", fileUrl("dolt_backup1")),
"create table t(a int primary key, b int);",
"insert into t values (1, 100), (2, 200);",
"call dolt_add('t');",
"call dolt_commit('-m', 'initial commit');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: fmt.Sprintf("call dolt_backup('sync-url', '%s');", fileUrl("dolt_backup2")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from t;",
Expected: []sql.Row{{1, 100}, {2, 200}},
},
{
Query: fmt.Sprintf("call dolt_backup('sync-url', '%s');", fileUrl("dolt_backup1")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from t;",
Expected: []sql.Row{{1, 100}, {2, 200}},
},
{
Query: "call dolt_backup('sync-url');",
ExpectedErrStr: "usage: dolt_backup('sync-url', 'remote_url', ['--aws-region=<region>'], ['--aws-creds-type=<type>'], ['--aws-creds-file=<file>'], ['--aws-creds-profile=<profile>'])",
},
{
Query: "call dolt_backup('sync-url', '', 'extra');",
ExpectedErrStr: "usage: dolt_backup('sync-url', 'remote_url', ['--aws-region=<region>'], ['--aws-creds-type=<type>'], ['--aws-creds-file=<file>'], ['--aws-creds-profile=<profile>'])",
},
{
Query: "call dolt_backup('sync-url', 'invalid://url');",
ExpectedErrStr: "unknown url scheme: 'invalid'",
},
{
Query: "call dolt_backup('sync-url', '');",
ExpectedErrStr: "backup URL '' is invalid",
},
},
},
{
Name: "dolt_backup restore",
SetUpScript: []string{
fmt.Sprintf("call dolt_backup('add', 'dolt_backup1', '%s');", fileUrl("dolt_backup1")),
"call dolt_backup('sync', 'dolt_backup1');",
"create table t(a int primary key, b int);",
"insert into t values (1, 100), (2, 200);",
"call dolt_add('t');",
"call dolt_commit('-m', 'restore this commit');",
fmt.Sprintf("call dolt_backup('add', 'dolt_backup2', '%s');", fileUrl("dolt_backup2")),
"call dolt_backup('sync', 'dolt_backup2');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: fmt.Sprintf("call dolt_backup('restore', '%s', 'restored_db');", fileUrl("dolt_backup2")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from restored_db.t order by a;",
Expected: []sql.Row{{1, 100}, {2, 200}},
},
{
Query: "select message from restored_db.dolt_log order by commit_order;",
Expected: []sql.Row{{"Initialize data repository"}, {"checkpoint enginetest database mydb"}, {"restore this commit"}},
},
{
Query: "call dolt_backup('restore');",
ExpectedErrStr: "usage: dolt_backup('restore', 'remote_url', 'new_db_name', ['--force'], ['--aws-region=<region>'], ['--aws-creds-type=<type>'], ['--aws-creds-file=<file>'], ['--aws-creds-profile=<profile>'])",
},
{
Query: fmt.Sprintf("call dolt_backup('restore', '%s');", fileUrl("dolt_backup1")),
ExpectedErrStr: "usage: dolt_backup('restore', 'remote_url', 'new_db_name', ['--force'], ['--aws-region=<region>'], ['--aws-creds-type=<type>'], ['--aws-creds-file=<file>'], ['--aws-creds-profile=<profile>'])",
},
{
Query: fmt.Sprintf("call dolt_backup('restore', '%s', 'restored_db');", fileUrl("dolt_backup2")),
ExpectedErrStr: "database 'restored_db' already exists, use '--force' to overwrite",
},
{
Query: fmt.Sprintf("call dolt_backup('restore', '%s', 'restored_db', '--force');", fileUrl("dolt_backup1")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from restored_db.t",
ExpectedErr: sql.ErrTableNotFound,
},
{
Query: "select message from restored_db.dolt_log order by commit_order;",
Expected: []sql.Row{{"Initialize data repository"}, {"checkpoint enginetest database mydb"}},
},
{
Query: "call dolt_backup('restore', 'invalid://url', 'restored_db2');",
ExpectedErrStr: "unknown url scheme: 'invalid'",
},
},
},
{
Name: "dolt_backup error",
Assertions: []queries.ScriptTestAssertion{
{
Query: fmt.Sprintf("call dolt_backup('invalid', 'dolt_backup1', '%s');", fileUrl("dolt_backup1")),
ExpectedErrStr: "unrecognized dolt_backup parameter 'invalid'",
},
{
Query: "call dolt_backup();",
ExpectedErrStr: "use 'dolt_backups' table to list backups",
},
{
Query: "call dolt_backup('--verbose');",
ExpectedErrStr: "use 'dolt_backups' table to list backups",
},
},
},
{
Name: "dolt_backup transactional operations",
SetUpScript: []string{
"create table t(a int primary key);",
`CREATE PROCEDURE transaction_backup_ops()
BEGIN
START TRANSACTION;
INSERT INTO t VALUES (1);
CALL dolt_backup('add', 'txn_bak', '` + fileUrl("txn_backup") + `');
CALL dolt_backup('sync', 'txn_bak');
CALL dolt_backup('sync-url', '` + fileUrl("txn_backup_url") + `');
CALL dolt_backup('restore', '` + fileUrl("txn_backup") + `', 'restored_txn_db');
CALL dolt_backup('remove', 'txn_bak');
COMMIT;
END`,
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "call transaction_backup_ops();",
Expected: []sql.Row{{gmstypes.OkResult{}}},
},
{
Query: "select * from restored_txn_db.t;",
Expected: []sql.Row{{1}},
},
{
Query: fmt.Sprintf("call dolt_backup('restore', '%s', 'restored_txn_url_db');", fileUrl("txn_backup_url")),
Expected: []sql.Row{{0}},
},
{
Query: "select * from restored_txn_url_db.t;",
Expected: []sql.Row{{1}},
},
{
Query: "select count(*) from dolt_backups where name='txn_bak';",
Expected: []sql.Row{{0}},
},
},
},
}
var DoltProcedureTests = []queries.ScriptTest{
{
Name: "dolt_commit in a loop",
@@ -692,7 +692,7 @@ var DoltRevisionDbScripts = []queries.ScriptTest{
},
{
Query: "select * from dolt_status",
Expected: []sql.Row{{"t01", false, "modified"}},
Expected: []sql.Row{{"t01", byte(0), "modified"}},
},
{
Query: "call dolt_checkout('t01')",
@@ -726,11 +726,11 @@ var DoltScripts = []queries.ScriptTest{
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT * FROM dolt_status;",
Expected: []sql.Row{{"t", false, "new table"}},
Expected: []sql.Row{{"t", byte(0), "new table"}},
},
{
Query: "SELECT * FROM `mydb/main`.dolt_status;",
Expected: []sql.Row{{"t", false, "new table"}},
Expected: []sql.Row{{"t", byte(0), "new table"}},
},
{
Query: "SELECT * FROM dolt_status AS OF 'tag1';",
@@ -743,7 +743,7 @@ var DoltScripts = []queries.ScriptTest{
{
// HEAD is a special revision spec
Query: "SELECT * FROM dolt_status AS OF 'head';",
Expected: []sql.Row{{"t", false, "new table"}},
Expected: []sql.Row{{"t", byte(0), "new table"}},
},
{
Query: "SELECT * FROM dolt_status AS OF 'HEAD~1';",
@@ -751,11 +751,11 @@ var DoltScripts = []queries.ScriptTest{
},
{
Query: "SELECT * FROM dolt_status AS OF 'branch1';",
Expected: []sql.Row{{"abc", true, "new table"}},
Expected: []sql.Row{{"abc", byte(1), "new table"}},
},
{
Query: "SELECT * FROM `mydb/branch1`.dolt_status;",
Expected: []sql.Row{{"abc", true, "new table"}},
Expected: []sql.Row{{"abc", byte(1), "new table"}},
},
},
},
@@ -3709,7 +3709,7 @@ var DoltCheckoutScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"t1", true, "modified"},
{"t1", byte(1), "modified"},
},
},
{
@@ -3740,7 +3740,7 @@ var DoltCheckoutScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"t2", true, "modified"},
{"t2", byte(1), "modified"},
},
},
{
@@ -3790,8 +3790,8 @@ var DoltCheckoutScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"t1", true, "modified"},
{"t2", true, "modified"},
{"t1", byte(1), "modified"},
{"t2", byte(1), "modified"},
},
},
{
@@ -4469,7 +4469,7 @@ var DoltResetTestScripts = []queries.ScriptTest{
{
// dolt_status should only show the unstaged table t being added
Query: "select * from dolt_status",
Expected: []sql.Row{{"t", false, "new table"}},
Expected: []sql.Row{{"t", byte(0), "new table"}},
},
},
},
@@ -4492,7 +4492,7 @@ var DoltResetTestScripts = []queries.ScriptTest{
{
// dolt_status should only show the unstaged table t being added
Query: "select * from dolt_status",
Expected: []sql.Row{{"t", false, "new table"}},
Expected: []sql.Row{{"t", byte(0), "new table"}},
},
},
},
@@ -5324,7 +5324,7 @@ var LogTableFunctionScriptTests = []queries.ScriptTest{
{"commit_hash", "text", "NO", "PRI", nil, ""},
{"committer", "text", "NO", "", nil, ""},
{"email", "text", "NO", "", nil, ""},
{"date", "datetime", "NO", "", nil, ""},
{"date", "datetime(3)", "NO", "", nil, ""},
{"message", "text", "NO", "", nil, ""},
{"commit_order", "bigint unsigned", "NO", "", nil, ""},
},
@@ -5428,6 +5428,107 @@ var LogTableFunctionScriptTests = []queries.ScriptTest{
},
}
var JsonDiffTableFunctionScriptTests = []queries.ScriptTest{
{
Name: "basic functionality with JSON literals",
SetUpScript: []string{},
Assertions: []queries.ScriptTestAssertion{
{
Query: `SELECT * from dolt_json_diff('{"a":1}', '{"a":2}');`,
Expected: []sql.Row{{"modified", "$.a", types.JSONDocument{Val: 1}, types.JSONDocument{Val: 2}}},
},
{
Query: `SELECT * from dolt_json_diff('{}', '{"added_key":"added_value"}');`,
Expected: []sql.Row{{"added", "$.added_key", nil, types.JSONDocument{Val: "added_value"}}},
},
{
Query: `SELECT * from dolt_json_diff('{"removed_key":true}', '{}');`,
Expected: []sql.Row{{"removed", "$.removed_key", types.JSONDocument{Val: true}, nil}},
},
{
Query: `SELECT * from dolt_json_diff('{"a":null}', '{"b":null}');`,
Expected: []sql.Row{
{"removed", "$.a", types.JSONDocument{Val: nil}, nil},
{"added", "$.b", nil, types.JSONDocument{Val: nil}},
},
},
{
Query: `SELECT * from dolt_json_diff('{"a": [0, 1, 2]}', '{"a": [0, 1, 3]}');`,
Expected: []sql.Row{
{"modified", "$.a[2]", types.JSONDocument{Val: 2}, types.JSONDocument{Val: 3}},
},
},
{
Query: `SELECT * from dolt_json_diff('[0, 1, 2]', '[0, 1, 3]');`,
Expected: []sql.Row{
{"modified", "$[2]", types.JSONDocument{Val: 2}, types.JSONDocument{Val: 3}},
},
},
},
},
{
Name: "lateral join with small json objects retrieved from tables",
SetUpScript: []string{
"CREATE TABLE test_table(pk int primary key, from_json json, to_json json);",
`INSERT INTO test_table VALUES (0, '{"a":1}', '{"a":2}');`,
`INSERT INTO test_table VALUES (1, '{"b":3}', '{"c":3}');`,
`INSERT INTO test_table VALUES (2, '[0, 1, 2]', '[0, 1, 3]');`,
`INSERT INTO test_table VALUES (3, '{"a": [0, 1, 2]}', '{"a": [0, 1, 3]}');`,
},
Assertions: []queries.ScriptTestAssertion{
{
Query: `SELECT * FROM dolt_json_diff((select from_json FROM test_table where pk = 0), (select to_json FROM test_table where pk = 0));`,
Expected: []sql.Row{{"modified", "$.a", types.JSONDocument{Val: 1}, types.JSONDocument{Val: 2}}},
},
{
Query: `SELECT from_json, to_json, diff_type, path, from_value, to_value FROM test_table JOIN LATERAL (SELECT * from dolt_json_diff(from_json, to_json)) sq;`,
Expected: []sql.Row{
{
types.JSONDocument{Val: types.JsonObject{"a": 1}},
types.JSONDocument{Val: types.JsonObject{"a": 2}},
"modified",
"$.a",
types.JSONDocument{Val: 1},
types.JSONDocument{Val: 2},
},
{
types.JSONDocument{Val: types.JsonObject{"b": 3}},
types.JSONDocument{Val: types.JsonObject{"c": 3}},
"removed",
"$.b",
types.JSONDocument{Val: 3},
nil,
},
{
types.JSONDocument{Val: types.JsonObject{"b": 3}},
types.JSONDocument{Val: types.JsonObject{"c": 3}},
"added",
"$.c",
nil,
types.JSONDocument{Val: 3},
},
{
types.JSONDocument{Val: types.JsonArray{0, 1, 2}},
types.JSONDocument{Val: types.JsonArray{0, 1, 3}},
"modified",
"$[2]",
types.JSONDocument{Val: 2},
types.JSONDocument{Val: 3},
},
{
types.JSONDocument{Val: types.JsonObject{"a": types.JsonArray{0, 1, 2}}},
types.JSONDocument{Val: types.JsonObject{"a": types.JsonArray{0, 1, 3}}},
"modified",
"$.a[2]",
types.JSONDocument{Val: 2},
types.JSONDocument{Val: 3},
},
},
},
},
},
}
var BranchStatusTableFunctionScriptTests = []queries.ScriptTest{
{
// * anc
@@ -7374,7 +7475,7 @@ var DoltCherryPickTests = []queries.ScriptTest{
},
{
Query: "select * from dolt_status",
Expected: []sql.Row{{"t", false, "modified"}, {"t", false, "conflict"}},
Expected: []sql.Row{{"t", byte(0), "modified"}, {"t", byte(0), "conflict"}},
},
{
Query: "select base_pk, base_v, our_pk, our_diff_type, their_pk, their_diff_type from dolt_conflicts_t;",
@@ -7388,7 +7489,7 @@ var DoltCherryPickTests = []queries.ScriptTest{
},
{
Query: "select * from dolt_status",
Expected: []sql.Row{{"t", false, "modified"}},
Expected: []sql.Row{{"t", byte(0), "modified"}},
},
{
Query: "select * from dolt_conflicts;",
@@ -7539,7 +7640,7 @@ var DoltCherryPickTests = []queries.ScriptTest{
{
// An ignored table should still be present (and unstaged) after aborting the merge.
Query: "select * from dolt_status;",
Expected: []sql.Row{{"generated_foo", false, "new table"}},
Expected: []sql.Row{{"generated_foo", byte(0), "new table"}},
},
{
// Changes made to the table during the merge should not be reverted.
@@ -5428,25 +5428,36 @@ var CommitDiffSystemTableScriptTests = []queries.ScriptTest{
},
},
{
Name: "working and staged commits",
Name: "working, staged, and head commits",
SetUpScript: []string{
"create table t (pk int primary key, c1 int, c2 int);",
"call dolt_commit('-Am', 'created table');",
"set @Commit0 = HASHOF('HEAD');",
"call dolt_branch('Commit0');",
"insert into t values (7, 8, 9);",
"call dolt_commit('-Am', 'insert into table');",
"call dolt_branch('Commit1')",
"insert into t values (1, 2, 3);",
"call dolt_add('.');",
"insert into t values (4, 5, 6);",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type FROM DOLT_COMMIT_DIFF_t WHERE TO_COMMIT='WORKING' and FROM_COMMIT=@Commit0;",
Query: "SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type FROM DOLT_COMMIT_DIFF_t WHERE TO_COMMIT='WORKING' and FROM_COMMIT='HEAD';",
Expected: []sql.Row{
{1, 2, 3, nil, nil, nil, "added"},
{4, 5, 6, nil, nil, nil, "added"},
},
},
{
Query: "SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type FROM DOLT_COMMIT_DIFF_t WHERE TO_COMMIT='STAGED' and FROM_COMMIT=@Commit0;",
Query: "SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type FROM DOLT_COMMIT_DIFF_t WHERE TO_COMMIT='WORKING' and FROM_COMMIT='HEAD~';",
Expected: []sql.Row{
{1, 2, 3, nil, nil, nil, "added"},
{4, 5, 6, nil, nil, nil, "added"},
{7, 8, 9, nil, nil, nil, "added"},
},
},
{
Query: "SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type FROM DOLT_COMMIT_DIFF_t WHERE TO_COMMIT='STAGED' and FROM_COMMIT='HEAD';",
Expected: []sql.Row{
{1, 2, 3, nil, nil, nil, "added"},
},
@@ -5463,6 +5474,12 @@ var CommitDiffSystemTableScriptTests = []queries.ScriptTest{
{nil, nil, nil, 4, 5, 6, "removed"},
},
},
{
Query: "SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type FROM DOLT_COMMIT_DIFF_t AS OF Commit1 WHERE TO_COMMIT='HEAD' and FROM_COMMIT='Commit0';",
Expected: []sql.Row{
{7, 8, 9, nil, nil, nil, "added"},
},
},
},
},
{
@@ -6163,7 +6180,7 @@ var DoltDatabaseCollationScriptTests = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"__DATABASE__mydb", false, "modified"},
{"__DATABASE__mydb", byte(0), "modified"},
},
},
{
@@ -6182,7 +6199,7 @@ var DoltDatabaseCollationScriptTests = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"__DATABASE__mydb", true, "modified"},
{"__DATABASE__mydb", byte(1), "modified"},
},
},
{
@@ -6245,7 +6262,7 @@ var DoltDatabaseCollationScriptTests = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"__DATABASE__mydb", false, "modified"},
{"__DATABASE__mydb", byte(0), "modified"},
},
},
{
@@ -6264,7 +6281,7 @@ var DoltDatabaseCollationScriptTests = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"__DATABASE__mydb", true, "modified"},
{"__DATABASE__mydb", byte(1), "modified"},
},
},
{
@@ -6327,7 +6344,7 @@ var DoltDatabaseCollationScriptTests = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"__DATABASE__mydb", false, "modified"},
{"__DATABASE__mydb", byte(0), "modified"},
},
},
{
@@ -6377,7 +6394,7 @@ var DoltDatabaseCollationScriptTests = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"__DATABASE__mydb", false, "modified"},
{"__DATABASE__mydb", byte(0), "modified"},
},
},
{
@@ -176,7 +176,7 @@ var MergeScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"aTable", false, "constraint violation"},
{"aTable", byte(0), "constraint violation"},
},
},
{
@@ -513,7 +513,7 @@ var MergeScripts = []queries.ScriptTest{
},
{
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", true, "modified"}},
Expected: []sql.Row{{"test", byte(1), "modified"}},
},
{
Query: "SELECT COUNT(*) FROM dolt_log",
@@ -577,7 +577,7 @@ var MergeScripts = []queries.ScriptTest{
},
{
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", false, "modified"}, {"test", false, "conflict"}},
Expected: []sql.Row{{"test", byte(0), "modified"}, {"test", byte(0), "conflict"}},
},
{
Query: "SELECT COUNT(*) FROM dolt_log",
@@ -605,7 +605,7 @@ var MergeScripts = []queries.ScriptTest{
},
{
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", false, "modified"}},
Expected: []sql.Row{{"test", byte(0), "modified"}},
},
{
Query: "SELECT * from test ORDER BY pk",
@@ -682,7 +682,7 @@ var MergeScripts = []queries.ScriptTest{
},
{
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", false, "conflict"}},
Expected: []sql.Row{{"test", byte(0), "conflict"}},
},
{
Query: "SELECT COUNT(*) FROM dolt_log",
@@ -775,7 +775,7 @@ var MergeScripts = []queries.ScriptTest{
{
Skip: true,
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", false, "schema conflict"}},
Expected: []sql.Row{{"test", byte(0), "schema conflict"}},
},
{
Skip: true,
@@ -810,7 +810,7 @@ var MergeScripts = []queries.ScriptTest{
{
Skip: true,
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", true, "merged"}},
Expected: []sql.Row{{"test", byte(1), "merged"}},
},
{
Skip: true,
@@ -856,7 +856,7 @@ var MergeScripts = []queries.ScriptTest{
},
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{{"test", false, "modified"}, {"test", false, "conflict"}},
Expected: []sql.Row{{"test", byte(0), "modified"}, {"test", byte(0), "conflict"}},
},
{
// errors because creating a new branch implicitly commits the current transaction
@@ -1331,7 +1331,7 @@ var MergeScripts = []queries.ScriptTest{
},
{
Query: "SELECT * from dolt_status",
Expected: []sql.Row{{"test", false, "modified"}, {"test", false, "conflict"}},
Expected: []sql.Row{{"test", byte(0), "modified"}, {"test", byte(0), "conflict"}},
},
{
Query: "SELECT COUNT(*) FROM dolt_conflicts",
@@ -1580,7 +1580,7 @@ var MergeScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"child", false, "constraint violation"},
{"child", byte(0), "constraint violation"},
},
},
},
@@ -1670,7 +1670,7 @@ var MergeScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"t", false, "constraint violation"},
{"t", byte(0), "constraint violation"},
},
},
},
@@ -1708,7 +1708,7 @@ var MergeScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"t", false, "constraint violation"},
{"t", byte(0), "constraint violation"},
},
},
},
@@ -1746,7 +1746,7 @@ var MergeScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"t", false, "constraint violation"},
{"t", byte(0), "constraint violation"},
},
},
},
@@ -1784,7 +1784,7 @@ var MergeScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"t", false, "constraint violation"},
{"t", byte(0), "constraint violation"},
},
},
},
@@ -2965,6 +2965,116 @@ var MergeScripts = []queries.ScriptTest{
},
},
},
{
Name: "--ff-only flag success when fast-forward is possible",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY KEY, c1 varchar(20));",
"INSERT INTO t VALUES (1, 'main1'), (2, 'main2');",
"CALL dolt_commit('-Am', 'main commit');",
"CALL dolt_checkout('-b', 'feature');",
"INSERT INTO t VALUES (3, 'feature1');",
"CALL dolt_commit('-am', 'feature commit');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--ff-only', 'feature');",
Expected: []sql.Row{{doltCommit, 1, 0, "merge successful"}},
},
{
Query: "SELECT * FROM t ORDER BY pk;",
Expected: []sql.Row{{1, "main1"}, {2, "main2"}, {3, "feature1"}},
},
},
},
{
Name: "--ff-only flag failure when fast-forward is not possible",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY KEY, c1 varchar(20));",
"INSERT INTO t VALUES (1, 'main1'), (2, 'main2');",
"CALL dolt_commit('-Am', 'main commit');",
"CALL dolt_checkout('-b', 'feature');",
"INSERT INTO t VALUES (3, 'feature1');",
"CALL dolt_commit('-am', 'feature commit');",
"CALL dolt_checkout('main');",
"INSERT INTO t VALUES (4, 'main3');",
"CALL dolt_commit('-am', 'main commit 2');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--ff-only', 'feature');",
ExpectedErrStr: "fatal: Not possible to fast-forward, aborting",
},
{
Query: "SELECT * FROM t ORDER BY pk;",
Expected: []sql.Row{{1, "main1"}, {2, "main2"}, {4, "main3"}}, // No changes
},
},
},
{
Name: "--ff-only flag with already up-to-date branch",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY KEY, c1 varchar(20));",
"INSERT INTO t VALUES (1, 'main1'), (2, 'main2');",
"CALL dolt_commit('-Am', 'main commit');",
"CALL dolt_checkout('-b', 'feature');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--ff-only', 'feature');",
Expected: []sql.Row{{"", 0, 0, "Everything up-to-date"}},
},
},
},
{
Name: "--ff-only conflicts with --no-ff",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY KEY, c1 varchar(20));",
"CALL dolt_commit('-Am', 'initial commit');",
"CALL dolt_checkout('-b', 'feature');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--ff-only', '--no-ff', 'feature');",
ExpectedErrStr: "error: Flags '--ff-only' and '--no-ff' cannot be used together",
},
},
},
{
Name: "--ff-only conflicts with --squash",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY KEY, c1 varchar(20));",
"CALL dolt_commit('-Am', 'initial commit');",
"CALL dolt_checkout('-b', 'feature');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--ff-only', '--squash', 'feature');",
ExpectedErrStr: "error: Flags '--ff-only' and '--squash' cannot be used together",
},
},
},
{
Name: "--ff-only with no-commit flag should work",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY KEY, c1 varchar(20));",
"INSERT INTO t VALUES (1, 'main1');",
"CALL dolt_commit('-Am', 'main commit');",
"CALL dolt_checkout('-b', 'feature');",
"INSERT INTO t VALUES (2, 'feature1');",
"CALL dolt_commit('-am', 'feature commit');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL dolt_merge('--ff-only', '--no-commit', 'feature');",
Expected: []sql.Row{{doltCommit, 1, 0, "merge successful"}}, // Fast-forward merge with commit hash
},
},
},
}
var KeylessMergeCVsAndConflictsScripts = []queries.ScriptTest{
@@ -4343,7 +4453,7 @@ var SchemaConflictScripts = []queries.ScriptTest{
{
Query: "select * from dolt_status",
Expected: []sql.Row{
{"t", false, "schema conflict"},
{"t", byte(0), "schema conflict"},
},
},
},
@@ -135,7 +135,7 @@ var RevertScripts = []queries.ScriptTest{
},
{
Query: "select * from dolt_status",
Expected: []sql.Row{{"dont_track", false, "new table"}},
Expected: []sql.Row{{"dont_track", byte(0), "new table"}},
},
},
},
@@ -140,8 +140,8 @@ var DoltRmTests = []queries.ScriptTest{
{
Query: "select * from dolt_status;",
Expected: []sql.Row{
{"test", true, "deleted"},
{"test", false, "new table"},
{"test", byte(1), "deleted"},
{"test", byte(0), "new table"},
},
},
},
@@ -160,7 +160,7 @@ var DoltRmTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"test", false, "new table"},
{"test", byte(0), "new table"},
},
},
},
@@ -181,9 +181,9 @@ var DoltRmTests = []queries.ScriptTest{
{
Query: "SELECT * FROM dolt_status",
Expected: []sql.Row{
{"committed", true, "deleted"},
{"staged", false, "new table"},
{"committed", false, "new table"},
{"committed", byte(1), "deleted"},
{"staged", byte(0), "new table"},
{"committed", byte(0), "new table"},
},
},
},
@@ -321,8 +321,8 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"test", true, "modified"},
{"test", false, "modified"},
{"test", byte(1), "modified"},
{"test", byte(0), "modified"},
},
},
{
@@ -340,7 +340,7 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM dolt_status;",
Expected: []sql.Row{
{"test", false, "modified"},
{"test", byte(0), "modified"},
},
},
},
@@ -356,8 +356,8 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"test", true, "new table"},
{"new", false, "new table"},
{"test", byte(1), "new table"},
{"new", byte(0), "new table"},
},
},
{
@@ -375,8 +375,8 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM dolt_status;",
Expected: []sql.Row{
{"test", true, "new table"},
{"new", false, "new table"},
{"test", byte(1), "new table"},
{"new", byte(0), "new table"},
},
},
},
@@ -393,9 +393,9 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"new", true, "new table"},
{"test", false, "new table"},
{"new", false, "modified"},
{"new", byte(1), "new table"},
{"test", byte(0), "new table"},
{"new", byte(0), "modified"},
},
},
{
@@ -405,7 +405,7 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"test", false, "new table"},
{"test", byte(0), "new table"},
},
},
{
@@ -415,8 +415,8 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"new", true, "new table"},
{"test", false, "new table"},
{"new", byte(1), "new table"},
{"test", byte(0), "new table"},
},
},
},
@@ -432,7 +432,7 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS;",
Expected: []sql.Row{
{"new_tab", false, "deleted"},
{"new_tab", byte(0), "deleted"},
},
},
{
@@ -521,7 +521,7 @@ var DoltStashTests = []queries.ScriptTest{
{
Query: "SELECT * FROM DOLT_STATUS",
Expected: []sql.Row{
{"test", false, "deleted"},
{"test", byte(0), "deleted"},
},
},
},
