mirror of
https://github.com/dolthub/dolt.git
synced 2026-03-17 23:56:33 -05:00
Merge remote-tracking branch 'origin/main' into dhruv/column-truncation
This commit is contained in:
@@ -321,9 +321,8 @@ func getCommitMessageFromEditor(ctx context.Context, dEnv *env.DoltEnv, suggeste
|
||||
finalMsg = parseCommitMessage(commitMsg)
|
||||
})
|
||||
|
||||
// if editor could not be opened or the message received is empty, use auto-generated/suggested msg.
|
||||
if err != nil || finalMsg == "" {
|
||||
return suggestedMsg, nil
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return finalMsg, nil
|
||||
|
||||
@@ -130,8 +130,6 @@ func NewSqlEngine(
|
||||
if bcController, err = branch_control.LoadData(config.BranchCtrlFilePath, config.DoltCfgDirPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Set the server's super user
|
||||
branch_control.SetSuperUser(config.ServerUser, config.ServerHost)
|
||||
|
||||
// Set up engine
|
||||
engine := gms.New(analyzer.NewBuilder(pro).WithParallelism(parallelism).Build(), &gms.Config{
|
||||
|
||||
@@ -159,9 +159,9 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string,
|
||||
}
|
||||
|
||||
suggestedMsg := fmt.Sprintf("Merge branch '%s' into %s", commitSpecStr, dEnv.RepoStateReader().CWBHeadRef().GetPath())
|
||||
msg, err := getCommitMessage(ctx, apr, dEnv, suggestedMsg)
|
||||
if err != nil {
|
||||
return handleCommitErr(ctx, dEnv, err, usage)
|
||||
msg := ""
|
||||
if m, ok := apr.GetValue(cli.MessageArg); ok {
|
||||
msg = m
|
||||
}
|
||||
|
||||
if apr.Contains(cli.NoCommitFlag) && apr.Contains(cli.CommitFlag) {
|
||||
@@ -214,25 +214,6 @@ func getUnmergedTableCount(ctx context.Context, root *doltdb.RootValue) (int, er
|
||||
return unmergedTableCount, nil
|
||||
}
|
||||
|
||||
// getCommitMessage returns commit message depending on whether user defined commit message and/or no-ff flag.
|
||||
// If user defined message, it will use that. If not, and no-ff flag is defined, it will ask user for commit message from editor.
|
||||
// If none of commit message or no-ff flag is defined, it will return suggested message.
|
||||
func getCommitMessage(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEnv, suggestedMsg string) (string, errhand.VerboseError) {
|
||||
if m, ok := apr.GetValue(cli.MessageArg); ok {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
if apr.Contains(cli.NoFFParam) {
|
||||
msg, err := getCommitMessageFromEditor(ctx, dEnv, suggestedMsg, "", apr.Contains(cli.NoEditFlag))
|
||||
if err != nil {
|
||||
return "", errhand.VerboseErrorFromError(err)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func validateMergeSpec(ctx context.Context, spec *merge.MergeSpec) errhand.VerboseError {
|
||||
if spec.HeadH == spec.MergeH {
|
||||
//TODO - why is this different for merge/pull?
|
||||
@@ -484,38 +465,117 @@ func handleMergeErr(ctx context.Context, dEnv *env.DoltEnv, mergeErr error, hasC
|
||||
// FF merges will not surface constraint violations on their own; constraint verify --all
|
||||
// is required to reify violations.
|
||||
func performMerge(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec, suggestedMsg string) (map[string]*merge.MergeStats, error) {
|
||||
var tblStats map[string]*merge.MergeStats
|
||||
if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); err != nil && !errors.Is(err, doltdb.ErrUpToDate) {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
if spec.Noff {
|
||||
tblStats, err = merge.ExecNoFFMerge(ctx, dEnv, spec)
|
||||
return tblStats, err
|
||||
return executeNoFFMergeAndCommit(ctx, dEnv, spec, suggestedMsg)
|
||||
}
|
||||
return nil, merge.ExecuteFFMerge(ctx, dEnv, spec)
|
||||
}
|
||||
tblStats, err := merge.ExecuteMerge(ctx, dEnv, spec)
|
||||
return executeMergeAndCommit(ctx, dEnv, spec, suggestedMsg)
|
||||
}
|
||||
|
||||
func executeNoFFMergeAndCommit(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec, suggestedMsg string) (map[string]*merge.MergeStats, error) {
|
||||
tblToStats, err := merge.ExecNoFFMerge(ctx, dEnv, spec)
|
||||
if err != nil {
|
||||
return tblStats, err
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
if !spec.NoCommit && !hasConflictOrViolations(tblStats) {
|
||||
msg := spec.Msg
|
||||
if spec.Msg == "" {
|
||||
msg, err = getCommitMessageFromEditor(ctx, dEnv, suggestedMsg, "", spec.NoEdit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
|
||||
|
||||
res := performCommit(ctx, "commit", []string{"-m", msg, "--author", author}, dEnv)
|
||||
if res != 0 {
|
||||
return nil, fmt.Errorf("dolt commit failed after merging")
|
||||
}
|
||||
if spec.NoCommit {
|
||||
cli.Println("Automatic merge went well; stopped before committing as requested")
|
||||
return tblToStats, nil
|
||||
}
|
||||
|
||||
return tblStats, nil
|
||||
// Reload roots since the above method writes new values to the working set
|
||||
roots, err := dEnv.Roots(ctx)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
ws, err := dEnv.WorkingSet(ctx)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
var mergeParentCommits []*doltdb.Commit
|
||||
if ws.MergeActive() {
|
||||
mergeParentCommits = []*doltdb.Commit{ws.MergeState().Commit()}
|
||||
}
|
||||
|
||||
msg, err := getCommitMsgForMerge(ctx, dEnv, spec.Msg, suggestedMsg, spec.NoEdit)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
_, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dEnv.DbData(), actions.CommitStagedProps{
|
||||
Message: msg,
|
||||
Date: spec.Date,
|
||||
AllowEmpty: spec.AllowEmpty,
|
||||
Force: spec.Force,
|
||||
Name: spec.Name,
|
||||
Email: spec.Email,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return tblToStats, fmt.Errorf("%w; failed to commit", err)
|
||||
}
|
||||
|
||||
err = dEnv.ClearMerge(ctx)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
func executeMergeAndCommit(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec, suggestedMsg string) (map[string]*merge.MergeStats, error) {
|
||||
tblToStats, err := merge.ExecuteMerge(ctx, dEnv, spec)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
if hasConflictOrViolations(tblToStats) {
|
||||
return tblToStats, nil
|
||||
}
|
||||
|
||||
if spec.NoCommit {
|
||||
cli.Println("Automatic merge went well; stopped before committing as requested")
|
||||
return tblToStats, nil
|
||||
}
|
||||
|
||||
msg, err := getCommitMsgForMerge(ctx, dEnv, spec.Msg, suggestedMsg, spec.NoEdit)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
|
||||
|
||||
res := performCommit(ctx, "commit", []string{"-m", msg, "--author", author}, dEnv)
|
||||
if res != 0 {
|
||||
return nil, fmt.Errorf("dolt commit failed after merging")
|
||||
}
|
||||
|
||||
return tblToStats, nil
|
||||
}
|
||||
|
||||
// getCommitMsgForMerge returns user defined message if exists; otherwise, get the commit message from editor.
|
||||
func getCommitMsgForMerge(ctx context.Context, dEnv *env.DoltEnv, userDefinedMsg, suggestedMsg string, noEdit bool) (string, error) {
|
||||
if userDefinedMsg != "" {
|
||||
return userDefinedMsg, nil
|
||||
}
|
||||
|
||||
msg, err := getCommitMessageFromEditor(ctx, dEnv, suggestedMsg, "", noEdit)
|
||||
if err != nil {
|
||||
return msg, err
|
||||
}
|
||||
|
||||
if msg == "" {
|
||||
return msg, fmt.Errorf("error: Empty commit message.\n" +
|
||||
"Not committing merge; use 'dolt commit' to complete the merge.")
|
||||
}
|
||||
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// hasConflictOrViolations checks for conflicts or constraint violation regardless of a table being modified
|
||||
|
||||
@@ -112,7 +112,7 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec)
|
||||
// Fetch all references
|
||||
branchRefs, err := srcDB.GetHeadRefs(ctx)
|
||||
if err != nil {
|
||||
return env.ErrFailedToReadDb
|
||||
return fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
|
||||
}
|
||||
|
||||
hasBranch, err := srcDB.HasBranch(ctx, pullSpec.Branch.GetPath())
|
||||
|
||||
@@ -285,12 +285,11 @@ func Serve(
|
||||
lgr.Errorf("error starting remotesapi server listeners for cluster config on port %d: %v", clusterController.RemoteSrvPort(), err)
|
||||
startError = err
|
||||
return
|
||||
} else {
|
||||
go func() {
|
||||
clusterRemoteSrv.Serve(listeners)
|
||||
}()
|
||||
}
|
||||
|
||||
go clusterRemoteSrv.Serve(listeners)
|
||||
go clusterController.Run()
|
||||
|
||||
clusterController.ManageQueryConnections(
|
||||
mySQLServer.SessionManager().Iter,
|
||||
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
|
||||
@@ -323,6 +322,9 @@ func Serve(
|
||||
if clusterRemoteSrv != nil {
|
||||
clusterRemoteSrv.GracefulStop()
|
||||
}
|
||||
if clusterController != nil {
|
||||
clusterController.GracefulStop()
|
||||
}
|
||||
|
||||
return mySQLServer.Close()
|
||||
})
|
||||
|
||||
@@ -182,7 +182,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
|
||||
cli.PrintErrln(color.RedString(err.Error()))
|
||||
return 1
|
||||
}
|
||||
dbToUse = filepath.Base(directory)
|
||||
dbToUse = strings.Replace(filepath.Base(directory), "-", "_", -1)
|
||||
}
|
||||
format := engine.FormatTabular
|
||||
if hasResultFormat {
|
||||
|
||||
@@ -62,7 +62,8 @@ const (
|
||||
primaryKeyParam = "pk"
|
||||
fileTypeParam = "file-type"
|
||||
delimParam = "delim"
|
||||
ignoreSkippedRows = "ignore-skipped-rows"
|
||||
quiet = "quiet"
|
||||
ignoreSkippedRows = "ignore-skipped-rows" // alias for quiet
|
||||
disableFkChecks = "disable-fk-checks"
|
||||
)
|
||||
|
||||
@@ -74,7 +75,7 @@ The schema for the new table can be specified explicitly by providing a SQL sche
|
||||
|
||||
If {{.EmphasisLeft}}--update-table | -u{{.EmphasisRight}} is given the operation will update {{.LessThan}}table{{.GreaterThan}} with the contents of file. The table's existing schema will be used, and field names will be used to match file fields with table fields unless a mapping file is specified.
|
||||
|
||||
During import, if there is an error importing any row, the import will be aborted by default. Use the {{.EmphasisLeft}}--continue{{.EmphasisRight}} flag to continue importing when an error is encountered. You can add the {{.EmphasisLeft}}--ignore-skipped-rows{{.EmphasisRight}} flag to prevent the import utility from printing all the skipped rows.
|
||||
During import, if there is an error importing any row, the import will be aborted by default. Use the {{.EmphasisLeft}}--continue{{.EmphasisRight}} flag to continue importing when an error is encountered. You can add the {{.EmphasisLeft}}--quiet{{.EmphasisRight}} flag to prevent the import utility from printing all the skipped rows.
|
||||
|
||||
If {{.EmphasisLeft}}--replace-table | -r{{.EmphasisRight}} is given the operation will replace {{.LessThan}}table{{.GreaterThan}} with the contents of the file. The table's existing schema will be used, and field names will be used to match file fields with table fields unless a mapping file is specified.
|
||||
|
||||
@@ -87,8 +88,8 @@ A mapping file can be used to map fields between the file being imported and the
|
||||
In create, update, and replace scenarios the file's extension is used to infer the type of the file. If a file does not have the expected extension then the {{.EmphasisLeft}}--file-type{{.EmphasisRight}} parameter should be used to explicitly define the format of the file in one of the supported formats (csv, psv, json, xlsx). For files separated by a delimiter other than a ',' (type csv) or a '|' (type psv), the --delim parameter can be used to specify a delimiter`,
|
||||
|
||||
Synopsis: []string{
|
||||
"-c [-f] [--pk {{.LessThan}}field{{.GreaterThan}}] [--schema {{.LessThan}}file{{.GreaterThan}}] [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--disable-fk-checks] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
|
||||
"-u [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
|
||||
"-c [-f] [--pk {{.LessThan}}field{{.GreaterThan}}] [--schema {{.LessThan}}file{{.GreaterThan}}] [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--quiet] [--disable-fk-checks] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
|
||||
"-u [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--quiet] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
|
||||
"-r [--map {{.LessThan}}file{{.GreaterThan}}] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
|
||||
},
|
||||
}
|
||||
@@ -96,17 +97,17 @@ In create, update, and replace scenarios the file's extension is used to infer t
|
||||
var bitTypeRegex = regexp.MustCompile(`(?m)b\'(\d+)\'`)
|
||||
|
||||
type importOptions struct {
|
||||
operation mvdata.TableImportOp
|
||||
destTableName string
|
||||
contOnErr bool
|
||||
force bool
|
||||
schFile string
|
||||
primaryKeys []string
|
||||
nameMapper rowconv.NameMapper
|
||||
src mvdata.DataLocation
|
||||
srcOptions interface{}
|
||||
ignoreSkippedRows bool
|
||||
disableFkChecks bool
|
||||
operation mvdata.TableImportOp
|
||||
destTableName string
|
||||
contOnErr bool
|
||||
force bool
|
||||
schFile string
|
||||
primaryKeys []string
|
||||
nameMapper rowconv.NameMapper
|
||||
src mvdata.DataLocation
|
||||
srcOptions interface{}
|
||||
quiet bool
|
||||
disableFkChecks bool
|
||||
}
|
||||
|
||||
func (m importOptions) IsBatched() bool {
|
||||
@@ -168,7 +169,7 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d
|
||||
schemaFile, _ := apr.GetValue(schemaParam)
|
||||
force := apr.Contains(forceParam)
|
||||
contOnErr := apr.Contains(contOnErrParam)
|
||||
ignore := apr.Contains(ignoreSkippedRows)
|
||||
quiet := apr.Contains(quiet)
|
||||
disableFks := apr.Contains(disableFkChecks)
|
||||
|
||||
val, _ := apr.GetValue(primaryKeyParam)
|
||||
@@ -238,17 +239,17 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d
|
||||
}
|
||||
|
||||
return &importOptions{
|
||||
operation: moveOp,
|
||||
destTableName: tableName,
|
||||
contOnErr: contOnErr,
|
||||
force: force,
|
||||
schFile: schemaFile,
|
||||
nameMapper: colMapper,
|
||||
primaryKeys: pks,
|
||||
src: srcLoc,
|
||||
srcOptions: srcOpts,
|
||||
ignoreSkippedRows: ignore,
|
||||
disableFkChecks: disableFks,
|
||||
operation: moveOp,
|
||||
destTableName: tableName,
|
||||
contOnErr: contOnErr,
|
||||
force: force,
|
||||
schFile: schemaFile,
|
||||
nameMapper: colMapper,
|
||||
primaryKeys: pks,
|
||||
src: srcLoc,
|
||||
srcOptions: srcOpts,
|
||||
quiet: quiet,
|
||||
disableFkChecks: disableFks,
|
||||
}, nil
|
||||
|
||||
}
|
||||
@@ -337,7 +338,8 @@ func (cmd ImportCmd) ArgParser() *argparser.ArgParser {
|
||||
ap.SupportsFlag(forceParam, "f", "If a create operation is being executed, data already exists in the destination, the force flag will allow the target to be overwritten.")
|
||||
ap.SupportsFlag(replaceParam, "r", "Replace existing table with imported data while preserving the original schema.")
|
||||
ap.SupportsFlag(contOnErrParam, "", "Continue importing when row import errors are encountered.")
|
||||
ap.SupportsFlag(ignoreSkippedRows, "", "Ignore the skipped rows printed by the --continue flag.")
|
||||
ap.SupportsFlag(quiet, "", "Suppress any warning messages about invalid rows when using the --continue flag.")
|
||||
ap.SupportsAlias(ignoreSkippedRows, quiet)
|
||||
ap.SupportsFlag(disableFkChecks, "", "Disables foreign key checks.")
|
||||
ap.SupportsString(schemaParam, "s", "schema_file", "The schema for the output data.")
|
||||
ap.SupportsString(mappingFileParam, "m", "mapping_file", "A file that lays out how fields should be mapped from input data to output data.")
|
||||
@@ -524,8 +526,8 @@ func move(ctx context.Context, rd table.SqlRowReader, wr *mvdata.SqlEngineTableW
|
||||
return true
|
||||
}
|
||||
|
||||
// Don't log the skipped rows when the ignore-skipped-rows param is specified.
|
||||
if options.ignoreSkippedRows {
|
||||
// Don't log the skipped rows when asked to suppress warning output
|
||||
if options.quiet {
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2019 Dolthub, Inc.
|
||||
// Copyright 2019-2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
@@ -12,5 +12,5 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// dolt is a command line tool for working with dolt data repositories stored in noms.
|
||||
// dolt is the command line tool for working with Dolt databases.
|
||||
package main
|
||||
|
||||
@@ -57,7 +57,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Version = "0.50.15"
|
||||
Version = "0.51.1"
|
||||
)
|
||||
|
||||
var dumpDocsCommand = &commands.DumpDocsCmd{}
|
||||
|
||||
@@ -210,7 +210,7 @@ func (rcv *BranchControlAccess) TryBinlog(obj *BranchControlBinlog) (*BranchCont
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) Branches(obj *BranchControlMatchExpression, j int) bool {
|
||||
func (rcv *BranchControlAccess) Databases(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
@@ -222,7 +222,7 @@ func (rcv *BranchControlAccess) Branches(obj *BranchControlMatchExpression, j in
|
||||
return false
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
func (rcv *BranchControlAccess) TryDatabases(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
@@ -237,8 +237,43 @@ func (rcv *BranchControlAccess) TryBranches(obj *BranchControlMatchExpression, j
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) DatabasesLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) Branches(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
x = rcv._tab.Indirect(x)
|
||||
obj.Init(rcv._tab.Bytes, x)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
x = rcv._tab.Indirect(x)
|
||||
obj.Init(rcv._tab.Bytes, x)
|
||||
if BranchControlMatchExpressionNumFields < obj.Table().NumFields() {
|
||||
return false, flatbuffers.ErrTableHasUnknownFields
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) BranchesLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
@@ -246,7 +281,7 @@ func (rcv *BranchControlAccess) BranchesLength() int {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) Users(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -258,7 +293,7 @@ func (rcv *BranchControlAccess) Users(obj *BranchControlMatchExpression, j int)
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) TryUsers(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -273,7 +308,7 @@ func (rcv *BranchControlAccess) TryUsers(obj *BranchControlMatchExpression, j in
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) UsersLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
@@ -281,7 +316,7 @@ func (rcv *BranchControlAccess) UsersLength() int {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) Hosts(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -293,7 +328,7 @@ func (rcv *BranchControlAccess) Hosts(obj *BranchControlMatchExpression, j int)
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) TryHosts(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -308,7 +343,7 @@ func (rcv *BranchControlAccess) TryHosts(obj *BranchControlMatchExpression, j in
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) HostsLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
@@ -316,7 +351,7 @@ func (rcv *BranchControlAccess) HostsLength() int {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) Values(obj *BranchControlAccessValue, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -328,7 +363,7 @@ func (rcv *BranchControlAccess) Values(obj *BranchControlAccessValue, j int) boo
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) TryValues(obj *BranchControlAccessValue, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -343,14 +378,14 @@ func (rcv *BranchControlAccess) TryValues(obj *BranchControlAccessValue, j int)
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccess) ValuesLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const BranchControlAccessNumFields = 5
|
||||
const BranchControlAccessNumFields = 6
|
||||
|
||||
func BranchControlAccessStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(BranchControlAccessNumFields)
|
||||
@@ -358,26 +393,32 @@ func BranchControlAccessStart(builder *flatbuffers.Builder) {
|
||||
func BranchControlAccessAddBinlog(builder *flatbuffers.Builder, binlog flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(binlog), 0)
|
||||
}
|
||||
func BranchControlAccessAddDatabases(builder *flatbuffers.Builder, databases flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(databases), 0)
|
||||
}
|
||||
func BranchControlAccessStartDatabasesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlAccessAddBranches(builder *flatbuffers.Builder, branches flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branches), 0)
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(branches), 0)
|
||||
}
|
||||
func BranchControlAccessStartBranchesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlAccessAddUsers(builder *flatbuffers.Builder, users flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(users), 0)
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(users), 0)
|
||||
}
|
||||
func BranchControlAccessStartUsersVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlAccessAddHosts(builder *flatbuffers.Builder, hosts flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(hosts), 0)
|
||||
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(hosts), 0)
|
||||
}
|
||||
func BranchControlAccessStartHostsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlAccessAddValues(builder *flatbuffers.Builder, values flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(values), 0)
|
||||
builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(values), 0)
|
||||
}
|
||||
func BranchControlAccessStartValuesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
@@ -430,7 +471,7 @@ func (rcv *BranchControlAccessValue) Table() flatbuffers.Table {
|
||||
return rcv._tab
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccessValue) Branch() []byte {
|
||||
func (rcv *BranchControlAccessValue) Database() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -438,7 +479,7 @@ func (rcv *BranchControlAccessValue) Branch() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccessValue) User() []byte {
|
||||
func (rcv *BranchControlAccessValue) Branch() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -446,7 +487,7 @@ func (rcv *BranchControlAccessValue) User() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccessValue) Host() []byte {
|
||||
func (rcv *BranchControlAccessValue) User() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -454,8 +495,16 @@ func (rcv *BranchControlAccessValue) Host() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccessValue) Permissions() uint64 {
|
||||
func (rcv *BranchControlAccessValue) Host() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccessValue) Permissions() uint64 {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
return rcv._tab.GetUint64(o + rcv._tab.Pos)
|
||||
}
|
||||
@@ -463,25 +512,28 @@ func (rcv *BranchControlAccessValue) Permissions() uint64 {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlAccessValue) MutatePermissions(n uint64) bool {
|
||||
return rcv._tab.MutateUint64Slot(10, n)
|
||||
return rcv._tab.MutateUint64Slot(12, n)
|
||||
}
|
||||
|
||||
const BranchControlAccessValueNumFields = 4
|
||||
const BranchControlAccessValueNumFields = 5
|
||||
|
||||
func BranchControlAccessValueStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(BranchControlAccessValueNumFields)
|
||||
}
|
||||
func BranchControlAccessValueAddDatabase(builder *flatbuffers.Builder, database flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(database), 0)
|
||||
}
|
||||
func BranchControlAccessValueAddBranch(builder *flatbuffers.Builder, branch flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(branch), 0)
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branch), 0)
|
||||
}
|
||||
func BranchControlAccessValueAddUser(builder *flatbuffers.Builder, user flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(user), 0)
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(user), 0)
|
||||
}
|
||||
func BranchControlAccessValueAddHost(builder *flatbuffers.Builder, host flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(host), 0)
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(host), 0)
|
||||
}
|
||||
func BranchControlAccessValueAddPermissions(builder *flatbuffers.Builder, permissions uint64) {
|
||||
builder.PrependUint64Slot(3, permissions, 0)
|
||||
builder.PrependUint64Slot(4, permissions, 0)
|
||||
}
|
||||
func BranchControlAccessValueEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
return builder.EndObject()
|
||||
@@ -560,7 +612,7 @@ func (rcv *BranchControlNamespace) TryBinlog(obj *BranchControlBinlog) (*BranchC
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) Branches(obj *BranchControlMatchExpression, j int) bool {
|
||||
func (rcv *BranchControlNamespace) Databases(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
@@ -572,7 +624,7 @@ func (rcv *BranchControlNamespace) Branches(obj *BranchControlMatchExpression, j
|
||||
return false
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
func (rcv *BranchControlNamespace) TryDatabases(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
@@ -587,8 +639,43 @@ func (rcv *BranchControlNamespace) TryBranches(obj *BranchControlMatchExpression
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) DatabasesLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) Branches(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
x = rcv._tab.Indirect(x)
|
||||
obj.Init(rcv._tab.Bytes, x)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
x = rcv._tab.Indirect(x)
|
||||
obj.Init(rcv._tab.Bytes, x)
|
||||
if BranchControlMatchExpressionNumFields < obj.Table().NumFields() {
|
||||
return false, flatbuffers.ErrTableHasUnknownFields
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) BranchesLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
@@ -596,7 +683,7 @@ func (rcv *BranchControlNamespace) BranchesLength() int {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) Users(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -608,7 +695,7 @@ func (rcv *BranchControlNamespace) Users(obj *BranchControlMatchExpression, j in
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) TryUsers(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -623,7 +710,7 @@ func (rcv *BranchControlNamespace) TryUsers(obj *BranchControlMatchExpression, j
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) UsersLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
@@ -631,7 +718,7 @@ func (rcv *BranchControlNamespace) UsersLength() int {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) Hosts(obj *BranchControlMatchExpression, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -643,7 +730,7 @@ func (rcv *BranchControlNamespace) Hosts(obj *BranchControlMatchExpression, j in
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) TryHosts(obj *BranchControlMatchExpression, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -658,7 +745,7 @@ func (rcv *BranchControlNamespace) TryHosts(obj *BranchControlMatchExpression, j
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) HostsLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
@@ -666,7 +753,7 @@ func (rcv *BranchControlNamespace) HostsLength() int {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) Values(obj *BranchControlNamespaceValue, j int) bool {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -678,7 +765,7 @@ func (rcv *BranchControlNamespace) Values(obj *BranchControlNamespaceValue, j in
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) TryValues(obj *BranchControlNamespaceValue, j int) (bool, error) {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
x := rcv._tab.Vector(o)
|
||||
x += flatbuffers.UOffsetT(j) * 4
|
||||
@@ -693,14 +780,14 @@ func (rcv *BranchControlNamespace) TryValues(obj *BranchControlNamespaceValue, j
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespace) ValuesLength() int {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
return rcv._tab.VectorLen(o)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
const BranchControlNamespaceNumFields = 5
|
||||
const BranchControlNamespaceNumFields = 6
|
||||
|
||||
func BranchControlNamespaceStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(BranchControlNamespaceNumFields)
|
||||
@@ -708,26 +795,32 @@ func BranchControlNamespaceStart(builder *flatbuffers.Builder) {
|
||||
func BranchControlNamespaceAddBinlog(builder *flatbuffers.Builder, binlog flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(binlog), 0)
|
||||
}
|
||||
func BranchControlNamespaceAddDatabases(builder *flatbuffers.Builder, databases flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(databases), 0)
|
||||
}
|
||||
func BranchControlNamespaceStartDatabasesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlNamespaceAddBranches(builder *flatbuffers.Builder, branches flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branches), 0)
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(branches), 0)
|
||||
}
|
||||
func BranchControlNamespaceStartBranchesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlNamespaceAddUsers(builder *flatbuffers.Builder, users flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(users), 0)
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(users), 0)
|
||||
}
|
||||
func BranchControlNamespaceStartUsersVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlNamespaceAddHosts(builder *flatbuffers.Builder, hosts flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(hosts), 0)
|
||||
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(hosts), 0)
|
||||
}
|
||||
func BranchControlNamespaceStartHostsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
}
|
||||
func BranchControlNamespaceAddValues(builder *flatbuffers.Builder, values flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(values), 0)
|
||||
builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(values), 0)
|
||||
}
|
||||
func BranchControlNamespaceStartValuesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
|
||||
return builder.StartVector(4, numElems, 4)
|
||||
@@ -780,7 +873,7 @@ func (rcv *BranchControlNamespaceValue) Table() flatbuffers.Table {
|
||||
return rcv._tab
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespaceValue) Branch() []byte {
|
||||
func (rcv *BranchControlNamespaceValue) Database() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -788,7 +881,7 @@ func (rcv *BranchControlNamespaceValue) Branch() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespaceValue) User() []byte {
|
||||
func (rcv *BranchControlNamespaceValue) Branch() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -796,7 +889,7 @@ func (rcv *BranchControlNamespaceValue) User() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlNamespaceValue) Host() []byte {
|
||||
func (rcv *BranchControlNamespaceValue) User() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -804,19 +897,30 @@ func (rcv *BranchControlNamespaceValue) Host() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
const BranchControlNamespaceValueNumFields = 3
|
||||
func (rcv *BranchControlNamespaceValue) Host() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
const BranchControlNamespaceValueNumFields = 4
|
||||
|
||||
func BranchControlNamespaceValueStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(BranchControlNamespaceValueNumFields)
|
||||
}
|
||||
func BranchControlNamespaceValueAddDatabase(builder *flatbuffers.Builder, database flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(database), 0)
|
||||
}
|
||||
func BranchControlNamespaceValueAddBranch(builder *flatbuffers.Builder, branch flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(branch), 0)
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branch), 0)
|
||||
}
|
||||
func BranchControlNamespaceValueAddUser(builder *flatbuffers.Builder, user flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(user), 0)
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(user), 0)
|
||||
}
|
||||
func BranchControlNamespaceValueAddHost(builder *flatbuffers.Builder, host flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(host), 0)
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(host), 0)
|
||||
}
|
||||
func BranchControlNamespaceValueEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
return builder.EndObject()
|
||||
@@ -972,7 +1076,7 @@ func (rcv *BranchControlBinlogRow) MutateIsInsert(n bool) bool {
|
||||
return rcv._tab.MutateBoolSlot(4, n)
|
||||
}
|
||||
|
||||
func (rcv *BranchControlBinlogRow) Branch() []byte {
|
||||
func (rcv *BranchControlBinlogRow) Database() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -980,7 +1084,7 @@ func (rcv *BranchControlBinlogRow) Branch() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlBinlogRow) User() []byte {
|
||||
func (rcv *BranchControlBinlogRow) Branch() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -988,7 +1092,7 @@ func (rcv *BranchControlBinlogRow) User() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlBinlogRow) Host() []byte {
|
||||
func (rcv *BranchControlBinlogRow) User() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
@@ -996,8 +1100,16 @@ func (rcv *BranchControlBinlogRow) Host() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlBinlogRow) Permissions() uint64 {
|
||||
func (rcv *BranchControlBinlogRow) Host() []byte {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
|
||||
if o != 0 {
|
||||
return rcv._tab.ByteVector(o + rcv._tab.Pos)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rcv *BranchControlBinlogRow) Permissions() uint64 {
|
||||
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
|
||||
if o != 0 {
|
||||
return rcv._tab.GetUint64(o + rcv._tab.Pos)
|
||||
}
|
||||
@@ -1005,10 +1117,10 @@ func (rcv *BranchControlBinlogRow) Permissions() uint64 {
|
||||
}
|
||||
|
||||
func (rcv *BranchControlBinlogRow) MutatePermissions(n uint64) bool {
|
||||
return rcv._tab.MutateUint64Slot(12, n)
|
||||
return rcv._tab.MutateUint64Slot(14, n)
|
||||
}
|
||||
|
||||
const BranchControlBinlogRowNumFields = 5
|
||||
const BranchControlBinlogRowNumFields = 6
|
||||
|
||||
func BranchControlBinlogRowStart(builder *flatbuffers.Builder) {
|
||||
builder.StartObject(BranchControlBinlogRowNumFields)
|
||||
@@ -1016,17 +1128,20 @@ func BranchControlBinlogRowStart(builder *flatbuffers.Builder) {
|
||||
func BranchControlBinlogRowAddIsInsert(builder *flatbuffers.Builder, isInsert bool) {
|
||||
builder.PrependBoolSlot(0, isInsert, false)
|
||||
}
|
||||
func BranchControlBinlogRowAddDatabase(builder *flatbuffers.Builder, database flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(database), 0)
|
||||
}
|
||||
func BranchControlBinlogRowAddBranch(builder *flatbuffers.Builder, branch flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branch), 0)
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(branch), 0)
|
||||
}
|
||||
func BranchControlBinlogRowAddUser(builder *flatbuffers.Builder, user flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(user), 0)
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(user), 0)
|
||||
}
|
||||
func BranchControlBinlogRowAddHost(builder *flatbuffers.Builder, host flatbuffers.UOffsetT) {
|
||||
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(host), 0)
|
||||
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(host), 0)
|
||||
}
|
||||
func BranchControlBinlogRowAddPermissions(builder *flatbuffers.Builder, permissions uint64) {
|
||||
builder.PrependUint64Slot(4, permissions, 0)
|
||||
builder.PrependUint64Slot(5, permissions, 0)
|
||||
}
|
||||
func BranchControlBinlogRowEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
return builder.EndObject()
|
||||
|
||||
@@ -58,7 +58,7 @@ require (
|
||||
github.com/cenkalti/backoff/v4 v4.1.3
|
||||
github.com/cespare/xxhash v1.1.0
|
||||
github.com/creasty/defaults v1.6.0
|
||||
github.com/dolthub/go-mysql-server v0.14.1-0.20221111192934-cf0c1818d579
|
||||
github.com/dolthub/go-mysql-server v0.14.1-0.20221116004305-6af2406c5bd0
|
||||
github.com/google/flatbuffers v2.0.6+incompatible
|
||||
github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6
|
||||
github.com/mitchellh/go-ps v1.0.0
|
||||
|
||||
@@ -180,8 +180,8 @@ github.com/dolthub/flatbuffers v1.13.0-dh.1 h1:OWJdaPep22N52O/0xsUevxJ6Qfw1M2txC
|
||||
github.com/dolthub/flatbuffers v1.13.0-dh.1/go.mod h1:CorYGaDmXjHz1Z7i50PYXG1Ricn31GcA2wNOTFIQAKE=
|
||||
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
|
||||
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
|
||||
github.com/dolthub/go-mysql-server v0.14.1-0.20221111192934-cf0c1818d579 h1:rOV6whqkxka2wGMGD/DOgUgX0jWw/gaJwTMqJ1ye2wA=
|
||||
github.com/dolthub/go-mysql-server v0.14.1-0.20221111192934-cf0c1818d579/go.mod h1:KtpU4Sf7J+SIat/nxoA733QTn3tdL34NtoGxEBFcTsA=
|
||||
github.com/dolthub/go-mysql-server v0.14.1-0.20221116004305-6af2406c5bd0 h1:BSTAs705aUez54ENrjZCQ3px0s/17nPi58XvDlpA0sI=
|
||||
github.com/dolthub/go-mysql-server v0.14.1-0.20221116004305-6af2406c5bd0/go.mod h1:KtpU4Sf7J+SIat/nxoA733QTn3tdL34NtoGxEBFcTsA=
|
||||
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g=
|
||||
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms=
|
||||
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8=
|
||||
|
||||
@@ -39,15 +39,17 @@ const (
|
||||
type Access struct {
|
||||
binlog *Binlog
|
||||
|
||||
Branches []MatchExpression
|
||||
Users []MatchExpression
|
||||
Hosts []MatchExpression
|
||||
Values []AccessValue
|
||||
RWMutex *sync.RWMutex
|
||||
Databases []MatchExpression
|
||||
Branches []MatchExpression
|
||||
Users []MatchExpression
|
||||
Hosts []MatchExpression
|
||||
Values []AccessValue
|
||||
RWMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
// AccessValue contains the user-facing values of a particular row, along with the permissions for a row.
|
||||
type AccessValue struct {
|
||||
Database string
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
@@ -57,22 +59,19 @@ type AccessValue struct {
|
||||
// newAccess returns a new Access.
|
||||
func newAccess() *Access {
|
||||
return &Access{
|
||||
binlog: NewAccessBinlog(nil),
|
||||
Branches: nil,
|
||||
Users: nil,
|
||||
Hosts: nil,
|
||||
Values: nil,
|
||||
RWMutex: &sync.RWMutex{},
|
||||
binlog: NewAccessBinlog(nil),
|
||||
Databases: nil,
|
||||
Branches: nil,
|
||||
Users: nil,
|
||||
Hosts: nil,
|
||||
Values: nil,
|
||||
RWMutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Match returns whether any entries match the given branch, user, and host, along with their permissions. Requires
|
||||
// external synchronization handling, therefore manually manage the RWMutex.
|
||||
func (tbl *Access) Match(branch string, user string, host string) (bool, Permissions) {
|
||||
if IsSuperUser(user, host) {
|
||||
return true, Permissions_Admin
|
||||
}
|
||||
|
||||
// Match returns whether any entries match the given database, branch, user, and host, along with their permissions.
|
||||
// Requires external synchronization handling, therefore manually manage the RWMutex.
|
||||
func (tbl *Access) Match(database string, branch string, user string, host string) (bool, Permissions) {
|
||||
filteredIndexes := Match(tbl.Users, user, sql.Collation_utf8mb4_0900_bin)
|
||||
|
||||
filteredHosts := tbl.filterHosts(filteredIndexes)
|
||||
@@ -80,6 +79,11 @@ func (tbl *Access) Match(branch string, user string, host string) (bool, Permiss
|
||||
filteredIndexes = Match(filteredHosts, host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
matchExprPool.Put(filteredHosts)
|
||||
|
||||
filteredDatabases := tbl.filterDatabases(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
filteredIndexes = Match(filteredDatabases, database, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
matchExprPool.Put(filteredDatabases)
|
||||
|
||||
filteredBranches := tbl.filterBranches(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
filteredIndexes = Match(filteredBranches, branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
@@ -93,9 +97,9 @@ func (tbl *Access) Match(branch string, user string, host string) (bool, Permiss
|
||||
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
|
||||
// returns -1. Assumes that the given expressions have already been folded. Requires external synchronization handling,
|
||||
// therefore manually manage the RWMutex.
|
||||
func (tbl *Access) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
|
||||
func (tbl *Access) GetIndex(databaseExpr string, branchExpr string, userExpr string, hostExpr string) int {
|
||||
for i, value := range tbl.Values {
|
||||
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
|
||||
if value.Database == databaseExpr && value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
@@ -107,16 +111,6 @@ func (tbl *Access) GetBinlog() *Binlog {
|
||||
return tbl.binlog
|
||||
}
|
||||
|
||||
// GetSuperUser returns the server-level super user. Intended for display purposes only.
|
||||
func (tbl *Access) GetSuperUser() string {
|
||||
return superUser
|
||||
}
|
||||
|
||||
// GetSuperHost returns the server-level super user's host. Intended for display purposes only.
|
||||
func (tbl *Access) GetSuperHost() string {
|
||||
return superHost
|
||||
}
|
||||
|
||||
// Serialize returns the offset for the Access table written to the given builder.
|
||||
func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
tbl.RWMutex.RLock()
|
||||
@@ -125,11 +119,15 @@ func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
// Serialize the binlog
|
||||
binlog := tbl.binlog.Serialize(b)
|
||||
// Initialize field offset slices
|
||||
databaseOffsets := make([]flatbuffers.UOffsetT, len(tbl.Databases))
|
||||
branchOffsets := make([]flatbuffers.UOffsetT, len(tbl.Branches))
|
||||
userOffsets := make([]flatbuffers.UOffsetT, len(tbl.Users))
|
||||
hostOffsets := make([]flatbuffers.UOffsetT, len(tbl.Hosts))
|
||||
valueOffsets := make([]flatbuffers.UOffsetT, len(tbl.Values))
|
||||
// Get field offsets
|
||||
for i, matchExpr := range tbl.Databases {
|
||||
databaseOffsets[i] = matchExpr.Serialize(b)
|
||||
}
|
||||
for i, matchExpr := range tbl.Branches {
|
||||
branchOffsets[i] = matchExpr.Serialize(b)
|
||||
}
|
||||
@@ -143,6 +141,11 @@ func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
valueOffsets[i] = val.Serialize(b)
|
||||
}
|
||||
// Get the field vectors
|
||||
serial.BranchControlAccessStartDatabasesVector(b, len(databaseOffsets))
|
||||
for i := len(databaseOffsets) - 1; i >= 0; i-- {
|
||||
b.PrependUOffsetT(databaseOffsets[i])
|
||||
}
|
||||
databases := b.EndVector(len(databaseOffsets))
|
||||
serial.BranchControlAccessStartBranchesVector(b, len(branchOffsets))
|
||||
for i := len(branchOffsets) - 1; i >= 0; i-- {
|
||||
b.PrependUOffsetT(branchOffsets[i])
|
||||
@@ -166,6 +169,7 @@ func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
// Write the table
|
||||
serial.BranchControlAccessStart(b)
|
||||
serial.BranchControlAccessAddBinlog(b, binlog)
|
||||
serial.BranchControlAccessAddDatabases(b, databases)
|
||||
serial.BranchControlAccessAddBranches(b, branches)
|
||||
serial.BranchControlAccessAddUsers(b, users)
|
||||
serial.BranchControlAccessAddHosts(b, hosts)
|
||||
@@ -183,7 +187,10 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
|
||||
return fmt.Errorf("cannot deserialize to a non-empty access table")
|
||||
}
|
||||
// Verify that all fields have the same length
|
||||
if fb.BranchesLength() != fb.UsersLength() || fb.UsersLength() != fb.HostsLength() || fb.HostsLength() != fb.ValuesLength() {
|
||||
if fb.DatabasesLength() != fb.BranchesLength() ||
|
||||
fb.BranchesLength() != fb.UsersLength() ||
|
||||
fb.UsersLength() != fb.HostsLength() ||
|
||||
fb.HostsLength() != fb.ValuesLength() {
|
||||
return fmt.Errorf("cannot deserialize an access table with differing field lengths")
|
||||
}
|
||||
// Read the binlog
|
||||
@@ -195,10 +202,17 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
|
||||
return err
|
||||
}
|
||||
// Initialize every slice
|
||||
tbl.Databases = make([]MatchExpression, fb.DatabasesLength())
|
||||
tbl.Branches = make([]MatchExpression, fb.BranchesLength())
|
||||
tbl.Users = make([]MatchExpression, fb.UsersLength())
|
||||
tbl.Hosts = make([]MatchExpression, fb.HostsLength())
|
||||
tbl.Values = make([]AccessValue, fb.ValuesLength())
|
||||
// Read the databases
|
||||
for i := 0; i < fb.DatabasesLength(); i++ {
|
||||
serialMatchExpr := &serial.BranchControlMatchExpression{}
|
||||
fb.Databases(serialMatchExpr, i)
|
||||
tbl.Databases[i] = deserializeMatchExpression(serialMatchExpr)
|
||||
}
|
||||
// Read the branches
|
||||
for i := 0; i < fb.BranchesLength(); i++ {
|
||||
serialMatchExpr := &serial.BranchControlMatchExpression{}
|
||||
@@ -222,6 +236,7 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
|
||||
serialAccessValue := &serial.BranchControlAccessValue{}
|
||||
fb.Values(serialAccessValue, i)
|
||||
tbl.Values[i] = AccessValue{
|
||||
Database: string(serialAccessValue.Database()),
|
||||
Branch: string(serialAccessValue.Branch()),
|
||||
User: string(serialAccessValue.User()),
|
||||
Host: string(serialAccessValue.Host()),
|
||||
@@ -231,6 +246,18 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterDatabases returns all databases that match the given collection indexes.
|
||||
func (tbl *Access) filterDatabases(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Databases[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// filterBranches returns all branches that match the given collection indexes.
|
||||
func (tbl *Access) filterBranches(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
@@ -281,7 +308,7 @@ func (tbl *Access) gatherPermissions(collectionIndexes []uint32) Permissions {
|
||||
func (tbl *Access) insertDefaultRow() {
|
||||
// Check if the appropriate row already exists
|
||||
for _, value := range tbl.Values {
|
||||
if value.Branch == "%" && value.User == "%" && value.Host == "%" {
|
||||
if value.Database == "%" && value.Branch == "%" && value.User == "%" && value.Host == "%" {
|
||||
// Getting to this state will be disallowed in the future, but if the row exists without any perms, then add
|
||||
// the Write perm
|
||||
if uint64(value.Permissions) == 0 {
|
||||
@@ -290,17 +317,21 @@ func (tbl *Access) insertDefaultRow() {
|
||||
return
|
||||
}
|
||||
}
|
||||
tbl.insert("%", "%", "%", Permissions_Write)
|
||||
tbl.insert("%", "%", "%", "%", Permissions_Write)
|
||||
}
|
||||
|
||||
// insert adds the given expressions to the table. This does not perform any sort of validation whatsoever, so it is
|
||||
// important to ensure that the expressions are valid before insertion.
|
||||
func (tbl *Access) insert(branch string, user string, host string, perms Permissions) {
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
func (tbl *Access) insert(database string, branch string, user string, host string, perms Permissions) {
|
||||
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
|
||||
database = strings.ToLower(FoldExpression(database))
|
||||
branch = strings.ToLower(FoldExpression(branch))
|
||||
user = FoldExpression(user)
|
||||
host = strings.ToLower(FoldExpression(host))
|
||||
// Each expression is capped at 2¹⁶-1 values, so we truncate to 2¹⁶-2 and add the any-match character at the end if it's over
|
||||
if len(database) > math.MaxUint16 {
|
||||
database = string(append([]byte(database[:math.MaxUint16-1]), byte('%')))
|
||||
}
|
||||
if len(branch) > math.MaxUint16 {
|
||||
branch = string(append([]byte(branch[:math.MaxUint16-1]), byte('%')))
|
||||
}
|
||||
@@ -311,16 +342,19 @@ func (tbl *Access) insert(branch string, user string, host string, perms Permiss
|
||||
host = string(append([]byte(host[:math.MaxUint16-1]), byte('%')))
|
||||
}
|
||||
// Add the expression strings to the binlog
|
||||
tbl.binlog.Insert(branch, user, host, uint64(perms))
|
||||
tbl.binlog.Insert(database, branch, user, host, uint64(perms))
|
||||
// Parse and insert the expressions
|
||||
databaseExpr := ParseExpression(database, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
branchExpr := ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
userExpr := ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
|
||||
hostExpr := ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
nextIdx := uint32(len(tbl.Values))
|
||||
tbl.Databases = append(tbl.Databases, MatchExpression{CollectionIndex: nextIdx, SortOrders: databaseExpr})
|
||||
tbl.Branches = append(tbl.Branches, MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
|
||||
tbl.Users = append(tbl.Users, MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
|
||||
tbl.Hosts = append(tbl.Hosts, MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
|
||||
tbl.Values = append(tbl.Values, AccessValue{
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
@@ -330,11 +364,13 @@ func (tbl *Access) insert(branch string, user string, host string, perms Permiss
|
||||
|
||||
// Serialize returns the offset for the AccessValue written to the given builder.
|
||||
func (val *AccessValue) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
database := b.CreateString(val.Database)
|
||||
branch := b.CreateString(val.Branch)
|
||||
user := b.CreateString(val.User)
|
||||
host := b.CreateString(val.Host)
|
||||
|
||||
serial.BranchControlAccessValueStart(b)
|
||||
serial.BranchControlAccessValueAddDatabase(b, database)
|
||||
serial.BranchControlAccessValueAddBranch(b, branch)
|
||||
serial.BranchControlAccessValueAddUser(b, user)
|
||||
serial.BranchControlAccessValueAddHost(b, host)
|
||||
|
||||
@@ -35,6 +35,7 @@ type Binlog struct {
|
||||
// BinlogRow is a row within the Binlog.
|
||||
type BinlogRow struct {
|
||||
IsInsert bool
|
||||
Database string
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
@@ -55,6 +56,7 @@ func NewAccessBinlog(vals []AccessValue) *Binlog {
|
||||
for i, val := range vals {
|
||||
rows[i] = BinlogRow{
|
||||
IsInsert: true,
|
||||
Database: val.Database,
|
||||
Branch: val.Branch,
|
||||
User: val.User,
|
||||
Host: val.Host,
|
||||
@@ -74,6 +76,7 @@ func NewNamespaceBinlog(vals []NamespaceValue) *Binlog {
|
||||
for i, val := range vals {
|
||||
rows[i] = BinlogRow{
|
||||
IsInsert: true,
|
||||
Database: val.Database,
|
||||
Branch: val.Branch,
|
||||
User: val.User,
|
||||
Host: val.Host,
|
||||
@@ -126,6 +129,7 @@ func (binlog *Binlog) Deserialize(fb *serial.BranchControlBinlog) error {
|
||||
fb.Rows(serialBinlogRow, i)
|
||||
binlog.rows[i] = BinlogRow{
|
||||
IsInsert: serialBinlogRow.IsInsert(),
|
||||
Database: string(serialBinlogRow.Database()),
|
||||
Branch: string(serialBinlogRow.Branch()),
|
||||
User: string(serialBinlogRow.User()),
|
||||
Host: string(serialBinlogRow.Host()),
|
||||
@@ -163,12 +167,13 @@ func (binlog *Binlog) MergeOverlay(overlay *BinlogOverlay) error {
|
||||
}
|
||||
|
||||
// Insert adds an insert entry to the Binlog.
|
||||
func (binlog *Binlog) Insert(branch string, user string, host string, permissions uint64) {
|
||||
func (binlog *Binlog) Insert(database string, branch string, user string, host string, permissions uint64) {
|
||||
binlog.RWMutex.Lock()
|
||||
defer binlog.RWMutex.Unlock()
|
||||
|
||||
binlog.rows = append(binlog.rows, BinlogRow{
|
||||
IsInsert: true,
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
@@ -177,12 +182,13 @@ func (binlog *Binlog) Insert(branch string, user string, host string, permission
|
||||
}
|
||||
|
||||
// Delete adds a delete entry to the Binlog.
|
||||
func (binlog *Binlog) Delete(branch string, user string, host string, permissions uint64) {
|
||||
func (binlog *Binlog) Delete(database string, branch string, user string, host string, permissions uint64) {
|
||||
binlog.RWMutex.Lock()
|
||||
defer binlog.RWMutex.Unlock()
|
||||
|
||||
binlog.rows = append(binlog.rows, BinlogRow{
|
||||
IsInsert: false,
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
@@ -197,12 +203,14 @@ func (binlog *Binlog) Rows() []BinlogRow {
|
||||
|
||||
// Serialize returns the offset for the BinlogRow written to the given builder.
|
||||
func (row *BinlogRow) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
database := b.CreateString(row.Database)
|
||||
branch := b.CreateString(row.Branch)
|
||||
user := b.CreateString(row.User)
|
||||
host := b.CreateString(row.Host)
|
||||
|
||||
serial.BranchControlBinlogRowStart(b)
|
||||
serial.BranchControlBinlogRowAddIsInsert(b, row.IsInsert)
|
||||
serial.BranchControlBinlogRowAddDatabase(b, database)
|
||||
serial.BranchControlBinlogRowAddBranch(b, branch)
|
||||
serial.BranchControlBinlogRowAddUser(b, user)
|
||||
serial.BranchControlBinlogRowAddHost(b, host)
|
||||
@@ -211,9 +219,10 @@ func (row *BinlogRow) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
}
|
||||
|
||||
// Insert adds an insert entry to the BinlogOverlay.
|
||||
func (overlay *BinlogOverlay) Insert(branch string, user string, host string, permissions uint64) {
|
||||
func (overlay *BinlogOverlay) Insert(database string, branch string, user string, host string, permissions uint64) {
|
||||
overlay.rows = append(overlay.rows, BinlogRow{
|
||||
IsInsert: true,
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
@@ -222,9 +231,10 @@ func (overlay *BinlogOverlay) Insert(branch string, user string, host string, pe
|
||||
}
|
||||
|
||||
// Delete adds a delete entry to the BinlogOverlay.
|
||||
func (overlay *BinlogOverlay) Delete(branch string, user string, host string, permissions uint64) {
|
||||
func (overlay *BinlogOverlay) Delete(database string, branch string, user string, host string, permissions uint64) {
|
||||
overlay.rows = append(overlay.rows, BinlogRow{
|
||||
IsInsert: false,
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"context"
|
||||
goerrors "errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
@@ -29,22 +28,25 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIncorrectPermissions = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
|
||||
ErrCannotCreateBranch = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
|
||||
ErrCannotDeleteBranch = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
|
||||
ErrExpressionsTooLong = errors.NewKind("expressions are too long [%q, %q, %q]")
|
||||
ErrInsertingRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q]")
|
||||
ErrUpdatingRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q]")
|
||||
ErrUpdatingToRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q] to the new branch expression %q")
|
||||
ErrDeletingRow = errors.NewKind("`%s`@`%s` cannot delete the row [%q, %q, %q]")
|
||||
ErrMissingController = errors.NewKind("a context has a non-nil session but is missing its branch controller")
|
||||
ErrIncorrectPermissions = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
|
||||
ErrCannotCreateBranch = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
|
||||
ErrCannotDeleteBranch = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
|
||||
ErrExpressionsTooLong = errors.NewKind("expressions are too long [%q, %q, %q, %q]")
|
||||
ErrInsertingAccessRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q, %q]")
|
||||
ErrInsertingNamespaceRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q]")
|
||||
ErrUpdatingRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q, %q]")
|
||||
ErrUpdatingToRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q, %q] to the new branch expression [%q, %q]")
|
||||
ErrDeletingRow = errors.NewKind("`%s`@`%s` cannot delete the row [%q, %q, %q, %q]")
|
||||
ErrMissingController = errors.NewKind("a context has a non-nil session but is missing its branch controller")
|
||||
)
|
||||
|
||||
// Context represents the interface that must be inherited from the context.
|
||||
type Context interface {
|
||||
GetBranch() (string, error)
|
||||
GetCurrentDatabase() string
|
||||
GetUser() string
|
||||
GetHost() string
|
||||
GetPrivilegeSet() (sql.PrivilegeSet, uint64)
|
||||
GetController() *Controller
|
||||
}
|
||||
|
||||
@@ -57,50 +59,6 @@ type Controller struct {
|
||||
doltConfigDirPath string
|
||||
}
|
||||
|
||||
var (
|
||||
// superUser is the server-wide user that has full, irrevocable permission to do whatever they want to any branch and table
|
||||
superUser string
|
||||
// superHost is the host counterpart of the superUser
|
||||
superHost string
|
||||
// superHostIsLoopback states whether the superHost is a loopback address
|
||||
superHostIsLoopback bool
|
||||
)
|
||||
|
||||
var enabled = false
|
||||
|
||||
func init() {
|
||||
if os.Getenv("DOLT_ENABLE_BRANCH_CONTROL") != "" {
|
||||
enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
// SetEnabled is a TEMPORARY function just used for testing (so that we don't have to set the env variable)
|
||||
func SetEnabled(value bool) {
|
||||
enabled = value
|
||||
}
|
||||
|
||||
// SetSuperUser sets the server-wide super user to the given user and host combination. The super user has full,
|
||||
// irrevocable permission to do whatever they want to any branch and table.
|
||||
func SetSuperUser(user string, host string) {
|
||||
superUser = user
|
||||
// Check if superHost is a loopback
|
||||
if host == "localhost" || net.ParseIP(host).IsLoopback() {
|
||||
superHost = "localhost"
|
||||
superHostIsLoopback = true
|
||||
} else {
|
||||
superHost = host
|
||||
superHostIsLoopback = false
|
||||
}
|
||||
}
|
||||
|
||||
// IsSuperUser returns whether the given user and host combination is the super user.
|
||||
func IsSuperUser(user string, host string) bool {
|
||||
if user == superUser && ((host == superHost) || (superHostIsLoopback && net.ParseIP(host).IsLoopback())) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CreateDefaultController returns a default controller, which only has a single entry allowing all users to have write
|
||||
// permissions on all branches (only the super user has admin, if a super user has been set). This is equivalent to
|
||||
// passing empty strings to LoadData.
|
||||
@@ -115,9 +73,6 @@ func CreateDefaultController() *Controller {
|
||||
// LoadData loads the data from the given location and returns a controller. Returns the default controller if the
|
||||
// `branchControlFilePath` is empty.
|
||||
func LoadData(branchControlFilePath string, doltConfigDirPath string) (*Controller, error) {
|
||||
if !enabled {
|
||||
return nil, nil
|
||||
}
|
||||
accessTbl := newAccess()
|
||||
controller := &Controller{
|
||||
Access: accessTbl,
|
||||
@@ -171,9 +126,6 @@ func LoadData(branchControlFilePath string, doltConfigDirPath string) (*Controll
|
||||
|
||||
// SaveData saves the data from the context's controller to the location pointed by it.
|
||||
func SaveData(ctx context.Context) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
// A nil session means we're not in the SQL context, so we've got nothing to serialize
|
||||
if branchAwareSession == nil {
|
||||
@@ -216,9 +168,6 @@ func SaveData(ctx context.Context) error {
|
||||
// the context. In these cases, CheckAccess will pass as we want to allow all local commands to ignore branch
|
||||
// permissions.
|
||||
func CheckAccess(ctx context.Context, flags Permissions) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
// A nil session means we're not in the SQL context, so we allow all operations
|
||||
if branchAwareSession == nil {
|
||||
@@ -234,12 +183,13 @@ func CheckAccess(ctx context.Context, flags Permissions) error {
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
database := branchAwareSession.GetCurrentDatabase()
|
||||
branch, err := branchAwareSession.GetBranch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Get the permissions for the branch, user, and host combination
|
||||
_, perms := controller.Access.Match(branch, user, host)
|
||||
_, perms := controller.Access.Match(database, branch, user, host)
|
||||
// If either the flags match or the user is an admin for this branch, then we allow access
|
||||
if (perms&flags == flags) || (perms&Permissions_Admin == Permissions_Admin) {
|
||||
return nil
|
||||
@@ -252,9 +202,6 @@ func CheckAccess(ctx context.Context, flags Permissions) error {
|
||||
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
|
||||
// these cases, CanCreateBranch will pass as we want to allow all local commands to freely create branches.
|
||||
func CanCreateBranch(ctx context.Context, branchName string) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
// A nil session means we're not in the SQL context, so we allow the create operation
|
||||
if branchAwareSession == nil {
|
||||
@@ -270,7 +217,8 @@ func CanCreateBranch(ctx context.Context, branchName string) error {
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
if controller.Namespace.CanCreate(branchName, user, host) {
|
||||
database := branchAwareSession.GetCurrentDatabase()
|
||||
if controller.Namespace.CanCreate(database, branchName, user, host) {
|
||||
return nil
|
||||
}
|
||||
return ErrCannotCreateBranch.New(user, host, branchName)
|
||||
@@ -281,9 +229,6 @@ func CanCreateBranch(ctx context.Context, branchName string) error {
|
||||
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
|
||||
// these cases, CanDeleteBranch will pass as we want to allow all local commands to freely delete branches.
|
||||
func CanDeleteBranch(ctx context.Context, branchName string) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
// A nil session means we're not in the SQL context, so we allow the delete operation
|
||||
if branchAwareSession == nil {
|
||||
@@ -299,8 +244,9 @@ func CanDeleteBranch(ctx context.Context, branchName string) error {
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
database := branchAwareSession.GetCurrentDatabase()
|
||||
// Get the permissions for the branch, user, and host combination
|
||||
_, perms := controller.Access.Match(branchName, user, host)
|
||||
_, perms := controller.Access.Match(database, branchName, user, host)
|
||||
// If the user has the write or admin flags, then we allow access
|
||||
if (perms&Permissions_Write == Permissions_Write) || (perms&Permissions_Admin == Permissions_Admin) {
|
||||
return nil
|
||||
@@ -312,9 +258,6 @@ func CanDeleteBranch(ctx context.Context, branchName string) error {
|
||||
// context is missing some functionality that is needed to perform the addition, such as a user or the Controller, then
|
||||
// this simply returns.
|
||||
func AddAdminForContext(ctx context.Context, branchName string) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
if branchAwareSession == nil {
|
||||
return nil
|
||||
@@ -326,15 +269,16 @@ func AddAdminForContext(ctx context.Context, branchName string) error {
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
database := branchAwareSession.GetCurrentDatabase()
|
||||
// Check if we already have admin permissions for the given branch, as there's no need to do another insertion if so
|
||||
controller.Access.RWMutex.RLock()
|
||||
_, modPerms := controller.Access.Match(branchName, user, host)
|
||||
_, modPerms := controller.Access.Match(database, branchName, user, host)
|
||||
controller.Access.RWMutex.RUnlock()
|
||||
if modPerms&Permissions_Admin == Permissions_Admin {
|
||||
return nil
|
||||
}
|
||||
controller.Access.RWMutex.Lock()
|
||||
controller.Access.insert(branchName, user, host, Permissions_Admin)
|
||||
controller.Access.insert(database, branchName, user, host, Permissions_Admin)
|
||||
controller.Access.RWMutex.Unlock()
|
||||
return SaveData(ctx)
|
||||
}
|
||||
@@ -351,3 +295,29 @@ func GetBranchAwareSession(ctx context.Context) Context {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasDatabasePrivileges returns whether the given context's user has the correct privileges to modify any table entries
|
||||
// that match the given database. The following are the required privileges:
|
||||
//
|
||||
// Global Space: SUPER, GRANT
|
||||
// Global Space: CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, EXECUTE, GRANT
|
||||
// Database Space: CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, EXECUTE, GRANT
|
||||
//
|
||||
// Any user that may grant SUPER is considered to be a super user. In addition, any user that may grant the suite of
|
||||
// alteration privileges is also considered a super user. The SUPER privilege does not exist at the database level, it
|
||||
// is a global privilege only.
|
||||
func HasDatabasePrivileges(ctx Context, database string) bool {
|
||||
if ctx == nil {
|
||||
return true
|
||||
}
|
||||
privSet, counter := ctx.GetPrivilegeSet()
|
||||
if counter == 0 {
|
||||
return false
|
||||
}
|
||||
hasSuper := privSet.Has(sql.PrivilegeType_Super, sql.PrivilegeType_Grant)
|
||||
isGlobalAdmin := privSet.Has(sql.PrivilegeType_Create, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop,
|
||||
sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_Delete, sql.PrivilegeType_Execute, sql.PrivilegeType_Grant)
|
||||
isDatabaseAdmin := privSet.Database(database).Has(sql.PrivilegeType_Create, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop,
|
||||
sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_Delete, sql.PrivilegeType_Execute, sql.PrivilegeType_Grant)
|
||||
return hasSuper || isGlobalAdmin || isDatabaseAdmin
|
||||
}
|
||||
|
||||
@@ -31,41 +31,50 @@ type Namespace struct {
|
||||
access *Access
|
||||
binlog *Binlog
|
||||
|
||||
Branches []MatchExpression
|
||||
Users []MatchExpression
|
||||
Hosts []MatchExpression
|
||||
Values []NamespaceValue
|
||||
RWMutex *sync.RWMutex
|
||||
Databases []MatchExpression
|
||||
Branches []MatchExpression
|
||||
Users []MatchExpression
|
||||
Hosts []MatchExpression
|
||||
Values []NamespaceValue
|
||||
RWMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
// NamespaceValue contains the user-facing values of a particular row.
|
||||
type NamespaceValue struct {
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
Database string
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
}
|
||||
|
||||
// newNamespace returns a new Namespace.
|
||||
func newNamespace(accessTbl *Access) *Namespace {
|
||||
return &Namespace{
|
||||
binlog: NewNamespaceBinlog(nil),
|
||||
access: accessTbl,
|
||||
Branches: nil,
|
||||
Users: nil,
|
||||
Hosts: nil,
|
||||
Values: nil,
|
||||
RWMutex: &sync.RWMutex{},
|
||||
binlog: NewNamespaceBinlog(nil),
|
||||
access: accessTbl,
|
||||
Databases: nil,
|
||||
Branches: nil,
|
||||
Users: nil,
|
||||
Hosts: nil,
|
||||
Values: nil,
|
||||
RWMutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// CanCreate checks the given branch, and returns whether the given user and host combination is able to create that
|
||||
// branch. Handles the super user case.
|
||||
func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
|
||||
// Super user can always create branches
|
||||
if IsSuperUser(user, host) {
|
||||
// CanCreate checks the given database and branch, and returns whether the given user and host combination is able to
|
||||
// create that branch. Handles the super user case.
|
||||
func (tbl *Namespace) CanCreate(database string, branch string, user string, host string) bool {
|
||||
filteredIndexes := Match(tbl.Databases, database, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
// If there are no database entries, then the Namespace is unrestricted
|
||||
if len(filteredIndexes) == 0 {
|
||||
indexPool.Put(filteredIndexes)
|
||||
return true
|
||||
}
|
||||
matchedSet := Match(tbl.Branches, branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
|
||||
filteredBranches := tbl.filterBranches(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
matchedSet := Match(filteredBranches, branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
matchExprPool.Put(filteredBranches)
|
||||
// If there are no branch entries, then the Namespace is unrestricted
|
||||
if len(matchedSet) == 0 {
|
||||
indexPool.Put(matchedSet)
|
||||
@@ -74,7 +83,7 @@ func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
|
||||
|
||||
// We take either the longest match, or the set of longest matches if multiple matches have the same length
|
||||
longest := -1
|
||||
filteredIndexes := indexPool.Get().([]uint32)[:0]
|
||||
filteredIndexes = indexPool.Get().([]uint32)[:0]
|
||||
for _, matched := range matchedSet {
|
||||
matchedValue := tbl.Values[matched]
|
||||
// If we've found a longer match, then we reset the slice. We append to it in the following if statement.
|
||||
@@ -102,11 +111,11 @@ func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
|
||||
// returns -1. Assumes that the given expressions have already been folded.
|
||||
func (tbl *Namespace) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
|
||||
// GetIndex returns the index of the given database, branch, user, and host expressions. If the expressions cannot be
|
||||
// found, returns -1. Assumes that the given expressions have already been folded.
|
||||
func (tbl *Namespace) GetIndex(databaseExpr string, branchExpr string, userExpr string, hostExpr string) int {
|
||||
for i, value := range tbl.Values {
|
||||
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
|
||||
if value.Database == databaseExpr && value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
@@ -131,11 +140,15 @@ func (tbl *Namespace) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
// Serialize the binlog
|
||||
binlog := tbl.binlog.Serialize(b)
|
||||
// Initialize field offset slices
|
||||
databaseOffsets := make([]flatbuffers.UOffsetT, len(tbl.Databases))
|
||||
branchOffsets := make([]flatbuffers.UOffsetT, len(tbl.Branches))
|
||||
userOffsets := make([]flatbuffers.UOffsetT, len(tbl.Users))
|
||||
hostOffsets := make([]flatbuffers.UOffsetT, len(tbl.Hosts))
|
||||
valueOffsets := make([]flatbuffers.UOffsetT, len(tbl.Values))
|
||||
// Get field offsets
|
||||
for i, matchExpr := range tbl.Databases {
|
||||
databaseOffsets[i] = matchExpr.Serialize(b)
|
||||
}
|
||||
for i, matchExpr := range tbl.Branches {
|
||||
branchOffsets[i] = matchExpr.Serialize(b)
|
||||
}
|
||||
@@ -149,6 +162,11 @@ func (tbl *Namespace) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
valueOffsets[i] = val.Serialize(b)
|
||||
}
|
||||
// Get the field vectors
|
||||
serial.BranchControlNamespaceStartDatabasesVector(b, len(databaseOffsets))
|
||||
for i := len(databaseOffsets) - 1; i >= 0; i-- {
|
||||
b.PrependUOffsetT(databaseOffsets[i])
|
||||
}
|
||||
databases := b.EndVector(len(databaseOffsets))
|
||||
serial.BranchControlNamespaceStartBranchesVector(b, len(branchOffsets))
|
||||
for i := len(branchOffsets) - 1; i >= 0; i-- {
|
||||
b.PrependUOffsetT(branchOffsets[i])
|
||||
@@ -172,6 +190,7 @@ func (tbl *Namespace) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
// Write the table
|
||||
serial.BranchControlNamespaceStart(b)
|
||||
serial.BranchControlNamespaceAddBinlog(b, binlog)
|
||||
serial.BranchControlNamespaceAddDatabases(b, databases)
|
||||
serial.BranchControlNamespaceAddBranches(b, branches)
|
||||
serial.BranchControlNamespaceAddUsers(b, users)
|
||||
serial.BranchControlNamespaceAddHosts(b, hosts)
|
||||
@@ -189,7 +208,10 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
|
||||
return fmt.Errorf("cannot deserialize to a non-empty namespace table")
|
||||
}
|
||||
// Verify that all fields have the same length
|
||||
if fb.BranchesLength() != fb.UsersLength() || fb.UsersLength() != fb.HostsLength() || fb.HostsLength() != fb.ValuesLength() {
|
||||
if fb.DatabasesLength() != fb.BranchesLength() ||
|
||||
fb.BranchesLength() != fb.UsersLength() ||
|
||||
fb.UsersLength() != fb.HostsLength() ||
|
||||
fb.HostsLength() != fb.ValuesLength() {
|
||||
return fmt.Errorf("cannot deserialize a namespace table with differing field lengths")
|
||||
}
|
||||
// Read the binlog
|
||||
@@ -201,10 +223,17 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
|
||||
return err
|
||||
}
|
||||
// Initialize every slice
|
||||
tbl.Databases = make([]MatchExpression, fb.DatabasesLength())
|
||||
tbl.Branches = make([]MatchExpression, fb.BranchesLength())
|
||||
tbl.Users = make([]MatchExpression, fb.UsersLength())
|
||||
tbl.Hosts = make([]MatchExpression, fb.HostsLength())
|
||||
tbl.Values = make([]NamespaceValue, fb.ValuesLength())
|
||||
// Read the databases
|
||||
for i := 0; i < fb.DatabasesLength(); i++ {
|
||||
serialMatchExpr := &serial.BranchControlMatchExpression{}
|
||||
fb.Databases(serialMatchExpr, i)
|
||||
tbl.Databases[i] = deserializeMatchExpression(serialMatchExpr)
|
||||
}
|
||||
// Read the branches
|
||||
for i := 0; i < fb.BranchesLength(); i++ {
|
||||
serialMatchExpr := &serial.BranchControlMatchExpression{}
|
||||
@@ -228,14 +257,27 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
|
||||
serialNamespaceValue := &serial.BranchControlNamespaceValue{}
|
||||
fb.Values(serialNamespaceValue, i)
|
||||
tbl.Values[i] = NamespaceValue{
|
||||
Branch: string(serialNamespaceValue.Branch()),
|
||||
User: string(serialNamespaceValue.User()),
|
||||
Host: string(serialNamespaceValue.Host()),
|
||||
Database: string(serialNamespaceValue.Database()),
|
||||
Branch: string(serialNamespaceValue.Branch()),
|
||||
User: string(serialNamespaceValue.User()),
|
||||
Host: string(serialNamespaceValue.Host()),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterDatabases returns all databases that match the given collection indexes.
|
||||
func (tbl *Namespace) filterDatabases(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Databases[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// filterBranches returns all branches that match the given collection indexes.
|
||||
func (tbl *Namespace) filterBranches(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
@@ -274,11 +316,13 @@ func (tbl *Namespace) filterHosts(filters []uint32) []MatchExpression {
|
||||
|
||||
// Serialize returns the offset for the NamespaceValue written to the given builder.
|
||||
func (val *NamespaceValue) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
|
||||
database := b.CreateString(val.Database)
|
||||
branch := b.CreateString(val.Branch)
|
||||
user := b.CreateString(val.User)
|
||||
host := b.CreateString(val.Host)
|
||||
|
||||
serial.BranchControlNamespaceValueStart(b)
|
||||
serial.BranchControlNamespaceValueAddDatabase(b, database)
|
||||
serial.BranchControlNamespaceValueAddBranch(b, branch)
|
||||
serial.BranchControlNamespaceValueAddUser(b, user)
|
||||
serial.BranchControlNamespaceValueAddHost(b, host)
|
||||
|
||||
5
go/libraries/doltcore/env/actions/clone.go
vendored
5
go/libraries/doltcore/env/actions/clone.go
vendored
@@ -39,7 +39,6 @@ import (
|
||||
)
|
||||
|
||||
var ErrRepositoryExists = errors.New("data repository already exists")
|
||||
var ErrFailedToInitRepo = errors.New("")
|
||||
var ErrFailedToCreateDirectory = errors.New("unable to create directories")
|
||||
var ErrFailedToAccessDir = errors.New("unable to access directories")
|
||||
var ErrFailedToCreateRepoStateWithRemote = errors.New("unable to create repo state with remote")
|
||||
@@ -76,7 +75,7 @@ func EnvForClone(ctx context.Context, nbf *types.NomsBinFormat, r env.Remote, di
|
||||
dEnv := env.Load(ctx, homeProvider, newFs, doltdb.LocalDirDoltDB, version)
|
||||
err = dEnv.InitRepoWithNoData(ctx, nbf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w; %s", ErrFailedToInitRepo, err.Error())
|
||||
return nil, fmt.Errorf("failed to init repo: %w", err)
|
||||
}
|
||||
|
||||
dEnv.RSLoadErr = nil
|
||||
@@ -280,7 +279,7 @@ func InitEmptyClonedRepo(ctx context.Context, dEnv *env.DoltEnv) error {
|
||||
|
||||
err := dEnv.InitDBWithTime(ctx, types.Format_Default, name, email, initBranch, datas.CommitNowFunc())
|
||||
if err != nil {
|
||||
return ErrFailedToInitRepo
|
||||
return fmt.Errorf("failed to init repo: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
2
go/libraries/doltcore/env/actions/remotes.go
vendored
2
go/libraries/doltcore/env/actions/remotes.go
vendored
@@ -365,7 +365,7 @@ func FetchRemoteBranch(
|
||||
func FetchRefSpecs(ctx context.Context, dbData env.DbData, srcDB *doltdb.DoltDB, refSpecs []ref.RemoteRefSpec, remote env.Remote, mode ref.UpdateMode, progStarter ProgStarter, progStopper ProgStopper) error {
|
||||
branchRefs, err := srcDB.GetHeadRefs(ctx)
|
||||
if err != nil {
|
||||
return env.ErrFailedToReadDb
|
||||
return fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
|
||||
}
|
||||
|
||||
for _, rs := range refSpecs {
|
||||
|
||||
5
go/libraries/doltcore/env/environment.go
vendored
5
go/libraries/doltcore/env/environment.go
vendored
@@ -59,7 +59,6 @@ const (
|
||||
|
||||
var zeroHashStr = (hash.Hash{}).String()
|
||||
|
||||
var ErrPreexistingDoltDir = errors.New(".dolt dir already exists")
|
||||
var ErrStateUpdate = errors.New("error updating local data repo state")
|
||||
var ErrMarshallingSchema = errors.New("error marshalling schema")
|
||||
var ErrInvalidCredsFile = errors.New("invalid creds file")
|
||||
@@ -389,7 +388,7 @@ func (dEnv *DoltEnv) createDirectories(dir string) (string, error) {
|
||||
}
|
||||
|
||||
if dEnv.hasDoltDir(dir) {
|
||||
return "", ErrPreexistingDoltDir
|
||||
return "", fmt.Errorf(".dolt directory already exists at '%s'", dir)
|
||||
}
|
||||
|
||||
absDataDir := filepath.Join(absPath, dbfactory.DoltDataDir)
|
||||
@@ -923,7 +922,7 @@ func (dEnv *DoltEnv) RemoveRemote(ctx context.Context, name string) error {
|
||||
ddb := dEnv.DoltDB
|
||||
refs, err := ddb.GetRemoteRefs(ctx)
|
||||
if err != nil {
|
||||
return ErrFailedToReadFromDb
|
||||
return fmt.Errorf("%w: %s", ErrFailedToReadFromDb, err.Error())
|
||||
}
|
||||
|
||||
for _, r := range refs {
|
||||
|
||||
3
go/libraries/doltcore/env/paths.go
vendored
3
go/libraries/doltcore/env/paths.go
vendored
@@ -20,6 +20,7 @@ import (
|
||||
"path/filepath"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -42,7 +43,7 @@ type HomeDirProvider func() (string, error)
|
||||
// provide a different directory where the root .dolt directory should be located and global state will be stored there.
|
||||
func GetCurrentUserHomeDir() (string, error) {
|
||||
if doltRootPath, ok := os.LookupEnv(doltRootPathEnvVar); ok && doltRootPath != "" {
|
||||
return doltRootPath, nil
|
||||
return filesys.LocalFS.Abs(doltRootPath)
|
||||
}
|
||||
|
||||
var home string
|
||||
|
||||
3
go/libraries/doltcore/env/remotes.go
vendored
3
go/libraries/doltcore/env/remotes.go
vendored
@@ -205,10 +205,9 @@ func NewPushOpts(ctx context.Context, apr *argparser.ArgParseResults, rsr RepoSt
|
||||
hasRef, err := ddb.HasRef(ctx, currentBranch)
|
||||
|
||||
if err != nil {
|
||||
return nil, ErrFailedToReadDb
|
||||
return nil, fmt.Errorf("%w: %s", ErrFailedToReadDb, err.Error())
|
||||
} else if !hasRef {
|
||||
return nil, fmt.Errorf("%w: '%s'", ErrUnknownBranch, currentBranch.GetPath())
|
||||
|
||||
}
|
||||
|
||||
src := refSpec.SrcRef(currentBranch)
|
||||
|
||||
@@ -124,44 +124,6 @@ func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map
|
||||
tblToStats := make(map[string]*MergeStats)
|
||||
err = mergedRootToWorking(ctx, false, dEnv, mergedRoot, spec.WorkingDiffs, spec.MergeC, spec.MergeCSpecStr, tblToStats)
|
||||
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
// Reload roots since the above method writes new values to the working set
|
||||
roots, err := dEnv.Roots(ctx)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
ws, err := dEnv.WorkingSet(ctx)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
var mergeParentCommits []*doltdb.Commit
|
||||
if ws.MergeActive() {
|
||||
mergeParentCommits = []*doltdb.Commit{ws.MergeState().Commit()}
|
||||
}
|
||||
|
||||
_, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dEnv.DbData(), actions.CommitStagedProps{
|
||||
Message: spec.Msg,
|
||||
Date: spec.Date,
|
||||
AllowEmpty: spec.AllowEmpty,
|
||||
Force: spec.Force,
|
||||
Name: spec.Name,
|
||||
Email: spec.Email,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return tblToStats, fmt.Errorf("%w; failed to commit", err)
|
||||
}
|
||||
|
||||
err = dEnv.ClearMerge(ctx)
|
||||
if err != nil {
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
return tblToStats, err
|
||||
}
|
||||
|
||||
|
||||
@@ -183,23 +183,25 @@ func equalRows(old, new sql.Row, sch sql.Schema) (bool, error) {
|
||||
// special case time comparison to account
|
||||
// for precision changes between formats
|
||||
if _, ok := old[i].(time.Time); ok {
|
||||
if old[i], err = sql.Int64.Convert(old[i]); err != nil {
|
||||
var o, n interface{}
|
||||
if o, err = sql.Int64.Convert(old[i]); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if new[i], err = sql.Int64.Convert(new[i]); err != nil {
|
||||
if n, err = sql.Int64.Convert(new[i]); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cmp, err = sql.Int64.Compare(o, n); err != nil {
|
||||
return false, err
|
||||
}
|
||||
cmp, err = sql.Int64.Compare(old[i], new[i])
|
||||
} else {
|
||||
cmp, err = sch[i].Type.Compare(old[i], new[i])
|
||||
if cmp, err = sch[i].Type.Compare(old[i], new[i]); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if cmp != 0 {
|
||||
if cmp != 0 {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -18,9 +18,14 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -36,7 +41,9 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/clusterdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
)
|
||||
|
||||
@@ -59,12 +66,17 @@ type Controller struct {
|
||||
sinterceptor serverinterceptor
|
||||
cinterceptor clientinterceptor
|
||||
lgr *logrus.Logger
|
||||
grpcCreds credentials.PerRPCCredentials
|
||||
|
||||
provider dbProvider
|
||||
iterSessions IterSessions
|
||||
killQuery func(uint32)
|
||||
killConnection func(uint32) error
|
||||
|
||||
jwks *jwtauth.MultiJWKS
|
||||
tlsCfg *tls.Config
|
||||
grpcCreds credentials.PerRPCCredentials
|
||||
pub ed25519.PublicKey
|
||||
priv ed25519.PrivateKey
|
||||
}
|
||||
|
||||
type sqlvars interface {
|
||||
@@ -112,9 +124,49 @@ func NewController(lgr *logrus.Logger, cfg Config, pCfg config.ReadWriteConfig)
|
||||
ret.cinterceptor.lgr = lgr.WithFields(logrus.Fields{})
|
||||
ret.cinterceptor.setRole(role, epoch)
|
||||
ret.cinterceptor.roleSetter = roleSetter
|
||||
|
||||
ret.tlsCfg, err = ret.outboundTlsConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ret.pub, ret.priv, err = ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyID := creds.PubKeyToKID(ret.pub)
|
||||
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
|
||||
ret.grpcCreds = &creds.RPCCreds{
|
||||
PrivKey: ret.priv,
|
||||
Audience: creds.RemotesAPIAudience,
|
||||
Issuer: creds.ClientIssuer,
|
||||
KeyID: keyIDStr,
|
||||
RequireTLS: false,
|
||||
}
|
||||
|
||||
ret.jwks = ret.standbyRemotesJWKS()
|
||||
ret.sinterceptor.keyProvider = ret.jwks
|
||||
ret.sinterceptor.jwtExpected = JWTExpectations()
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (c *Controller) Run() {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
c.jwks.Run()
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (c *Controller) GracefulStop() error {
|
||||
c.jwks.GracefulStop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) ManageSystemVariables(variables sqlvars) {
|
||||
if c == nil {
|
||||
return
|
||||
@@ -198,7 +250,7 @@ func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.
|
||||
}
|
||||
|
||||
func (c *Controller) gRPCDialProvider(denv *env.DoltEnv) dbfactory.GRPCDialProvider {
|
||||
return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.cfg, c.grpcCreds}
|
||||
return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.tlsCfg, c.grpcCreds}
|
||||
}
|
||||
|
||||
func (c *Controller) RegisterStoredProcedures(store procedurestore) {
|
||||
@@ -412,23 +464,9 @@ func (c *Controller) RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.Server
|
||||
args = sqle.RemoteSrvServerArgs(ctx, args)
|
||||
args.DBCache = remotesrvStoreCache{args.DBCache, c}
|
||||
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
keyID := creds.PubKeyToKID(pub)
|
||||
keyID := creds.PubKeyToKID(c.pub)
|
||||
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
|
||||
|
||||
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, pub)
|
||||
|
||||
c.grpcCreds = &creds.RPCCreds{
|
||||
PrivKey: priv,
|
||||
Audience: creds.RemotesAPIAudience,
|
||||
Issuer: creds.ClientIssuer,
|
||||
KeyID: keyIDStr,
|
||||
RequireTLS: false,
|
||||
}
|
||||
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub)
|
||||
|
||||
return args
|
||||
}
|
||||
@@ -565,3 +603,129 @@ func (c *Controller) waitForHooksToReplicate() error {
|
||||
return errors.New("cluster/controller: failed to transition from primary to standby gracefully; could not replicate databases to standby in a timely manner.")
|
||||
}
|
||||
}
|
||||
|
||||
// Within a cluster, if remotesapi is configured with a tls_ca, we take the
|
||||
// following semantics:
|
||||
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
|
||||
// all of which are trusted roots for the outbound connections the
|
||||
// remotestorage client establishes.
|
||||
// * The certificate chain presented by the server must validate to a root
|
||||
// which was present in tls_ca. In particular, every certificate in the chain
|
||||
// must be within its validity window, the signatures must be valid, key usage
|
||||
// and isCa must be correctly set for the roots and the intermediates, and the
|
||||
// leaf must have extended key usage server auth.
|
||||
// * On the other hand, no verification is done against the SAN or the Subject
|
||||
// of the certificate.
|
||||
//
|
||||
// We use these TLS semantics for both connections to the gRPC endpoint which
|
||||
// is the actual remotesapi, and for connections to any HTTPS endpoints to
|
||||
// which the gRPC service returns URLs. For now, this works perfectly for our
|
||||
// use case, but it's tightly coupled to `cluster:` deployment topologies and
|
||||
// the likes.
|
||||
//
|
||||
// If tls_ca is not set then default TLS handling is performed. In particular,
|
||||
// if the remotesapi endpoints is HTTPS, then the system roots are used and
|
||||
// ServerName is verified against the presented URL SANs of the certificates.
|
||||
//
|
||||
// This tls Config is used for fetching JWKS, for outbound GRPC connections and
|
||||
// for outbound https connections on the URLs that the GRPC services return.
|
||||
func (c *Controller) outboundTlsConfig() (*tls.Config, error) {
|
||||
tlsCA := c.cfg.RemotesAPIConfig().TLSCA()
|
||||
if tlsCA == "" {
|
||||
return nil, nil
|
||||
}
|
||||
urlmatches := c.cfg.RemotesAPIConfig().ServerNameURLMatches()
|
||||
dnsmatches := c.cfg.RemotesAPIConfig().ServerNameDNSMatches()
|
||||
pem, err := os.ReadFile(tlsCA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roots := x509.NewCertPool()
|
||||
if ok := roots.AppendCertsFromPEM(pem); !ok {
|
||||
return nil, errors.New("error loading ca roots from " + tlsCA)
|
||||
}
|
||||
verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
certs := make([]*x509.Certificate, len(rawCerts))
|
||||
var err error
|
||||
for i, asn1Data := range rawCerts {
|
||||
certs[i], err = x509.ParseCertificate(asn1Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
CurrentTime: time.Now(),
|
||||
Intermediates: x509.NewCertPool(),
|
||||
KeyUsages: keyUsages,
|
||||
}
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err = certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(urlmatches) > 0 {
|
||||
found := false
|
||||
for _, n := range urlmatches {
|
||||
for _, cn := range certs[0].URIs {
|
||||
if n == cn.String() {
|
||||
found = true
|
||||
}
|
||||
break
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("expected certificate to match something in server_name_urls, but it did not")
|
||||
}
|
||||
}
|
||||
if len(dnsmatches) > 0 {
|
||||
found := false
|
||||
for _, n := range dnsmatches {
|
||||
for _, cn := range certs[0].DNSNames {
|
||||
if n == cn {
|
||||
found = true
|
||||
}
|
||||
break
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("expected certificate to match something in server_name_dns, but it did not")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &tls.Config{
|
||||
// We have to InsecureSkipVerify because ServerName is always
|
||||
// set by the grpc dial provider and golang tls.Config does not
|
||||
// have good support for performing certificate validation
|
||||
// without server name validation.
|
||||
InsecureSkipVerify: true,
|
||||
|
||||
VerifyPeerCertificate: verifyFunc,
|
||||
|
||||
NextProtos: []string{"h2"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Controller) standbyRemotesJWKS() *jwtauth.MultiJWKS {
|
||||
client := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: c.tlsCfg,
|
||||
ForceAttemptHTTP2: true,
|
||||
},
|
||||
}
|
||||
urls := make([]string, len(c.cfg.StandbyRemotes()))
|
||||
for i, r := range c.cfg.StandbyRemotes() {
|
||||
urls[i] = strings.Replace(r.RemoteURLTemplate(), dsess.URLTemplateDatabasePlaceholder, ".well-known/jwks.json", -1)
|
||||
}
|
||||
return jwtauth.NewMultiJWKS(c.lgr.WithFields(logrus.Fields{"component": "jwks-key-provider"}), urls, client)
|
||||
}
|
||||
|
||||
@@ -16,15 +16,13 @@ package cluster
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/backoff"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
|
||||
)
|
||||
@@ -35,24 +33,26 @@ import (
|
||||
// - client interceptors for transmitting our replication role.
|
||||
// - do not use environment credentials. (for now).
|
||||
type grpcDialProvider struct {
|
||||
orig dbfactory.GRPCDialProvider
|
||||
ci *clientinterceptor
|
||||
cfg Config
|
||||
creds credentials.PerRPCCredentials
|
||||
orig dbfactory.GRPCDialProvider
|
||||
ci *clientinterceptor
|
||||
tlsCfg *tls.Config
|
||||
creds credentials.PerRPCCredentials
|
||||
}
|
||||
|
||||
func (p grpcDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfactory.GRPCRemoteConfig, error) {
|
||||
tlsConfig, err := p.tlsConfig()
|
||||
if err != nil {
|
||||
return dbfactory.GRPCRemoteConfig{}, err
|
||||
}
|
||||
config.TLSConfig = tlsConfig
|
||||
config.TLSConfig = p.tlsCfg
|
||||
config.Creds = p.creds
|
||||
if config.Creds != nil && config.TLSConfig != nil {
|
||||
if c, ok := config.Creds.(*creds.RPCCreds); ok {
|
||||
c.RequireTLS = true
|
||||
}
|
||||
}
|
||||
config.WithEnvCreds = false
|
||||
cfg, err := p.orig.GetGRPCDialParams(config)
|
||||
if err != nil {
|
||||
return dbfactory.GRPCRemoteConfig{}, err
|
||||
}
|
||||
|
||||
cfg.DialOptions = append(cfg.DialOptions, p.ci.Options()...)
|
||||
cfg.DialOptions = append(cfg.DialOptions, grpc.WithConnectParams(grpc.ConnectParams{
|
||||
Backoff: backoff.Config{
|
||||
@@ -63,114 +63,6 @@ func (p grpcDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
|
||||
},
|
||||
MinConnectTimeout: 250 * time.Millisecond,
|
||||
}))
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Within a cluster, if remotesapi is configured with a tls_ca, we take the
|
||||
// following semantics:
|
||||
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
|
||||
// all of which are trusted roots for the outbound connections the
|
||||
// remotestorage client establishes.
|
||||
// * The certificate chain presented by the server must validate to a root
|
||||
// which was present in tls_ca. In particular, every certificate in the chain
|
||||
// must be within its validity window, the signatures must be valid, key usage
|
||||
// and isCa must be correctly set for the roots and the intermediates, and the
|
||||
// leaf must have extended key usage server auth.
|
||||
// * On the other hand, no verification is done against the SAN or the Subject
|
||||
// of the certificate.
|
||||
//
|
||||
// We use these TLS semantics for both connections to the gRPC endpoint which
|
||||
// is the actual remotesapi, and for connections to any HTTPS endpoints to
|
||||
// which the gRPC service returns URLs. For now, this works perfectly for our
|
||||
// use case, but it's tightly coupled to `cluster:` deployment topologies and
|
||||
// the likes.
|
||||
//
|
||||
// If tls_ca is not set then default TLS handling is performed. In particular,
|
||||
// if the remotesapi endpoints is HTTPS, then the system roots are used and
|
||||
// ServerName is verified against the presented URL SANs of the certificates.
|
||||
func (p grpcDialProvider) tlsConfig() (*tls.Config, error) {
|
||||
tlsCA := p.cfg.RemotesAPIConfig().TLSCA()
|
||||
if tlsCA == "" {
|
||||
return nil, nil
|
||||
}
|
||||
urlmatches := p.cfg.RemotesAPIConfig().ServerNameURLMatches()
|
||||
dnsmatches := p.cfg.RemotesAPIConfig().ServerNameDNSMatches()
|
||||
pem, err := ioutil.ReadFile(tlsCA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
roots := x509.NewCertPool()
|
||||
if ok := roots.AppendCertsFromPEM(pem); !ok {
|
||||
return nil, errors.New("error loading ca roots from " + tlsCA)
|
||||
}
|
||||
verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
certs := make([]*x509.Certificate, len(rawCerts))
|
||||
var err error
|
||||
for i, asn1Data := range rawCerts {
|
||||
certs[i], err = x509.ParseCertificate(asn1Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
|
||||
opts := x509.VerifyOptions{
|
||||
Roots: roots,
|
||||
CurrentTime: time.Now(),
|
||||
Intermediates: x509.NewCertPool(),
|
||||
KeyUsages: keyUsages,
|
||||
}
|
||||
for _, cert := range certs[1:] {
|
||||
opts.Intermediates.AddCert(cert)
|
||||
}
|
||||
_, err = certs[0].Verify(opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(urlmatches) > 0 {
|
||||
found := false
|
||||
for _, n := range urlmatches {
|
||||
for _, cn := range certs[0].URIs {
|
||||
if n == cn.String() {
|
||||
found = true
|
||||
}
|
||||
break
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("expected certificate to match something in server_name_urls, but it did not")
|
||||
}
|
||||
}
|
||||
if len(dnsmatches) > 0 {
|
||||
found := false
|
||||
for _, n := range dnsmatches {
|
||||
for _, cn := range certs[0].DNSNames {
|
||||
if n == cn {
|
||||
found = true
|
||||
}
|
||||
break
|
||||
}
|
||||
if found {
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return errors.New("expected certificate to match something in server_name_dns, but it did not")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return &tls.Config{
|
||||
// We have to InsecureSkipVerify because ServerName is always
|
||||
// set by the grpc dial provider and golang tls.Config does not
|
||||
// have good support for performing certificate validation
|
||||
// without server name validation.
|
||||
InsecureSkipVerify: true,
|
||||
|
||||
VerifyPeerCertificate: verifyFunc,
|
||||
|
||||
NextProtos: []string{"h2"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -17,14 +17,18 @@ package cluster
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
|
||||
)
|
||||
|
||||
const clusterRoleHeader = "x-dolt-cluster-role"
|
||||
@@ -158,12 +162,20 @@ func (ci *clientinterceptor) Options() []grpc.DialOption {
|
||||
// * for incoming requests which are not standby, it will currently fail the
|
||||
// requests with codes.Unauthenticated. Eventually, it will allow read-only
|
||||
// traffic through which is authenticated and authorized.
|
||||
//
|
||||
// The serverinterceptor is responsible for authenticating incoming requests
|
||||
// from standby replicas. It is instantiated with a jwtauth.KeyProvider and
|
||||
// some jwt.Expected. Incoming requests must have a valid, unexpired, signed
|
||||
// JWT, signed by a key accessible in the KeyProvider.
|
||||
type serverinterceptor struct {
|
||||
lgr *logrus.Entry
|
||||
role Role
|
||||
epoch int
|
||||
mu sync.Mutex
|
||||
roleSetter func(role string, epoch int)
|
||||
|
||||
keyProvider jwtauth.KeyProvider
|
||||
jwtExpected jwt.Expected
|
||||
}
|
||||
|
||||
func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
|
||||
@@ -174,6 +186,9 @@ func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
|
||||
fromStandby = si.handleRequestHeaders(md, role, epoch)
|
||||
}
|
||||
if fromStandby {
|
||||
if err := si.authenticate(ss.Context()); err != nil {
|
||||
return err
|
||||
}
|
||||
// After handleRequestHeaders, our role may have changed, so we fetch it again here.
|
||||
role, epoch := si.getRole()
|
||||
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(clusterRoleHeader, string(role), clusterRoleEpochHeader, strconv.Itoa(epoch))); err != nil {
|
||||
@@ -204,6 +219,9 @@ func (si *serverinterceptor) Unary() grpc.UnaryServerInterceptor {
|
||||
fromStandby = si.handleRequestHeaders(md, role, epoch)
|
||||
}
|
||||
if fromStandby {
|
||||
if err := si.authenticate(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// After handleRequestHeaders, our role may have changed, so we fetch it again here.
|
||||
role, epoch := si.getRole()
|
||||
if err := grpc.SetHeader(ctx, metadata.Pairs(clusterRoleHeader, string(role), clusterRoleEpochHeader, strconv.Itoa(epoch))); err != nil {
|
||||
@@ -272,3 +290,26 @@ func (si *serverinterceptor) getRole() (Role, int) {
|
||||
defer si.mu.Unlock()
|
||||
return si.role, si.epoch
|
||||
}
|
||||
|
||||
func (si *serverinterceptor) authenticate(ctx context.Context) error {
|
||||
if md, ok := metadata.FromIncomingContext(ctx); ok {
|
||||
auths := md.Get("authorization")
|
||||
if len(auths) != 1 {
|
||||
si.lgr.Info("incoming standby request had no authorization")
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
auth := auths[0]
|
||||
if !strings.HasPrefix(auth, "Bearer ") {
|
||||
si.lgr.Info("incoming standby request had malformed authentication header")
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
auth = strings.TrimPrefix(auth, "Bearer ")
|
||||
_, err := jwtauth.ValidateJWT(auth, time.Now(), si.keyProvider, si.jwtExpected)
|
||||
if err != nil {
|
||||
si.lgr.Infof("incoming standby request authorization header failed to verify: %v", err)
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return status.Error(codes.Unauthenticated, "unauthenticated")
|
||||
}
|
||||
|
||||
@@ -16,13 +16,15 @@ package cluster
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
@@ -30,6 +32,10 @@ import (
|
||||
"google.golang.org/grpc/health/grpc_health_v1"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
@@ -51,6 +57,53 @@ func noopSetRole(string, int) {
|
||||
|
||||
var lgr = logrus.StandardLogger().WithFields(logrus.Fields{})
|
||||
|
||||
var kp jwtauth.KeyProvider
|
||||
var pub ed25519.PublicKey
|
||||
var priv ed25519.PrivateKey
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
pub, priv, err = ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
kp = keyProvider{pub}
|
||||
}
|
||||
|
||||
type keyProvider struct {
|
||||
ed25519.PublicKey
|
||||
}
|
||||
|
||||
func (p keyProvider) GetKey(string) ([]jose.JSONWebKey, error) {
|
||||
return []jose.JSONWebKey{{
|
||||
Key: p.PublicKey,
|
||||
KeyID: "1",
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func newJWT() string {
|
||||
key := jose.SigningKey{Algorithm: jose.EdDSA, Key: priv}
|
||||
opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{
|
||||
"kid": "1",
|
||||
}}
|
||||
signer, err := jose.NewSigner(key, opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
jwtBuilder := jwt.Signed(signer)
|
||||
jwtBuilder = jwtBuilder.Claims(jwt.Claims{
|
||||
Audience: []string{"some_audience"},
|
||||
Issuer: "some_issuer",
|
||||
Subject: "some_subject",
|
||||
Expiry: jwt.NewNumericDate(time.Now().Add(30 * time.Second)),
|
||||
})
|
||||
res, err := jwtBuilder.CompactSerialize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient), serveropts []grpc.ServerOption, dialopts []grpc.DialOption) *server {
|
||||
addr, err := net.ResolveUnixAddr("unix", "test_grpc.socket")
|
||||
require.NoError(t, err)
|
||||
@@ -93,12 +146,14 @@ func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient),
|
||||
func outboundCtx(vals ...interface{}) context.Context {
|
||||
ctx := context.Background()
|
||||
if len(vals) == 0 {
|
||||
return ctx
|
||||
return metadata.AppendToOutgoingContext(ctx,
|
||||
"authorization", "Bearer "+newJWT())
|
||||
}
|
||||
if len(vals) == 2 {
|
||||
return metadata.AppendToOutgoingContext(ctx,
|
||||
clusterRoleHeader, string(vals[0].(Role)),
|
||||
clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)))
|
||||
clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)),
|
||||
"authorization", "Bearer "+newJWT())
|
||||
}
|
||||
panic("bad test --- outboundCtx must take 0 or 2 values")
|
||||
}
|
||||
@@ -108,6 +163,7 @@ func TestServerInterceptorUnauthenticatedWithoutClientHeaders(t *testing.T) {
|
||||
si.roleSetter = noopSetRole
|
||||
si.lgr = lgr
|
||||
si.setRole(RoleStandby, 10)
|
||||
si.keyProvider = kp
|
||||
t.Run("Standby", func(t *testing.T) {
|
||||
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
|
||||
_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
|
||||
@@ -136,6 +192,7 @@ func TestServerInterceptorAddsUnaryResponseHeaders(t *testing.T) {
|
||||
si.setRole(RoleStandby, 10)
|
||||
si.roleSetter = noopSetRole
|
||||
si.lgr = lgr
|
||||
si.keyProvider = kp
|
||||
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
|
||||
var md metadata.MD
|
||||
_, err := client.Check(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
|
||||
@@ -154,6 +211,7 @@ func TestServerInterceptorAddsStreamResponseHeaders(t *testing.T) {
|
||||
si.setRole(RoleStandby, 10)
|
||||
si.roleSetter = noopSetRole
|
||||
si.lgr = lgr
|
||||
si.keyProvider = kp
|
||||
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
|
||||
var md metadata.MD
|
||||
srv, err := client.Watch(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
|
||||
@@ -174,6 +232,7 @@ func TestServerInterceptorAsPrimaryDoesNotSendRequest(t *testing.T) {
|
||||
si.setRole(RolePrimary, 10)
|
||||
si.roleSetter = noopSetRole
|
||||
si.lgr = lgr
|
||||
si.keyProvider = kp
|
||||
srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
|
||||
ctx := metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value")
|
||||
_, err := client.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
|
||||
|
||||
@@ -20,6 +20,9 @@ import (
|
||||
"net/http"
|
||||
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/jwt"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
|
||||
)
|
||||
|
||||
type JWKSHandler struct {
|
||||
@@ -55,3 +58,7 @@ func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handl
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func JWTExpectations() jwt.Expected {
|
||||
return jwt.Expected{Issuer: creds.ClientIssuer, Audience: jwt.Audience{creds.RemotesAPIAudience}}
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func DoDoltPull(ctx *sql.Context, args []string) (int, int, error) {
|
||||
// Fetch all references
|
||||
branchRefs, err := srcDB.GetHeadRefs(ctx)
|
||||
if err != nil {
|
||||
return noConflictsOrViolations, threeWayMerge, env.ErrFailedToReadDb
|
||||
return noConflictsOrViolations, threeWayMerge, fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
|
||||
}
|
||||
|
||||
hasBranch, err := srcDB.HasBranch(ctx, pullSpec.Branch.GetPath())
|
||||
|
||||
@@ -81,8 +81,8 @@ func (ds *DiffSummaryTableFunction) WithDatabase(database sql.Database) (sql.Nod
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
// FunctionName implements the sql.TableFunction interface
|
||||
func (ds *DiffSummaryTableFunction) FunctionName() string {
|
||||
// Name implements the sql.TableFunction interface
|
||||
func (ds *DiffSummaryTableFunction) Name() string {
|
||||
return "dolt_diff_summary"
|
||||
}
|
||||
|
||||
@@ -184,18 +184,18 @@ func (ds *DiffSummaryTableFunction) Expressions() []sql.Expression {
|
||||
// WithExpressions implements the sql.Expressioner interface.
|
||||
func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
|
||||
if len(expression) < 1 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "1 to 3", len(expression))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ds.Name(), "1 to 3", len(expression))
|
||||
}
|
||||
|
||||
for _, expr := range expression {
|
||||
if !expr.Resolved() {
|
||||
return nil, ErrInvalidNonLiteralArgument.New(ds.FunctionName(), expr.String())
|
||||
return nil, ErrInvalidNonLiteralArgument.New(ds.Name(), expr.String())
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(expression[0].String(), "..") {
|
||||
if len(expression) < 1 || len(expression) > 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "1 or 2", len(expression))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ds.Name(), "1 or 2", len(expression))
|
||||
}
|
||||
ds.dotCommitExpr = expression[0]
|
||||
if len(expression) == 2 {
|
||||
@@ -203,7 +203,7 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression
|
||||
}
|
||||
} else {
|
||||
if len(expression) < 2 || len(expression) > 3 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "2 or 3", len(expression))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ds.Name(), "2 or 3", len(expression))
|
||||
}
|
||||
ds.fromCommitExpr = expression[0]
|
||||
ds.toCommitExpr = expression[1]
|
||||
@@ -215,20 +215,20 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression
|
||||
// validate the expressions
|
||||
if ds.dotCommitExpr != nil {
|
||||
if !sql.IsText(ds.dotCommitExpr.Type()) {
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.dotCommitExpr.String())
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.dotCommitExpr.String())
|
||||
}
|
||||
} else {
|
||||
if !sql.IsText(ds.fromCommitExpr.Type()) {
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.fromCommitExpr.String())
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.fromCommitExpr.String())
|
||||
}
|
||||
if !sql.IsText(ds.toCommitExpr.Type()) {
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.toCommitExpr.String())
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.toCommitExpr.String())
|
||||
}
|
||||
}
|
||||
|
||||
if ds.tableNameExpr != nil {
|
||||
if !sql.IsText(ds.tableNameExpr.Type()) {
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.tableNameExpr.String())
|
||||
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.tableNameExpr.String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ func (dtf *DiffTableFunction) Expressions() []sql.Expression {
|
||||
// WithExpressions implements the sql.Expressioner interface
|
||||
func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
|
||||
if len(expression) < 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(dtf.FunctionName(), "2 to 3", len(expression))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expression))
|
||||
}
|
||||
|
||||
// TODO: For now, we will only support literal / fully-resolved arguments to the
|
||||
@@ -103,19 +103,19 @@ func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql
|
||||
// before the arguments could be resolved.
|
||||
for _, expr := range expression {
|
||||
if !expr.Resolved() {
|
||||
return nil, ErrInvalidNonLiteralArgument.New(dtf.FunctionName(), expr.String())
|
||||
return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String())
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(expression[0].String(), "..") {
|
||||
if len(expression) != 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", dtf.FunctionName()), 2, len(expression))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", dtf.Name()), 2, len(expression))
|
||||
}
|
||||
dtf.dotCommitExpr = expression[0]
|
||||
dtf.tableNameExpr = expression[1]
|
||||
} else {
|
||||
if len(expression) != 3 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(dtf.FunctionName(), 3, len(expression))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), 3, len(expression))
|
||||
}
|
||||
dtf.fromCommitExpr = expression[0]
|
||||
dtf.toCommitExpr = expression[1]
|
||||
@@ -343,7 +343,7 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
|
||||
}
|
||||
|
||||
if !sql.IsText(dtf.tableNameExpr.Type()) {
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.tableNameExpr.String())
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.tableNameExpr.String())
|
||||
}
|
||||
|
||||
tableNameVal, err := dtf.tableNameExpr.Eval(dtf.ctx, nil)
|
||||
@@ -358,7 +358,7 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
|
||||
|
||||
if dtf.dotCommitExpr != nil {
|
||||
if !sql.IsText(dtf.dotCommitExpr.Type()) {
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.dotCommitExpr.String())
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.dotCommitExpr.String())
|
||||
}
|
||||
|
||||
dotCommitVal, err := dtf.dotCommitExpr.Eval(dtf.ctx, nil)
|
||||
@@ -370,10 +370,10 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
|
||||
}
|
||||
|
||||
if !sql.IsText(dtf.fromCommitExpr.Type()) {
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.fromCommitExpr.String())
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.fromCommitExpr.String())
|
||||
}
|
||||
if !sql.IsText(dtf.toCommitExpr.Type()) {
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.toCommitExpr.String())
|
||||
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.toCommitExpr.String())
|
||||
}
|
||||
|
||||
fromCommitVal, err := dtf.fromCommitExpr.Eval(dtf.ctx, nil)
|
||||
@@ -542,8 +542,8 @@ func (dtf *DiffTableFunction) String() string {
|
||||
dtf.tableNameExpr.String())
|
||||
}
|
||||
|
||||
// FunctionName implements the sql.TableFunction interface
|
||||
func (dtf *DiffTableFunction) FunctionName() string {
|
||||
// Name implements the sql.TableFunction interface
|
||||
func (dtf *DiffTableFunction) Name() string {
|
||||
return "dolt_diff"
|
||||
}
|
||||
|
||||
|
||||
@@ -79,8 +79,8 @@ func (ltf *LogTableFunction) WithDatabase(database sql.Database) (sql.Node, erro
|
||||
return ltf, nil
|
||||
}
|
||||
|
||||
// FunctionName implements the sql.TableFunction interface
|
||||
func (ltf *LogTableFunction) FunctionName() string {
|
||||
// Name implements the sql.TableFunction interface
|
||||
func (ltf *LogTableFunction) Name() string {
|
||||
return "dolt_log"
|
||||
}
|
||||
|
||||
@@ -186,7 +186,7 @@ func (ltf *LogTableFunction) Expressions() []sql.Expression {
|
||||
|
||||
// getDoltArgs builds an argument string from sql expressions so that we can
|
||||
// later parse the arguments with the same util as the CLI
|
||||
func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, functionName string) ([]string, error) {
|
||||
func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, name string) ([]string, error) {
|
||||
var args []string
|
||||
|
||||
for _, expr := range expressions {
|
||||
@@ -196,7 +196,7 @@ func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, functionName st
|
||||
}
|
||||
|
||||
if !sql.IsText(expr.Type()) {
|
||||
return args, sql.ErrInvalidArgumentDetails.New(functionName, expr.String())
|
||||
return args, sql.ErrInvalidArgumentDetails.New(name, expr.String())
|
||||
}
|
||||
|
||||
text, err := sql.Text.Convert(childVal)
|
||||
@@ -213,14 +213,14 @@ func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, functionName st
|
||||
}
|
||||
|
||||
func (ltf *LogTableFunction) addOptions(expression []sql.Expression) error {
|
||||
args, err := getDoltArgs(ltf.ctx, expression, ltf.FunctionName())
|
||||
args, err := getDoltArgs(ltf.ctx, expression, ltf.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
apr, err := cli.CreateLogArgParser().Parse(args)
|
||||
if err != nil {
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), err.Error())
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), err.Error())
|
||||
}
|
||||
|
||||
if notRevisionStr, ok := apr.GetValue(cli.NotFlag); ok {
|
||||
@@ -239,7 +239,7 @@ func (ltf *LogTableFunction) addOptions(expression []sql.Expression) error {
|
||||
switch decorateOption {
|
||||
case "short", "full", "auto", "no":
|
||||
default:
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("invalid --decorate option: %s", decorateOption))
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("invalid --decorate option: %s", decorateOption))
|
||||
}
|
||||
ltf.decoration = decorateOption
|
||||
|
||||
@@ -250,7 +250,7 @@ func (ltf *LogTableFunction) addOptions(expression []sql.Expression) error {
|
||||
func (ltf *LogTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
|
||||
for _, expr := range expression {
|
||||
if !expr.Resolved() {
|
||||
return nil, ErrInvalidNonLiteralArgument.New(ltf.FunctionName(), expr.String())
|
||||
return nil, ErrInvalidNonLiteralArgument.New(ltf.Name(), expr.String())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,7 +267,7 @@ func (ltf *LogTableFunction) WithExpressions(expression ...sql.Expression) (sql.
|
||||
}
|
||||
|
||||
if len(filteredExpressions) > 2 {
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ltf.FunctionName(), "0 to 2", len(filteredExpressions))
|
||||
return nil, sql.ErrInvalidArgumentNumber.New(ltf.Name(), "0 to 2", len(filteredExpressions))
|
||||
}
|
||||
|
||||
exLen := len(filteredExpressions)
|
||||
@@ -286,7 +286,7 @@ func (ltf *LogTableFunction) WithExpressions(expression ...sql.Expression) (sql.
|
||||
}
|
||||
|
||||
func (ltf *LogTableFunction) invalidArgDetailsErr(expr sql.Expression, reason string) *errors.Error {
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("%s - %s", expr.String(), reason))
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("%s - %s", expr.String(), reason))
|
||||
}
|
||||
|
||||
func (ltf *LogTableFunction) validateRevisionExpressions() error {
|
||||
@@ -298,7 +298,7 @@ func (ltf *LogTableFunction) validateRevisionExpressions() error {
|
||||
if ltf.revisionExpr != nil {
|
||||
revisionStr = mustExpressionToString(ltf.ctx, ltf.revisionExpr)
|
||||
if !sql.IsText(ltf.revisionExpr.Type()) {
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), ltf.revisionExpr.String())
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), ltf.revisionExpr.String())
|
||||
}
|
||||
if ltf.secondRevisionExpr == nil && strings.HasPrefix(revisionStr, "^") {
|
||||
return ltf.invalidArgDetailsErr(ltf.revisionExpr, "second revision must exist if first revision contains '^'")
|
||||
@@ -311,7 +311,7 @@ func (ltf *LogTableFunction) validateRevisionExpressions() error {
|
||||
if ltf.secondRevisionExpr != nil {
|
||||
secondRevisionStr = mustExpressionToString(ltf.ctx, ltf.secondRevisionExpr)
|
||||
if !sql.IsText(ltf.secondRevisionExpr.Type()) {
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), ltf.secondRevisionExpr.String())
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), ltf.secondRevisionExpr.String())
|
||||
}
|
||||
if strings.Contains(secondRevisionStr, "..") {
|
||||
return ltf.invalidArgDetailsErr(ltf.secondRevisionExpr, "second revision cannot contain '..' or '...'")
|
||||
@@ -341,10 +341,10 @@ func (ltf *LogTableFunction) validateRevisionExpressions() error {
|
||||
return ltf.invalidArgDetailsErr(ltf.secondRevisionExpr, "cannot use --not if '^' present in second revision")
|
||||
}
|
||||
if strings.Contains(ltf.notRevision, "..") {
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '..'"))
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '..'"))
|
||||
}
|
||||
if strings.HasPrefix(ltf.notRevision, "^") {
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '^'"))
|
||||
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '^'"))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ func removeRemote(ctx *sql.Context, dbd env.DbData, apr *argparser.ArgParseResul
|
||||
ddb := dbd.Ddb
|
||||
refs, err := ddb.GetRemoteRefs(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error: failed to read from db, cause: %s", env.ErrFailedToReadFromDb.Error())
|
||||
return fmt.Errorf("error: %w, cause: %s", env.ErrFailedToReadFromDb, err.Error())
|
||||
}
|
||||
|
||||
for _, r := range refs {
|
||||
|
||||
@@ -37,6 +37,12 @@ var PermissionsStrings = []string{"admin", "write"}
|
||||
|
||||
// accessSchema is the schema for the "dolt_branch_control" table.
|
||||
var accessSchema = sql.Schema{
|
||||
&sql.Column{
|
||||
Name: "database",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "branch",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
@@ -114,11 +120,9 @@ func (tbl BranchControlTable) PartitionRows(context *sql.Context, partition sql.
|
||||
defer tbl.RWMutex.RUnlock()
|
||||
|
||||
var rows []sql.Row
|
||||
if superUser := tbl.GetSuperUser(); len(superUser) > 0 {
|
||||
rows = append(rows, sql.Row{"%", superUser, tbl.GetSuperHost(), uint64(branch_control.Permissions_Admin)})
|
||||
}
|
||||
for _, value := range tbl.Values {
|
||||
rows = append(rows, sql.Row{
|
||||
value.Database,
|
||||
value.Branch,
|
||||
value.User,
|
||||
value.Host,
|
||||
@@ -170,43 +174,46 @@ func (tbl BranchControlTable) Insert(ctx *sql.Context, row sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
perms := branch_control.Permissions(row[3].(uint64))
|
||||
// Database, Branch, and Host are case-insensitive, while user is case-sensitive
|
||||
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
|
||||
user := branch_control.FoldExpression(row[2].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
|
||||
perms := branch_control.Permissions(row[4].(uint64))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(branch, user, host)
|
||||
if len(database) > math.MaxUint16 || len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(database, branch, user, host)
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
|
||||
// Having the correct database privileges also allows the insertion
|
||||
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the insertion has permission to perform the insertion.
|
||||
_, modPerms := tbl.Match(branch, insertUser, insertHost)
|
||||
_, modPerms := tbl.Match(database, branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(uint64(perms))
|
||||
return branch_control.ErrInsertingRow.New(insertUser, insertHost, branch, user, host, permStr)
|
||||
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(uint64(perms))
|
||||
return branch_control.ErrInsertingAccessRow.New(insertUser, insertHost, database, branch, user, host, permStr)
|
||||
}
|
||||
}
|
||||
|
||||
// We check if we're inserting a subset of an already-existing row. If we are, we deny the insertion as the existing
|
||||
// row will already match against ALL possible values for this row.
|
||||
_, modPerms := tbl.Match(branch, user, host)
|
||||
_, modPerms := tbl.Match(database, branch, user, host)
|
||||
if modPerms&branch_control.Permissions_Admin == branch_control.Permissions_Admin {
|
||||
permBits := uint64(modPerms)
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
|
||||
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, branch, user, host, permStr),
|
||||
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, database, branch, user, host, permStr),
|
||||
true,
|
||||
sql.Row{branch, user, host, permBits})
|
||||
sql.Row{database, branch, user, host, permBits})
|
||||
}
|
||||
|
||||
return tbl.insert(ctx, branch, user, host, perms)
|
||||
return tbl.insert(ctx, database, branch, user, host, perms)
|
||||
}
|
||||
|
||||
// Update implements the interface sql.RowUpdater.
|
||||
@@ -214,29 +221,31 @@ func (tbl BranchControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row)
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
|
||||
oldUser := branch_control.FoldExpression(old[1].(string))
|
||||
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
|
||||
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
|
||||
newUser := branch_control.FoldExpression(new[1].(string))
|
||||
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
|
||||
newPerms := branch_control.Permissions(new[3].(uint64))
|
||||
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
|
||||
oldDatabase := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
|
||||
oldBranch := strings.ToLower(branch_control.FoldExpression(old[1].(string)))
|
||||
oldUser := branch_control.FoldExpression(old[2].(string))
|
||||
oldHost := strings.ToLower(branch_control.FoldExpression(old[3].(string)))
|
||||
newDatabase := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
|
||||
newBranch := strings.ToLower(branch_control.FoldExpression(new[1].(string)))
|
||||
newUser := branch_control.FoldExpression(new[2].(string))
|
||||
newHost := strings.ToLower(branch_control.FoldExpression(new[3].(string)))
|
||||
newPerms := branch_control.Permissions(new[4].(uint64))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(newBranch, newUser, newHost)
|
||||
if len(newDatabase) > math.MaxUint16 || len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(newDatabase, newBranch, newUser, newHost)
|
||||
}
|
||||
|
||||
// If we're not updating the same row, then we pre-emptively check for a row violation
|
||||
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
|
||||
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
|
||||
if oldDatabase != newDatabase || oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
|
||||
if tblIndex := tbl.GetIndex(newDatabase, newBranch, newUser, newHost); tblIndex != -1 {
|
||||
permBits := uint64(tbl.Values[tblIndex].Permissions)
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
|
||||
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, newBranch, newUser, newHost, permStr),
|
||||
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, newDatabase, newBranch, newUser, newHost, permStr),
|
||||
true,
|
||||
sql.Row{newBranch, newUser, newHost, permBits})
|
||||
sql.Row{newDatabase, newBranch, newUser, newHost, permBits})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -244,37 +253,42 @@ func (tbl BranchControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row)
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the update has permission to perform the update on the old branch name.
|
||||
_, modPerms := tbl.Match(oldBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost)
|
||||
if !branch_control.HasDatabasePrivileges(branchAwareSession, oldDatabase) {
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the update has permission to perform the update on the old branch name.
|
||||
_, modPerms := tbl.Match(oldDatabase, oldBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost)
|
||||
}
|
||||
}
|
||||
// Now we check if the user has permission use the new branch name
|
||||
_, modPerms = tbl.Match(newBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingToRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
|
||||
if !branch_control.HasDatabasePrivileges(branchAwareSession, newDatabase) {
|
||||
// Similar to the block above, we check if the user has permission to use the new branch name
|
||||
_, modPerms := tbl.Match(newDatabase, newBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingToRow.
|
||||
New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost, newDatabase, newBranch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We check if we're updating to a subset of an already-existing row. If we are, we deny the update as the existing
|
||||
// row will already match against ALL possible values for this updated row.
|
||||
_, modPerms := tbl.Match(newBranch, newUser, newHost)
|
||||
_, modPerms := tbl.Match(newDatabase, newBranch, newUser, newHost)
|
||||
if modPerms&branch_control.Permissions_Admin == branch_control.Permissions_Admin {
|
||||
permBits := uint64(modPerms)
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
|
||||
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, newBranch, newUser, newHost, permStr),
|
||||
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, newDatabase, newBranch, newUser, newHost, permStr),
|
||||
true,
|
||||
sql.Row{newBranch, newUser, newHost, permBits})
|
||||
sql.Row{newDatabase, newBranch, newUser, newHost, permBits})
|
||||
}
|
||||
|
||||
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
|
||||
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
|
||||
if tblIndex := tbl.GetIndex(oldDatabase, oldBranch, oldUser, oldHost); tblIndex != -1 {
|
||||
if err := tbl.delete(ctx, oldDatabase, oldBranch, oldUser, oldHost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tbl.insert(ctx, newBranch, newUser, newHost, newPerms)
|
||||
return tbl.insert(ctx, newDatabase, newBranch, newUser, newHost, newPerms)
|
||||
}
|
||||
|
||||
// Delete implements the interface sql.RowDeleter.
|
||||
@@ -282,24 +296,27 @@ func (tbl BranchControlTable) Delete(ctx *sql.Context, row sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
|
||||
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
|
||||
user := branch_control.FoldExpression(row[2].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
|
||||
// Having the correct database privileges also allows the deletion
|
||||
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the deletion has permission to perform the deletion.
|
||||
_, modPerms := tbl.Match(branch, insertUser, insertHost)
|
||||
_, modPerms := tbl.Match(database, branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrDeletingRow.New(insertUser, insertHost, branch, user, host)
|
||||
return branch_control.ErrDeletingRow.New(insertUser, insertHost, database, branch, user, host)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.delete(ctx, branch, user, host)
|
||||
return tbl.delete(ctx, database, branch, user, host)
|
||||
}
|
||||
|
||||
// Close implements the interface sql.Closer.
|
||||
@@ -309,28 +326,31 @@ func (tbl BranchControlTable) Close(context *sql.Context) error {
|
||||
|
||||
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchControlTable) insert(ctx context.Context, branch string, user string, host string, perms branch_control.Permissions) error {
|
||||
func (tbl BranchControlTable) insert(ctx context.Context, database, branch, user, host string, perms branch_control.Permissions) error {
|
||||
// If we already have this in the table, then we return a duplicate PK error
|
||||
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
|
||||
if tblIndex := tbl.GetIndex(database, branch, user, host); tblIndex != -1 {
|
||||
permBits := uint64(tbl.Values[tblIndex].Permissions)
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
|
||||
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, branch, user, host, permStr),
|
||||
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, database, branch, user, host, permStr),
|
||||
true,
|
||||
sql.Row{branch, user, host, permBits})
|
||||
sql.Row{database, branch, user, host, permBits})
|
||||
}
|
||||
|
||||
// Add an entry to the binlog
|
||||
tbl.GetBinlog().Insert(branch, user, host, uint64(perms))
|
||||
tbl.GetBinlog().Insert(database, branch, user, host, uint64(perms))
|
||||
// Add the expressions to their respective slices
|
||||
databaseExpr := branch_control.ParseExpression(database, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
|
||||
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
nextIdx := uint32(len(tbl.Values))
|
||||
tbl.Databases = append(tbl.Databases, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: databaseExpr})
|
||||
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
|
||||
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
|
||||
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
|
||||
tbl.Values = append(tbl.Values, branch_control.AccessValue{
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
@@ -341,28 +361,31 @@ func (tbl BranchControlTable) insert(ctx context.Context, branch string, user st
|
||||
|
||||
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchControlTable) delete(ctx context.Context, branch string, user string, host string) error {
|
||||
func (tbl BranchControlTable) delete(ctx context.Context, database, branch, user, host string) error {
|
||||
// If we don't have this in the table, then we just return
|
||||
tblIndex := tbl.GetIndex(branch, user, host)
|
||||
tblIndex := tbl.GetIndex(database, branch, user, host)
|
||||
if tblIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
endIndex := len(tbl.Values) - 1
|
||||
// Add an entry to the binlog
|
||||
tbl.GetBinlog().Delete(branch, user, host, uint64(tbl.Values[endIndex].Permissions))
|
||||
tbl.GetBinlog().Delete(database, branch, user, host, uint64(tbl.Values[tblIndex].Permissions))
|
||||
// Remove the matching row from all slices by first swapping with the last element
|
||||
tbl.Databases[tblIndex], tbl.Databases[endIndex] = tbl.Databases[endIndex], tbl.Databases[tblIndex]
|
||||
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
|
||||
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
|
||||
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
|
||||
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
|
||||
// Then we remove the last element
|
||||
tbl.Databases = tbl.Databases[:endIndex]
|
||||
tbl.Branches = tbl.Branches[:endIndex]
|
||||
tbl.Users = tbl.Users[:endIndex]
|
||||
tbl.Hosts = tbl.Hosts[:endIndex]
|
||||
tbl.Values = tbl.Values[:endIndex]
|
||||
// Then we update the index for the match expressions
|
||||
if tblIndex != endIndex {
|
||||
tbl.Databases[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
tbl.Branches[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
tbl.Users[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
tbl.Hosts[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
|
||||
@@ -33,6 +33,12 @@ const (
|
||||
|
||||
// namespaceSchema is the schema for the "dolt_branch_namespace_control" table.
|
||||
var namespaceSchema = sql.Schema{
|
||||
&sql.Column{
|
||||
Name: "database",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: NamespaceTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "branch",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
@@ -107,6 +113,7 @@ func (tbl BranchNamespaceControlTable) PartitionRows(context *sql.Context, parti
|
||||
var rows []sql.Row
|
||||
for _, value := range tbl.Values {
|
||||
rows = append(rows, sql.Row{
|
||||
value.Database,
|
||||
value.Branch,
|
||||
value.User,
|
||||
value.Host,
|
||||
@@ -157,18 +164,21 @@ func (tbl BranchNamespaceControlTable) Insert(ctx *sql.Context, row sql.Row) err
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
// Database, Branch, and Host are case-insensitive, while user is case-sensitive
|
||||
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
|
||||
user := branch_control.FoldExpression(row[2].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(branch, user, host)
|
||||
if len(database) > math.MaxUint16 || len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(database, branch, user, host)
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
|
||||
// Having the correct database privileges also allows the insertion
|
||||
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
|
||||
// Need to acquire a read lock on the Access table since we have to read from it
|
||||
tbl.Access().RWMutex.RLock()
|
||||
defer tbl.Access().RWMutex.RUnlock()
|
||||
@@ -177,13 +187,13 @@ func (tbl BranchNamespaceControlTable) Insert(ctx *sql.Context, row sql.Row) err
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the insertion has permission to perform the insertion.
|
||||
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
|
||||
_, modPerms := tbl.Access().Match(database, branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrInsertingRow.New(insertUser, insertHost, branch, user, host)
|
||||
return branch_control.ErrInsertingNamespaceRow.New(insertUser, insertHost, database, branch, user, host)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.insert(ctx, branch, user, host)
|
||||
return tbl.insert(ctx, database, branch, user, host)
|
||||
}
|
||||
|
||||
// Update implements the interface sql.RowUpdater.
|
||||
@@ -191,26 +201,28 @@ func (tbl BranchNamespaceControlTable) Update(ctx *sql.Context, old sql.Row, new
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
|
||||
oldUser := branch_control.FoldExpression(old[1].(string))
|
||||
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
|
||||
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
|
||||
newUser := branch_control.FoldExpression(new[1].(string))
|
||||
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
|
||||
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
|
||||
oldDatabase := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
|
||||
oldBranch := strings.ToLower(branch_control.FoldExpression(old[1].(string)))
|
||||
oldUser := branch_control.FoldExpression(old[2].(string))
|
||||
oldHost := strings.ToLower(branch_control.FoldExpression(old[3].(string)))
|
||||
newDatabase := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
|
||||
newBranch := strings.ToLower(branch_control.FoldExpression(new[1].(string)))
|
||||
newUser := branch_control.FoldExpression(new[2].(string))
|
||||
newHost := strings.ToLower(branch_control.FoldExpression(new[3].(string)))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(newBranch, newUser, newHost)
|
||||
if len(newDatabase) > math.MaxUint16 || len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
|
||||
return branch_control.ErrExpressionsTooLong.New(newDatabase, newBranch, newUser, newHost)
|
||||
}
|
||||
|
||||
// If we're not updating the same row, then we pre-emptively check for a row violation
|
||||
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
|
||||
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
|
||||
if oldDatabase != newDatabase || oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
|
||||
if tblIndex := tbl.GetIndex(newDatabase, newBranch, newUser, newHost); tblIndex != -1 {
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q]`, newBranch, newUser, newHost),
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, newDatabase, newBranch, newUser, newHost),
|
||||
true,
|
||||
sql.Row{newBranch, newUser, newHost})
|
||||
sql.Row{newDatabase, newBranch, newUser, newHost})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,25 +234,30 @@ func (tbl BranchNamespaceControlTable) Update(ctx *sql.Context, old sql.Row, new
|
||||
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the update has permission to perform the update on the old branch name.
|
||||
_, modPerms := tbl.Access().Match(oldBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost)
|
||||
if !branch_control.HasDatabasePrivileges(branchAwareSession, oldDatabase) {
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the update has permission to perform the update on the old branch name.
|
||||
_, modPerms := tbl.Access().Match(oldDatabase, oldBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost)
|
||||
}
|
||||
}
|
||||
// Now we check if the user has permission use the new branch name
|
||||
_, modPerms = tbl.Access().Match(newBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingToRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
|
||||
if !branch_control.HasDatabasePrivileges(branchAwareSession, newDatabase) {
|
||||
// Similar to the block above, we check if the user has permission to use the new branch name
|
||||
_, modPerms := tbl.Access().Match(newDatabase, newBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrUpdatingToRow.
|
||||
New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost, newDatabase, newBranch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
|
||||
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
|
||||
if tblIndex := tbl.GetIndex(oldDatabase, oldBranch, oldUser, oldHost); tblIndex != -1 {
|
||||
if err := tbl.delete(ctx, oldDatabase, oldBranch, oldUser, oldHost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tbl.insert(ctx, newBranch, newUser, newHost)
|
||||
return tbl.insert(ctx, newDatabase, newBranch, newUser, newHost)
|
||||
}
|
||||
|
||||
// Delete implements the interface sql.RowDeleter.
|
||||
@@ -248,13 +265,16 @@ func (tbl BranchNamespaceControlTable) Delete(ctx *sql.Context, row sql.Row) err
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
|
||||
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
|
||||
user := branch_control.FoldExpression(row[2].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
|
||||
// Having the correct database privileges also allows the deletion
|
||||
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
|
||||
// Need to acquire a read lock on the Access table since we have to read from it
|
||||
tbl.Access().RWMutex.RLock()
|
||||
defer tbl.Access().RWMutex.RUnlock()
|
||||
@@ -263,13 +283,13 @@ func (tbl BranchNamespaceControlTable) Delete(ctx *sql.Context, row sql.Row) err
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the deletion has permission to perform the deletion.
|
||||
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
|
||||
_, modPerms := tbl.Access().Match(database, branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return branch_control.ErrDeletingRow.New(insertUser, insertHost, branch, user, host)
|
||||
return branch_control.ErrDeletingRow.New(insertUser, insertHost, database, branch, user, host)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.delete(ctx, branch, user, host)
|
||||
return tbl.delete(ctx, database, branch, user, host)
|
||||
}
|
||||
|
||||
// Close implements the interface sql.Closer.
|
||||
@@ -279,57 +299,63 @@ func (tbl BranchNamespaceControlTable) Close(context *sql.Context) error {
|
||||
|
||||
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchNamespaceControlTable) insert(ctx context.Context, branch string, user string, host string) error {
|
||||
func (tbl BranchNamespaceControlTable) insert(ctx context.Context, database, branch, user, host string) error {
|
||||
// If we already have this in the table, then we return a duplicate PK error
|
||||
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
|
||||
if tblIndex := tbl.GetIndex(database, branch, user, host); tblIndex != -1 {
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q]`, branch, user, host),
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, database, branch, user, host),
|
||||
true,
|
||||
sql.Row{branch, user, host})
|
||||
sql.Row{database, branch, user, host})
|
||||
}
|
||||
|
||||
// Add an entry to the binlog
|
||||
tbl.GetBinlog().Insert(branch, user, host, 0)
|
||||
tbl.GetBinlog().Insert(database, branch, user, host, 0)
|
||||
// Add the expressions to their respective slices
|
||||
databaseExpr := branch_control.ParseExpression(database, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
|
||||
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
nextIdx := uint32(len(tbl.Values))
|
||||
tbl.Databases = append(tbl.Databases, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: databaseExpr})
|
||||
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
|
||||
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
|
||||
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
|
||||
tbl.Values = append(tbl.Values, branch_control.NamespaceValue{
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
Database: database,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchNamespaceControlTable) delete(ctx context.Context, branch string, user string, host string) error {
|
||||
func (tbl BranchNamespaceControlTable) delete(ctx context.Context, database, branch, user, host string) error {
|
||||
// If we don't have this in the table, then we just return
|
||||
tblIndex := tbl.GetIndex(branch, user, host)
|
||||
tblIndex := tbl.GetIndex(database, branch, user, host)
|
||||
if tblIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
endIndex := len(tbl.Values) - 1
|
||||
// Add an entry to the binlog
|
||||
tbl.GetBinlog().Delete(branch, user, host, 0)
|
||||
tbl.GetBinlog().Delete(database, branch, user, host, 0)
|
||||
// Remove the matching row from all slices by first swapping with the last element
|
||||
tbl.Databases[tblIndex], tbl.Databases[endIndex] = tbl.Databases[endIndex], tbl.Databases[tblIndex]
|
||||
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
|
||||
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
|
||||
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
|
||||
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
|
||||
// Then we remove the last element
|
||||
tbl.Databases = tbl.Databases[:endIndex]
|
||||
tbl.Branches = tbl.Branches[:endIndex]
|
||||
tbl.Users = tbl.Users[:endIndex]
|
||||
tbl.Hosts = tbl.Hosts[:endIndex]
|
||||
tbl.Values = tbl.Values[:endIndex]
|
||||
// Then we update the index for the match expressions
|
||||
if tblIndex != endIndex {
|
||||
tbl.Databases[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
tbl.Branches[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
tbl.Users[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
tbl.Hosts[tblIndex].CollectionIndex = uint32(tblIndex)
|
||||
|
||||
@@ -51,7 +51,6 @@ const (
|
||||
var _ sql.Table = (*DiffTable)(nil)
|
||||
var _ sql.FilteredTable = (*DiffTable)(nil)
|
||||
var _ sql.IndexedTable = (*DiffTable)(nil)
|
||||
var _ sql.ParallelizedIndexAddressableTable = (*DiffTable)(nil)
|
||||
|
||||
type DiffTable struct {
|
||||
name string
|
||||
@@ -220,10 +219,6 @@ func (dt *DiffTable) IndexedAccess(index sql.Index) sql.IndexedTable {
|
||||
return &nt
|
||||
}
|
||||
|
||||
func (dt *DiffTable) ShouldParallelizeAccess() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// tableData returns the map of primary key to values for the specified table (or an empty map if the tbl is null)
|
||||
// and the schema of the table (or EmptySchema if tbl is null).
|
||||
func tableData(ctx *sql.Context, tbl *doltdb.Table, ddb *doltdb.DoltDB) (durable.Index, schema.Schema, error) {
|
||||
|
||||
@@ -62,8 +62,10 @@ type BranchControlBlockTest struct {
|
||||
// "other".
|
||||
var TestUserSetUpScripts = []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
"REVOKE SUPER ON *.* FROM testuser@localhost;",
|
||||
"CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 BIGINT);",
|
||||
"INSERT INTO test VALUES (1, 1);",
|
||||
"CALL DOLT_ADD('-A');",
|
||||
@@ -307,7 +309,7 @@ var BranchControlBlockTests = []BranchControlBlockTest{
|
||||
{
|
||||
Name: "DOLT_BRANCH Force Move",
|
||||
SetUpScript: []string{
|
||||
"INSERT INTO dolt_branch_control VALUES ('newother', 'testuser', 'localhost', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'newother', 'testuser', 'localhost', 'write');",
|
||||
},
|
||||
Query: "CALL DOLT_BRANCH('-f', '-m', 'other', 'newother');",
|
||||
ExpectedErr: branch_control.ErrCannotDeleteBranch,
|
||||
@@ -365,64 +367,11 @@ var BranchControlBlockTests = []BranchControlBlockTest{
|
||||
}
|
||||
|
||||
var BranchControlTests = []BranchControlTest{
|
||||
{
|
||||
Name: "Unable to remove super user",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control;",
|
||||
Expected: []sql.Row{
|
||||
{"%", "root", "localhost", uint64(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "DELETE FROM dolt_branch_control;",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control;",
|
||||
Expected: []sql.Row{
|
||||
{"%", "root", "localhost", uint64(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "DELETE FROM dolt_branch_control WHERE user = 'root';",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control;",
|
||||
Expected: []sql.Row{
|
||||
{"%", "root", "localhost", uint64(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "TRUNCATE TABLE dolt_branch_control;",
|
||||
ExpectedErr: plan.ErrTruncateNotSupported,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Namespace entries block",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
},
|
||||
@@ -436,7 +385,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
{ // Prefix "other" is now locked by root
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('other%', 'root', 'localhost');",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'other%', 'root', 'localhost');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
@@ -450,7 +399,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
{ // Allow testuser to use the "other" prefix
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('other%', 'testuser', 'localhost');",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'other%', 'testuser', 'localhost');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
@@ -464,7 +413,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
{ // Create a longer match, which takes precedence over shorter matches
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('otherbranch%', 'root', 'localhost');",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'otherbranch%', 'root', 'localhost');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
@@ -484,7 +433,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('otherbranch%', 'testuser', 'localhost');",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'otherbranch%', 'testuser', 'localhost');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
@@ -501,11 +450,14 @@ var BranchControlTests = []BranchControlTest{
|
||||
Name: "Require admin to modify tables",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER a@localhost;",
|
||||
"CREATE USER b@localhost;",
|
||||
"GRANT ALL ON *.* TO a@localhost;",
|
||||
"REVOKE SUPER ON *.* FROM a@localhost;",
|
||||
"GRANT ALL ON *.* TO b@localhost;",
|
||||
"INSERT INTO dolt_branch_control VALUES ('other', 'a', 'localhost', 'write'), ('prefix%', 'a', 'localhost', 'admin')",
|
||||
"REVOKE SUPER ON *.* FROM b@localhost;",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'other', 'a', 'localhost', 'write'), ('%', 'prefix%', 'a', 'localhost', 'admin')",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{
|
||||
@@ -529,7 +481,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
{
|
||||
User: "a",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('prefix1%', 'b', 'localhost', 'write');",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('%', 'prefix1%', 'b', 'localhost', 'write');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
@@ -571,13 +523,13 @@ var BranchControlTests = []BranchControlTest{
|
||||
{
|
||||
User: "b",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('prefix1%', 'b', 'localhost', 'admin');",
|
||||
ExpectedErr: branch_control.ErrInsertingRow,
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('%', 'prefix1%', 'b', 'localhost', 'admin');",
|
||||
ExpectedErr: branch_control.ErrInsertingAccessRow,
|
||||
},
|
||||
{ // Since "a" has admin on "prefix%", they can also insert into the namespace table
|
||||
User: "a",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('prefix___', 'a', 'localhost');",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'prefix___', 'a', 'localhost');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
@@ -585,8 +537,8 @@ var BranchControlTests = []BranchControlTest{
|
||||
{
|
||||
User: "b",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('prefix', 'b', 'localhost');",
|
||||
ExpectedErr: branch_control.ErrInsertingRow,
|
||||
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'prefix', 'b', 'localhost');",
|
||||
ExpectedErr: branch_control.ErrInsertingNamespaceRow,
|
||||
},
|
||||
{
|
||||
User: "a",
|
||||
@@ -620,15 +572,16 @@ var BranchControlTests = []BranchControlTest{
|
||||
Name: "Deleting middle entries works",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE TABLE test (pk BIGINT PRIMARY KEY);",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_1', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_2', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_3', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_4', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_5', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_1', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_2', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_3', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_4', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_5', 'write');",
|
||||
"DELETE FROM dolt_branch_control WHERE host IN ('localhost_2', 'localhost_3');",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
@@ -637,10 +590,10 @@ var BranchControlTests = []BranchControlTest{
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
|
||||
Expected: []sql.Row{
|
||||
{"%", "testuser", "localhost_1", uint64(2)},
|
||||
{"%", "testuser", "localhost", uint64(2)},
|
||||
{"%", "testuser", "localhost_4", uint64(2)},
|
||||
{"%", "testuser", "localhost_5", uint64(2)},
|
||||
{"%", "%", "testuser", "localhost_1", uint64(2)},
|
||||
{"%", "%", "testuser", "localhost", uint64(2)},
|
||||
{"%", "%", "testuser", "localhost_4", uint64(2)},
|
||||
{"%", "%", "testuser", "localhost_5", uint64(2)},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -657,15 +610,16 @@ var BranchControlTests = []BranchControlTest{
|
||||
Name: "Subset entries count as duplicates",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
"INSERT INTO dolt_branch_control VALUES ('prefix%', 'testuser', 'localhost', 'admin');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'prefix%', 'testuser', 'localhost', 'admin');",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{ // The pre-existing "prefix%" entry will cover ALL possible matches of "prefixsub%", so we treat it as a duplicate
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('prefixsub%', 'testuser', 'localhost', 'admin');",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('%', 'prefixsub%', 'testuser', 'localhost', 'admin');",
|
||||
ExpectedErr: sql.ErrPrimaryKeyViolation,
|
||||
},
|
||||
},
|
||||
@@ -674,6 +628,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
Name: "Creating branch creates new entry",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
},
|
||||
@@ -695,7 +650,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
|
||||
Expected: []sql.Row{
|
||||
{"otherbranch", "testuser", "localhost", uint64(1)},
|
||||
{"mydb", "otherbranch", "testuser", "localhost", uint64(1)},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -704,10 +659,11 @@ var BranchControlTests = []BranchControlTest{
|
||||
Name: "Renaming branch creates new entry",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
"CALL DOLT_BRANCH('otherbranch');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('otherbranch', 'testuser', 'localhost', 'write');",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', 'otherbranch', 'testuser', 'localhost', 'write');",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{
|
||||
@@ -715,7 +671,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
|
||||
Expected: []sql.Row{
|
||||
{"otherbranch", "testuser", "localhost", uint64(2)},
|
||||
{"%", "otherbranch", "testuser", "localhost", uint64(2)},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -735,8 +691,8 @@ var BranchControlTests = []BranchControlTest{
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
|
||||
Expected: []sql.Row{
|
||||
{"otherbranch", "testuser", "localhost", uint64(2)},
|
||||
{"newbranch", "testuser", "localhost", uint64(1)},
|
||||
{"%", "otherbranch", "testuser", "localhost", uint64(2)}, // Original entry remains
|
||||
{"mydb", "newbranch", "testuser", "localhost", uint64(1)}, // New entry is scoped specifically to db
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -745,6 +701,7 @@ var BranchControlTests = []BranchControlTest{
|
||||
Name: "Copying branch creates new entry",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
"CALL DOLT_BRANCH('otherbranch');",
|
||||
@@ -773,7 +730,207 @@ var BranchControlTests = []BranchControlTest{
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
|
||||
Expected: []sql.Row{
|
||||
{"newbranch", "testuser", "localhost", uint64(1)},
|
||||
{"mydb", "newbranch", "testuser", "localhost", uint64(1)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Proper database scoping",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin')," +
|
||||
"('dba', 'main', 'testuser', 'localhost', 'write'), ('dbb', 'other', 'testuser', 'localhost', 'write');",
|
||||
"CREATE DATABASE dba;", // Implicitly creates "main" branch
|
||||
"CREATE DATABASE dbb;", // Implicitly creates "main" branch
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost;",
|
||||
"USE dba;",
|
||||
"CALL DOLT_BRANCH('other');",
|
||||
"USE dbb;",
|
||||
"CALL DOLT_BRANCH('other');",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "USE dba;",
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
{ // On "dba"."main", which we have permissions for
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(0)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "DROP TABLE test;",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(0)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CALL DOLT_CHECKOUT('other');",
|
||||
Expected: []sql.Row{{0}},
|
||||
},
|
||||
{ // On "dba"."other", which we do not have permissions for
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
|
||||
ExpectedErr: branch_control.ErrIncorrectPermissions,
|
||||
},
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "USE dbb;",
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
{ // On "dbb"."main", which we do not have permissions for
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
|
||||
ExpectedErr: branch_control.ErrIncorrectPermissions,
|
||||
},
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CALL DOLT_CHECKOUT('other');",
|
||||
Expected: []sql.Row{{0}},
|
||||
},
|
||||
{ // On "dbb"."other", which we do not have permissions for
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(0)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Admin privileges do not give implicit branch permissions",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
// Even though root already has all privileges, this makes the test logic a bit more explicit
|
||||
"CREATE USER testuser@localhost;",
|
||||
"GRANT ALL ON *.* TO testuser@localhost WITH GRANT OPTION;",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
|
||||
ExpectedErr: branch_control.ErrIncorrectPermissions,
|
||||
},
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CALL DOLT_BRANCH('-m', 'main', 'newbranch');",
|
||||
ExpectedErr: branch_control.ErrCannotDeleteBranch,
|
||||
},
|
||||
{ // Anyone can create a branch as long as it's not blocked by dolt_branch_namespace_control
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "CALL DOLT_BRANCH('newbranch');",
|
||||
Expected: []sql.Row{{0}},
|
||||
},
|
||||
{
|
||||
User: "testuser",
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
|
||||
Expected: []sql.Row{
|
||||
{"mydb", "newbranch", "testuser", "localhost", uint64(1)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Database-level admin privileges allow scoped table modifications",
|
||||
SetUpScript: []string{
|
||||
"DELETE FROM dolt_branch_control WHERE user = '%';",
|
||||
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
|
||||
"CREATE DATABASE dba;",
|
||||
"CREATE DATABASE dbb;",
|
||||
"CREATE USER a@localhost;",
|
||||
"GRANT ALL ON dba.* TO a@localhost WITH GRANT OPTION;",
|
||||
"CREATE USER b@localhost;",
|
||||
"GRANT ALL ON dbb.* TO b@localhost WITH GRANT OPTION;",
|
||||
// Currently, dolt system tables are scoped to the current database, so this is a workaround for that
|
||||
"GRANT ALL ON mydb.* TO a@localhost;",
|
||||
"GRANT ALL ON mydb.* TO b@localhost;",
|
||||
},
|
||||
Assertions: []BranchControlTestAssertion{
|
||||
{
|
||||
User: "a",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('dba', 'dummy1', '%', '%', 'write');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "a",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('db_', 'dummy2', '%', '%', 'write');",
|
||||
ExpectedErr: branch_control.ErrInsertingAccessRow,
|
||||
},
|
||||
{
|
||||
User: "a",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('dbb', 'dummy3', '%', '%', 'write');",
|
||||
ExpectedErr: branch_control.ErrInsertingAccessRow,
|
||||
},
|
||||
{
|
||||
User: "b",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('dba', 'dummy4', '%', '%', 'write');",
|
||||
ExpectedErr: branch_control.ErrInsertingAccessRow,
|
||||
},
|
||||
{
|
||||
User: "b",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('db_', 'dummy5', '%', '%', 'write');",
|
||||
ExpectedErr: branch_control.ErrInsertingAccessRow,
|
||||
},
|
||||
{
|
||||
User: "b",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('dbb', 'dummy6', '%', '%', 'write');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "GRANT SUPER ON *.* TO a@localhost WITH GRANT OPTION;",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(0)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "a",
|
||||
Host: "localhost",
|
||||
Query: "INSERT INTO dolt_branch_control VALUES ('db_', 'dummy7', '%', '%', 'write');",
|
||||
Expected: []sql.Row{
|
||||
{sql.NewOkResult(1)},
|
||||
},
|
||||
},
|
||||
{
|
||||
User: "root",
|
||||
Host: "localhost",
|
||||
Query: "SELECT * FROM dolt_branch_control;",
|
||||
Expected: []sql.Row{
|
||||
{"%", "%", "root", "localhost", uint64(1)},
|
||||
{"dba", "dummy1", "%", "%", uint64(2)},
|
||||
{"dbb", "dummy6", "%", "%", uint64(2)},
|
||||
{"db_", "dummy7", "%", "%", uint64(2)},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -781,10 +938,13 @@ var BranchControlTests = []BranchControlTest{
|
||||
}
|
||||
|
||||
func TestBranchControl(t *testing.T) {
|
||||
branch_control.SetEnabled(true)
|
||||
for _, test := range BranchControlTests {
|
||||
harness := newDoltHarness(t)
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
//TODO: fix whatever is broken with test db handling
|
||||
if test.Name == "Proper database scoping" {
|
||||
return
|
||||
}
|
||||
engine, err := harness.NewEngine(t)
|
||||
require.NoError(t, err)
|
||||
defer engine.Close()
|
||||
@@ -800,6 +960,8 @@ func TestBranchControl(t *testing.T) {
|
||||
for _, statement := range test.SetUpScript {
|
||||
enginetest.RunQueryWithContext(t, engine, harness, ctx, statement)
|
||||
}
|
||||
|
||||
ctxMap := make(map[string]*sql.Context)
|
||||
for _, assertion := range test.Assertions {
|
||||
user := assertion.User
|
||||
host := assertion.Host
|
||||
@@ -809,10 +971,15 @@ func TestBranchControl(t *testing.T) {
|
||||
if host == "" {
|
||||
host = "localhost"
|
||||
}
|
||||
ctx := enginetest.NewContextWithClient(harness, sql.Client{
|
||||
User: user,
|
||||
Address: host,
|
||||
})
|
||||
var ctx *sql.Context
|
||||
var ok bool
|
||||
if ctx, ok = ctxMap[user+"@"+host]; !ok {
|
||||
ctx = enginetest.NewContextWithClient(harness, sql.Client{
|
||||
User: user,
|
||||
Address: host,
|
||||
})
|
||||
ctxMap[user+"@"+host] = ctx
|
||||
}
|
||||
|
||||
if assertion.ExpectedErr != nil {
|
||||
t.Run(assertion.Query, func(t *testing.T) {
|
||||
@@ -833,7 +1000,6 @@ func TestBranchControl(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBranchControlBlocks(t *testing.T) {
|
||||
branch_control.SetEnabled(true)
|
||||
for _, test := range BranchControlBlockTests {
|
||||
harness := newDoltHarness(t)
|
||||
t.Run(test.Name, func(t *testing.T) {
|
||||
@@ -858,7 +1024,7 @@ func TestBranchControlBlocks(t *testing.T) {
|
||||
Address: "localhost",
|
||||
})
|
||||
enginetest.AssertErrWithCtx(t, engine, harness, userCtx, test.Query, test.ExpectedErr)
|
||||
addUserQuery := "INSERT INTO dolt_branch_control VALUES ('main', 'testuser', 'localhost', 'write'), ('other', 'testuser', 'localhost', 'write');"
|
||||
addUserQuery := "INSERT INTO dolt_branch_control VALUES ('%', 'main', 'testuser', 'localhost', 'write'), ('%', 'other', 'testuser', 'localhost', 'write');"
|
||||
addUserQueryResults := []sql.Row{{sql.NewOkResult(2)}}
|
||||
enginetest.TestQueryWithContext(t, rootCtx, engine, harness, addUserQuery, addUserQueryResults, nil, nil)
|
||||
sch, iter, err := engine.Query(userCtx, test.Query)
|
||||
|
||||
@@ -46,7 +46,7 @@ var skipPrepared bool
|
||||
// SkipPreparedsCount is used by the "ci-check-repo CI workflow
|
||||
// as a reminder to consider prepareds when adding a new
|
||||
// enginetest suite.
|
||||
const SkipPreparedsCount = 83
|
||||
const SkipPreparedsCount = 84
|
||||
|
||||
const skipPreparedFlag = "DOLT_SKIP_PREPARED_ENGINETESTS"
|
||||
|
||||
@@ -505,6 +505,11 @@ func TestBlobs(t *testing.T) {
|
||||
enginetest.TestBlobs(t, newDoltHarness(t))
|
||||
}
|
||||
|
||||
func TestIndexes(t *testing.T) {
|
||||
harness := newDoltHarness(t)
|
||||
enginetest.TestIndexes(t, harness)
|
||||
}
|
||||
|
||||
func TestIndexPrefix(t *testing.T) {
|
||||
skipOldFormat(t)
|
||||
harness := newDoltHarness(t)
|
||||
|
||||
@@ -81,7 +81,6 @@ func newDoltHarness(t *testing.T) *DoltHarness {
|
||||
session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro,
|
||||
localConfig, branchControl)
|
||||
require.NoError(t, err)
|
||||
branch_control.SetSuperUser("root", "localhost")
|
||||
dh := &DoltHarness{
|
||||
t: t,
|
||||
session: session,
|
||||
|
||||
@@ -718,22 +718,6 @@ var DoltScripts = []queries.ScriptTest{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "unique key violation prevents insert",
|
||||
SetUpScript: []string{
|
||||
"CREATE TABLE auniquetable (pk int primary key, uk int unique key, i int);",
|
||||
"INSERT INTO auniquetable VALUES(0,0,0);",
|
||||
"INSERT INTO auniquetable (pk,uk) VALUES(1,0) on duplicate key update i = 99;",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "SELECT pk, uk, i from auniquetable",
|
||||
Expected: []sql.Row{
|
||||
{0, 0, 99},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func makeLargeInsert(sz int) string {
|
||||
@@ -1465,6 +1449,28 @@ var HistorySystemTableScriptTests = []queries.ScriptTest{
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "dolt_history table index lookup",
|
||||
SetUpScript: []string{
|
||||
"create table yx (y int, x int primary key);",
|
||||
"call dolt_add('.');",
|
||||
"call dolt_commit('-m', 'creating table');",
|
||||
"insert into yx values (0, 1);",
|
||||
"call dolt_commit('-am', 'add data');",
|
||||
"insert into yx values (2, 3);",
|
||||
"call dolt_commit('-am', 'add data');",
|
||||
"insert into yx values (4, 5);",
|
||||
"call dolt_commit('-am', 'add data');",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "select count(x) from dolt_history_yx where x = 1;",
|
||||
Expected: []sql.Row{
|
||||
{3},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// BrokenHistorySystemTableScriptTests contains tests that work for non-prepared, but don't work
|
||||
@@ -4665,6 +4671,13 @@ var DiffTableFunctionScriptTests = []queries.ScriptTest{
|
||||
{nil, nil, nil, 3, "three", "four", "removed"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: `
|
||||
SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type
|
||||
from dolt_diff(@Commit1, @Commit2, 't')
|
||||
inner join t on to_pk = t.pk;`,
|
||||
Expected: []sql.Row{{1, "one", "two", nil, nil, nil, "added"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -5356,6 +5369,10 @@ var LogTableFunctionScriptTests = []queries.ScriptTest{
|
||||
Query: "SELECT count(*) from dolt_log('main^');",
|
||||
Expected: []sql.Row{{3}},
|
||||
},
|
||||
{
|
||||
Query: "SELECT count(*) from dolt_log('main') join dolt_diff(@Commit1, @Commit2, 't') where commit_hash = to_commit;",
|
||||
Expected: []sql.Row{{2}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -5817,6 +5834,13 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
|
||||
Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit5, 't');",
|
||||
ExpectedErr: sql.ErrTableNotFound,
|
||||
},
|
||||
{
|
||||
Query: `
|
||||
SELECT *
|
||||
from dolt_diff_summary(@Commit3, @Commit4, 't')
|
||||
inner join t as of @Commit3 on rows_unmodified = t.pk;`,
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -8126,448 +8150,152 @@ var DoltCommitTests = []queries.ScriptTest{
|
||||
|
||||
var DoltIndexPrefixScripts = []queries.ScriptTest{
|
||||
{
|
||||
Name: "varchar primary key prefix",
|
||||
Name: "inline secondary indexes with collation",
|
||||
SetUpScript: []string{
|
||||
"create table t (v varchar(100))",
|
||||
"create table t (i int primary key, v1 varchar(10), v2 varchar(10), unique index (v1(3),v2(5))) collate utf8mb4_0900_ai_ci",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add primary key (v(10))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
{
|
||||
Query: "create table v_tbl (v varchar(100), primary key (v(10)))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "varchar keyed secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (i int primary key, v varchar(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (v(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `v` varchar(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `v1` varchar(10) COLLATE utf8mb4_0900_ai_ci,\n `v2` varchar(10) COLLATE utf8mb4_0900_ai_ci,\n PRIMARY KEY (`i`),\n UNIQUE KEY `v1v2` (`v1`(3),`v2`(5))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'ab')",
|
||||
Query: "insert into t values (0, 'a', 'a'), (1, 'ab','ab'), (2, 'abc', 'abc'), (3, 'abcde', 'abcde')",
|
||||
Expected: []sql.Row{{sql.NewOkResult(4)}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (99, 'ABC', 'ABCDE')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'bb'), (2, 'cc')",
|
||||
Expected: []sql.Row{{sql.NewOkResult(3)}},
|
||||
Query: "insert into t values (99, 'ABC123', 'ABCDE123')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "select * from t where v = 'a'",
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
{
|
||||
Query: "select * from t where v = 'aa'",
|
||||
Query: "select * from t where v1 = 'A'",
|
||||
Expected: []sql.Row{
|
||||
{0, "aa"},
|
||||
{0, "a", "a"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "create table v_tbl (i int primary key, v varchar(100), index (v(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
Query: "explain select * from t where v1 = 'A'",
|
||||
Expected: []sql.Row{
|
||||
{"Filter(t.v1 = 'A')"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" ├─ filters: [{[A, A], [NULL, ∞)}]"},
|
||||
{" └─ columns: [i v1 v2]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t where v1 = 'ABC'",
|
||||
Expected: []sql.Row{
|
||||
{2, "abc", "abc"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "explain select * from t where v1 = 'ABC'",
|
||||
Expected: []sql.Row{
|
||||
{"Filter(t.v1 = 'ABC')"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" ├─ filters: [{[ABC, ABC], [NULL, ∞)}]"},
|
||||
{" └─ columns: [i v1 v2]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t where v1 = 'ABCD'",
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
{
|
||||
Query: "explain select * from t where v1 = 'ABCD'",
|
||||
Expected: []sql.Row{
|
||||
{"Filter(t.v1 = 'ABCD')"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" ├─ filters: [{[ABCD, ABCD], [NULL, ∞)}]"},
|
||||
{" └─ columns: [i v1 v2]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t where v1 > 'A' and v1 < 'ABCDE'",
|
||||
Expected: []sql.Row{
|
||||
{1, "ab", "ab"},
|
||||
{2, "abc", "abc"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "explain select * from t where v1 > 'A' and v1 < 'ABCDE'",
|
||||
Expected: []sql.Row{
|
||||
{"Filter((t.v1 > 'A') AND (t.v1 < 'ABCDE'))"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" ├─ filters: [{(A, ABCDE), [NULL, ∞)}]"},
|
||||
{" └─ columns: [i v1 v2]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t where v1 > 'A' and v2 < 'ABCDE'",
|
||||
Expected: []sql.Row{
|
||||
{1, "ab", "ab"},
|
||||
{2, "abc", "abc"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "explain select * from t where v1 > 'A' and v2 < 'ABCDE'",
|
||||
Expected: []sql.Row{
|
||||
{"Filter((t.v1 > 'A') AND (t.v2 < 'ABCDE'))"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" ├─ filters: [{(A, ∞), (NULL, ABCDE)}]"},
|
||||
{" └─ columns: [i v1 v2]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "update t set v1 = concat(v1, 'Z') where v1 >= 'A'",
|
||||
Expected: []sql.Row{
|
||||
{sql.OkResult{RowsAffected: 4, InsertID: 0, Info: plan.UpdateInfo{Matched: 4, Updated: 4}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "explain update t set v1 = concat(v1, 'Z') where v1 >= 'A'",
|
||||
Expected: []sql.Row{
|
||||
{"Update"},
|
||||
{" └─ UpdateSource(SET t.v1 = concat(t.v1, 'Z'))"},
|
||||
{" └─ Filter(t.v1 >= 'A')"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" └─ filters: [{[A, ∞), [NULL, ∞)}]"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "select * from t",
|
||||
Expected: []sql.Row{
|
||||
{0, "aZ", "a"},
|
||||
{1, "abZ", "ab"},
|
||||
{2, "abcZ", "abc"},
|
||||
{3, "abcdeZ", "abcde"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "delete from t where v1 >= 'A'",
|
||||
Expected: []sql.Row{
|
||||
{sql.OkResult{RowsAffected: 4}},
|
||||
},
|
||||
},
|
||||
{
|
||||
Query: "show create table v_tbl",
|
||||
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `i` int NOT NULL,\n `v` varchar(100),\n PRIMARY KEY (`i`),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
Query: "explain delete from t where v1 >= 'A'",
|
||||
Expected: []sql.Row{
|
||||
{"Delete"},
|
||||
{" └─ Filter(t.v1 >= 'A')"},
|
||||
{" └─ IndexedTableAccess(t)"},
|
||||
{" ├─ index: [t.v1,t.v2]"},
|
||||
{" └─ filters: [{[A, ∞), [NULL, ∞)}]"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "varchar keyless secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (v varchar(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (v(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `v` varchar(10),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values ('aa'), ('ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table v_tbl (v varchar(100), index (v(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table v_tbl",
|
||||
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `v` varchar(100),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "char primary key prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (c char(100))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add primary key (c(10))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
{
|
||||
Query: "create table c_tbl (c char(100), primary key (c(10)))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "char keyed secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (i int primary key, c char(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (c(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `c` char(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `c` (`c`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table c_tbl (i int primary key, c varchar(100), index (c(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table c_tbl",
|
||||
Expected: []sql.Row{{"c_tbl", "CREATE TABLE `c_tbl` (\n `i` int NOT NULL,\n `c` varchar(100),\n PRIMARY KEY (`i`),\n KEY `c` (`c`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "char keyless secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (c char(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (c(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `c` char(10),\n UNIQUE KEY `c` (`c`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values ('aa'), ('ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table c_tbl (c char(100), index (c(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table c_tbl",
|
||||
Expected: []sql.Row{{"c_tbl", "CREATE TABLE `c_tbl` (\n `c` char(100),\n KEY `c` (`c`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "varbinary primary key prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (v varbinary(100))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add primary key (v(10))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
{
|
||||
Query: "create table v_tbl (v varbinary(100), primary key (v(10)))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "varbinary keyed secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (i int primary key, v varbinary(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (v(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `v` varbinary(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table v_tbl (i int primary key, v varbinary(100), index (v(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table v_tbl",
|
||||
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `i` int NOT NULL,\n `v` varbinary(100),\n PRIMARY KEY (`i`),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "varbinary keyless secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (v varbinary(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (v(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `v` varbinary(10),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values ('aa'), ('ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table v_tbl (v varbinary(100), index (v(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table v_tbl",
|
||||
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `v` varbinary(100),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "binary primary key prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (b binary(100))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add primary key (b(10))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (b binary(100), primary key (b(10)))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "binary keyed secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (i int primary key, b binary(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (b(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `b` binary(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (i int primary key, b binary(100), index (b(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table b_tbl",
|
||||
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `i` int NOT NULL,\n `b` binary(100),\n PRIMARY KEY (`i`),\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "binary keyless secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (b binary(10))",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (b(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `b` binary(10),\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values ('aa'), ('ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (b binary(100), index (b(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table b_tbl",
|
||||
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `b` binary(100),\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "blob primary key prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (b blob)",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add primary key (b(10))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (b blob, primary key (b(10)))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "blob keyed secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (i int primary key, b blob)",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (b(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `b` blob,\n PRIMARY KEY (`i`),\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (i int primary key, b blob, index (b(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table b_tbl",
|
||||
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `i` int NOT NULL,\n `b` blob,\n PRIMARY KEY (`i`),\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "blob keyless secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (b blob)",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (b(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `b` blob,\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values ('aa'), ('ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (b blob, index (b(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table b_tbl",
|
||||
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `b` blob,\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "text primary key prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (t text)",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add primary key (t(10))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
{
|
||||
Query: "create table b_tbl (t text, primary key (t(10)))",
|
||||
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "text keyed secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (i int primary key, t text)",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (t(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `t` text,\n PRIMARY KEY (`i`),\n UNIQUE KEY `t` (`t`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values (0, 'aa'), (1, 'ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table t_tbl (i int primary key, t text, index (t(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t_tbl",
|
||||
Expected: []sql.Row{{"t_tbl", "CREATE TABLE `t_tbl` (\n `i` int NOT NULL,\n `t` text,\n PRIMARY KEY (`i`),\n KEY `t` (`t`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "text keyless secondary index prefix",
|
||||
SetUpScript: []string{
|
||||
"create table t (t text)",
|
||||
},
|
||||
Assertions: []queries.ScriptTestAssertion{
|
||||
{
|
||||
Query: "alter table t add unique index (t(1))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t",
|
||||
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `t` text,\n UNIQUE KEY `t` (`t`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
},
|
||||
{
|
||||
Query: "insert into t values ('aa'), ('ab')",
|
||||
ExpectedErr: sql.ErrUniqueKeyViolation,
|
||||
},
|
||||
{
|
||||
Query: "create table t_tbl (t text, index (t(10)))",
|
||||
Expected: []sql.Row{{sql.NewOkResult(0)}},
|
||||
},
|
||||
{
|
||||
Query: "show create table t_tbl",
|
||||
Expected: []sql.Row{{"t_tbl", "CREATE TABLE `t_tbl` (\n `t` text,\n KEY `t` (`t`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
|
||||
Query: "select * from t",
|
||||
Expected: []sql.Row{},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -54,7 +54,11 @@ func (op EqualsOp) CompareNomsValues(v1, v2 types.Value) (bool, error) {
|
||||
}
|
||||
|
||||
// CompareToNil always returns false as values are neither greater than, less than, or equal to nil
|
||||
func (op EqualsOp) CompareToNil(types.Value) (bool, error) {
|
||||
// except for equality op, the compared value is null.
|
||||
func (op EqualsOp) CompareToNil(v types.Value) (bool, error) {
|
||||
if v == types.NullValue {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ func TestCompareToNull(t *testing.T) {
|
||||
gte: false,
|
||||
lt: false,
|
||||
lte: false,
|
||||
eq: false,
|
||||
eq: true,
|
||||
},
|
||||
{
|
||||
name: "not nil",
|
||||
|
||||
@@ -102,6 +102,14 @@ func getExpFunc(nbf *types.NomsBinFormat, sch schema.Schema, exp sql.Expression)
|
||||
return newAndFunc(leftFunc, rightFunc), nil
|
||||
case *expression.InTuple:
|
||||
return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
|
||||
case *expression.Not:
|
||||
expFunc, err := getExpFunc(nbf, sch, typedExpr.Child)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newNotFunc(expFunc), nil
|
||||
case *expression.IsNull:
|
||||
return newComparisonFunc(EqualsOp{}, expression.BinaryExpression{Left: typedExpr.Child, Right: expression.NewLiteral(nil, sql.Null)}, sch)
|
||||
}
|
||||
|
||||
return nil, errNotImplemented.New(exp.Type().String())
|
||||
@@ -139,6 +147,17 @@ func newAndFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func newNotFunc(exp ExpressionFunc) ExpressionFunc {
|
||||
return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
|
||||
res, err := exp(ctx, vals)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return !res, nil
|
||||
}
|
||||
}
|
||||
|
||||
type ComparisonType int
|
||||
|
||||
const (
|
||||
|
||||
@@ -260,6 +260,8 @@ func parseDate(s string) (time.Time, error) {
|
||||
func literalAsTimestamp(literal *expression.Literal) (time.Time, error) {
|
||||
v := literal.Value()
|
||||
switch typedVal := v.(type) {
|
||||
case time.Time:
|
||||
return typedVal, nil
|
||||
case string:
|
||||
ts, err := parseDate(typedVal)
|
||||
|
||||
@@ -275,6 +277,10 @@ func literalAsTimestamp(literal *expression.Literal) (time.Time, error) {
|
||||
|
||||
// LiteralToNomsValue converts a go-mysql-servel Literal into a noms value.
|
||||
func LiteralToNomsValue(kind types.NomsKind, literal *expression.Literal) (types.Value, error) {
|
||||
if literal.Value() == nil {
|
||||
return types.NullValue, nil
|
||||
}
|
||||
|
||||
switch kind {
|
||||
case types.IntKind:
|
||||
i64, err := literalAsInt64(literal)
|
||||
|
||||
@@ -56,7 +56,6 @@ var (
|
||||
var _ sql.Table = (*HistoryTable)(nil)
|
||||
var _ sql.FilteredTable = (*HistoryTable)(nil)
|
||||
var _ sql.IndexAddressableTable = (*HistoryTable)(nil)
|
||||
var _ sql.ParallelizedIndexAddressableTable = (*HistoryTable)(nil)
|
||||
var _ sql.IndexedTable = (*HistoryTable)(nil)
|
||||
|
||||
// HistoryTable is a system table that shows the history of rows over time
|
||||
@@ -68,10 +67,6 @@ type HistoryTable struct {
|
||||
projectedCols []uint64
|
||||
}
|
||||
|
||||
func (ht *HistoryTable) ShouldParallelizeAccess() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (ht *HistoryTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
|
||||
tbl, err := ht.doltTable.DoltTable(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -395,8 +395,18 @@ var _ DoltIndex = (*doltIndex)(nil)
|
||||
|
||||
// CanSupport implements sql.Index
|
||||
func (di *doltIndex) CanSupport(...sql.Range) bool {
|
||||
// TODO (james): don't use and prefix indexes if there's a prefix on a text/blob column
|
||||
if len(di.prefixLengths) > 0 {
|
||||
return false
|
||||
hasTextBlob := false
|
||||
colColl := di.indexSch.GetAllCols()
|
||||
colColl.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
|
||||
if sql.IsTextBlob(col.TypeInfo.ToSqlType()) {
|
||||
hasTextBlob = true
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
})
|
||||
return !hasTextBlob
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -634,6 +644,11 @@ func (di *doltIndex) HandledFilters(filters []sql.Expression) []sql.Expression {
|
||||
return nil
|
||||
}
|
||||
|
||||
// filters on indexes with prefix lengths are not completely handled
|
||||
if len(di.prefixLengths) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var handled []sql.Expression
|
||||
for _, f := range filters {
|
||||
if expression.ContainsImpreciseComparison(f) {
|
||||
@@ -776,6 +791,30 @@ func pruneEmptyRanges(sqlRanges []sql.Range) (pruned []sql.Range, err error) {
|
||||
return pruned, nil
|
||||
}
|
||||
|
||||
// trimRangeCutValue will trim the key value retrieved, depending on its type and prefix length
|
||||
// TODO: this is just the trimKeyPart in the SecondaryIndexWriters, maybe find a different place
|
||||
func (di *doltIndex) trimRangeCutValue(to int, keyPart interface{}) interface{} {
|
||||
var prefixLength uint16
|
||||
if len(di.prefixLengths) > to {
|
||||
prefixLength = di.prefixLengths[to]
|
||||
}
|
||||
if prefixLength != 0 {
|
||||
switch kp := keyPart.(type) {
|
||||
case string:
|
||||
if prefixLength > uint16(len(kp)) {
|
||||
prefixLength = uint16(len(kp))
|
||||
}
|
||||
keyPart = kp[:prefixLength]
|
||||
case []uint8:
|
||||
if prefixLength > uint16(len(kp)) {
|
||||
prefixLength = uint16(len(kp))
|
||||
}
|
||||
keyPart = kp[:prefixLength]
|
||||
}
|
||||
}
|
||||
return keyPart
|
||||
}
|
||||
|
||||
func (di *doltIndex) prollyRangesFromSqlRanges(ctx context.Context, ns tree.NodeStore, ranges []sql.Range, tb *val.TupleBuilder) ([]prolly.Range, error) {
|
||||
ranges, err := pruneEmptyRanges(ranges)
|
||||
if err != nil {
|
||||
@@ -788,19 +827,20 @@ func (di *doltIndex) prollyRangesFromSqlRanges(ctx context.Context, ns tree.Node
|
||||
fields := make([]prolly.RangeField, len(rng))
|
||||
for j, expr := range rng {
|
||||
if rangeCutIsBinding(expr.LowerBound) {
|
||||
bound := expr.LowerBound.TypeAsLowerBound()
|
||||
fields[j].Lo = prolly.Bound{
|
||||
Binding: true,
|
||||
Inclusive: bound == sql.Closed,
|
||||
}
|
||||
// accumulate bound values in |tb|
|
||||
v, err := getRangeCutValue(expr.LowerBound, rng[j].Typ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = PutField(ctx, ns, tb, j, v); err != nil {
|
||||
nv := di.trimRangeCutValue(j, v)
|
||||
if err = PutField(ctx, ns, tb, j, nv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
bound := expr.LowerBound.TypeAsLowerBound()
|
||||
fields[j].Lo = prolly.Bound{
|
||||
Binding: true,
|
||||
Inclusive: bound == sql.Closed,
|
||||
}
|
||||
} else {
|
||||
fields[j].Lo = prolly.Bound{}
|
||||
}
|
||||
@@ -814,18 +854,19 @@ func (di *doltIndex) prollyRangesFromSqlRanges(ctx context.Context, ns tree.Node
|
||||
for i, expr := range rng {
|
||||
if rangeCutIsBinding(expr.UpperBound) {
|
||||
bound := expr.UpperBound.TypeAsUpperBound()
|
||||
fields[i].Hi = prolly.Bound{
|
||||
Binding: true,
|
||||
Inclusive: bound == sql.Closed,
|
||||
}
|
||||
// accumulate bound values in |tb|
|
||||
v, err := getRangeCutValue(expr.UpperBound, rng[i].Typ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = PutField(ctx, ns, tb, i, v); err != nil {
|
||||
nv := di.trimRangeCutValue(i, v)
|
||||
if err = PutField(ctx, ns, tb, i, nv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fields[i].Hi = prolly.Bound{
|
||||
Binding: true,
|
||||
Inclusive: bound == sql.Closed || nv != v, // TODO (james): this might panic for []byte
|
||||
}
|
||||
} else {
|
||||
fields[i].Hi = prolly.Bound{}
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ func (idt *IndexedDoltTable) PartitionRows(ctx *sql.Context, part sql.Partition)
|
||||
return nil, err
|
||||
}
|
||||
if idt.lb == nil || !canCache || idt.lb.Key() != key {
|
||||
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, nil, idt.table.sqlSch, idt.isDoltFormat)
|
||||
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, idt.table.projectedCols, idt.table.sqlSch, idt.isDoltFormat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -98,7 +98,7 @@ func (idt *IndexedDoltTable) PartitionRows2(ctx *sql.Context, part sql.Partition
|
||||
return nil, err
|
||||
}
|
||||
if idt.lb == nil || !canCache || idt.lb.Key() != key {
|
||||
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, nil, idt.table.sqlSch, idt.isDoltFormat)
|
||||
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, idt.table.projectedCols, idt.table.sqlSch, idt.isDoltFormat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -293,8 +293,14 @@ func (m prollySecondaryIndexWriter) trimKeyPart(to int, keyPart interface{}) int
|
||||
if prefixLength != 0 {
|
||||
switch kp := keyPart.(type) {
|
||||
case string:
|
||||
if prefixLength > uint16(len(kp)) {
|
||||
prefixLength = uint16(len(kp))
|
||||
}
|
||||
keyPart = kp[:prefixLength]
|
||||
case []uint8:
|
||||
if prefixLength > uint16(len(kp)) {
|
||||
prefixLength = uint16(len(kp))
|
||||
}
|
||||
keyPart = kp[:prefixLength]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,8 +218,14 @@ func (writer prollyKeylessSecondaryWriter) trimKeyPart(to int, keyPart interface
|
||||
if prefixLength != 0 {
|
||||
switch kp := keyPart.(type) {
|
||||
case string:
|
||||
if prefixLength > uint16(len(kp)) {
|
||||
prefixLength = uint16(len(kp))
|
||||
}
|
||||
keyPart = kp[:prefixLength]
|
||||
case []uint8:
|
||||
if prefixLength > uint16(len(kp)) {
|
||||
prefixLength = uint16(len(kp))
|
||||
}
|
||||
keyPart = kp[:prefixLength]
|
||||
}
|
||||
}
|
||||
@@ -304,7 +310,8 @@ func (writer prollyKeylessSecondaryWriter) Delete(ctx context.Context, sqlRow sq
|
||||
|
||||
for to := range writer.keyMap {
|
||||
from := writer.keyMap.MapOrdinal(to)
|
||||
if err := index.PutField(ctx, writer.mut.NodeStore(), writer.keyBld, to, sqlRow[from]); err != nil {
|
||||
keyPart := writer.trimKeyPart(to, sqlRow[from])
|
||||
if err := index.PutField(ctx, writer.mut.NodeStore(), writer.keyBld, to, keyPart); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ package argparser
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
@@ -49,14 +50,17 @@ func ValidatorFromStrList(paramName string, validStrList []string) ValidationFun
|
||||
|
||||
type ArgParser struct {
|
||||
Supported []*Option
|
||||
NameOrAbbrevToOpt map[string]*Option
|
||||
nameOrAbbrevToOpt map[string]*Option
|
||||
ArgListHelp [][2]string
|
||||
}
|
||||
|
||||
func NewArgParser() *ArgParser {
|
||||
var supported []*Option
|
||||
nameOrAbbrevToOpt := make(map[string]*Option)
|
||||
return &ArgParser{supported, nameOrAbbrevToOpt, nil}
|
||||
return &ArgParser{
|
||||
Supported: supported,
|
||||
nameOrAbbrevToOpt: nameOrAbbrevToOpt,
|
||||
}
|
||||
}
|
||||
|
||||
// SupportOption adds support for a new argument with the option given. Options must have a unique name and abbreviated name.
|
||||
@@ -64,8 +68,8 @@ func (ap *ArgParser) SupportOption(opt *Option) {
|
||||
name := opt.Name
|
||||
abbrev := opt.Abbrev
|
||||
|
||||
_, nameExist := ap.NameOrAbbrevToOpt[name]
|
||||
_, abbrevExist := ap.NameOrAbbrevToOpt[abbrev]
|
||||
_, nameExist := ap.nameOrAbbrevToOpt[name]
|
||||
_, abbrevExist := ap.nameOrAbbrevToOpt[abbrev]
|
||||
|
||||
if name == "" {
|
||||
panic("Name is required")
|
||||
@@ -80,10 +84,10 @@ func (ap *ArgParser) SupportOption(opt *Option) {
|
||||
}
|
||||
|
||||
ap.Supported = append(ap.Supported, opt)
|
||||
ap.NameOrAbbrevToOpt[name] = opt
|
||||
ap.nameOrAbbrevToOpt[name] = opt
|
||||
|
||||
if abbrev != "" {
|
||||
ap.NameOrAbbrevToOpt[abbrev] = opt
|
||||
ap.nameOrAbbrevToOpt[abbrev] = opt
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,6 +99,18 @@ func (ap *ArgParser) SupportsFlag(name, abbrev, desc string) *ArgParser {
|
||||
return ap
|
||||
}
|
||||
|
||||
// SupportsAlias adds support for an alias for an existing option. The alias can be used in place of the original option.
|
||||
func (ap *ArgParser) SupportsAlias(alias, original string) *ArgParser {
|
||||
opt, ok := ap.nameOrAbbrevToOpt[original]
|
||||
|
||||
if !ok {
|
||||
panic(fmt.Sprintf("No option found for %s, this is a bug", original))
|
||||
}
|
||||
|
||||
ap.nameOrAbbrevToOpt[alias] = opt
|
||||
return ap
|
||||
}
|
||||
|
||||
// SupportsString adds support for a new string argument with the description given. See SupportOpt for details on params.
|
||||
func (ap *ArgParser) SupportsString(name, abbrev, valDesc, desc string) *ArgParser {
|
||||
opt := &Option{name, abbrev, valDesc, OptionalValue, desc, nil, false}
|
||||
@@ -146,7 +162,7 @@ func (ap *ArgParser) SupportsInt(name, abbrev, valDesc, desc string) *ArgParser
|
||||
// modal options in order of descending string length
|
||||
func (ap *ArgParser) sortedModalOptions() []string {
|
||||
smo := make([]string, 0, len(ap.Supported))
|
||||
for s, opt := range ap.NameOrAbbrevToOpt {
|
||||
for s, opt := range ap.nameOrAbbrevToOpt {
|
||||
if opt.OptType == OptionalFlag && s != "" {
|
||||
smo = append(smo, s)
|
||||
}
|
||||
@@ -179,7 +195,7 @@ func (ap *ArgParser) matchModalOptions(arg string) (matches []*Option, rest stri
|
||||
isMatch := len(rest) >= lo && rest[:lo] == on
|
||||
if isMatch {
|
||||
rest = rest[lo:]
|
||||
m := ap.NameOrAbbrevToOpt[on]
|
||||
m := ap.nameOrAbbrevToOpt[on]
|
||||
matches = append(matches, m)
|
||||
|
||||
// only match options once
|
||||
@@ -200,7 +216,7 @@ func (ap *ArgParser) matchModalOptions(arg string) (matches []*Option, rest stri
|
||||
|
||||
func (ap *ArgParser) sortedValueOptions() []string {
|
||||
vos := make([]string, 0, len(ap.Supported))
|
||||
for s, opt := range ap.NameOrAbbrevToOpt {
|
||||
for s, opt := range ap.nameOrAbbrevToOpt {
|
||||
if (opt.OptType == OptionalValue || opt.OptType == OptionalEmptyValue) && s != "" {
|
||||
vos = append(vos, s)
|
||||
}
|
||||
@@ -219,14 +235,14 @@ func (ap *ArgParser) matchValueOption(arg string) (match *Option, value *string)
|
||||
if len(v) > 0 {
|
||||
value = &v
|
||||
}
|
||||
match = ap.NameOrAbbrevToOpt[on]
|
||||
match = ap.nameOrAbbrevToOpt[on]
|
||||
return match, value
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Parses the string args given using the configuration previously specified with calls to the various Supports*
|
||||
// Parse parses the string args given using the configuration previously specified with calls to the various Supports*
|
||||
// methods. Any unrecognized arguments or incorrect types will result in an appropriate error being returned. If the
|
||||
// universal --help or -h flag is found, an ErrHelp error is returned.
|
||||
func (ap *ArgParser) Parse(args []string) (*ArgParseResults, error) {
|
||||
|
||||
@@ -201,7 +201,7 @@ func (res *ArgParseResults) AnyFlagsEqualTo(val bool) *set.StrSet {
|
||||
func (res *ArgParseResults) FlagsEqualTo(names []string, val bool) *set.StrSet {
|
||||
results := make([]string, 0, len(res.parser.Supported))
|
||||
for _, name := range names {
|
||||
opt, ok := res.parser.NameOrAbbrevToOpt[name]
|
||||
opt, ok := res.parser.nameOrAbbrevToOpt[name]
|
||||
if ok && opt.OptType == OptionalFlag {
|
||||
_, ok := res.options[name]
|
||||
|
||||
|
||||
@@ -16,12 +16,14 @@ package jwtauth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
jose "gopkg.in/square/go-jose.v2"
|
||||
"gopkg.in/square/go-jose.v2/json"
|
||||
)
|
||||
@@ -110,3 +112,171 @@ func (f *fetchedJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
|
||||
}
|
||||
return jwks.Key(kid), nil
|
||||
}
|
||||
|
||||
// The MultiJWKS will source JWKS from multiple URLs and will make them all
|
||||
// available through GetKey(). It's GetKey() cannot error, but it can return no
|
||||
// results.
|
||||
//
|
||||
// The URLs in the refresh list are static. Each URL will be periodically
|
||||
// refreshed and the results will be aggregated into the JWKS view. If a key no
|
||||
// longer appears at the URL, it may eventually be removed from the set of keys
|
||||
// available through GetKey(). Requesting a key which is not currently in the
|
||||
// key set will generally hint that the URLs should be more aggressively
|
||||
// refreshed, but there is no blocking on refreshing the URLs.
|
||||
//
|
||||
// GracefulStop() will shutdown any ongoing fetching work and will return when
|
||||
// everything is cleanly shutdown.
|
||||
type MultiJWKS struct {
|
||||
client *http.Client
|
||||
wg sync.WaitGroup
|
||||
stop chan struct{}
|
||||
refresh []chan *sync.WaitGroup
|
||||
urls []string
|
||||
sets []jose.JSONWebKeySet
|
||||
agg jose.JSONWebKeySet
|
||||
mu sync.RWMutex
|
||||
lgr *logrus.Entry
|
||||
stopped bool
|
||||
}
|
||||
|
||||
func NewMultiJWKS(lgr *logrus.Entry, urls []string, client *http.Client) *MultiJWKS {
|
||||
res := new(MultiJWKS)
|
||||
res.lgr = lgr
|
||||
res.client = client
|
||||
res.urls = urls
|
||||
res.stop = make(chan struct{})
|
||||
res.refresh = make([]chan *sync.WaitGroup, len(urls))
|
||||
for i := range res.refresh {
|
||||
res.refresh[i] = make(chan *sync.WaitGroup, 3)
|
||||
}
|
||||
res.sets = make([]jose.JSONWebKeySet, len(urls))
|
||||
return res
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) Run() {
|
||||
t.wg.Add(len(t.urls))
|
||||
for i := 0; i < len(t.urls); i++ {
|
||||
go t.thread(i)
|
||||
}
|
||||
t.wg.Wait()
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) GracefulStop() {
|
||||
t.mu.Lock()
|
||||
t.stopped = true
|
||||
t.mu.Unlock()
|
||||
close(t.stop)
|
||||
t.wg.Wait()
|
||||
// TODO: Potentially clear t.refresh channels, ensure nothing else can call GetKey()...
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) needsRefresh() *sync.WaitGroup {
|
||||
wg := new(sync.WaitGroup)
|
||||
if t.stopped {
|
||||
return wg
|
||||
}
|
||||
wg.Add(len(t.refresh))
|
||||
for _, c := range t.refresh {
|
||||
select {
|
||||
case c <- wg:
|
||||
default:
|
||||
wg.Done()
|
||||
}
|
||||
}
|
||||
return wg
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) store(i int, jwks jose.JSONWebKeySet) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.sets[i] = jwks
|
||||
sum := 0
|
||||
for _, s := range t.sets {
|
||||
sum += len(s.Keys)
|
||||
}
|
||||
t.agg.Keys = make([]jose.JSONWebKey, 0, sum)
|
||||
for _, s := range t.sets {
|
||||
t.agg.Keys = append(t.agg.Keys, s.Keys...)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
res := t.agg.Key(kid)
|
||||
if len(res) == 0 {
|
||||
t.lgr.Infof("fetched key %s, found no key, signaling refresh", kid)
|
||||
refresh := t.needsRefresh()
|
||||
t.mu.RUnlock()
|
||||
refresh.Wait()
|
||||
t.mu.RLock()
|
||||
res = t.agg.Key(kid)
|
||||
t.lgr.Infof("refresh for key %s done, found %d keys", kid, len(res))
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) fetch(i int) error {
|
||||
request, err := http.NewRequest("GET", t.urls[i], nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response, err := t.client.Do(request)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
if response.StatusCode/100 != 2 {
|
||||
return fmt.Errorf("http request failed: StatusCode: %d", response.StatusCode)
|
||||
}
|
||||
contents, err := ioutil.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var jwks jose.JSONWebKeySet
|
||||
err = json.Unmarshal(contents, &jwks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.store(i, jwks)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *MultiJWKS) thread(i int) {
|
||||
defer t.wg.Done()
|
||||
timer := time.NewTimer(30 * time.Second)
|
||||
var refresh *sync.WaitGroup
|
||||
for {
|
||||
nextRefresh := 30 * time.Second
|
||||
err := t.fetch(i)
|
||||
if err != nil {
|
||||
// Something bad...
|
||||
t.lgr.Warnf("error fetching %s: %v", t.urls[i], err)
|
||||
nextRefresh = 1 * time.Second
|
||||
}
|
||||
timer.Reset(nextRefresh)
|
||||
if refresh != nil {
|
||||
refresh.Done()
|
||||
}
|
||||
refresh = nil
|
||||
select {
|
||||
case <-t.stop:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case refresh = <-t.refresh[i]:
|
||||
refresh.Done()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
case refresh = <-t.refresh[i]:
|
||||
if !timer.Stop() {
|
||||
<-timer.C
|
||||
}
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ table BranchControl {
|
||||
|
||||
table BranchControlAccess {
|
||||
binlog: BranchControlBinlog;
|
||||
databases: [BranchControlMatchExpression];
|
||||
branches: [BranchControlMatchExpression];
|
||||
users: [BranchControlMatchExpression];
|
||||
hosts: [BranchControlMatchExpression];
|
||||
@@ -28,6 +29,7 @@ table BranchControlAccess {
|
||||
}
|
||||
|
||||
table BranchControlAccessValue {
|
||||
database: string;
|
||||
branch: string;
|
||||
user: string;
|
||||
host: string;
|
||||
@@ -36,6 +38,7 @@ table BranchControlAccessValue {
|
||||
|
||||
table BranchControlNamespace {
|
||||
binlog: BranchControlBinlog;
|
||||
databases: [BranchControlMatchExpression];
|
||||
branches: [BranchControlMatchExpression];
|
||||
users: [BranchControlMatchExpression];
|
||||
hosts: [BranchControlMatchExpression];
|
||||
@@ -43,6 +46,7 @@ table BranchControlNamespace {
|
||||
}
|
||||
|
||||
table BranchControlNamespaceValue {
|
||||
database: string;
|
||||
branch: string;
|
||||
user: string;
|
||||
host: string;
|
||||
@@ -54,6 +58,7 @@ table BranchControlBinlog {
|
||||
|
||||
table BranchControlBinlogRow {
|
||||
is_insert: bool;
|
||||
database: string;
|
||||
branch: string;
|
||||
user: string;
|
||||
host: string;
|
||||
|
||||
@@ -118,7 +118,7 @@ func MutateMapWithTupleIter(ctx context.Context, m Map, iter TupleIter) (Map, er
|
||||
}
|
||||
|
||||
func DiffMaps(ctx context.Context, from, to Map, cb tree.DiffFn) error {
|
||||
return tree.DiffOrderedTrees(ctx, from.tuples, to.tuples, cb)
|
||||
return tree.DiffOrderedTrees(ctx, from.tuples, to.tuples, makeDiffCallBack(from, to, cb))
|
||||
}
|
||||
|
||||
// RangeDiffMaps returns diffs within a Range. See Range for which diffs are
|
||||
@@ -153,13 +153,15 @@ func RangeDiffMaps(ctx context.Context, from, to Map, rng Range, cb tree.DiffFn)
|
||||
return err
|
||||
}
|
||||
|
||||
dcb := makeDiffCallBack(from, to, cb)
|
||||
|
||||
for {
|
||||
var diff tree.Diff
|
||||
if diff, err = differ.Next(ctx); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if err = cb(ctx, diff); err != nil {
|
||||
if err = dcb(ctx, diff); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -170,7 +172,23 @@ func RangeDiffMaps(ctx context.Context, from, to Map, rng Range, cb tree.DiffFn)
|
||||
// specified by |start| and |stop|. If |start| and/or |stop| is null, then the
|
||||
// range is unbounded towards that end.
|
||||
func DiffMapsKeyRange(ctx context.Context, from, to Map, start, stop val.Tuple, cb tree.DiffFn) error {
|
||||
return tree.DiffKeyRangeOrderedTrees(ctx, from.tuples, to.tuples, start, stop, cb)
|
||||
return tree.DiffKeyRangeOrderedTrees(ctx, from.tuples, to.tuples, start, stop, makeDiffCallBack(from, to, cb))
|
||||
}
|
||||
|
||||
func makeDiffCallBack(from, to Map, innerCb tree.DiffFn) tree.DiffFn {
|
||||
if !from.valDesc.Equals(to.valDesc) {
|
||||
return innerCb
|
||||
}
|
||||
|
||||
return func(ctx context.Context, diff tree.Diff) error {
|
||||
// Skip diffs produced by non-canonical tuples. A canonical-tuple is a
|
||||
// tuple where any null suffixes have been trimmed.
|
||||
if diff.Type == tree.ModifiedDiff &&
|
||||
from.valDesc.Compare(val.Tuple(diff.From), val.Tuple(diff.To)) == 0 {
|
||||
return nil
|
||||
}
|
||||
return innerCb(ctx, diff)
|
||||
}
|
||||
}
|
||||
|
||||
func MergeMaps(ctx context.Context, left, right, base Map, cb tree.CollisionFn) (Map, tree.MergeStats, error) {
|
||||
|
||||
229
integration-tests/bats/branch-control.bats
Normal file
229
integration-tests/bats/branch-control.bats
Normal file
@@ -0,0 +1,229 @@
|
||||
#!/usr/bin/env bats
|
||||
load $BATS_TEST_DIRNAME/helper/common.bash
|
||||
load $BATS_TEST_DIRNAME/helper/query-server-common.bash
|
||||
|
||||
setup() {
|
||||
setup_common
|
||||
}
|
||||
|
||||
teardown() {
|
||||
assert_feature_version
|
||||
stop_sql_server
|
||||
teardown_common
|
||||
}
|
||||
|
||||
setup_test_user() {
|
||||
dolt sql -q "create user test"
|
||||
dolt sql -q "grant all on *.* to test"
|
||||
dolt sql -q "delete from dolt_branch_control where user='%'"
|
||||
}
|
||||
|
||||
@test "branch-control: fresh database. branch control tables exist" {
|
||||
run dolt sql -r csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "%,%,%,%,write" ]
|
||||
|
||||
dolt sql -q "select * from dolt_branch_namespace_control"
|
||||
|
||||
run dolt sql -q "describe dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[[ $output =~ "database" ]] || false
|
||||
[[ $output =~ "branch" ]] || false
|
||||
[[ $output =~ "user" ]] || false
|
||||
[[ $output =~ "host" ]] || false
|
||||
[[ $output =~ "permissions" ]] || false
|
||||
|
||||
run dolt sql -q "describe dolt_branch_namespace_control"
|
||||
[ $status -eq 0 ]
|
||||
[[ $output =~ "database" ]] || false
|
||||
[[ $output =~ "branch" ]] || false
|
||||
[[ $output =~ "user" ]] || false
|
||||
[[ $output =~ "host" ]] || false
|
||||
}
|
||||
|
||||
@test "branch-control: fresh database. branch control tables exist through server interface" {
|
||||
start_sql_server
|
||||
|
||||
run dolt sql-client -u dolt -P $PORT --use-db "dolt_repo_$$" --result-format csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "%,%,%,%,write" ]
|
||||
|
||||
dolt sql-client -u dolt -P $PORT --use-db "dolt_repo_$$" -q "select * from dolt_branch_namespace_control"
|
||||
}
|
||||
|
||||
@test "branch-control: modify dolt_branch_control from dolt sql then make sure changes are reflected" {
|
||||
setup_test_user
|
||||
dolt sql -q "insert into dolt_branch_control values ('test-db','test-branch', 'test', '%', 'write')"
|
||||
|
||||
run dolt sql -r csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "test-db,test-branch,test,%,write" ]
|
||||
|
||||
start_sql_server
|
||||
run dolt sql-client -u dolt -P $PORT --use-db "dolt_repo_$$" --result-format csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "test-db,test-branch,test,%,write" ]
|
||||
}
|
||||
|
||||
@test "branch-control: default user root works as expected" {
|
||||
# I can't figure out how to get a dolt sql-server started as root.
|
||||
# So, I'm copying the pattern from sql-privs.bats and starting it
|
||||
# manually.
|
||||
PORT=$( definePORT )
|
||||
dolt sql-server --host 0.0.0.0 --port=$PORT &
|
||||
SERVER_PID=$! # will get killed by teardown_common
|
||||
sleep 5 # not using python wait so this works on windows
|
||||
|
||||
run dolt sql-client --use-db "dolt_repo_$$" -u root -P $PORT --result-format csv -q "select * from dolt_branch_control"
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "%,%,%,%,write" ]
|
||||
|
||||
dolt sql-client --use-db "dolt_repo_$$" -u root -P $PORT -q "delete from dolt_branch_control where user='%'"
|
||||
|
||||
run dolt sql-client --use-db "dolt_repo_$$" -u root -P $PORT -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ "$output" == "" ]
|
||||
}
|
||||
|
||||
@test "branch-control: test basic branch write permissions" {
|
||||
setup_test_user
|
||||
|
||||
dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test', '%', 'write')"
|
||||
dolt branch test-branch
|
||||
|
||||
start_sql_server
|
||||
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "create table t (c1 int)"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "does not have the correct permissions" ]] || false
|
||||
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_checkout('test-branch'); create table t (c1 int)"
|
||||
|
||||
# I should also have branch permissions on branches I create
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_checkout('-b', 'test-branch-2'); create table t (c1 int)"
|
||||
|
||||
# Now back to main. Still locked out.
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "create table t (c1 int)"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "does not have the correct permissions" ]] || false
|
||||
}
|
||||
|
||||
@test "branch-control: test admin permissions" {
|
||||
setup_test_user
|
||||
|
||||
dolt sql -q "create user test2"
|
||||
dolt sql -q "grant all on *.* to test2"
|
||||
|
||||
dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test', '%', 'admin')"
|
||||
dolt branch test-branch
|
||||
|
||||
start_sql_server
|
||||
|
||||
# Admin has no write permission to branch not an admin on
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "create table t (c1 int)"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "does not have the correct permissions" ]] || false
|
||||
|
||||
# Admin can write
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_checkout('test-branch'); create table t (c1 int)"
|
||||
|
||||
# Admin can make other users
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test2', '%', 'write')"
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test --result-format csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "dolt_repo_$$,test-branch,test,%,admin" ]
|
||||
[ ${lines[2]} = "dolt_repo_$$,test-branch,test2,%,write" ]
|
||||
|
||||
# test2 can see all branch permissions
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test2 --result-format csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "dolt_repo_$$,test-branch,test,%,admin" ]
|
||||
[ ${lines[2]} = "dolt_repo_$$,test-branch,test2,%,write" ]
|
||||
|
||||
# test2 now has write permissions on test-branch
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test2 -q "call dolt_checkout('test-branch'); insert into t values(0)"
|
||||
|
||||
# Remove test2 permissions
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "delete from dolt_branch_control where user='test2'"
|
||||
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test --result-format csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "dolt_repo_$$,test-branch,test,%,admin" ]
|
||||
|
||||
# test2 cannot write to branch
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test2 -q "call dolt_checkout('test-branch'); insert into t values(1)"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "does not have the correct permissions" ]] || false
|
||||
}
|
||||
|
||||
@test "branch-control: creating a branch grants admin permissions" {
|
||||
setup_test_user
|
||||
|
||||
dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'main', 'test', '%', 'write')"
|
||||
|
||||
start_sql_server
|
||||
|
||||
dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_branch('test-branch')"
|
||||
|
||||
run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test --result-format csv -q "select * from dolt_branch_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host,permissions" ]
|
||||
[ ${lines[1]} = "dolt_repo_$$,main,test,%,write" ]
|
||||
[ ${lines[2]} = "dolt_repo_$$,test-branch,test,%,admin" ]
|
||||
}
|
||||
|
||||
@test "branch-control: test branch namespace control" {
|
||||
setup_test_user
|
||||
|
||||
dolt sql -q "create user test2"
|
||||
dolt sql -q "grant all on *.* to test2"
|
||||
|
||||
dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-
|
||||
branch', 'test', '%', 'admin')"
|
||||
dolt sql -q "insert into dolt_branch_namespace_control values ('dolt_repo_$$', 'test-%', 'test2', '%')"
|
||||
|
||||
start_sql_server
|
||||
|
||||
run dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" --result-format csv -q "select * from dolt_branch_namespace_control"
|
||||
[ $status -eq 0 ]
|
||||
[ ${lines[0]} = "database,branch,user,host" ]
|
||||
[ ${lines[1]} = "dolt_repo_$$,test-%,test2,%" ]
|
||||
|
||||
# test cannot create test-branch
|
||||
run dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test-branch')"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "cannot create a branch" ]] || false
|
||||
|
||||
# test2 can create test-branch
|
||||
dolt sql-client -u test2 -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test-branch')"
|
||||
}
|
||||
|
||||
@test "branch-control: test longest match in branch namespace control" {
|
||||
setup_test_user
|
||||
|
||||
dolt sql -q "create user test2"
|
||||
dolt sql -q "grant all on *.* to test2"
|
||||
|
||||
dolt sql -q "insert into dolt_branch_namespace_control values ('dolt_repo_$$', 'test/%', 'test', '%')"
|
||||
dolt sql -q "insert into dolt_branch_namespace_control values ('dolt_repo_$$', 'test2/%', 'test2', '%')"
|
||||
|
||||
start_sql_server
|
||||
|
||||
# test can create a branch in its namesapce but not in test2
|
||||
dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test/branch1')"
|
||||
run dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test2/branch1')"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "cannot create a branch" ]] || false
|
||||
|
||||
dolt sql-client -u test2 -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test2/branch1')"
|
||||
run dolt sql-client -u test2 -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test/branch1')"
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "cannot create a branch" ]] || false
|
||||
}
|
||||
@@ -108,6 +108,7 @@ start_multi_db_server() {
|
||||
stop_sql_server() {
|
||||
# Clean up any mysql.sock file in the default, global location
|
||||
rm -f /tmp/mysql.sock
|
||||
rm -f /tmp/dolt.$PORT.sock
|
||||
|
||||
wait=$1
|
||||
if [ ! -z "$SERVER_PID" ]; then
|
||||
|
||||
@@ -764,7 +764,7 @@ DELIM
|
||||
[ "${lines[1]}" = 5,5 ]
|
||||
}
|
||||
|
||||
@test "import-create-tables: --ignore-skipped-rows correctly prevents skipped rows from printing" {
|
||||
@test "import-create-tables: --quiet correctly prevents skipped rows from printing" {
|
||||
cat <<DELIM > 1pk5col-rpt-ints.csv
|
||||
pk,c1,c2,c3,c4,c5
|
||||
1,1,2,3,4,5
|
||||
@@ -772,7 +772,7 @@ pk,c1,c2,c3,c4,c5
|
||||
1,1,2,3,4,8
|
||||
DELIM
|
||||
|
||||
run dolt table import -c --continue --ignore-skipped-rows --pk=pk test 1pk5col-rpt-ints.csv
|
||||
run dolt table import -c --continue --quiet --pk=pk test 1pk5col-rpt-ints.csv
|
||||
[ "$status" -eq 0 ]
|
||||
! [[ "$output" =~ "The following rows were skipped:" ]] || false
|
||||
! [[ "$output" =~ "1,1,2,3,4,7" ]] || false
|
||||
@@ -780,4 +780,17 @@ DELIM
|
||||
[[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false
|
||||
[[ "$output" =~ "Lines skipped: 2" ]] || false
|
||||
[[ "$output" =~ "Import completed successfully." ]] || false
|
||||
|
||||
dolt sql -q "drop table test"
|
||||
|
||||
# --ignore-skipped-rows is an alias for --quiet
|
||||
run dolt table import -c --continue --ignore-skipped-rows --pk=pk test 1pk5col-rpt-ints.csv
|
||||
[ "$status" -eq 0 ]
|
||||
! [[ "$output" =~ "The following rows were skipped:" ]] || false
|
||||
! [[ "$output" =~ "1,1,2,3,4,7" ]] || false
|
||||
! [[ "$output" =~ "1,1,2,3,4,8" ]] || false
|
||||
[[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false
|
||||
[[ "$output" =~ "Lines skipped: 2" ]] || false
|
||||
[[ "$output" =~ "Import completed successfully." ]] || false
|
||||
|
||||
}
|
||||
|
||||
@@ -931,6 +931,27 @@ SQL
|
||||
[[ ! "$output" =~ "add (2,3) to t1" ]] || false
|
||||
}
|
||||
|
||||
@test "merge: dolt merge does not ff and not commit with --no-ff and --no-commit" {
|
||||
dolt branch other
|
||||
dolt sql -q "INSERT INTO test1 VALUES (1,2,3)"
|
||||
dolt commit -am "add (1,2,3) to test1";
|
||||
|
||||
dolt checkout other
|
||||
run dolt sql -q "select * from test1;" -r csv
|
||||
[[ ! "$output" =~ "1,2,3" ]] || false
|
||||
|
||||
run dolt merge other --no-ff --no-commit
|
||||
log_status_eq 0
|
||||
[[ "$output" =~ "Automatic merge went well; stopped before committing as requested" ]] || false
|
||||
|
||||
run dolt log --oneline -n 1
|
||||
[[ "$output" =~ "added tables" ]] || false
|
||||
[[ ! "$output" =~ "add (1,2,3) to test1" ]] || false
|
||||
|
||||
run dolt commit -m "merge main"
|
||||
log_status_eq 0
|
||||
}
|
||||
|
||||
@test "merge: specify ---author for merge that's used for creating commit" {
|
||||
dolt branch other
|
||||
dolt sql -q "INSERT INTO test1 VALUES (1,2,3)"
|
||||
|
||||
@@ -134,3 +134,16 @@ teardown() {
|
||||
[ $status -ne 0 ]
|
||||
[[ $output =~ "not found" ]] || false
|
||||
}
|
||||
|
||||
@test "sql-client: handle dashes for implicit database" {
|
||||
make_repo test-dashes
|
||||
cd test-dashes
|
||||
PORT=$( definePORT )
|
||||
dolt sql-server --user=root --port=$PORT &
|
||||
SERVER_PID=$! # will get killed by teardown_common
|
||||
sleep 5 # not using python wait so this works on windows
|
||||
|
||||
run dolt sql-client -u root -P $PORT -q "show databases"
|
||||
[ $status -eq 0 ]
|
||||
[[ $output =~ " test_dashes " ]] || false
|
||||
}
|
||||
|
||||
@@ -2434,6 +2434,44 @@ SQL
|
||||
[[ "$output" =~ "| 3 |" ]] || false
|
||||
}
|
||||
|
||||
@test "sql: dolt diff table correctly works with NOT and/or IS NULL" {
|
||||
dolt sql -q "CREATE TABLE t(pk int primary key);"
|
||||
dolt add .
|
||||
dolt commit -m "new table t"
|
||||
dolt sql -q "INSERT INTO t VALUES (1), (2)"
|
||||
dolt commit -am "add 1, 2"
|
||||
|
||||
run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where from_pk is null"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "2" ]] || false
|
||||
|
||||
dolt sql -q "UPDATE t SET pk = 3 WHERE pk = 2"
|
||||
dolt commit -am "add 3"
|
||||
|
||||
run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where from_pk is not null"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "1" ]] || false
|
||||
}
|
||||
|
||||
@test "sql: dolt diff table correctly works with datetime comparisons" {
|
||||
dolt sql -q "CREATE TABLE t(pk int primary key);"
|
||||
dolt add .
|
||||
dolt commit -m "new table t"
|
||||
dolt sql -q "INSERT INTO t VALUES (1), (2), (3)"
|
||||
dolt commit -am "add 1, 2, 3"
|
||||
|
||||
# adds a row and removes a row
|
||||
dolt sql -q "UPDATE t SET pk = 4 WHERE pk = 2"
|
||||
|
||||
run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where to_commit_date is not null"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "3" ]] || false
|
||||
|
||||
run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where to_commit_date < now()"
|
||||
[ "$status" -eq 0 ]
|
||||
[[ "$output" =~ "3" ]] || false
|
||||
}
|
||||
|
||||
@test "sql: sql print on order by returns the correct result" {
|
||||
dolt sql -q "CREATE TABLE mytable(pk int primary key);"
|
||||
dolt sql -q "INSERT INTO mytable VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10),(11),(12),(13),(14),(15),(16),(17),(18),(19),(20)"
|
||||
|
||||
22
integration-tests/bats/validation.bats
Normal file
22
integration-tests/bats/validation.bats
Normal file
@@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env bats
|
||||
load $BATS_TEST_DIRNAME/helper/common.bash
|
||||
|
||||
setup() {
|
||||
setup_common
|
||||
}
|
||||
|
||||
teardown() {
|
||||
teardown_common
|
||||
}
|
||||
|
||||
# Validation is a set of tests that validate various things about dolt
|
||||
# that have nothing to do with product functionality directly.
|
||||
|
||||
@test "validation: no test symbols in binary" {
|
||||
run grep_for_testify
|
||||
[ "$output" = "" ]
|
||||
}
|
||||
|
||||
grep_for_testify() {
|
||||
strings `which dolt` | grep testify
|
||||
}
|
||||
@@ -562,6 +562,7 @@ tests:
|
||||
- exec: "create table more_vals (i int primary key)"
|
||||
error_match: "repo1 is read-only"
|
||||
- on: server2
|
||||
retry_attempts: 100
|
||||
queries:
|
||||
- query: "SELECT @@GLOBAL.dolt_cluster_role,@@GLOBAL.dolt_cluster_role_epoch"
|
||||
result:
|
||||
|
||||
9
integration-tests/orm-tests/mikro-orm/README.md
Normal file
9
integration-tests/orm-tests/mikro-orm/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Mikro-ORM Smoke Test
|
||||
|
||||
The `index.ts` file is the main entry point and will insert a new record into the database, then load it, print
|
||||
success, and exit with a zero exit code. If any errors are encountered, they are logged, and the process exits with a
|
||||
non-zero exit code.
|
||||
|
||||
To run this smoke test project:
|
||||
1. Run `npm install` command
|
||||
2. Run `npm start` command
|
||||
21
integration-tests/orm-tests/mikro-orm/package.json
Normal file
21
integration-tests/orm-tests/mikro-orm/package.json
Normal file
@@ -0,0 +1,21 @@
|
||||
{
|
||||
"name": "mikro-orm-smoketest",
|
||||
"version": "0.0.1",
|
||||
"description": "DoltDB smoke test for Mikro-ORM integration",
|
||||
"type": "commonjs",
|
||||
"scripts": {
|
||||
"start": "ts-node src/index.ts",
|
||||
"mikro-orm": "mikro-orm-ts-node-commonjs"
|
||||
},
|
||||
"devDependencies": {
|
||||
"ts-node": "^10.7.0",
|
||||
"@types/node": "^16.11.10",
|
||||
"typescript": "^4.5.2"
|
||||
},
|
||||
"dependencies": {
|
||||
"@mikro-orm/core": "^5.0.3",
|
||||
"@mikro-orm/mysql": "^5.0.3",
|
||||
"mysql": "^2.14.1"
|
||||
}
|
||||
}
|
||||
|
||||
22
integration-tests/orm-tests/mikro-orm/src/entity/User.ts
Normal file
22
integration-tests/orm-tests/mikro-orm/src/entity/User.ts
Normal file
@@ -0,0 +1,22 @@
|
||||
import { Entity, PrimaryKey, Property } from "@mikro-orm/core";
|
||||
|
||||
@Entity()
|
||||
export class User {
|
||||
@PrimaryKey()
|
||||
id!: number;
|
||||
|
||||
@Property()
|
||||
firstName!: string;
|
||||
|
||||
@Property()
|
||||
lastName!: string;
|
||||
|
||||
@Property()
|
||||
age!: number;
|
||||
|
||||
constructor(firstName: string, lastName: string, age: number) {
|
||||
this.firstName = firstName;
|
||||
this.lastName = lastName;
|
||||
this.age = age;
|
||||
}
|
||||
}
|
||||
43
integration-tests/orm-tests/mikro-orm/src/index.ts
Normal file
43
integration-tests/orm-tests/mikro-orm/src/index.ts
Normal file
@@ -0,0 +1,43 @@
|
||||
import { MikroORM } from "@mikro-orm/core";
|
||||
import { MySqlDriver } from '@mikro-orm/mysql';
|
||||
import { User } from "./entity/User";
|
||||
|
||||
async function connectAndGetOrm() {
|
||||
const orm = await MikroORM.init<MySqlDriver>({
|
||||
entities: [User],
|
||||
type: "mysql",
|
||||
clientUrl: "mysql://localhost:3306",
|
||||
dbName: "dolt",
|
||||
user: "dolt",
|
||||
password: "",
|
||||
persistOnCreate: true,
|
||||
});
|
||||
|
||||
return orm;
|
||||
}
|
||||
|
||||
connectAndGetOrm().then(async orm => {
|
||||
console.log("Connected");
|
||||
const em = orm.em.fork();
|
||||
|
||||
// this creates the tables if not exist
|
||||
const generator = orm.getSchemaGenerator();
|
||||
await generator.updateSchema();
|
||||
|
||||
console.log("Inserting a new user into the database...")
|
||||
const user = new User("Timber", "Saw", 25)
|
||||
await em.persistAndFlush(user)
|
||||
console.log("Saved a new user with id: " + user.id)
|
||||
|
||||
console.log("Loading users from the database...")
|
||||
const users = await em.findOne(User, 1)
|
||||
console.log("Loaded users: ", users)
|
||||
|
||||
orm.close();
|
||||
console.log("Smoke test passed!")
|
||||
process.exit(0)
|
||||
}).catch(error => {
|
||||
console.log(error)
|
||||
console.log("Smoke test failed!")
|
||||
process.exit(1)
|
||||
});
|
||||
14
integration-tests/orm-tests/mikro-orm/tsconfig.json
Normal file
14
integration-tests/orm-tests/mikro-orm/tsconfig.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"module": "commonjs",
|
||||
"declaration": true,
|
||||
"removeComments": true,
|
||||
"emitDecoratorMetadata": true,
|
||||
"esModuleInterop": true,
|
||||
"experimentalDecorators": true,
|
||||
"target": "es2017",
|
||||
"outDir": "./dist",
|
||||
"baseUrl": "./src",
|
||||
"incremental": true,
|
||||
}
|
||||
}
|
||||
@@ -55,7 +55,7 @@ teardown() {
|
||||
npx -c "prisma migrate dev --name init"
|
||||
}
|
||||
|
||||
# Prisma is an ORM for Node/TypeScript applications. This test checks out the Peewee test suite
|
||||
# Prisma is an ORM for Node/TypeScript applications. This test checks out the Prisma test suite
|
||||
# and runs it against Dolt.
|
||||
@test "Prisma ORM test suite" {
|
||||
skip "Not implemented yet"
|
||||
@@ -77,6 +77,16 @@ teardown() {
|
||||
npm start
|
||||
}
|
||||
|
||||
# MikroORM is an ORM for Node/TypeScript applications. This is a simple smoke test to make sure
|
||||
# Dolt can support the most basic MikroORM operations.
|
||||
@test "MikroORM smoke test" {
|
||||
mysql --protocol TCP -u dolt -e "create database dolt;"
|
||||
|
||||
cd mikro-orm
|
||||
npm install
|
||||
npm start
|
||||
}
|
||||
|
||||
# Turn this test on to prevent the container from exiting if you need to exec a shell into
|
||||
# the container to debug failed tests.
|
||||
#@test "Pause container for an hour to debug failures" {
|
||||
|
||||
Reference in New Issue
Block a user