amend dolt_backup proc to block aws params on server

This commit is contained in:
elianddb
2025-11-25 15:24:09 -08:00
parent 8558aeea77
commit c65a8e566d
4 changed files with 55 additions and 47 deletions

View File

@@ -382,21 +382,21 @@ func CreateGlobalArgParser(name string) *argparser.ArgParser {
return ap
}
var awsParams = []string{dbfactory.AWSRegionParam, dbfactory.AWSCredsTypeParam, dbfactory.AWSCredsFileParam, dbfactory.AWSCredsProfile}
var AwsParams = []string{dbfactory.AWSRegionParam, dbfactory.AWSCredsTypeParam, dbfactory.AWSCredsFileParam, dbfactory.AWSCredsProfile}
var ossParams = []string{dbfactory.OSSCredsFileParam, dbfactory.OSSCredsProfile}
func AddAWSParams(remoteUrl string, apr *argparser.ArgParseResults, params map[string]string) error {
isAWS := strings.HasPrefix(remoteUrl, "aws")
if !isAWS {
for _, p := range awsParams {
for _, p := range AwsParams {
if _, ok := apr.GetValue(p); ok {
return fmt.Errorf("%s param is only valid for aws cloud remotes in the format aws://dynamo-table:s3-bucket/database", p)
}
}
}
for _, p := range awsParams {
for _, p := range AwsParams {
if val, ok := apr.GetValue(p); ok {
params[p] = val
}
@@ -426,7 +426,7 @@ func AddOSSParams(remoteUrl string, apr *argparser.ArgParseResults, params map[s
}
func VerifyNoAwsParams(apr *argparser.ArgParseResults) error {
if awsParams := apr.GetValues(awsParams...); len(awsParams) > 0 {
if awsParams := apr.GetValues(AwsParams...); len(awsParams) > 0 {
awsParamKeys := make([]string, 0, len(awsParams))
for k := range awsParams {
awsParamKeys = append(awsParamKeys, k)

View File

@@ -53,6 +53,8 @@ const (
tempTablesDir = "temptf"
TmpDirName = "tmp"
InvalidNameCharacters = " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|"
)
var zeroHashStr = (hash.Hash{}).String()
@@ -1023,7 +1025,7 @@ func (dEnv *DoltEnv) AddRemote(r Remote) error {
return ErrRemoteAlreadyExists
}
if strings.IndexAny(r.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 {
if strings.IndexAny(r.Name, InvalidNameCharacters) != -1 {
return ErrInvalidRemoteName
}
@@ -1055,7 +1057,7 @@ func (dEnv *DoltEnv) AddBackup(r Remote) error {
return ErrBackupAlreadyExists.New(r.Name)
}
if strings.IndexAny(r.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 {
if strings.IndexAny(r.Name, InvalidNameCharacters) != -1 {
return ErrBackupInvalidName.New(r.Name)
}

View File

@@ -28,6 +28,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/dolt/go/libraries/utils/argparser"
"github.com/dolthub/dolt/go/store/datas/pull"
)
@@ -43,7 +44,7 @@ const (
DoltBackupParamRestore = "restore"
)
var optionalAwsParams = []string{
var awsParamsUsage = []string{
fmt.Sprintf("--%s=<region>", dbfactory.AWSRegionParam),
fmt.Sprintf("--%s=<type>", dbfactory.AWSCredsTypeParam),
fmt.Sprintf("--%s=<file>", dbfactory.AWSCredsFileParam),
@@ -76,6 +77,10 @@ func doltBackup(ctx *sql.Context, args ...string) (sql.RowIter, error) {
return nil, err
}
if sqlserver.RunningInServerMode() && apr.ContainsAny(cli.AwsParams...) {
return nil, fmt.Errorf("AWS parameters are unavailable when running in server mode")
}
if apr.NArg() == 0 || (apr.NArg() == 1 && apr.Contains(cli.VerboseFlag)) {
return nil, fmt.Errorf("use '%s' table to list backups", doltdb.BackupsTableName)
}
@@ -89,17 +94,32 @@ func doltBackup(ctx *sql.Context, args ...string) (sql.RowIter, error) {
funcParam := apr.Arg(0)
switch funcParam {
case DoltBackupParamAdd:
if apr.NArg() != 3 {
return nil, errDoltBackupUsage(funcParam, []string{"name", "url"}, awsParamsUsage)
}
err = doltBackupAdd(ctx, dbData, doltSess, apr)
case DoltBackupParamRemove, DoltBackupParamRm:
if apr.NArg() != 2 {
return nil, errDoltBackupUsage(funcParam, []string{"name"}, nil)
}
err = dbData.Rsw.RemoveBackup(ctx, apr.Arg(1))
name := apr.Arg(1)
err = dbData.Rsw.RemoveBackup(ctx, name)
case DoltBackupParamSync:
err = doltBackupSync(ctx, dbData, doltSess, apr)
if apr.NArg() != 2 {
return nil, errDoltBackupUsage(funcParam, []string{"name"}, nil)
}
name := apr.Arg(1)
err = doltBackupSync(ctx, dbData, doltSess, name)
case DoltBackupParamSyncUrl:
if apr.NArg() != 2 {
return nil, errDoltBackupUsage(funcParam, []string{"remote_url"}, awsParamsUsage)
}
err = doltBackupSyncUrl(ctx, dbData, doltSess, apr)
case DoltBackupParamRestore:
if apr.NArg() != 3 {
forceParamUsage := []string{fmt.Sprintf("--%s", cli.ForceFlag)}
return nil, errDoltBackupUsage(funcParam, []string{"remote_url", "new_db_name"}, append(forceParamUsage, awsParamsUsage...))
}
err = doltBackupRestore(ctx, dbData, doltSess, apr)
default:
return nil, fmt.Errorf("unrecognized %s parameter '%s'", DoltBackupProcedureName, funcParam)
@@ -112,10 +132,6 @@ func doltBackup(ctx *sql.Context, args ...string) (sql.RowIter, error) {
// path. AWS parameters are extracted from command-line flags in |apr| if present, otherwise they are loaded from
// session variables if the URL scheme matches.
func doltBackupAdd(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
if apr.NArg() != 3 {
return errDoltBackupUsage(DoltBackupParamAdd, []string{"name", "url"}, optionalAwsParams)
}
backupName := apr.Arg(1)
backupUrlScheme, backupUrl, err := newAbsRemoteUrl(dsess, apr.Arg(2))
if err != nil {
@@ -127,7 +143,7 @@ func doltBackupAdd(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dse
return err
}
if len(backupParams) == 0 {
if len(backupParams) == 0 && backupUrlScheme == dbfactory.AWSScheme {
backupParams, err = newParamsWithAwsSessionVars(ctx, backupUrlScheme)
if err != nil {
return err
@@ -142,17 +158,12 @@ func doltBackupAdd(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dse
// doltBackupSync syncs the current database to an existing backup identified by name in |apr|. The backup is looked up
// from the repository state via |dbData.Rsr|. The sync operation copies all roots from the current database to the
// backup location, overwriting any existing data.
func doltBackupSync(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
if apr.NArg() != 2 {
return errDoltBackupUsage(DoltBackupParamSync, []string{"name"}, nil)
}
func doltBackupSync(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, backupName string) error {
backups, err := dbData.Rsr.GetBackups()
if err != nil {
return err
}
backupName := apr.Arg(1)
backupRemote, ok := backups.Get(backupName)
if !ok {
return env.ErrBackupNotFound.New(backupName)
@@ -166,10 +177,6 @@ func doltBackupSync(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *ds
// in |apr| if present, otherwise they are loaded from session variables if the URL scheme matches. The sync operation
// copies all roots from the current database to the remote location, overwriting any existing data.
func doltBackupSyncUrl(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
if apr.NArg() != 2 {
return errDoltBackupUsage(DoltBackupParamSyncUrl, []string{"remote_url"}, optionalAwsParams)
}
remoteUrlScheme, remoteUrl, err := newAbsRemoteUrl(dsess, apr.Arg(1))
if err != nil {
return err
@@ -180,7 +187,7 @@ func doltBackupSyncUrl(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess
return err
}
if len(remoteParams) == 0 {
if len(remoteParams) == 0 && remoteUrlScheme == dbfactory.AWSScheme {
remoteParams, err = newParamsWithAwsSessionVars(ctx, remoteUrlScheme)
if err != nil {
return err
@@ -199,10 +206,6 @@ func doltBackupSyncUrl(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess
// If the target database already exists, the restore operation fails unless the --force flag is provided, in which case
// the existing database is dropped before cloning.
func doltBackupRestore(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.DoltSession, apr *argparser.ArgParseResults) error {
if apr.NArg() != 3 {
return errDoltBackupUsage(DoltBackupParamRestore, []string{"remote_url", "new_db_name"}, optionalAwsParams)
}
remoteUrlScheme, remoteUrl, err := newAbsRemoteUrl(dsess, apr.Arg(1))
if err != nil {
return err
@@ -213,7 +216,7 @@ func doltBackupRestore(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess
return err
}
if len(remoteParams) == 0 {
if len(remoteParams) == 0 && remoteUrlScheme == dbfactory.AWSScheme {
remoteParams, err = newParamsWithAwsSessionVars(ctx, remoteUrlScheme)
if err != nil {
return err
@@ -294,15 +297,17 @@ func syncRemote(ctx *sql.Context, dbData env.DbData[*sql.Context], dsess *dsess.
func newParams(apr *argparser.ArgParseResults, url string, urlScheme string) (map[string]string, error) {
params := map[string]string{}
var err error
if urlScheme == dbfactory.AWSScheme {
switch urlScheme {
case dbfactory.AWSScheme:
err = cli.AddAWSParams(url, apr, params)
} else {
case dbfactory.OSSScheme:
// TODO(elianddb): This func mainly interfaces with apr to set the OSS key-vals in params, but the backup arg
// parser does not include any OSS-related flags? I'm guessing they must be processed elsewhere?
err = cli.AddOSSParams(url, apr, params)
default:
err = cli.VerifyNoAwsParams(apr)
}
if err != nil {
return nil, err
}
return params, nil
return params, err
}
// newParamsWithAwsSessionVars extracts AWS-specific parameters from read-only session variables in |ctx|. It reads
@@ -310,9 +315,6 @@ func newParams(apr *argparser.ArgParseResults, url string, urlScheme string) (ma
// map. If URL scheme is not AWS, an empty parameter map is returned.
func newParamsWithAwsSessionVars(ctx *sql.Context, urlScheme string) (map[string]string, error) {
params := map[string]string{}
if urlScheme != dbfactory.AWSScheme { // In case newParams() isn't used first.
return params, nil
}
credsFile, err := ctx.Session.GetSessionVariable(ctx, dsess.AwsCredsFile)
if err != nil {
@@ -358,10 +360,11 @@ func newAbsRemoteUrl(dsess *dsess.DoltSession, url string) (string, string, erro
return env.GetAbsRemoteUrl(dsess.GetFileSystem(), config, url)
}
// errDoltBackupUsage constructs a usage error message for the dolt_backup procedure. It formats |param| as the
// subcommand, |required| as required positional arguments, and |optional| as optional flag arguments. The resulting
// error message follows the format: "usage: dolt_backup('<param>', '<required1>', ..., ['<optional1>'], ...)".
func errDoltBackupUsage(funcParam string, required, optional []string) error {
// errDoltBackupUsage constructs a usage error message for the dolt_backup procedure. It formats |funcParam| as the
// operation, |requiredParams| as required positional arguments, and |optionalParams| as optional flag arguments. The
// resulting error message follows the format:
// "usage: dolt_backup('<param>', '<required1>', ..., ['<optional1>'], ...)".
func errDoltBackupUsage(funcParam string, requiredParams, optionalParams []string) error {
var builder strings.Builder
builder.WriteString("usage: ")
@@ -370,13 +373,13 @@ func errDoltBackupUsage(funcParam string, required, optional []string) error {
builder.WriteString(funcParam)
builder.WriteByte('\'')
for _, req := range required {
for _, req := range requiredParams {
builder.WriteString(", '")
builder.WriteString(req)
builder.WriteByte('\'')
}
for _, opt := range optional {
for _, opt := range optionalParams {
builder.WriteString(", ['")
builder.WriteString(opt)
builder.WriteString("']")

View File

@@ -121,6 +121,10 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
return env.ErrRemoteAlreadyExists
}
if strings.IndexAny(remote.Name, env.InvalidNameCharacters) != -1 {
return env.ErrInvalidRemoteName
}
fs, err := s.session.Provider().FileSystemForDatabase(s.dbName)
if err != nil {
return err
@@ -142,8 +146,7 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
}
func (s SessionStateAdapter) AddBackup(remote env.Remote) error {
const invalidNameCharacters = " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|"
if remote.Name == "" || strings.IndexAny(remote.Name, invalidNameCharacters) != -1 {
if remote.Name == "" || strings.IndexAny(remote.Name, env.InvalidNameCharacters) != -1 {
return env.ErrBackupInvalidName.New(remote.Name)
}