Merge remote-tracking branch 'origin/main' into dhruv/column-truncation

This commit is contained in:
Dhruv Sringari
2022-11-16 10:36:16 -08:00
68 changed files with 2276 additions and 1229 deletions

View File

@@ -321,9 +321,8 @@ func getCommitMessageFromEditor(ctx context.Context, dEnv *env.DoltEnv, suggeste
finalMsg = parseCommitMessage(commitMsg)
})
// if editor could not be opened or the message received is empty, use auto-generated/suggested msg.
if err != nil || finalMsg == "" {
return suggestedMsg, nil
if err != nil {
return "", err
}
return finalMsg, nil

View File

@@ -130,8 +130,6 @@ func NewSqlEngine(
if bcController, err = branch_control.LoadData(config.BranchCtrlFilePath, config.DoltCfgDirPath); err != nil {
return nil, err
}
// Set the server's super user
branch_control.SetSuperUser(config.ServerUser, config.ServerHost)
// Set up engine
engine := gms.New(analyzer.NewBuilder(pro).WithParallelism(parallelism).Build(), &gms.Config{

View File

@@ -159,9 +159,9 @@ func (cmd MergeCmd) Exec(ctx context.Context, commandStr string, args []string,
}
suggestedMsg := fmt.Sprintf("Merge branch '%s' into %s", commitSpecStr, dEnv.RepoStateReader().CWBHeadRef().GetPath())
msg, err := getCommitMessage(ctx, apr, dEnv, suggestedMsg)
if err != nil {
return handleCommitErr(ctx, dEnv, err, usage)
msg := ""
if m, ok := apr.GetValue(cli.MessageArg); ok {
msg = m
}
if apr.Contains(cli.NoCommitFlag) && apr.Contains(cli.CommitFlag) {
@@ -214,25 +214,6 @@ func getUnmergedTableCount(ctx context.Context, root *doltdb.RootValue) (int, er
return unmergedTableCount, nil
}
// getCommitMessage returns commit message depending on whether user defined commit message and/or no-ff flag.
// If user defined message, it will use that. If not, and no-ff flag is defined, it will ask user for commit message from editor.
// If none of commit message or no-ff flag is defined, it will return suggested message.
func getCommitMessage(ctx context.Context, apr *argparser.ArgParseResults, dEnv *env.DoltEnv, suggestedMsg string) (string, errhand.VerboseError) {
if m, ok := apr.GetValue(cli.MessageArg); ok {
return m, nil
}
if apr.Contains(cli.NoFFParam) {
msg, err := getCommitMessageFromEditor(ctx, dEnv, suggestedMsg, "", apr.Contains(cli.NoEditFlag))
if err != nil {
return "", errhand.VerboseErrorFromError(err)
}
return msg, nil
}
return "", nil
}
func validateMergeSpec(ctx context.Context, spec *merge.MergeSpec) errhand.VerboseError {
if spec.HeadH == spec.MergeH {
//TODO - why is this different for merge/pull?
@@ -484,38 +465,117 @@ func handleMergeErr(ctx context.Context, dEnv *env.DoltEnv, mergeErr error, hasC
// FF merges will not surface constraint violations on their own; constraint verify --all
// is required to reify violations.
func performMerge(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec, suggestedMsg string) (map[string]*merge.MergeStats, error) {
var tblStats map[string]*merge.MergeStats
if ok, err := spec.HeadC.CanFastForwardTo(ctx, spec.MergeC); err != nil && !errors.Is(err, doltdb.ErrUpToDate) {
return nil, err
} else if ok {
if spec.Noff {
tblStats, err = merge.ExecNoFFMerge(ctx, dEnv, spec)
return tblStats, err
return executeNoFFMergeAndCommit(ctx, dEnv, spec, suggestedMsg)
}
return nil, merge.ExecuteFFMerge(ctx, dEnv, spec)
}
tblStats, err := merge.ExecuteMerge(ctx, dEnv, spec)
return executeMergeAndCommit(ctx, dEnv, spec, suggestedMsg)
}
func executeNoFFMergeAndCommit(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec, suggestedMsg string) (map[string]*merge.MergeStats, error) {
tblToStats, err := merge.ExecNoFFMerge(ctx, dEnv, spec)
if err != nil {
return tblStats, err
return tblToStats, err
}
if !spec.NoCommit && !hasConflictOrViolations(tblStats) {
msg := spec.Msg
if spec.Msg == "" {
msg, err = getCommitMessageFromEditor(ctx, dEnv, suggestedMsg, "", spec.NoEdit)
if err != nil {
return nil, err
}
}
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
res := performCommit(ctx, "commit", []string{"-m", msg, "--author", author}, dEnv)
if res != 0 {
return nil, fmt.Errorf("dolt commit failed after merging")
}
if spec.NoCommit {
cli.Println("Automatic merge went well; stopped before committing as requested")
return tblToStats, nil
}
return tblStats, nil
// Reload roots since the above method writes new values to the working set
roots, err := dEnv.Roots(ctx)
if err != nil {
return tblToStats, err
}
ws, err := dEnv.WorkingSet(ctx)
if err != nil {
return tblToStats, err
}
var mergeParentCommits []*doltdb.Commit
if ws.MergeActive() {
mergeParentCommits = []*doltdb.Commit{ws.MergeState().Commit()}
}
msg, err := getCommitMsgForMerge(ctx, dEnv, spec.Msg, suggestedMsg, spec.NoEdit)
if err != nil {
return tblToStats, err
}
_, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dEnv.DbData(), actions.CommitStagedProps{
Message: msg,
Date: spec.Date,
AllowEmpty: spec.AllowEmpty,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
})
if err != nil {
return tblToStats, fmt.Errorf("%w; failed to commit", err)
}
err = dEnv.ClearMerge(ctx)
if err != nil {
return tblToStats, err
}
return tblToStats, err
}
// executeMergeAndCommit executes the merge described by spec and, when the
// merge is clean and committing was not suppressed, creates the merge commit
// via the commit command path. The returned stats map is keyed by table name.
func executeMergeAndCommit(ctx context.Context, dEnv *env.DoltEnv, spec *merge.MergeSpec, suggestedMsg string) (map[string]*merge.MergeStats, error) {
tblToStats, err := merge.ExecuteMerge(ctx, dEnv, spec)
if err != nil {
return tblToStats, err
}
// Conflicts or constraint violations: leave the working set as-is so the
// user can resolve them; this is not an error.
if hasConflictOrViolations(tblToStats) {
return tblToStats, nil
}
// --no-commit: stop after a successful merge without creating a commit.
if spec.NoCommit {
cli.Println("Automatic merge went well; stopped before committing as requested")
return tblToStats, nil
}
// Resolve the commit message: user-supplied spec.Msg wins, otherwise the
// editor flow is used (seeded with suggestedMsg).
msg, err := getCommitMsgForMerge(ctx, dEnv, spec.Msg, suggestedMsg, spec.NoEdit)
if err != nil {
return tblToStats, err
}
author := fmt.Sprintf("%s <%s>", spec.Name, spec.Email)
// performCommit returns a process-style exit code; non-zero means the
// commit step failed even though the merge itself succeeded.
res := performCommit(ctx, "commit", []string{"-m", msg, "--author", author}, dEnv)
if res != 0 {
return nil, fmt.Errorf("dolt commit failed after merging")
}
return tblToStats, nil
}
// getCommitMsgForMerge resolves the commit message for a merge commit.
// A non-empty user-defined message is used verbatim; otherwise the message
// is collected through the commit-message editor flow (seeded with
// suggestedMsg). An empty result from the editor is reported as an error so
// the merge is left uncommitted for the user to finish with 'dolt commit'.
func getCommitMsgForMerge(ctx context.Context, dEnv *env.DoltEnv, userDefinedMsg, suggestedMsg string, noEdit bool) (string, error) {
	if len(userDefinedMsg) > 0 {
		return userDefinedMsg, nil
	}
	editorMsg, editorErr := getCommitMessageFromEditor(ctx, dEnv, suggestedMsg, "", noEdit)
	switch {
	case editorErr != nil:
		return editorMsg, editorErr
	case editorMsg == "":
		return editorMsg, fmt.Errorf("error: Empty commit message.\n" +
			"Not committing merge; use 'dolt commit' to complete the merge.")
	default:
		return editorMsg, nil
	}
}
// hasConflictOrViolations checks for conflicts or constraint violation regardless of a table being modified

View File

@@ -112,7 +112,7 @@ func pullHelper(ctx context.Context, dEnv *env.DoltEnv, pullSpec *env.PullSpec)
// Fetch all references
branchRefs, err := srcDB.GetHeadRefs(ctx)
if err != nil {
return env.ErrFailedToReadDb
return fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
}
hasBranch, err := srcDB.HasBranch(ctx, pullSpec.Branch.GetPath())

View File

@@ -285,12 +285,11 @@ func Serve(
lgr.Errorf("error starting remotesapi server listeners for cluster config on port %d: %v", clusterController.RemoteSrvPort(), err)
startError = err
return
} else {
go func() {
clusterRemoteSrv.Serve(listeners)
}()
}
go clusterRemoteSrv.Serve(listeners)
go clusterController.Run()
clusterController.ManageQueryConnections(
mySQLServer.SessionManager().Iter,
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
@@ -323,6 +322,9 @@ func Serve(
if clusterRemoteSrv != nil {
clusterRemoteSrv.GracefulStop()
}
if clusterController != nil {
clusterController.GracefulStop()
}
return mySQLServer.Close()
})

View File

@@ -182,7 +182,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
cli.PrintErrln(color.RedString(err.Error()))
return 1
}
dbToUse = filepath.Base(directory)
dbToUse = strings.Replace(filepath.Base(directory), "-", "_", -1)
}
format := engine.FormatTabular
if hasResultFormat {

View File

@@ -62,7 +62,8 @@ const (
primaryKeyParam = "pk"
fileTypeParam = "file-type"
delimParam = "delim"
ignoreSkippedRows = "ignore-skipped-rows"
quiet = "quiet"
ignoreSkippedRows = "ignore-skipped-rows" // alias for quiet
disableFkChecks = "disable-fk-checks"
)
@@ -74,7 +75,7 @@ The schema for the new table can be specified explicitly by providing a SQL sche
If {{.EmphasisLeft}}--update-table | -u{{.EmphasisRight}} is given the operation will update {{.LessThan}}table{{.GreaterThan}} with the contents of file. The table's existing schema will be used, and field names will be used to match file fields with table fields unless a mapping file is specified.
During import, if there is an error importing any row, the import will be aborted by default. Use the {{.EmphasisLeft}}--continue{{.EmphasisRight}} flag to continue importing when an error is encountered. You can add the {{.EmphasisLeft}}--ignore-skipped-rows{{.EmphasisRight}} flag to prevent the import utility from printing all the skipped rows.
During import, if there is an error importing any row, the import will be aborted by default. Use the {{.EmphasisLeft}}--continue{{.EmphasisRight}} flag to continue importing when an error is encountered. You can add the {{.EmphasisLeft}}--quiet{{.EmphasisRight}} flag to prevent the import utility from printing all the skipped rows.
If {{.EmphasisLeft}}--replace-table | -r{{.EmphasisRight}} is given the operation will replace {{.LessThan}}table{{.GreaterThan}} with the contents of the file. The table's existing schema will be used, and field names will be used to match file fields with table fields unless a mapping file is specified.
@@ -87,8 +88,8 @@ A mapping file can be used to map fields between the file being imported and the
In create, update, and replace scenarios the file's extension is used to infer the type of the file. If a file does not have the expected extension then the {{.EmphasisLeft}}--file-type{{.EmphasisRight}} parameter should be used to explicitly define the format of the file in one of the supported formats (csv, psv, json, xlsx). For files separated by a delimiter other than a ',' (type csv) or a '|' (type psv), the --delim parameter can be used to specify a delimiter`,
Synopsis: []string{
"-c [-f] [--pk {{.LessThan}}field{{.GreaterThan}}] [--schema {{.LessThan}}file{{.GreaterThan}}] [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--disable-fk-checks] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
"-u [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--ignore-skipped-rows] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
"-c [-f] [--pk {{.LessThan}}field{{.GreaterThan}}] [--schema {{.LessThan}}file{{.GreaterThan}}] [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--quiet] [--disable-fk-checks] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
"-u [--map {{.LessThan}}file{{.GreaterThan}}] [--continue] [--quiet] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
"-r [--map {{.LessThan}}file{{.GreaterThan}}] [--file-type {{.LessThan}}type{{.GreaterThan}}] {{.LessThan}}table{{.GreaterThan}} {{.LessThan}}file{{.GreaterThan}}",
},
}
@@ -96,17 +97,17 @@ In create, update, and replace scenarios the file's extension is used to infer t
var bitTypeRegex = regexp.MustCompile(`(?m)b\'(\d+)\'`)
type importOptions struct {
operation mvdata.TableImportOp
destTableName string
contOnErr bool
force bool
schFile string
primaryKeys []string
nameMapper rowconv.NameMapper
src mvdata.DataLocation
srcOptions interface{}
ignoreSkippedRows bool
disableFkChecks bool
operation mvdata.TableImportOp
destTableName string
contOnErr bool
force bool
schFile string
primaryKeys []string
nameMapper rowconv.NameMapper
src mvdata.DataLocation
srcOptions interface{}
quiet bool
disableFkChecks bool
}
func (m importOptions) IsBatched() bool {
@@ -168,7 +169,7 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d
schemaFile, _ := apr.GetValue(schemaParam)
force := apr.Contains(forceParam)
contOnErr := apr.Contains(contOnErrParam)
ignore := apr.Contains(ignoreSkippedRows)
quiet := apr.Contains(quiet)
disableFks := apr.Contains(disableFkChecks)
val, _ := apr.GetValue(primaryKeyParam)
@@ -238,17 +239,17 @@ func getImportMoveOptions(ctx context.Context, apr *argparser.ArgParseResults, d
}
return &importOptions{
operation: moveOp,
destTableName: tableName,
contOnErr: contOnErr,
force: force,
schFile: schemaFile,
nameMapper: colMapper,
primaryKeys: pks,
src: srcLoc,
srcOptions: srcOpts,
ignoreSkippedRows: ignore,
disableFkChecks: disableFks,
operation: moveOp,
destTableName: tableName,
contOnErr: contOnErr,
force: force,
schFile: schemaFile,
nameMapper: colMapper,
primaryKeys: pks,
src: srcLoc,
srcOptions: srcOpts,
quiet: quiet,
disableFkChecks: disableFks,
}, nil
}
@@ -337,7 +338,8 @@ func (cmd ImportCmd) ArgParser() *argparser.ArgParser {
ap.SupportsFlag(forceParam, "f", "If a create operation is being executed, data already exists in the destination, the force flag will allow the target to be overwritten.")
ap.SupportsFlag(replaceParam, "r", "Replace existing table with imported data while preserving the original schema.")
ap.SupportsFlag(contOnErrParam, "", "Continue importing when row import errors are encountered.")
ap.SupportsFlag(ignoreSkippedRows, "", "Ignore the skipped rows printed by the --continue flag.")
ap.SupportsFlag(quiet, "", "Suppress any warning messages about invalid rows when using the --continue flag.")
ap.SupportsAlias(ignoreSkippedRows, quiet)
ap.SupportsFlag(disableFkChecks, "", "Disables foreign key checks.")
ap.SupportsString(schemaParam, "s", "schema_file", "The schema for the output data.")
ap.SupportsString(mappingFileParam, "m", "mapping_file", "A file that lays out how fields should be mapped from input data to output data.")
@@ -524,8 +526,8 @@ func move(ctx context.Context, rd table.SqlRowReader, wr *mvdata.SqlEngineTableW
return true
}
// Don't log the skipped rows when the ignore-skipped-rows param is specified.
if options.ignoreSkippedRows {
// Don't log the skipped rows when asked to suppress warning output
if options.quiet {
return false
}

View File

@@ -1,4 +1,4 @@
// Copyright 2019 Dolthub, Inc.
// Copyright 2019-2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -12,5 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// dolt is a command line tool for working with dolt data repositories stored in noms.
// dolt is the command line tool for working with Dolt databases.
package main

View File

@@ -57,7 +57,7 @@ import (
)
const (
Version = "0.50.15"
Version = "0.51.1"
)
var dumpDocsCommand = &commands.DumpDocsCmd{}

View File

@@ -210,7 +210,7 @@ func (rcv *BranchControlAccess) TryBinlog(obj *BranchControlBinlog) (*BranchCont
return nil, nil
}
func (rcv *BranchControlAccess) Branches(obj *BranchControlMatchExpression, j int) bool {
func (rcv *BranchControlAccess) Databases(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
x := rcv._tab.Vector(o)
@@ -222,7 +222,7 @@ func (rcv *BranchControlAccess) Branches(obj *BranchControlMatchExpression, j in
return false
}
func (rcv *BranchControlAccess) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
func (rcv *BranchControlAccess) TryDatabases(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
x := rcv._tab.Vector(o)
@@ -237,8 +237,43 @@ func (rcv *BranchControlAccess) TryBranches(obj *BranchControlMatchExpression, j
return false, nil
}
// DatabasesLength reports the number of entries in the databases vector
// (vtable offset 6), or 0 when the field is absent.
// NOTE(review): this looks like FlatBuffers-generated accessor code — it
// should be regenerated from the schema rather than edited by hand.
func (rcv *BranchControlAccess) DatabasesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
return rcv._tab.VectorLen(o)
}
return 0
}
// Branches initializes obj with the j-th element of the branches vector
// (vtable offset 8) and reports whether the vector is present.
func (rcv *BranchControlAccess) Branches(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
x := rcv._tab.Vector(o)
// Each vector element is a 4-byte offset to an indirect table.
x += flatbuffers.UOffsetT(j) * 4
x = rcv._tab.Indirect(x)
obj.Init(rcv._tab.Bytes, x)
return true
}
return false
}
// TryBranches is like Branches but additionally fails with
// ErrTableHasUnknownFields when the stored table carries more fields than
// this generated code knows about (forward-compatibility check).
func (rcv *BranchControlAccess) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
x = rcv._tab.Indirect(x)
obj.Init(rcv._tab.Bytes, x)
if BranchControlMatchExpressionNumFields < obj.Table().NumFields() {
return false, flatbuffers.ErrTableHasUnknownFields
}
return true, nil
}
return false, nil
}
func (rcv *BranchControlAccess) BranchesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
return rcv._tab.VectorLen(o)
}
@@ -246,7 +281,7 @@ func (rcv *BranchControlAccess) BranchesLength() int {
}
func (rcv *BranchControlAccess) Users(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -258,7 +293,7 @@ func (rcv *BranchControlAccess) Users(obj *BranchControlMatchExpression, j int)
}
func (rcv *BranchControlAccess) TryUsers(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -273,7 +308,7 @@ func (rcv *BranchControlAccess) TryUsers(obj *BranchControlMatchExpression, j in
}
func (rcv *BranchControlAccess) UsersLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
return rcv._tab.VectorLen(o)
}
@@ -281,7 +316,7 @@ func (rcv *BranchControlAccess) UsersLength() int {
}
func (rcv *BranchControlAccess) Hosts(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -293,7 +328,7 @@ func (rcv *BranchControlAccess) Hosts(obj *BranchControlMatchExpression, j int)
}
func (rcv *BranchControlAccess) TryHosts(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -308,7 +343,7 @@ func (rcv *BranchControlAccess) TryHosts(obj *BranchControlMatchExpression, j in
}
func (rcv *BranchControlAccess) HostsLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
return rcv._tab.VectorLen(o)
}
@@ -316,7 +351,7 @@ func (rcv *BranchControlAccess) HostsLength() int {
}
func (rcv *BranchControlAccess) Values(obj *BranchControlAccessValue, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -328,7 +363,7 @@ func (rcv *BranchControlAccess) Values(obj *BranchControlAccessValue, j int) boo
}
func (rcv *BranchControlAccess) TryValues(obj *BranchControlAccessValue, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -343,14 +378,14 @@ func (rcv *BranchControlAccess) TryValues(obj *BranchControlAccessValue, j int)
}
func (rcv *BranchControlAccess) ValuesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
return rcv._tab.VectorLen(o)
}
return 0
}
const BranchControlAccessNumFields = 5
const BranchControlAccessNumFields = 6
func BranchControlAccessStart(builder *flatbuffers.Builder) {
builder.StartObject(BranchControlAccessNumFields)
@@ -358,26 +393,32 @@ func BranchControlAccessStart(builder *flatbuffers.Builder) {
func BranchControlAccessAddBinlog(builder *flatbuffers.Builder, binlog flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(binlog), 0)
}
func BranchControlAccessAddDatabases(builder *flatbuffers.Builder, databases flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(databases), 0)
}
func BranchControlAccessStartDatabasesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlAccessAddBranches(builder *flatbuffers.Builder, branches flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branches), 0)
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(branches), 0)
}
func BranchControlAccessStartBranchesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlAccessAddUsers(builder *flatbuffers.Builder, users flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(users), 0)
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(users), 0)
}
func BranchControlAccessStartUsersVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlAccessAddHosts(builder *flatbuffers.Builder, hosts flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(hosts), 0)
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(hosts), 0)
}
func BranchControlAccessStartHostsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlAccessAddValues(builder *flatbuffers.Builder, values flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(values), 0)
builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(values), 0)
}
func BranchControlAccessStartValuesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
@@ -430,7 +471,7 @@ func (rcv *BranchControlAccessValue) Table() flatbuffers.Table {
return rcv._tab
}
func (rcv *BranchControlAccessValue) Branch() []byte {
func (rcv *BranchControlAccessValue) Database() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -438,7 +479,7 @@ func (rcv *BranchControlAccessValue) Branch() []byte {
return nil
}
func (rcv *BranchControlAccessValue) User() []byte {
func (rcv *BranchControlAccessValue) Branch() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -446,7 +487,7 @@ func (rcv *BranchControlAccessValue) User() []byte {
return nil
}
func (rcv *BranchControlAccessValue) Host() []byte {
func (rcv *BranchControlAccessValue) User() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -454,8 +495,16 @@ func (rcv *BranchControlAccessValue) Host() []byte {
return nil
}
func (rcv *BranchControlAccessValue) Permissions() uint64 {
func (rcv *BranchControlAccessValue) Host() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
}
return nil
}
func (rcv *BranchControlAccessValue) Permissions() uint64 {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
return rcv._tab.GetUint64(o + rcv._tab.Pos)
}
@@ -463,25 +512,28 @@ func (rcv *BranchControlAccessValue) Permissions() uint64 {
}
func (rcv *BranchControlAccessValue) MutatePermissions(n uint64) bool {
return rcv._tab.MutateUint64Slot(10, n)
return rcv._tab.MutateUint64Slot(12, n)
}
const BranchControlAccessValueNumFields = 4
const BranchControlAccessValueNumFields = 5
func BranchControlAccessValueStart(builder *flatbuffers.Builder) {
builder.StartObject(BranchControlAccessValueNumFields)
}
func BranchControlAccessValueAddDatabase(builder *flatbuffers.Builder, database flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(database), 0)
}
func BranchControlAccessValueAddBranch(builder *flatbuffers.Builder, branch flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(branch), 0)
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branch), 0)
}
func BranchControlAccessValueAddUser(builder *flatbuffers.Builder, user flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(user), 0)
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(user), 0)
}
func BranchControlAccessValueAddHost(builder *flatbuffers.Builder, host flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(host), 0)
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(host), 0)
}
func BranchControlAccessValueAddPermissions(builder *flatbuffers.Builder, permissions uint64) {
builder.PrependUint64Slot(3, permissions, 0)
builder.PrependUint64Slot(4, permissions, 0)
}
func BranchControlAccessValueEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()
@@ -560,7 +612,7 @@ func (rcv *BranchControlNamespace) TryBinlog(obj *BranchControlBinlog) (*BranchC
return nil, nil
}
func (rcv *BranchControlNamespace) Branches(obj *BranchControlMatchExpression, j int) bool {
func (rcv *BranchControlNamespace) Databases(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
x := rcv._tab.Vector(o)
@@ -572,7 +624,7 @@ func (rcv *BranchControlNamespace) Branches(obj *BranchControlMatchExpression, j
return false
}
func (rcv *BranchControlNamespace) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
func (rcv *BranchControlNamespace) TryDatabases(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
x := rcv._tab.Vector(o)
@@ -587,8 +639,43 @@ func (rcv *BranchControlNamespace) TryBranches(obj *BranchControlMatchExpression
return false, nil
}
// DatabasesLength reports the number of entries in the databases vector
// (vtable offset 6), or 0 when the field is absent.
// NOTE(review): this looks like FlatBuffers-generated accessor code — it
// should be regenerated from the schema rather than edited by hand.
func (rcv *BranchControlNamespace) DatabasesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
return rcv._tab.VectorLen(o)
}
return 0
}
// Branches initializes obj with the j-th element of the branches vector
// (vtable offset 8) and reports whether the vector is present.
func (rcv *BranchControlNamespace) Branches(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
x := rcv._tab.Vector(o)
// Each vector element is a 4-byte offset to an indirect table.
x += flatbuffers.UOffsetT(j) * 4
x = rcv._tab.Indirect(x)
obj.Init(rcv._tab.Bytes, x)
return true
}
return false
}
// TryBranches is like Branches but additionally fails with
// ErrTableHasUnknownFields when the stored table carries more fields than
// this generated code knows about (forward-compatibility check).
func (rcv *BranchControlNamespace) TryBranches(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
x = rcv._tab.Indirect(x)
obj.Init(rcv._tab.Bytes, x)
if BranchControlMatchExpressionNumFields < obj.Table().NumFields() {
return false, flatbuffers.ErrTableHasUnknownFields
}
return true, nil
}
return false, nil
}
func (rcv *BranchControlNamespace) BranchesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
return rcv._tab.VectorLen(o)
}
@@ -596,7 +683,7 @@ func (rcv *BranchControlNamespace) BranchesLength() int {
}
func (rcv *BranchControlNamespace) Users(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -608,7 +695,7 @@ func (rcv *BranchControlNamespace) Users(obj *BranchControlMatchExpression, j in
}
func (rcv *BranchControlNamespace) TryUsers(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -623,7 +710,7 @@ func (rcv *BranchControlNamespace) TryUsers(obj *BranchControlMatchExpression, j
}
func (rcv *BranchControlNamespace) UsersLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
return rcv._tab.VectorLen(o)
}
@@ -631,7 +718,7 @@ func (rcv *BranchControlNamespace) UsersLength() int {
}
func (rcv *BranchControlNamespace) Hosts(obj *BranchControlMatchExpression, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -643,7 +730,7 @@ func (rcv *BranchControlNamespace) Hosts(obj *BranchControlMatchExpression, j in
}
func (rcv *BranchControlNamespace) TryHosts(obj *BranchControlMatchExpression, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -658,7 +745,7 @@ func (rcv *BranchControlNamespace) TryHosts(obj *BranchControlMatchExpression, j
}
func (rcv *BranchControlNamespace) HostsLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
return rcv._tab.VectorLen(o)
}
@@ -666,7 +753,7 @@ func (rcv *BranchControlNamespace) HostsLength() int {
}
func (rcv *BranchControlNamespace) Values(obj *BranchControlNamespaceValue, j int) bool {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -678,7 +765,7 @@ func (rcv *BranchControlNamespace) Values(obj *BranchControlNamespaceValue, j in
}
func (rcv *BranchControlNamespace) TryValues(obj *BranchControlNamespaceValue, j int) (bool, error) {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
x := rcv._tab.Vector(o)
x += flatbuffers.UOffsetT(j) * 4
@@ -693,14 +780,14 @@ func (rcv *BranchControlNamespace) TryValues(obj *BranchControlNamespaceValue, j
}
func (rcv *BranchControlNamespace) ValuesLength() int {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
return rcv._tab.VectorLen(o)
}
return 0
}
const BranchControlNamespaceNumFields = 5
const BranchControlNamespaceNumFields = 6
func BranchControlNamespaceStart(builder *flatbuffers.Builder) {
builder.StartObject(BranchControlNamespaceNumFields)
@@ -708,26 +795,32 @@ func BranchControlNamespaceStart(builder *flatbuffers.Builder) {
func BranchControlNamespaceAddBinlog(builder *flatbuffers.Builder, binlog flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(binlog), 0)
}
func BranchControlNamespaceAddDatabases(builder *flatbuffers.Builder, databases flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(databases), 0)
}
func BranchControlNamespaceStartDatabasesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlNamespaceAddBranches(builder *flatbuffers.Builder, branches flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branches), 0)
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(branches), 0)
}
func BranchControlNamespaceStartBranchesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlNamespaceAddUsers(builder *flatbuffers.Builder, users flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(users), 0)
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(users), 0)
}
func BranchControlNamespaceStartUsersVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlNamespaceAddHosts(builder *flatbuffers.Builder, hosts flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(hosts), 0)
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(hosts), 0)
}
func BranchControlNamespaceStartHostsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
}
func BranchControlNamespaceAddValues(builder *flatbuffers.Builder, values flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(values), 0)
builder.PrependUOffsetTSlot(5, flatbuffers.UOffsetT(values), 0)
}
func BranchControlNamespaceStartValuesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT {
return builder.StartVector(4, numElems, 4)
@@ -780,7 +873,7 @@ func (rcv *BranchControlNamespaceValue) Table() flatbuffers.Table {
return rcv._tab
}
func (rcv *BranchControlNamespaceValue) Branch() []byte {
func (rcv *BranchControlNamespaceValue) Database() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(4))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -788,7 +881,7 @@ func (rcv *BranchControlNamespaceValue) Branch() []byte {
return nil
}
func (rcv *BranchControlNamespaceValue) User() []byte {
func (rcv *BranchControlNamespaceValue) Branch() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -796,7 +889,7 @@ func (rcv *BranchControlNamespaceValue) User() []byte {
return nil
}
func (rcv *BranchControlNamespaceValue) Host() []byte {
func (rcv *BranchControlNamespaceValue) User() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -804,19 +897,30 @@ func (rcv *BranchControlNamespaceValue) Host() []byte {
return nil
}
const BranchControlNamespaceValueNumFields = 3
func (rcv *BranchControlNamespaceValue) Host() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
}
return nil
}
const BranchControlNamespaceValueNumFields = 4
func BranchControlNamespaceValueStart(builder *flatbuffers.Builder) {
builder.StartObject(BranchControlNamespaceValueNumFields)
}
func BranchControlNamespaceValueAddDatabase(builder *flatbuffers.Builder, database flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(database), 0)
}
func BranchControlNamespaceValueAddBranch(builder *flatbuffers.Builder, branch flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(0, flatbuffers.UOffsetT(branch), 0)
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branch), 0)
}
func BranchControlNamespaceValueAddUser(builder *flatbuffers.Builder, user flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(user), 0)
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(user), 0)
}
func BranchControlNamespaceValueAddHost(builder *flatbuffers.Builder, host flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(host), 0)
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(host), 0)
}
func BranchControlNamespaceValueEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()
@@ -972,7 +1076,7 @@ func (rcv *BranchControlBinlogRow) MutateIsInsert(n bool) bool {
return rcv._tab.MutateBoolSlot(4, n)
}
func (rcv *BranchControlBinlogRow) Branch() []byte {
func (rcv *BranchControlBinlogRow) Database() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(6))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -980,7 +1084,7 @@ func (rcv *BranchControlBinlogRow) Branch() []byte {
return nil
}
func (rcv *BranchControlBinlogRow) User() []byte {
func (rcv *BranchControlBinlogRow) Branch() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(8))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -988,7 +1092,7 @@ func (rcv *BranchControlBinlogRow) User() []byte {
return nil
}
func (rcv *BranchControlBinlogRow) Host() []byte {
func (rcv *BranchControlBinlogRow) User() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(10))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
@@ -996,8 +1100,16 @@ func (rcv *BranchControlBinlogRow) Host() []byte {
return nil
}
func (rcv *BranchControlBinlogRow) Permissions() uint64 {
func (rcv *BranchControlBinlogRow) Host() []byte {
o := flatbuffers.UOffsetT(rcv._tab.Offset(12))
if o != 0 {
return rcv._tab.ByteVector(o + rcv._tab.Pos)
}
return nil
}
func (rcv *BranchControlBinlogRow) Permissions() uint64 {
o := flatbuffers.UOffsetT(rcv._tab.Offset(14))
if o != 0 {
return rcv._tab.GetUint64(o + rcv._tab.Pos)
}
@@ -1005,10 +1117,10 @@ func (rcv *BranchControlBinlogRow) Permissions() uint64 {
}
func (rcv *BranchControlBinlogRow) MutatePermissions(n uint64) bool {
return rcv._tab.MutateUint64Slot(12, n)
return rcv._tab.MutateUint64Slot(14, n)
}
const BranchControlBinlogRowNumFields = 5
const BranchControlBinlogRowNumFields = 6
func BranchControlBinlogRowStart(builder *flatbuffers.Builder) {
builder.StartObject(BranchControlBinlogRowNumFields)
@@ -1016,17 +1128,20 @@ func BranchControlBinlogRowStart(builder *flatbuffers.Builder) {
func BranchControlBinlogRowAddIsInsert(builder *flatbuffers.Builder, isInsert bool) {
builder.PrependBoolSlot(0, isInsert, false)
}
func BranchControlBinlogRowAddDatabase(builder *flatbuffers.Builder, database flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(database), 0)
}
func BranchControlBinlogRowAddBranch(builder *flatbuffers.Builder, branch flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(1, flatbuffers.UOffsetT(branch), 0)
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(branch), 0)
}
func BranchControlBinlogRowAddUser(builder *flatbuffers.Builder, user flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(2, flatbuffers.UOffsetT(user), 0)
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(user), 0)
}
func BranchControlBinlogRowAddHost(builder *flatbuffers.Builder, host flatbuffers.UOffsetT) {
builder.PrependUOffsetTSlot(3, flatbuffers.UOffsetT(host), 0)
builder.PrependUOffsetTSlot(4, flatbuffers.UOffsetT(host), 0)
}
func BranchControlBinlogRowAddPermissions(builder *flatbuffers.Builder, permissions uint64) {
builder.PrependUint64Slot(4, permissions, 0)
builder.PrependUint64Slot(5, permissions, 0)
}
func BranchControlBinlogRowEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT {
return builder.EndObject()

View File

@@ -58,7 +58,7 @@ require (
github.com/cenkalti/backoff/v4 v4.1.3
github.com/cespare/xxhash v1.1.0
github.com/creasty/defaults v1.6.0
github.com/dolthub/go-mysql-server v0.14.1-0.20221111192934-cf0c1818d579
github.com/dolthub/go-mysql-server v0.14.1-0.20221116004305-6af2406c5bd0
github.com/google/flatbuffers v2.0.6+incompatible
github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6
github.com/mitchellh/go-ps v1.0.0

View File

@@ -180,8 +180,8 @@ github.com/dolthub/flatbuffers v1.13.0-dh.1 h1:OWJdaPep22N52O/0xsUevxJ6Qfw1M2txC
github.com/dolthub/flatbuffers v1.13.0-dh.1/go.mod h1:CorYGaDmXjHz1Z7i50PYXG1Ricn31GcA2wNOTFIQAKE=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-mysql-server v0.14.1-0.20221111192934-cf0c1818d579 h1:rOV6whqkxka2wGMGD/DOgUgX0jWw/gaJwTMqJ1ye2wA=
github.com/dolthub/go-mysql-server v0.14.1-0.20221111192934-cf0c1818d579/go.mod h1:KtpU4Sf7J+SIat/nxoA733QTn3tdL34NtoGxEBFcTsA=
github.com/dolthub/go-mysql-server v0.14.1-0.20221116004305-6af2406c5bd0 h1:BSTAs705aUez54ENrjZCQ3px0s/17nPi58XvDlpA0sI=
github.com/dolthub/go-mysql-server v0.14.1-0.20221116004305-6af2406c5bd0/go.mod h1:KtpU4Sf7J+SIat/nxoA733QTn3tdL34NtoGxEBFcTsA=
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371 h1:oyPHJlzumKta1vnOQqUnfdz+pk3EmnHS3Nd0cCT0I2g=
github.com/dolthub/ishell v0.0.0-20220112232610-14e753f0f371/go.mod h1:dhGBqcCEfK5kuFmeO5+WOx3hqc1k3M29c1oS/R7N4ms=
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8=

View File

@@ -39,15 +39,17 @@ const (
type Access struct {
binlog *Binlog
Branches []MatchExpression
Users []MatchExpression
Hosts []MatchExpression
Values []AccessValue
RWMutex *sync.RWMutex
Databases []MatchExpression
Branches []MatchExpression
Users []MatchExpression
Hosts []MatchExpression
Values []AccessValue
RWMutex *sync.RWMutex
}
// AccessValue contains the user-facing values of a particular row, along with the permissions for a row.
type AccessValue struct {
Database string
Branch string
User string
Host string
@@ -57,22 +59,19 @@ type AccessValue struct {
// newAccess returns a new Access.
func newAccess() *Access {
return &Access{
binlog: NewAccessBinlog(nil),
Branches: nil,
Users: nil,
Hosts: nil,
Values: nil,
RWMutex: &sync.RWMutex{},
binlog: NewAccessBinlog(nil),
Databases: nil,
Branches: nil,
Users: nil,
Hosts: nil,
Values: nil,
RWMutex: &sync.RWMutex{},
}
}
// Match returns whether any entries match the given branch, user, and host, along with their permissions. Requires
// external synchronization handling, therefore manually manage the RWMutex.
func (tbl *Access) Match(branch string, user string, host string) (bool, Permissions) {
if IsSuperUser(user, host) {
return true, Permissions_Admin
}
// Match returns whether any entries match the given database, branch, user, and host, along with their permissions.
// Requires external synchronization handling, therefore manually manage the RWMutex.
func (tbl *Access) Match(database string, branch string, user string, host string) (bool, Permissions) {
filteredIndexes := Match(tbl.Users, user, sql.Collation_utf8mb4_0900_bin)
filteredHosts := tbl.filterHosts(filteredIndexes)
@@ -80,6 +79,11 @@ func (tbl *Access) Match(branch string, user string, host string) (bool, Permiss
filteredIndexes = Match(filteredHosts, host, sql.Collation_utf8mb4_0900_ai_ci)
matchExprPool.Put(filteredHosts)
filteredDatabases := tbl.filterDatabases(filteredIndexes)
indexPool.Put(filteredIndexes)
filteredIndexes = Match(filteredDatabases, database, sql.Collation_utf8mb4_0900_ai_ci)
matchExprPool.Put(filteredDatabases)
filteredBranches := tbl.filterBranches(filteredIndexes)
indexPool.Put(filteredIndexes)
filteredIndexes = Match(filteredBranches, branch, sql.Collation_utf8mb4_0900_ai_ci)
@@ -93,9 +97,9 @@ func (tbl *Access) Match(branch string, user string, host string) (bool, Permiss
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
// returns -1. Assumes that the given expressions have already been folded. Requires external synchronization handling,
// therefore manually manage the RWMutex.
func (tbl *Access) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
func (tbl *Access) GetIndex(databaseExpr string, branchExpr string, userExpr string, hostExpr string) int {
for i, value := range tbl.Values {
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
if value.Database == databaseExpr && value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
return i
}
}
@@ -107,16 +111,6 @@ func (tbl *Access) GetBinlog() *Binlog {
return tbl.binlog
}
// GetSuperUser returns the server-level super user. Intended for display purposes only.
func (tbl *Access) GetSuperUser() string {
return superUser
}
// GetSuperHost returns the server-level super user's host. Intended for display purposes only.
func (tbl *Access) GetSuperHost() string {
return superHost
}
// Serialize returns the offset for the Access table written to the given builder.
func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
tbl.RWMutex.RLock()
@@ -125,11 +119,15 @@ func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
// Serialize the binlog
binlog := tbl.binlog.Serialize(b)
// Initialize field offset slices
databaseOffsets := make([]flatbuffers.UOffsetT, len(tbl.Databases))
branchOffsets := make([]flatbuffers.UOffsetT, len(tbl.Branches))
userOffsets := make([]flatbuffers.UOffsetT, len(tbl.Users))
hostOffsets := make([]flatbuffers.UOffsetT, len(tbl.Hosts))
valueOffsets := make([]flatbuffers.UOffsetT, len(tbl.Values))
// Get field offsets
for i, matchExpr := range tbl.Databases {
databaseOffsets[i] = matchExpr.Serialize(b)
}
for i, matchExpr := range tbl.Branches {
branchOffsets[i] = matchExpr.Serialize(b)
}
@@ -143,6 +141,11 @@ func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
valueOffsets[i] = val.Serialize(b)
}
// Get the field vectors
serial.BranchControlAccessStartDatabasesVector(b, len(databaseOffsets))
for i := len(databaseOffsets) - 1; i >= 0; i-- {
b.PrependUOffsetT(databaseOffsets[i])
}
databases := b.EndVector(len(databaseOffsets))
serial.BranchControlAccessStartBranchesVector(b, len(branchOffsets))
for i := len(branchOffsets) - 1; i >= 0; i-- {
b.PrependUOffsetT(branchOffsets[i])
@@ -166,6 +169,7 @@ func (tbl *Access) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
// Write the table
serial.BranchControlAccessStart(b)
serial.BranchControlAccessAddBinlog(b, binlog)
serial.BranchControlAccessAddDatabases(b, databases)
serial.BranchControlAccessAddBranches(b, branches)
serial.BranchControlAccessAddUsers(b, users)
serial.BranchControlAccessAddHosts(b, hosts)
@@ -183,7 +187,10 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
return fmt.Errorf("cannot deserialize to a non-empty access table")
}
// Verify that all fields have the same length
if fb.BranchesLength() != fb.UsersLength() || fb.UsersLength() != fb.HostsLength() || fb.HostsLength() != fb.ValuesLength() {
if fb.DatabasesLength() != fb.BranchesLength() ||
fb.BranchesLength() != fb.UsersLength() ||
fb.UsersLength() != fb.HostsLength() ||
fb.HostsLength() != fb.ValuesLength() {
return fmt.Errorf("cannot deserialize an access table with differing field lengths")
}
// Read the binlog
@@ -195,10 +202,17 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
return err
}
// Initialize every slice
tbl.Databases = make([]MatchExpression, fb.DatabasesLength())
tbl.Branches = make([]MatchExpression, fb.BranchesLength())
tbl.Users = make([]MatchExpression, fb.UsersLength())
tbl.Hosts = make([]MatchExpression, fb.HostsLength())
tbl.Values = make([]AccessValue, fb.ValuesLength())
// Read the databases
for i := 0; i < fb.DatabasesLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
fb.Databases(serialMatchExpr, i)
tbl.Databases[i] = deserializeMatchExpression(serialMatchExpr)
}
// Read the branches
for i := 0; i < fb.BranchesLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
@@ -222,6 +236,7 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
serialAccessValue := &serial.BranchControlAccessValue{}
fb.Values(serialAccessValue, i)
tbl.Values[i] = AccessValue{
Database: string(serialAccessValue.Database()),
Branch: string(serialAccessValue.Branch()),
User: string(serialAccessValue.User()),
Host: string(serialAccessValue.Host()),
@@ -231,6 +246,18 @@ func (tbl *Access) Deserialize(fb *serial.BranchControlAccess) error {
return nil
}
// filterDatabases returns all databases that match the given collection indexes.
func (tbl *Access) filterDatabases(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Databases[filter])
}
return matchExprs
}
// filterBranches returns all branches that match the given collection indexes.
func (tbl *Access) filterBranches(filters []uint32) []MatchExpression {
if len(filters) == 0 {
@@ -281,7 +308,7 @@ func (tbl *Access) gatherPermissions(collectionIndexes []uint32) Permissions {
func (tbl *Access) insertDefaultRow() {
// Check if the appropriate row already exists
for _, value := range tbl.Values {
if value.Branch == "%" && value.User == "%" && value.Host == "%" {
if value.Database == "%" && value.Branch == "%" && value.User == "%" && value.Host == "%" {
// Getting to this state will be disallowed in the future, but if the row exists without any perms, then add
// the Write perm
if uint64(value.Permissions) == 0 {
@@ -290,17 +317,21 @@ func (tbl *Access) insertDefaultRow() {
return
}
}
tbl.insert("%", "%", "%", Permissions_Write)
tbl.insert("%", "%", "%", "%", Permissions_Write)
}
// insert adds the given expressions to the table. This does not perform any sort of validation whatsoever, so it is
// important to ensure that the expressions are valid before insertion.
func (tbl *Access) insert(branch string, user string, host string, perms Permissions) {
// Branch and Host are case-insensitive, while user is case-sensitive
func (tbl *Access) insert(database string, branch string, user string, host string, perms Permissions) {
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
database = strings.ToLower(FoldExpression(database))
branch = strings.ToLower(FoldExpression(branch))
user = FoldExpression(user)
host = strings.ToLower(FoldExpression(host))
// Each expression is capped at 2¹⁶-1 values, so we truncate to 2¹⁶-2 and add the any-match character at the end if it's over
if len(database) > math.MaxUint16 {
database = string(append([]byte(database[:math.MaxUint16-1]), byte('%')))
}
if len(branch) > math.MaxUint16 {
branch = string(append([]byte(branch[:math.MaxUint16-1]), byte('%')))
}
@@ -311,16 +342,19 @@ func (tbl *Access) insert(branch string, user string, host string, perms Permiss
host = string(append([]byte(host[:math.MaxUint16-1]), byte('%')))
}
// Add the expression strings to the binlog
tbl.binlog.Insert(branch, user, host, uint64(perms))
tbl.binlog.Insert(database, branch, user, host, uint64(perms))
// Parse and insert the expressions
databaseExpr := ParseExpression(database, sql.Collation_utf8mb4_0900_ai_ci)
branchExpr := ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
userExpr := ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
hostExpr := ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
nextIdx := uint32(len(tbl.Values))
tbl.Databases = append(tbl.Databases, MatchExpression{CollectionIndex: nextIdx, SortOrders: databaseExpr})
tbl.Branches = append(tbl.Branches, MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
tbl.Users = append(tbl.Users, MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
tbl.Hosts = append(tbl.Hosts, MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
tbl.Values = append(tbl.Values, AccessValue{
Database: database,
Branch: branch,
User: user,
Host: host,
@@ -330,11 +364,13 @@ func (tbl *Access) insert(branch string, user string, host string, perms Permiss
// Serialize returns the offset for the AccessValue written to the given builder.
func (val *AccessValue) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
database := b.CreateString(val.Database)
branch := b.CreateString(val.Branch)
user := b.CreateString(val.User)
host := b.CreateString(val.Host)
serial.BranchControlAccessValueStart(b)
serial.BranchControlAccessValueAddDatabase(b, database)
serial.BranchControlAccessValueAddBranch(b, branch)
serial.BranchControlAccessValueAddUser(b, user)
serial.BranchControlAccessValueAddHost(b, host)

View File

@@ -35,6 +35,7 @@ type Binlog struct {
// BinlogRow is a row within the Binlog.
type BinlogRow struct {
IsInsert bool
Database string
Branch string
User string
Host string
@@ -55,6 +56,7 @@ func NewAccessBinlog(vals []AccessValue) *Binlog {
for i, val := range vals {
rows[i] = BinlogRow{
IsInsert: true,
Database: val.Database,
Branch: val.Branch,
User: val.User,
Host: val.Host,
@@ -74,6 +76,7 @@ func NewNamespaceBinlog(vals []NamespaceValue) *Binlog {
for i, val := range vals {
rows[i] = BinlogRow{
IsInsert: true,
Database: val.Database,
Branch: val.Branch,
User: val.User,
Host: val.Host,
@@ -126,6 +129,7 @@ func (binlog *Binlog) Deserialize(fb *serial.BranchControlBinlog) error {
fb.Rows(serialBinlogRow, i)
binlog.rows[i] = BinlogRow{
IsInsert: serialBinlogRow.IsInsert(),
Database: string(serialBinlogRow.Database()),
Branch: string(serialBinlogRow.Branch()),
User: string(serialBinlogRow.User()),
Host: string(serialBinlogRow.Host()),
@@ -163,12 +167,13 @@ func (binlog *Binlog) MergeOverlay(overlay *BinlogOverlay) error {
}
// Insert adds an insert entry to the Binlog.
func (binlog *Binlog) Insert(branch string, user string, host string, permissions uint64) {
func (binlog *Binlog) Insert(database string, branch string, user string, host string, permissions uint64) {
binlog.RWMutex.Lock()
defer binlog.RWMutex.Unlock()
binlog.rows = append(binlog.rows, BinlogRow{
IsInsert: true,
Database: database,
Branch: branch,
User: user,
Host: host,
@@ -177,12 +182,13 @@ func (binlog *Binlog) Insert(branch string, user string, host string, permission
}
// Delete adds a delete entry to the Binlog.
func (binlog *Binlog) Delete(branch string, user string, host string, permissions uint64) {
func (binlog *Binlog) Delete(database string, branch string, user string, host string, permissions uint64) {
binlog.RWMutex.Lock()
defer binlog.RWMutex.Unlock()
binlog.rows = append(binlog.rows, BinlogRow{
IsInsert: false,
Database: database,
Branch: branch,
User: user,
Host: host,
@@ -197,12 +203,14 @@ func (binlog *Binlog) Rows() []BinlogRow {
// Serialize returns the offset for the BinlogRow written to the given builder.
func (row *BinlogRow) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
database := b.CreateString(row.Database)
branch := b.CreateString(row.Branch)
user := b.CreateString(row.User)
host := b.CreateString(row.Host)
serial.BranchControlBinlogRowStart(b)
serial.BranchControlBinlogRowAddIsInsert(b, row.IsInsert)
serial.BranchControlBinlogRowAddDatabase(b, database)
serial.BranchControlBinlogRowAddBranch(b, branch)
serial.BranchControlBinlogRowAddUser(b, user)
serial.BranchControlBinlogRowAddHost(b, host)
@@ -211,9 +219,10 @@ func (row *BinlogRow) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
}
// Insert adds an insert entry to the BinlogOverlay.
func (overlay *BinlogOverlay) Insert(branch string, user string, host string, permissions uint64) {
func (overlay *BinlogOverlay) Insert(database string, branch string, user string, host string, permissions uint64) {
overlay.rows = append(overlay.rows, BinlogRow{
IsInsert: true,
Database: database,
Branch: branch,
User: user,
Host: host,
@@ -222,9 +231,10 @@ func (overlay *BinlogOverlay) Insert(branch string, user string, host string, pe
}
// Delete adds a delete entry to the BinlogOverlay.
func (overlay *BinlogOverlay) Delete(branch string, user string, host string, permissions uint64) {
func (overlay *BinlogOverlay) Delete(database string, branch string, user string, host string, permissions uint64) {
overlay.rows = append(overlay.rows, BinlogRow{
IsInsert: false,
Database: database,
Branch: branch,
User: user,
Host: host,

View File

@@ -18,7 +18,6 @@ import (
"context"
goerrors "errors"
"fmt"
"net"
"os"
"github.com/dolthub/go-mysql-server/sql"
@@ -29,22 +28,25 @@ import (
)
var (
ErrIncorrectPermissions = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
ErrCannotCreateBranch = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
ErrCannotDeleteBranch = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
ErrExpressionsTooLong = errors.NewKind("expressions are too long [%q, %q, %q]")
ErrInsertingRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q]")
ErrUpdatingRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q]")
ErrUpdatingToRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q] to the new branch expression %q")
ErrDeletingRow = errors.NewKind("`%s`@`%s` cannot delete the row [%q, %q, %q]")
ErrMissingController = errors.NewKind("a context has a non-nil session but is missing its branch controller")
ErrIncorrectPermissions = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
ErrCannotCreateBranch = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
ErrCannotDeleteBranch = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
ErrExpressionsTooLong = errors.NewKind("expressions are too long [%q, %q, %q, %q]")
ErrInsertingAccessRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q, %q]")
ErrInsertingNamespaceRow = errors.NewKind("`%s`@`%s` cannot add the row [%q, %q, %q, %q]")
ErrUpdatingRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q, %q]")
ErrUpdatingToRow = errors.NewKind("`%s`@`%s` cannot update the row [%q, %q, %q, %q] to the new branch expression [%q, %q]")
ErrDeletingRow = errors.NewKind("`%s`@`%s` cannot delete the row [%q, %q, %q, %q]")
ErrMissingController = errors.NewKind("a context has a non-nil session but is missing its branch controller")
)
// Context represents the interface that must be inherited from the context.
type Context interface {
GetBranch() (string, error)
GetCurrentDatabase() string
GetUser() string
GetHost() string
GetPrivilegeSet() (sql.PrivilegeSet, uint64)
GetController() *Controller
}
@@ -57,50 +59,6 @@ type Controller struct {
doltConfigDirPath string
}
var (
// superUser is the server-wide user that has full, irrevocable permission to do whatever they want to any branch and table
superUser string
// superHost is the host counterpart of the superUser
superHost string
// superHostIsLoopback states whether the superHost is a loopback address
superHostIsLoopback bool
)
var enabled = false
func init() {
if os.Getenv("DOLT_ENABLE_BRANCH_CONTROL") != "" {
enabled = true
}
}
// SetEnabled is a TEMPORARY function just used for testing (so that we don't have to set the env variable)
func SetEnabled(value bool) {
enabled = value
}
// SetSuperUser sets the server-wide super user to the given user and host combination. The super user has full,
// irrevocable permission to do whatever they want to any branch and table.
func SetSuperUser(user string, host string) {
superUser = user
// Check if superHost is a loopback
if host == "localhost" || net.ParseIP(host).IsLoopback() {
superHost = "localhost"
superHostIsLoopback = true
} else {
superHost = host
superHostIsLoopback = false
}
}
// IsSuperUser returns whether the given user and host combination is the super user.
func IsSuperUser(user string, host string) bool {
if user == superUser && ((host == superHost) || (superHostIsLoopback && net.ParseIP(host).IsLoopback())) {
return true
}
return false
}
// CreateDefaultController returns a default controller, which only has a single entry allowing all users to have write
// permissions on all branches (only the super user has admin, if a super user has been set). This is equivalent to
// passing empty strings to LoadData.
@@ -115,9 +73,6 @@ func CreateDefaultController() *Controller {
// LoadData loads the data from the given location and returns a controller. Returns the default controller if the
// `branchControlFilePath` is empty.
func LoadData(branchControlFilePath string, doltConfigDirPath string) (*Controller, error) {
if !enabled {
return nil, nil
}
accessTbl := newAccess()
controller := &Controller{
Access: accessTbl,
@@ -171,9 +126,6 @@ func LoadData(branchControlFilePath string, doltConfigDirPath string) (*Controll
// SaveData saves the data from the context's controller to the location pointed by it.
func SaveData(ctx context.Context) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
// A nil session means we're not in the SQL context, so we've got nothing to serialize
if branchAwareSession == nil {
@@ -216,9 +168,6 @@ func SaveData(ctx context.Context) error {
// the context. In these cases, CheckAccess will pass as we want to allow all local commands to ignore branch
// permissions.
func CheckAccess(ctx context.Context, flags Permissions) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
// A nil session means we're not in the SQL context, so we allow all operations
if branchAwareSession == nil {
@@ -234,12 +183,13 @@ func CheckAccess(ctx context.Context, flags Permissions) error {
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
database := branchAwareSession.GetCurrentDatabase()
branch, err := branchAwareSession.GetBranch()
if err != nil {
return err
}
// Get the permissions for the branch, user, and host combination
_, perms := controller.Access.Match(branch, user, host)
_, perms := controller.Access.Match(database, branch, user, host)
// If either the flags match or the user is an admin for this branch, then we allow access
if (perms&flags == flags) || (perms&Permissions_Admin == Permissions_Admin) {
return nil
@@ -252,9 +202,6 @@ func CheckAccess(ctx context.Context, flags Permissions) error {
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
// these cases, CanCreateBranch will pass as we want to allow all local commands to freely create branches.
func CanCreateBranch(ctx context.Context, branchName string) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
// A nil session means we're not in the SQL context, so we allow the create operation
if branchAwareSession == nil {
@@ -270,7 +217,8 @@ func CanCreateBranch(ctx context.Context, branchName string) error {
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
if controller.Namespace.CanCreate(branchName, user, host) {
database := branchAwareSession.GetCurrentDatabase()
if controller.Namespace.CanCreate(database, branchName, user, host) {
return nil
}
return ErrCannotCreateBranch.New(user, host, branchName)
@@ -281,9 +229,6 @@ func CanCreateBranch(ctx context.Context, branchName string) error {
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
// these cases, CanDeleteBranch will pass as we want to allow all local commands to freely delete branches.
func CanDeleteBranch(ctx context.Context, branchName string) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
// A nil session means we're not in the SQL context, so we allow the delete operation
if branchAwareSession == nil {
@@ -299,8 +244,9 @@ func CanDeleteBranch(ctx context.Context, branchName string) error {
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
database := branchAwareSession.GetCurrentDatabase()
// Get the permissions for the branch, user, and host combination
_, perms := controller.Access.Match(branchName, user, host)
_, perms := controller.Access.Match(database, branchName, user, host)
// If the user has the write or admin flags, then we allow access
if (perms&Permissions_Write == Permissions_Write) || (perms&Permissions_Admin == Permissions_Admin) {
return nil
@@ -312,9 +258,6 @@ func CanDeleteBranch(ctx context.Context, branchName string) error {
// context is missing some functionality that is needed to perform the addition, such as a user or the Controller, then
// this simply returns.
func AddAdminForContext(ctx context.Context, branchName string) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
if branchAwareSession == nil {
return nil
@@ -326,15 +269,16 @@ func AddAdminForContext(ctx context.Context, branchName string) error {
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
database := branchAwareSession.GetCurrentDatabase()
// Check if we already have admin permissions for the given branch, as there's no need to do another insertion if so
controller.Access.RWMutex.RLock()
_, modPerms := controller.Access.Match(branchName, user, host)
_, modPerms := controller.Access.Match(database, branchName, user, host)
controller.Access.RWMutex.RUnlock()
if modPerms&Permissions_Admin == Permissions_Admin {
return nil
}
controller.Access.RWMutex.Lock()
controller.Access.insert(branchName, user, host, Permissions_Admin)
controller.Access.insert(database, branchName, user, host, Permissions_Admin)
controller.Access.RWMutex.Unlock()
return SaveData(ctx)
}
@@ -351,3 +295,29 @@ func GetBranchAwareSession(ctx context.Context) Context {
}
return nil
}
// HasDatabasePrivileges returns whether the given context's user has the correct privileges to modify any table entries
// that match the given database. The following are the required privileges:
//
// Global Space: SUPER, GRANT
// Global Space: CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, EXECUTE, GRANT
// Database Space: CREATE, ALTER, DROP, INSERT, UPDATE, DELETE, EXECUTE, GRANT
//
// Any user that may grant SUPER is considered to be a super user. In addition, any user that may grant the suite of
// alteration privileges is also considered a super user. The SUPER privilege does not exist at the database level, it
// is a global privilege only.
func HasDatabasePrivileges(ctx Context, database string) bool {
if ctx == nil {
return true
}
privSet, counter := ctx.GetPrivilegeSet()
if counter == 0 {
return false
}
hasSuper := privSet.Has(sql.PrivilegeType_Super, sql.PrivilegeType_Grant)
isGlobalAdmin := privSet.Has(sql.PrivilegeType_Create, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop,
sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_Delete, sql.PrivilegeType_Execute, sql.PrivilegeType_Grant)
isDatabaseAdmin := privSet.Database(database).Has(sql.PrivilegeType_Create, sql.PrivilegeType_Alter, sql.PrivilegeType_Drop,
sql.PrivilegeType_Insert, sql.PrivilegeType_Update, sql.PrivilegeType_Delete, sql.PrivilegeType_Execute, sql.PrivilegeType_Grant)
return hasSuper || isGlobalAdmin || isDatabaseAdmin
}

View File

@@ -31,41 +31,50 @@ type Namespace struct {
access *Access
binlog *Binlog
Branches []MatchExpression
Users []MatchExpression
Hosts []MatchExpression
Values []NamespaceValue
RWMutex *sync.RWMutex
Databases []MatchExpression
Branches []MatchExpression
Users []MatchExpression
Hosts []MatchExpression
Values []NamespaceValue
RWMutex *sync.RWMutex
}
// NamespaceValue contains the user-facing values of a particular row.
type NamespaceValue struct {
Branch string
User string
Host string
Database string
Branch string
User string
Host string
}
// newNamespace returns a new Namespace.
func newNamespace(accessTbl *Access) *Namespace {
return &Namespace{
binlog: NewNamespaceBinlog(nil),
access: accessTbl,
Branches: nil,
Users: nil,
Hosts: nil,
Values: nil,
RWMutex: &sync.RWMutex{},
binlog: NewNamespaceBinlog(nil),
access: accessTbl,
Databases: nil,
Branches: nil,
Users: nil,
Hosts: nil,
Values: nil,
RWMutex: &sync.RWMutex{},
}
}
// CanCreate checks the given branch, and returns whether the given user and host combination is able to create that
// branch. Handles the super user case.
func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
// Super user can always create branches
if IsSuperUser(user, host) {
// CanCreate checks the given database and branch, and returns whether the given user and host combination is able to
// create that branch. Handles the super user case.
func (tbl *Namespace) CanCreate(database string, branch string, user string, host string) bool {
filteredIndexes := Match(tbl.Databases, database, sql.Collation_utf8mb4_0900_ai_ci)
// If there are no database entries, then the Namespace is unrestricted
if len(filteredIndexes) == 0 {
indexPool.Put(filteredIndexes)
return true
}
matchedSet := Match(tbl.Branches, branch, sql.Collation_utf8mb4_0900_ai_ci)
filteredBranches := tbl.filterBranches(filteredIndexes)
indexPool.Put(filteredIndexes)
matchedSet := Match(filteredBranches, branch, sql.Collation_utf8mb4_0900_ai_ci)
matchExprPool.Put(filteredBranches)
// If there are no branch entries, then the Namespace is unrestricted
if len(matchedSet) == 0 {
indexPool.Put(matchedSet)
@@ -74,7 +83,7 @@ func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
// We take either the longest match, or the set of longest matches if multiple matches have the same length
longest := -1
filteredIndexes := indexPool.Get().([]uint32)[:0]
filteredIndexes = indexPool.Get().([]uint32)[:0]
for _, matched := range matchedSet {
matchedValue := tbl.Values[matched]
// If we've found a longer match, then we reset the slice. We append to it in the following if statement.
@@ -102,11 +111,11 @@ func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
return result
}
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
// returns -1. Assumes that the given expressions have already been folded.
func (tbl *Namespace) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
// GetIndex returns the index of the given database, branch, user, and host expressions. If the expressions cannot be
// found, returns -1. Assumes that the given expressions have already been folded.
func (tbl *Namespace) GetIndex(databaseExpr string, branchExpr string, userExpr string, hostExpr string) int {
for i, value := range tbl.Values {
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
if value.Database == databaseExpr && value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
return i
}
}
@@ -131,11 +140,15 @@ func (tbl *Namespace) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
// Serialize the binlog
binlog := tbl.binlog.Serialize(b)
// Initialize field offset slices
databaseOffsets := make([]flatbuffers.UOffsetT, len(tbl.Databases))
branchOffsets := make([]flatbuffers.UOffsetT, len(tbl.Branches))
userOffsets := make([]flatbuffers.UOffsetT, len(tbl.Users))
hostOffsets := make([]flatbuffers.UOffsetT, len(tbl.Hosts))
valueOffsets := make([]flatbuffers.UOffsetT, len(tbl.Values))
// Get field offsets
for i, matchExpr := range tbl.Databases {
databaseOffsets[i] = matchExpr.Serialize(b)
}
for i, matchExpr := range tbl.Branches {
branchOffsets[i] = matchExpr.Serialize(b)
}
@@ -149,6 +162,11 @@ func (tbl *Namespace) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
valueOffsets[i] = val.Serialize(b)
}
// Get the field vectors
serial.BranchControlNamespaceStartDatabasesVector(b, len(databaseOffsets))
for i := len(databaseOffsets) - 1; i >= 0; i-- {
b.PrependUOffsetT(databaseOffsets[i])
}
databases := b.EndVector(len(databaseOffsets))
serial.BranchControlNamespaceStartBranchesVector(b, len(branchOffsets))
for i := len(branchOffsets) - 1; i >= 0; i-- {
b.PrependUOffsetT(branchOffsets[i])
@@ -172,6 +190,7 @@ func (tbl *Namespace) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
// Write the table
serial.BranchControlNamespaceStart(b)
serial.BranchControlNamespaceAddBinlog(b, binlog)
serial.BranchControlNamespaceAddDatabases(b, databases)
serial.BranchControlNamespaceAddBranches(b, branches)
serial.BranchControlNamespaceAddUsers(b, users)
serial.BranchControlNamespaceAddHosts(b, hosts)
@@ -189,7 +208,10 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
return fmt.Errorf("cannot deserialize to a non-empty namespace table")
}
// Verify that all fields have the same length
if fb.BranchesLength() != fb.UsersLength() || fb.UsersLength() != fb.HostsLength() || fb.HostsLength() != fb.ValuesLength() {
if fb.DatabasesLength() != fb.BranchesLength() ||
fb.BranchesLength() != fb.UsersLength() ||
fb.UsersLength() != fb.HostsLength() ||
fb.HostsLength() != fb.ValuesLength() {
return fmt.Errorf("cannot deserialize a namespace table with differing field lengths")
}
// Read the binlog
@@ -201,10 +223,17 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
return err
}
// Initialize every slice
tbl.Databases = make([]MatchExpression, fb.DatabasesLength())
tbl.Branches = make([]MatchExpression, fb.BranchesLength())
tbl.Users = make([]MatchExpression, fb.UsersLength())
tbl.Hosts = make([]MatchExpression, fb.HostsLength())
tbl.Values = make([]NamespaceValue, fb.ValuesLength())
// Read the databases
for i := 0; i < fb.DatabasesLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
fb.Databases(serialMatchExpr, i)
tbl.Databases[i] = deserializeMatchExpression(serialMatchExpr)
}
// Read the branches
for i := 0; i < fb.BranchesLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
@@ -228,14 +257,27 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
serialNamespaceValue := &serial.BranchControlNamespaceValue{}
fb.Values(serialNamespaceValue, i)
tbl.Values[i] = NamespaceValue{
Branch: string(serialNamespaceValue.Branch()),
User: string(serialNamespaceValue.User()),
Host: string(serialNamespaceValue.Host()),
Database: string(serialNamespaceValue.Database()),
Branch: string(serialNamespaceValue.Branch()),
User: string(serialNamespaceValue.User()),
Host: string(serialNamespaceValue.Host()),
}
}
return nil
}
// filterDatabases returns all databases that match the given collection indexes.
func (tbl *Namespace) filterDatabases(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Databases[filter])
}
return matchExprs
}
// filterBranches returns all branches that match the given collection indexes.
func (tbl *Namespace) filterBranches(filters []uint32) []MatchExpression {
if len(filters) == 0 {
@@ -274,11 +316,13 @@ func (tbl *Namespace) filterHosts(filters []uint32) []MatchExpression {
// Serialize returns the offset for the NamespaceValue written to the given builder.
func (val *NamespaceValue) Serialize(b *flatbuffers.Builder) flatbuffers.UOffsetT {
database := b.CreateString(val.Database)
branch := b.CreateString(val.Branch)
user := b.CreateString(val.User)
host := b.CreateString(val.Host)
serial.BranchControlNamespaceValueStart(b)
serial.BranchControlNamespaceValueAddDatabase(b, database)
serial.BranchControlNamespaceValueAddBranch(b, branch)
serial.BranchControlNamespaceValueAddUser(b, user)
serial.BranchControlNamespaceValueAddHost(b, host)

View File

@@ -39,7 +39,6 @@ import (
)
var ErrRepositoryExists = errors.New("data repository already exists")
var ErrFailedToInitRepo = errors.New("")
var ErrFailedToCreateDirectory = errors.New("unable to create directories")
var ErrFailedToAccessDir = errors.New("unable to access directories")
var ErrFailedToCreateRepoStateWithRemote = errors.New("unable to create repo state with remote")
@@ -76,7 +75,7 @@ func EnvForClone(ctx context.Context, nbf *types.NomsBinFormat, r env.Remote, di
dEnv := env.Load(ctx, homeProvider, newFs, doltdb.LocalDirDoltDB, version)
err = dEnv.InitRepoWithNoData(ctx, nbf)
if err != nil {
return nil, fmt.Errorf("%w; %s", ErrFailedToInitRepo, err.Error())
return nil, fmt.Errorf("failed to init repo: %w", err)
}
dEnv.RSLoadErr = nil
@@ -280,7 +279,7 @@ func InitEmptyClonedRepo(ctx context.Context, dEnv *env.DoltEnv) error {
err := dEnv.InitDBWithTime(ctx, types.Format_Default, name, email, initBranch, datas.CommitNowFunc())
if err != nil {
return ErrFailedToInitRepo
return fmt.Errorf("failed to init repo: %w", err)
}
return nil

View File

@@ -365,7 +365,7 @@ func FetchRemoteBranch(
func FetchRefSpecs(ctx context.Context, dbData env.DbData, srcDB *doltdb.DoltDB, refSpecs []ref.RemoteRefSpec, remote env.Remote, mode ref.UpdateMode, progStarter ProgStarter, progStopper ProgStopper) error {
branchRefs, err := srcDB.GetHeadRefs(ctx)
if err != nil {
return env.ErrFailedToReadDb
return fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
}
for _, rs := range refSpecs {

View File

@@ -59,7 +59,6 @@ const (
var zeroHashStr = (hash.Hash{}).String()
var ErrPreexistingDoltDir = errors.New(".dolt dir already exists")
var ErrStateUpdate = errors.New("error updating local data repo state")
var ErrMarshallingSchema = errors.New("error marshalling schema")
var ErrInvalidCredsFile = errors.New("invalid creds file")
@@ -389,7 +388,7 @@ func (dEnv *DoltEnv) createDirectories(dir string) (string, error) {
}
if dEnv.hasDoltDir(dir) {
return "", ErrPreexistingDoltDir
return "", fmt.Errorf(".dolt directory already exists at '%s'", dir)
}
absDataDir := filepath.Join(absPath, dbfactory.DoltDataDir)
@@ -923,7 +922,7 @@ func (dEnv *DoltEnv) RemoveRemote(ctx context.Context, name string) error {
ddb := dEnv.DoltDB
refs, err := ddb.GetRemoteRefs(ctx)
if err != nil {
return ErrFailedToReadFromDb
return fmt.Errorf("%w: %s", ErrFailedToReadFromDb, err.Error())
}
for _, r := range refs {

View File

@@ -20,6 +20,7 @@ import (
"path/filepath"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
)
const (
@@ -42,7 +43,7 @@ type HomeDirProvider func() (string, error)
// provide a different directory where the root .dolt directory should be located and global state will be stored there.
func GetCurrentUserHomeDir() (string, error) {
if doltRootPath, ok := os.LookupEnv(doltRootPathEnvVar); ok && doltRootPath != "" {
return doltRootPath, nil
return filesys.LocalFS.Abs(doltRootPath)
}
var home string

View File

@@ -205,10 +205,9 @@ func NewPushOpts(ctx context.Context, apr *argparser.ArgParseResults, rsr RepoSt
hasRef, err := ddb.HasRef(ctx, currentBranch)
if err != nil {
return nil, ErrFailedToReadDb
return nil, fmt.Errorf("%w: %s", ErrFailedToReadDb, err.Error())
} else if !hasRef {
return nil, fmt.Errorf("%w: '%s'", ErrUnknownBranch, currentBranch.GetPath())
}
src := refSpec.SrcRef(currentBranch)

View File

@@ -124,44 +124,6 @@ func ExecNoFFMerge(ctx context.Context, dEnv *env.DoltEnv, spec *MergeSpec) (map
tblToStats := make(map[string]*MergeStats)
err = mergedRootToWorking(ctx, false, dEnv, mergedRoot, spec.WorkingDiffs, spec.MergeC, spec.MergeCSpecStr, tblToStats)
if err != nil {
return tblToStats, err
}
// Reload roots since the above method writes new values to the working set
roots, err := dEnv.Roots(ctx)
if err != nil {
return tblToStats, err
}
ws, err := dEnv.WorkingSet(ctx)
if err != nil {
return tblToStats, err
}
var mergeParentCommits []*doltdb.Commit
if ws.MergeActive() {
mergeParentCommits = []*doltdb.Commit{ws.MergeState().Commit()}
}
_, err = actions.CommitStaged(ctx, roots, ws.MergeActive(), mergeParentCommits, dEnv.DbData(), actions.CommitStagedProps{
Message: spec.Msg,
Date: spec.Date,
AllowEmpty: spec.AllowEmpty,
Force: spec.Force,
Name: spec.Name,
Email: spec.Email,
})
if err != nil {
return tblToStats, fmt.Errorf("%w; failed to commit", err)
}
err = dEnv.ClearMerge(ctx)
if err != nil {
return tblToStats, err
}
return tblToStats, err
}

View File

@@ -183,23 +183,25 @@ func equalRows(old, new sql.Row, sch sql.Schema) (bool, error) {
// special case time comparison to account
// for precision changes between formats
if _, ok := old[i].(time.Time); ok {
if old[i], err = sql.Int64.Convert(old[i]); err != nil {
var o, n interface{}
if o, err = sql.Int64.Convert(old[i]); err != nil {
return false, err
}
if new[i], err = sql.Int64.Convert(new[i]); err != nil {
if n, err = sql.Int64.Convert(new[i]); err != nil {
return false, err
}
if cmp, err = sql.Int64.Compare(o, n); err != nil {
return false, err
}
cmp, err = sql.Int64.Compare(old[i], new[i])
} else {
cmp, err = sch[i].Type.Compare(old[i], new[i])
if cmp, err = sch[i].Type.Compare(old[i], new[i]); err != nil {
return false, err
}
}
if err != nil {
return false, err
} else if cmp != 0 {
if cmp != 0 {
return false, nil
}
}
return true, nil
}

View File

@@ -18,9 +18,14 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -36,7 +41,9 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/clusterdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
"github.com/dolthub/dolt/go/store/types"
)
@@ -59,12 +66,17 @@ type Controller struct {
sinterceptor serverinterceptor
cinterceptor clientinterceptor
lgr *logrus.Logger
grpcCreds credentials.PerRPCCredentials
provider dbProvider
iterSessions IterSessions
killQuery func(uint32)
killConnection func(uint32) error
jwks *jwtauth.MultiJWKS
tlsCfg *tls.Config
grpcCreds credentials.PerRPCCredentials
pub ed25519.PublicKey
priv ed25519.PrivateKey
}
type sqlvars interface {
@@ -112,9 +124,49 @@ func NewController(lgr *logrus.Logger, cfg Config, pCfg config.ReadWriteConfig)
ret.cinterceptor.lgr = lgr.WithFields(logrus.Fields{})
ret.cinterceptor.setRole(role, epoch)
ret.cinterceptor.roleSetter = roleSetter
ret.tlsCfg, err = ret.outboundTlsConfig()
if err != nil {
return nil, err
}
ret.pub, ret.priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
keyID := creds.PubKeyToKID(ret.pub)
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
ret.grpcCreds = &creds.RPCCreds{
PrivKey: ret.priv,
Audience: creds.RemotesAPIAudience,
Issuer: creds.ClientIssuer,
KeyID: keyIDStr,
RequireTLS: false,
}
ret.jwks = ret.standbyRemotesJWKS()
ret.sinterceptor.keyProvider = ret.jwks
ret.sinterceptor.jwtExpected = JWTExpectations()
return ret, nil
}
func (c *Controller) Run() {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c.jwks.Run()
}()
wg.Wait()
}
func (c *Controller) GracefulStop() error {
c.jwks.GracefulStop()
return nil
}
func (c *Controller) ManageSystemVariables(variables sqlvars) {
if c == nil {
return
@@ -198,7 +250,7 @@ func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.
}
func (c *Controller) gRPCDialProvider(denv *env.DoltEnv) dbfactory.GRPCDialProvider {
return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.cfg, c.grpcCreds}
return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.tlsCfg, c.grpcCreds}
}
func (c *Controller) RegisterStoredProcedures(store procedurestore) {
@@ -412,23 +464,9 @@ func (c *Controller) RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.Server
args = sqle.RemoteSrvServerArgs(ctx, args)
args.DBCache = remotesrvStoreCache{args.DBCache, c}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
keyID := creds.PubKeyToKID(pub)
keyID := creds.PubKeyToKID(c.pub)
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, pub)
c.grpcCreds = &creds.RPCCreds{
PrivKey: priv,
Audience: creds.RemotesAPIAudience,
Issuer: creds.ClientIssuer,
KeyID: keyIDStr,
RequireTLS: false,
}
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub)
return args
}
@@ -565,3 +603,129 @@ func (c *Controller) waitForHooksToReplicate() error {
return errors.New("cluster/controller: failed to transition from primary to standby gracefully; could not replicate databases to standby in a timely manner.")
}
}
// Within a cluster, if remotesapi is configured with a tls_ca, we take the
// following semantics:
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
// all of which are trusted roots for the outbound connections the
// remotestorage client establishes.
// * The certificate chain presented by the server must validate to a root
// which was present in tls_ca. In particular, every certificate in the chain
// must be within its validity window, the signatures must be valid, key usage
// and isCa must be correctly set for the roots and the intermediates, and the
// leaf must have extended key usage server auth.
// * On the other hand, no verification is done against the SAN or the Subject
// of the certificate.
//
// We use these TLS semantics for both connections to the gRPC endpoint which
// is the actual remotesapi, and for connections to any HTTPS endpoints to
// which the gRPC service returns URLs. For now, this works perfectly for our
// use case, but it's tightly coupled to `cluster:` deployment topologies and
// the likes.
//
// If tls_ca is not set then default TLS handling is performed. In particular,
// if the remotesapi endpoints is HTTPS, then the system roots are used and
// ServerName is verified against the presented URL SANs of the certificates.
//
// This tls Config is used for fetching JWKS, for outbound GRPC connections and
// for outbound https connections on the URLs that the GRPC services return.
func (c *Controller) outboundTlsConfig() (*tls.Config, error) {
tlsCA := c.cfg.RemotesAPIConfig().TLSCA()
if tlsCA == "" {
return nil, nil
}
urlmatches := c.cfg.RemotesAPIConfig().ServerNameURLMatches()
dnsmatches := c.cfg.RemotesAPIConfig().ServerNameDNSMatches()
pem, err := os.ReadFile(tlsCA)
if err != nil {
return nil, err
}
roots := x509.NewCertPool()
if ok := roots.AppendCertsFromPEM(pem); !ok {
return nil, errors.New("error loading ca roots from " + tlsCA)
}
verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(rawCerts))
var err error
for i, asn1Data := range rawCerts {
certs[i], err = x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages,
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
return err
}
if len(urlmatches) > 0 {
found := false
for _, n := range urlmatches {
for _, cn := range certs[0].URIs {
if n == cn.String() {
found = true
}
break
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_urls, but it did not")
}
}
if len(dnsmatches) > 0 {
found := false
for _, n := range dnsmatches {
for _, cn := range certs[0].DNSNames {
if n == cn {
found = true
}
break
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_dns, but it did not")
}
}
return nil
}
return &tls.Config{
// We have to InsecureSkipVerify because ServerName is always
// set by the grpc dial provider and golang tls.Config does not
// have good support for performing certificate validation
// without server name validation.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyFunc,
NextProtos: []string{"h2"},
}, nil
}
func (c *Controller) standbyRemotesJWKS() *jwtauth.MultiJWKS {
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: c.tlsCfg,
ForceAttemptHTTP2: true,
},
}
urls := make([]string, len(c.cfg.StandbyRemotes()))
for i, r := range c.cfg.StandbyRemotes() {
urls[i] = strings.Replace(r.RemoteURLTemplate(), dsess.URLTemplateDatabasePlaceholder, ".well-known/jwks.json", -1)
}
return jwtauth.NewMultiJWKS(c.lgr.WithFields(logrus.Fields{"component": "jwks-key-provider"}), urls, client)
}

View File

@@ -16,15 +16,13 @@ package cluster
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
)
@@ -35,24 +33,26 @@ import (
// - client interceptors for transmitting our replication role.
// - do not use environment credentials. (for now).
type grpcDialProvider struct {
orig dbfactory.GRPCDialProvider
ci *clientinterceptor
cfg Config
creds credentials.PerRPCCredentials
orig dbfactory.GRPCDialProvider
ci *clientinterceptor
tlsCfg *tls.Config
creds credentials.PerRPCCredentials
}
func (p grpcDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfactory.GRPCRemoteConfig, error) {
tlsConfig, err := p.tlsConfig()
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
config.TLSConfig = tlsConfig
config.TLSConfig = p.tlsCfg
config.Creds = p.creds
if config.Creds != nil && config.TLSConfig != nil {
if c, ok := config.Creds.(*creds.RPCCreds); ok {
c.RequireTLS = true
}
}
config.WithEnvCreds = false
cfg, err := p.orig.GetGRPCDialParams(config)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
cfg.DialOptions = append(cfg.DialOptions, p.ci.Options()...)
cfg.DialOptions = append(cfg.DialOptions, grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
@@ -63,114 +63,6 @@ func (p grpcDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
},
MinConnectTimeout: 250 * time.Millisecond,
}))
return cfg, nil
}
// Within a cluster, if remotesapi is configured with a tls_ca, we take the
// following semantics:
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
// all of which are trusted roots for the outbound connections the
// remotestorage client establishes.
// * The certificate chain presented by the server must validate to a root
// which was present in tls_ca. In particular, every certificate in the chain
// must be within its validity window, the signatures must be valid, key usage
// and isCa must be correctly set for the roots and the intermediates, and the
// leaf must have extended key usage server auth.
// * On the other hand, no verification is done against the SAN or the Subject
// of the certificate.
//
// We use these TLS semantics for both connections to the gRPC endpoint which
// is the actual remotesapi, and for connections to any HTTPS endpoints to
// which the gRPC service returns URLs. For now, this works perfectly for our
// use case, but it's tightly coupled to `cluster:` deployment topologies and
// the likes.
//
// If tls_ca is not set then default TLS handling is performed. In particular,
// if the remotesapi endpoints is HTTPS, then the system roots are used and
// ServerName is verified against the presented URL SANs of the certificates.
func (p grpcDialProvider) tlsConfig() (*tls.Config, error) {
tlsCA := p.cfg.RemotesAPIConfig().TLSCA()
if tlsCA == "" {
return nil, nil
}
urlmatches := p.cfg.RemotesAPIConfig().ServerNameURLMatches()
dnsmatches := p.cfg.RemotesAPIConfig().ServerNameDNSMatches()
pem, err := ioutil.ReadFile(tlsCA)
if err != nil {
return nil, err
}
roots := x509.NewCertPool()
if ok := roots.AppendCertsFromPEM(pem); !ok {
return nil, errors.New("error loading ca roots from " + tlsCA)
}
verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(rawCerts))
var err error
for i, asn1Data := range rawCerts {
certs[i], err = x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages,
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
return err
}
if len(urlmatches) > 0 {
found := false
for _, n := range urlmatches {
for _, cn := range certs[0].URIs {
if n == cn.String() {
found = true
}
break
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_urls, but it did not")
}
}
if len(dnsmatches) > 0 {
found := false
for _, n := range dnsmatches {
for _, cn := range certs[0].DNSNames {
if n == cn {
found = true
}
break
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_dns, but it did not")
}
}
return nil
}
return &tls.Config{
// We have to InsecureSkipVerify because ServerName is always
// set by the grpc dial provider and golang tls.Config does not
// have good support for performing certificate validation
// without server name validation.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyFunc,
NextProtos: []string{"h2"},
}, nil
}

View File

@@ -17,14 +17,18 @@ package cluster
import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
)
const clusterRoleHeader = "x-dolt-cluster-role"
@@ -158,12 +162,20 @@ func (ci *clientinterceptor) Options() []grpc.DialOption {
// * for incoming requests which are not standby, it will currently fail the
// requests with codes.Unauthenticated. Eventually, it will allow read-only
// traffic through which is authenticated and authorized.
//
// The serverinterceptor is responsible for authenticating incoming requests
// from standby replicas. It is instantiated with a jwtauth.KeyProvider and
// some jwt.Expected. Incoming requests must have a valid, unexpired, signed
// JWT, signed by a key accessible in the KeyProvider.
type serverinterceptor struct {
lgr *logrus.Entry
role Role
epoch int
mu sync.Mutex
roleSetter func(role string, epoch int)
keyProvider jwtauth.KeyProvider
jwtExpected jwt.Expected
}
func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
@@ -174,6 +186,9 @@ func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
fromStandby = si.handleRequestHeaders(md, role, epoch)
}
if fromStandby {
if err := si.authenticate(ss.Context()); err != nil {
return err
}
// After handleRequestHeaders, our role may have changed, so we fetch it again here.
role, epoch := si.getRole()
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(clusterRoleHeader, string(role), clusterRoleEpochHeader, strconv.Itoa(epoch))); err != nil {
@@ -204,6 +219,9 @@ func (si *serverinterceptor) Unary() grpc.UnaryServerInterceptor {
fromStandby = si.handleRequestHeaders(md, role, epoch)
}
if fromStandby {
if err := si.authenticate(ctx); err != nil {
return nil, err
}
// After handleRequestHeaders, our role may have changed, so we fetch it again here.
role, epoch := si.getRole()
if err := grpc.SetHeader(ctx, metadata.Pairs(clusterRoleHeader, string(role), clusterRoleEpochHeader, strconv.Itoa(epoch))); err != nil {
@@ -272,3 +290,26 @@ func (si *serverinterceptor) getRole() (Role, int) {
defer si.mu.Unlock()
return si.role, si.epoch
}
func (si *serverinterceptor) authenticate(ctx context.Context) error {
if md, ok := metadata.FromIncomingContext(ctx); ok {
auths := md.Get("authorization")
if len(auths) != 1 {
si.lgr.Info("incoming standby request had no authorization")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
auth := auths[0]
if !strings.HasPrefix(auth, "Bearer ") {
si.lgr.Info("incoming standby request had malformed authentication header")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
auth = strings.TrimPrefix(auth, "Bearer ")
_, err := jwtauth.ValidateJWT(auth, time.Now(), si.keyProvider, si.jwtExpected)
if err != nil {
si.lgr.Infof("incoming standby request authorization header failed to verify: %v", err)
return status.Error(codes.Unauthenticated, "unauthenticated")
}
return nil
}
return status.Error(codes.Unauthenticated, "unauthenticated")
}

View File

@@ -16,13 +16,15 @@ package cluster
import (
"context"
"crypto/ed25519"
"crypto/rand"
"net"
"strconv"
"sync"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
@@ -30,6 +32,10 @@ import (
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
)
type server struct {
@@ -51,6 +57,53 @@ func noopSetRole(string, int) {
var lgr = logrus.StandardLogger().WithFields(logrus.Fields{})
var kp jwtauth.KeyProvider
var pub ed25519.PublicKey
var priv ed25519.PrivateKey
func init() {
var err error
pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
kp = keyProvider{pub}
}
type keyProvider struct {
ed25519.PublicKey
}
func (p keyProvider) GetKey(string) ([]jose.JSONWebKey, error) {
return []jose.JSONWebKey{{
Key: p.PublicKey,
KeyID: "1",
}}, nil
}
func newJWT() string {
key := jose.SigningKey{Algorithm: jose.EdDSA, Key: priv}
opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{
"kid": "1",
}}
signer, err := jose.NewSigner(key, opts)
if err != nil {
panic(err)
}
jwtBuilder := jwt.Signed(signer)
jwtBuilder = jwtBuilder.Claims(jwt.Claims{
Audience: []string{"some_audience"},
Issuer: "some_issuer",
Subject: "some_subject",
Expiry: jwt.NewNumericDate(time.Now().Add(30 * time.Second)),
})
res, err := jwtBuilder.CompactSerialize()
if err != nil {
panic(err)
}
return res
}
func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient), serveropts []grpc.ServerOption, dialopts []grpc.DialOption) *server {
addr, err := net.ResolveUnixAddr("unix", "test_grpc.socket")
require.NoError(t, err)
@@ -93,12 +146,14 @@ func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient),
func outboundCtx(vals ...interface{}) context.Context {
ctx := context.Background()
if len(vals) == 0 {
return ctx
return metadata.AppendToOutgoingContext(ctx,
"authorization", "Bearer "+newJWT())
}
if len(vals) == 2 {
return metadata.AppendToOutgoingContext(ctx,
clusterRoleHeader, string(vals[0].(Role)),
clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)))
clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)),
"authorization", "Bearer "+newJWT())
}
panic("bad test --- outboundCtx must take 0 or 2 values")
}
@@ -108,6 +163,7 @@ func TestServerInterceptorUnauthenticatedWithoutClientHeaders(t *testing.T) {
si.roleSetter = noopSetRole
si.lgr = lgr
si.setRole(RoleStandby, 10)
si.keyProvider = kp
t.Run("Standby", func(t *testing.T) {
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
@@ -136,6 +192,7 @@ func TestServerInterceptorAddsUnaryResponseHeaders(t *testing.T) {
si.setRole(RoleStandby, 10)
si.roleSetter = noopSetRole
si.lgr = lgr
si.keyProvider = kp
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
var md metadata.MD
_, err := client.Check(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
@@ -154,6 +211,7 @@ func TestServerInterceptorAddsStreamResponseHeaders(t *testing.T) {
si.setRole(RoleStandby, 10)
si.roleSetter = noopSetRole
si.lgr = lgr
si.keyProvider = kp
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
var md metadata.MD
srv, err := client.Watch(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
@@ -174,6 +232,7 @@ func TestServerInterceptorAsPrimaryDoesNotSendRequest(t *testing.T) {
si.setRole(RolePrimary, 10)
si.roleSetter = noopSetRole
si.lgr = lgr
si.keyProvider = kp
srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
ctx := metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value")
_, err := client.Check(ctx, &grpc_health_v1.HealthCheckRequest{})

View File

@@ -20,6 +20,9 @@ import (
"net/http"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
)
type JWKSHandler struct {
@@ -55,3 +58,7 @@ func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handl
})
}
}
// JWTExpectations returns the validation expectations applied to
// incoming JWTs: the issuer must be the dolt client issuer and the
// audience must include the remotes API audience.
func JWTExpectations() jwt.Expected {
	expected := jwt.Expected{
		Issuer:   creds.ClientIssuer,
		Audience: jwt.Audience{creds.RemotesAPIAudience},
	}
	return expected
}

View File

@@ -126,7 +126,7 @@ func DoDoltPull(ctx *sql.Context, args []string) (int, int, error) {
// Fetch all references
branchRefs, err := srcDB.GetHeadRefs(ctx)
if err != nil {
return noConflictsOrViolations, threeWayMerge, env.ErrFailedToReadDb
return noConflictsOrViolations, threeWayMerge, fmt.Errorf("%w: %s", env.ErrFailedToReadDb, err.Error())
}
hasBranch, err := srcDB.HasBranch(ctx, pullSpec.Branch.GetPath())

View File

@@ -81,8 +81,8 @@ func (ds *DiffSummaryTableFunction) WithDatabase(database sql.Database) (sql.Nod
return ds, nil
}
// FunctionName implements the sql.TableFunction interface
func (ds *DiffSummaryTableFunction) FunctionName() string {
// Name implements the sql.TableFunction interface
func (ds *DiffSummaryTableFunction) Name() string {
return "dolt_diff_summary"
}
@@ -184,18 +184,18 @@ func (ds *DiffSummaryTableFunction) Expressions() []sql.Expression {
// WithExpressions implements the sql.Expressioner interface.
func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
if len(expression) < 1 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "1 to 3", len(expression))
return nil, sql.ErrInvalidArgumentNumber.New(ds.Name(), "1 to 3", len(expression))
}
for _, expr := range expression {
if !expr.Resolved() {
return nil, ErrInvalidNonLiteralArgument.New(ds.FunctionName(), expr.String())
return nil, ErrInvalidNonLiteralArgument.New(ds.Name(), expr.String())
}
}
if strings.Contains(expression[0].String(), "..") {
if len(expression) < 1 || len(expression) > 2 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "1 or 2", len(expression))
return nil, sql.ErrInvalidArgumentNumber.New(ds.Name(), "1 or 2", len(expression))
}
ds.dotCommitExpr = expression[0]
if len(expression) == 2 {
@@ -203,7 +203,7 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression
}
} else {
if len(expression) < 2 || len(expression) > 3 {
return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "2 or 3", len(expression))
return nil, sql.ErrInvalidArgumentNumber.New(ds.Name(), "2 or 3", len(expression))
}
ds.fromCommitExpr = expression[0]
ds.toCommitExpr = expression[1]
@@ -215,20 +215,20 @@ func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression
// validate the expressions
if ds.dotCommitExpr != nil {
if !sql.IsText(ds.dotCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.dotCommitExpr.String())
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.dotCommitExpr.String())
}
} else {
if !sql.IsText(ds.fromCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.fromCommitExpr.String())
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.fromCommitExpr.String())
}
if !sql.IsText(ds.toCommitExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.toCommitExpr.String())
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.toCommitExpr.String())
}
}
if ds.tableNameExpr != nil {
if !sql.IsText(ds.tableNameExpr.Type()) {
return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.tableNameExpr.String())
return nil, sql.ErrInvalidArgumentDetails.New(ds.Name(), ds.tableNameExpr.String())
}
}

View File

@@ -95,7 +95,7 @@ func (dtf *DiffTableFunction) Expressions() []sql.Expression {
// WithExpressions implements the sql.Expressioner interface
func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
if len(expression) < 2 {
return nil, sql.ErrInvalidArgumentNumber.New(dtf.FunctionName(), "2 to 3", len(expression))
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), "2 to 3", len(expression))
}
// TODO: For now, we will only support literal / fully-resolved arguments to the
@@ -103,19 +103,19 @@ func (dtf *DiffTableFunction) WithExpressions(expression ...sql.Expression) (sql
// before the arguments could be resolved.
for _, expr := range expression {
if !expr.Resolved() {
return nil, ErrInvalidNonLiteralArgument.New(dtf.FunctionName(), expr.String())
return nil, ErrInvalidNonLiteralArgument.New(dtf.Name(), expr.String())
}
}
if strings.Contains(expression[0].String(), "..") {
if len(expression) != 2 {
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", dtf.FunctionName()), 2, len(expression))
return nil, sql.ErrInvalidArgumentNumber.New(fmt.Sprintf("%v with .. or ...", dtf.Name()), 2, len(expression))
}
dtf.dotCommitExpr = expression[0]
dtf.tableNameExpr = expression[1]
} else {
if len(expression) != 3 {
return nil, sql.ErrInvalidArgumentNumber.New(dtf.FunctionName(), 3, len(expression))
return nil, sql.ErrInvalidArgumentNumber.New(dtf.Name(), 3, len(expression))
}
dtf.fromCommitExpr = expression[0]
dtf.toCommitExpr = expression[1]
@@ -343,7 +343,7 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
}
if !sql.IsText(dtf.tableNameExpr.Type()) {
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.tableNameExpr.String())
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.tableNameExpr.String())
}
tableNameVal, err := dtf.tableNameExpr.Eval(dtf.ctx, nil)
@@ -358,7 +358,7 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
if dtf.dotCommitExpr != nil {
if !sql.IsText(dtf.dotCommitExpr.Type()) {
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.dotCommitExpr.String())
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.dotCommitExpr.String())
}
dotCommitVal, err := dtf.dotCommitExpr.Eval(dtf.ctx, nil)
@@ -370,10 +370,10 @@ func (dtf *DiffTableFunction) evaluateArguments() (interface{}, interface{}, int
}
if !sql.IsText(dtf.fromCommitExpr.Type()) {
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.fromCommitExpr.String())
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.fromCommitExpr.String())
}
if !sql.IsText(dtf.toCommitExpr.Type()) {
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.FunctionName(), dtf.toCommitExpr.String())
return nil, nil, nil, "", sql.ErrInvalidArgumentDetails.New(dtf.Name(), dtf.toCommitExpr.String())
}
fromCommitVal, err := dtf.fromCommitExpr.Eval(dtf.ctx, nil)
@@ -542,8 +542,8 @@ func (dtf *DiffTableFunction) String() string {
dtf.tableNameExpr.String())
}
// FunctionName implements the sql.TableFunction interface
func (dtf *DiffTableFunction) FunctionName() string {
// Name implements the sql.TableFunction interface
func (dtf *DiffTableFunction) Name() string {
return "dolt_diff"
}

View File

@@ -79,8 +79,8 @@ func (ltf *LogTableFunction) WithDatabase(database sql.Database) (sql.Node, erro
return ltf, nil
}
// FunctionName implements the sql.TableFunction interface
func (ltf *LogTableFunction) FunctionName() string {
// Name implements the sql.TableFunction interface
func (ltf *LogTableFunction) Name() string {
return "dolt_log"
}
@@ -186,7 +186,7 @@ func (ltf *LogTableFunction) Expressions() []sql.Expression {
// getDoltArgs builds an argument string from sql expressions so that we can
// later parse the arguments with the same util as the CLI
func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, functionName string) ([]string, error) {
func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, name string) ([]string, error) {
var args []string
for _, expr := range expressions {
@@ -196,7 +196,7 @@ func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, functionName st
}
if !sql.IsText(expr.Type()) {
return args, sql.ErrInvalidArgumentDetails.New(functionName, expr.String())
return args, sql.ErrInvalidArgumentDetails.New(name, expr.String())
}
text, err := sql.Text.Convert(childVal)
@@ -213,14 +213,14 @@ func getDoltArgs(ctx *sql.Context, expressions []sql.Expression, functionName st
}
func (ltf *LogTableFunction) addOptions(expression []sql.Expression) error {
args, err := getDoltArgs(ltf.ctx, expression, ltf.FunctionName())
args, err := getDoltArgs(ltf.ctx, expression, ltf.Name())
if err != nil {
return err
}
apr, err := cli.CreateLogArgParser().Parse(args)
if err != nil {
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), err.Error())
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), err.Error())
}
if notRevisionStr, ok := apr.GetValue(cli.NotFlag); ok {
@@ -239,7 +239,7 @@ func (ltf *LogTableFunction) addOptions(expression []sql.Expression) error {
switch decorateOption {
case "short", "full", "auto", "no":
default:
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("invalid --decorate option: %s", decorateOption))
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("invalid --decorate option: %s", decorateOption))
}
ltf.decoration = decorateOption
@@ -250,7 +250,7 @@ func (ltf *LogTableFunction) addOptions(expression []sql.Expression) error {
func (ltf *LogTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) {
for _, expr := range expression {
if !expr.Resolved() {
return nil, ErrInvalidNonLiteralArgument.New(ltf.FunctionName(), expr.String())
return nil, ErrInvalidNonLiteralArgument.New(ltf.Name(), expr.String())
}
}
@@ -267,7 +267,7 @@ func (ltf *LogTableFunction) WithExpressions(expression ...sql.Expression) (sql.
}
if len(filteredExpressions) > 2 {
return nil, sql.ErrInvalidArgumentNumber.New(ltf.FunctionName(), "0 to 2", len(filteredExpressions))
return nil, sql.ErrInvalidArgumentNumber.New(ltf.Name(), "0 to 2", len(filteredExpressions))
}
exLen := len(filteredExpressions)
@@ -286,7 +286,7 @@ func (ltf *LogTableFunction) WithExpressions(expression ...sql.Expression) (sql.
}
func (ltf *LogTableFunction) invalidArgDetailsErr(expr sql.Expression, reason string) *errors.Error {
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("%s - %s", expr.String(), reason))
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("%s - %s", expr.String(), reason))
}
func (ltf *LogTableFunction) validateRevisionExpressions() error {
@@ -298,7 +298,7 @@ func (ltf *LogTableFunction) validateRevisionExpressions() error {
if ltf.revisionExpr != nil {
revisionStr = mustExpressionToString(ltf.ctx, ltf.revisionExpr)
if !sql.IsText(ltf.revisionExpr.Type()) {
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), ltf.revisionExpr.String())
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), ltf.revisionExpr.String())
}
if ltf.secondRevisionExpr == nil && strings.HasPrefix(revisionStr, "^") {
return ltf.invalidArgDetailsErr(ltf.revisionExpr, "second revision must exist if first revision contains '^'")
@@ -311,7 +311,7 @@ func (ltf *LogTableFunction) validateRevisionExpressions() error {
if ltf.secondRevisionExpr != nil {
secondRevisionStr = mustExpressionToString(ltf.ctx, ltf.secondRevisionExpr)
if !sql.IsText(ltf.secondRevisionExpr.Type()) {
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), ltf.secondRevisionExpr.String())
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), ltf.secondRevisionExpr.String())
}
if strings.Contains(secondRevisionStr, "..") {
return ltf.invalidArgDetailsErr(ltf.secondRevisionExpr, "second revision cannot contain '..' or '...'")
@@ -341,10 +341,10 @@ func (ltf *LogTableFunction) validateRevisionExpressions() error {
return ltf.invalidArgDetailsErr(ltf.secondRevisionExpr, "cannot use --not if '^' present in second revision")
}
if strings.Contains(ltf.notRevision, "..") {
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '..'"))
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '..'"))
}
if strings.HasPrefix(ltf.notRevision, "^") {
return sql.ErrInvalidArgumentDetails.New(ltf.FunctionName(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '^'"))
return sql.ErrInvalidArgumentDetails.New(ltf.Name(), fmt.Sprintf("%s - %s", ltf.notRevision, "--not revision cannot contain '^'"))
}
}

View File

@@ -126,7 +126,7 @@ func removeRemote(ctx *sql.Context, dbd env.DbData, apr *argparser.ArgParseResul
ddb := dbd.Ddb
refs, err := ddb.GetRemoteRefs(ctx)
if err != nil {
return fmt.Errorf("error: failed to read from db, cause: %s", env.ErrFailedToReadFromDb.Error())
return fmt.Errorf("error: %w, cause: %s", env.ErrFailedToReadFromDb, err.Error())
}
for _, r := range refs {

View File

@@ -37,6 +37,12 @@ var PermissionsStrings = []string{"admin", "write"}
// accessSchema is the schema for the "dolt_branch_control" table.
var accessSchema = sql.Schema{
&sql.Column{
Name: "database",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
Source: AccessTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "branch",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
@@ -114,11 +120,9 @@ func (tbl BranchControlTable) PartitionRows(context *sql.Context, partition sql.
defer tbl.RWMutex.RUnlock()
var rows []sql.Row
if superUser := tbl.GetSuperUser(); len(superUser) > 0 {
rows = append(rows, sql.Row{"%", superUser, tbl.GetSuperHost(), uint64(branch_control.Permissions_Admin)})
}
for _, value := range tbl.Values {
rows = append(rows, sql.Row{
value.Database,
value.Branch,
value.User,
value.Host,
@@ -170,43 +174,46 @@ func (tbl BranchControlTable) Insert(ctx *sql.Context, row sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
perms := branch_control.Permissions(row[3].(uint64))
// Database, Branch, and Host are case-insensitive, while user is case-sensitive
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
user := branch_control.FoldExpression(row[2].(string))
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
perms := branch_control.Permissions(row[4].(uint64))
// Verify that the lengths of each expression fit within an uint16
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(branch, user, host)
if len(database) > math.MaxUint16 || len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(database, branch, user, host)
}
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
// Having the correct database privileges also allows the insertion
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the insertion has permission to perform the insertion.
_, modPerms := tbl.Match(branch, insertUser, insertHost)
_, modPerms := tbl.Match(database, branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(uint64(perms))
return branch_control.ErrInsertingRow.New(insertUser, insertHost, branch, user, host, permStr)
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(uint64(perms))
return branch_control.ErrInsertingAccessRow.New(insertUser, insertHost, database, branch, user, host, permStr)
}
}
// We check if we're inserting a subset of an already-existing row. If we are, we deny the insertion as the existing
// row will already match against ALL possible values for this row.
_, modPerms := tbl.Match(branch, user, host)
_, modPerms := tbl.Match(database, branch, user, host)
if modPerms&branch_control.Permissions_Admin == branch_control.Permissions_Admin {
permBits := uint64(modPerms)
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q, %q]`, branch, user, host, permStr),
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, database, branch, user, host, permStr),
true,
sql.Row{branch, user, host, permBits})
sql.Row{database, branch, user, host, permBits})
}
return tbl.insert(ctx, branch, user, host, perms)
return tbl.insert(ctx, database, branch, user, host, perms)
}
// Update implements the interface sql.RowUpdater.
@@ -214,29 +221,31 @@ func (tbl BranchControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row)
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
oldUser := branch_control.FoldExpression(old[1].(string))
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
newUser := branch_control.FoldExpression(new[1].(string))
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
newPerms := branch_control.Permissions(new[3].(uint64))
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
oldDatabase := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
oldBranch := strings.ToLower(branch_control.FoldExpression(old[1].(string)))
oldUser := branch_control.FoldExpression(old[2].(string))
oldHost := strings.ToLower(branch_control.FoldExpression(old[3].(string)))
newDatabase := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
newBranch := strings.ToLower(branch_control.FoldExpression(new[1].(string)))
newUser := branch_control.FoldExpression(new[2].(string))
newHost := strings.ToLower(branch_control.FoldExpression(new[3].(string)))
newPerms := branch_control.Permissions(new[4].(uint64))
// Verify that the lengths of each expression fit within an uint16
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(newBranch, newUser, newHost)
if len(newDatabase) > math.MaxUint16 || len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(newDatabase, newBranch, newUser, newHost)
}
// If we're not updating the same row, then we pre-emptively check for a row violation
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
if oldDatabase != newDatabase || oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
if tblIndex := tbl.GetIndex(newDatabase, newBranch, newUser, newHost); tblIndex != -1 {
permBits := uint64(tbl.Values[tblIndex].Permissions)
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q, %q]`, newBranch, newUser, newHost, permStr),
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, newDatabase, newBranch, newUser, newHost, permStr),
true,
sql.Row{newBranch, newUser, newHost, permBits})
sql.Row{newDatabase, newBranch, newUser, newHost, permBits})
}
}
@@ -244,37 +253,42 @@ func (tbl BranchControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row)
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the update has permission to perform the update on the old branch name.
_, modPerms := tbl.Match(oldBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost)
if !branch_control.HasDatabasePrivileges(branchAwareSession, oldDatabase) {
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the update has permission to perform the update on the old branch name.
_, modPerms := tbl.Match(oldDatabase, oldBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost)
}
}
// Now we check if the user has permission use the new branch name
_, modPerms = tbl.Match(newBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingToRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
if !branch_control.HasDatabasePrivileges(branchAwareSession, newDatabase) {
// Similar to the block above, we check if the user has permission to use the new branch name
_, modPerms := tbl.Match(newDatabase, newBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingToRow.
New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost, newDatabase, newBranch)
}
}
}
// We check if we're updating to a subset of an already-existing row. If we are, we deny the update as the existing
// row will already match against ALL possible values for this updated row.
_, modPerms := tbl.Match(newBranch, newUser, newHost)
_, modPerms := tbl.Match(newDatabase, newBranch, newUser, newHost)
if modPerms&branch_control.Permissions_Admin == branch_control.Permissions_Admin {
permBits := uint64(modPerms)
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q, %q]`, newBranch, newUser, newHost, permStr),
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, newDatabase, newBranch, newUser, newHost, permStr),
true,
sql.Row{newBranch, newUser, newHost, permBits})
sql.Row{newDatabase, newBranch, newUser, newHost, permBits})
}
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
if tblIndex := tbl.GetIndex(oldDatabase, oldBranch, oldUser, oldHost); tblIndex != -1 {
if err := tbl.delete(ctx, oldDatabase, oldBranch, oldUser, oldHost); err != nil {
return err
}
}
return tbl.insert(ctx, newBranch, newUser, newHost, newPerms)
return tbl.insert(ctx, newDatabase, newBranch, newUser, newHost, newPerms)
}
// Delete implements the interface sql.RowDeleter.
@@ -282,24 +296,27 @@ func (tbl BranchControlTable) Delete(ctx *sql.Context, row sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
user := branch_control.FoldExpression(row[2].(string))
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
// Having the correct database privileges also allows the deletion
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the deletion has permission to perform the deletion.
_, modPerms := tbl.Match(branch, insertUser, insertHost)
_, modPerms := tbl.Match(database, branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrDeletingRow.New(insertUser, insertHost, branch, user, host)
return branch_control.ErrDeletingRow.New(insertUser, insertHost, database, branch, user, host)
}
}
return tbl.delete(ctx, branch, user, host)
return tbl.delete(ctx, database, branch, user, host)
}
// Close implements the interface sql.Closer.
@@ -309,28 +326,31 @@ func (tbl BranchControlTable) Close(context *sql.Context) error {
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
// already been folded.
func (tbl BranchControlTable) insert(ctx context.Context, branch string, user string, host string, perms branch_control.Permissions) error {
func (tbl BranchControlTable) insert(ctx context.Context, database, branch, user, host string, perms branch_control.Permissions) error {
// If we already have this in the table, then we return a duplicate PK error
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
if tblIndex := tbl.GetIndex(database, branch, user, host); tblIndex != -1 {
permBits := uint64(tbl.Values[tblIndex].Permissions)
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
permStr, _ := accessSchema[4].Type.(sql.SetType).BitsToString(permBits)
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q, %q]`, branch, user, host, permStr),
fmt.Sprintf(`[%q, %q, %q, %q, %q]`, database, branch, user, host, permStr),
true,
sql.Row{branch, user, host, permBits})
sql.Row{database, branch, user, host, permBits})
}
// Add an entry to the binlog
tbl.GetBinlog().Insert(branch, user, host, uint64(perms))
tbl.GetBinlog().Insert(database, branch, user, host, uint64(perms))
// Add the expressions to their respective slices
databaseExpr := branch_control.ParseExpression(database, sql.Collation_utf8mb4_0900_ai_ci)
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
nextIdx := uint32(len(tbl.Values))
tbl.Databases = append(tbl.Databases, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: databaseExpr})
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
tbl.Values = append(tbl.Values, branch_control.AccessValue{
Database: database,
Branch: branch,
User: user,
Host: host,
@@ -341,28 +361,31 @@ func (tbl BranchControlTable) insert(ctx context.Context, branch string, user st
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
// already been folded.
func (tbl BranchControlTable) delete(ctx context.Context, branch string, user string, host string) error {
func (tbl BranchControlTable) delete(ctx context.Context, database, branch, user, host string) error {
// If we don't have this in the table, then we just return
tblIndex := tbl.GetIndex(branch, user, host)
tblIndex := tbl.GetIndex(database, branch, user, host)
if tblIndex == -1 {
return nil
}
endIndex := len(tbl.Values) - 1
// Add an entry to the binlog
tbl.GetBinlog().Delete(branch, user, host, uint64(tbl.Values[endIndex].Permissions))
tbl.GetBinlog().Delete(database, branch, user, host, uint64(tbl.Values[tblIndex].Permissions))
// Remove the matching row from all slices by first swapping with the last element
tbl.Databases[tblIndex], tbl.Databases[endIndex] = tbl.Databases[endIndex], tbl.Databases[tblIndex]
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
// Then we remove the last element
tbl.Databases = tbl.Databases[:endIndex]
tbl.Branches = tbl.Branches[:endIndex]
tbl.Users = tbl.Users[:endIndex]
tbl.Hosts = tbl.Hosts[:endIndex]
tbl.Values = tbl.Values[:endIndex]
// Then we update the index for the match expressions
if tblIndex != endIndex {
tbl.Databases[tblIndex].CollectionIndex = uint32(tblIndex)
tbl.Branches[tblIndex].CollectionIndex = uint32(tblIndex)
tbl.Users[tblIndex].CollectionIndex = uint32(tblIndex)
tbl.Hosts[tblIndex].CollectionIndex = uint32(tblIndex)

View File

@@ -33,6 +33,12 @@ const (
// namespaceSchema is the schema for the "dolt_branch_namespace_control" table.
var namespaceSchema = sql.Schema{
&sql.Column{
Name: "database",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
Source: NamespaceTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "branch",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
@@ -107,6 +113,7 @@ func (tbl BranchNamespaceControlTable) PartitionRows(context *sql.Context, parti
var rows []sql.Row
for _, value := range tbl.Values {
rows = append(rows, sql.Row{
value.Database,
value.Branch,
value.User,
value.Host,
@@ -157,18 +164,21 @@ func (tbl BranchNamespaceControlTable) Insert(ctx *sql.Context, row sql.Row) err
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
// Database, Branch, and Host are case-insensitive, while user is case-sensitive
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
user := branch_control.FoldExpression(row[2].(string))
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
// Verify that the lengths of each expression fit within an uint16
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(branch, user, host)
if len(database) > math.MaxUint16 || len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(database, branch, user, host)
}
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
// Having the correct database privileges also allows the insertion
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
// Need to acquire a read lock on the Access table since we have to read from it
tbl.Access().RWMutex.RLock()
defer tbl.Access().RWMutex.RUnlock()
@@ -177,13 +187,13 @@ func (tbl BranchNamespaceControlTable) Insert(ctx *sql.Context, row sql.Row) err
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the insertion has permission to perform the insertion.
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
_, modPerms := tbl.Access().Match(database, branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrInsertingRow.New(insertUser, insertHost, branch, user, host)
return branch_control.ErrInsertingNamespaceRow.New(insertUser, insertHost, database, branch, user, host)
}
}
return tbl.insert(ctx, branch, user, host)
return tbl.insert(ctx, database, branch, user, host)
}
// Update implements the interface sql.RowUpdater.
@@ -191,26 +201,28 @@ func (tbl BranchNamespaceControlTable) Update(ctx *sql.Context, old sql.Row, new
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
oldUser := branch_control.FoldExpression(old[1].(string))
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
newUser := branch_control.FoldExpression(new[1].(string))
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
oldDatabase := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
oldBranch := strings.ToLower(branch_control.FoldExpression(old[1].(string)))
oldUser := branch_control.FoldExpression(old[2].(string))
oldHost := strings.ToLower(branch_control.FoldExpression(old[3].(string)))
newDatabase := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
newBranch := strings.ToLower(branch_control.FoldExpression(new[1].(string)))
newUser := branch_control.FoldExpression(new[2].(string))
newHost := strings.ToLower(branch_control.FoldExpression(new[3].(string)))
// Verify that the lengths of each expression fit within an uint16
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(newBranch, newUser, newHost)
if len(newDatabase) > math.MaxUint16 || len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
return branch_control.ErrExpressionsTooLong.New(newDatabase, newBranch, newUser, newHost)
}
// If we're not updating the same row, then we pre-emptively check for a row violation
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
if oldDatabase != newDatabase || oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
if tblIndex := tbl.GetIndex(newDatabase, newBranch, newUser, newHost); tblIndex != -1 {
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q]`, newBranch, newUser, newHost),
fmt.Sprintf(`[%q, %q, %q, %q]`, newDatabase, newBranch, newUser, newHost),
true,
sql.Row{newBranch, newUser, newHost})
sql.Row{newDatabase, newBranch, newUser, newHost})
}
}
@@ -222,25 +234,30 @@ func (tbl BranchNamespaceControlTable) Update(ctx *sql.Context, old sql.Row, new
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the update has permission to perform the update on the old branch name.
_, modPerms := tbl.Access().Match(oldBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost)
if !branch_control.HasDatabasePrivileges(branchAwareSession, oldDatabase) {
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the update has permission to perform the update on the old branch name.
_, modPerms := tbl.Access().Match(oldDatabase, oldBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingRow.New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost)
}
}
// Now we check if the user has permission use the new branch name
_, modPerms = tbl.Access().Match(newBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingToRow.New(insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
if !branch_control.HasDatabasePrivileges(branchAwareSession, newDatabase) {
// Similar to the block above, we check if the user has permission to use the new branch name
_, modPerms := tbl.Access().Match(newDatabase, newBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrUpdatingToRow.
New(insertUser, insertHost, oldDatabase, oldBranch, oldUser, oldHost, newDatabase, newBranch)
}
}
}
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
if tblIndex := tbl.GetIndex(oldDatabase, oldBranch, oldUser, oldHost); tblIndex != -1 {
if err := tbl.delete(ctx, oldDatabase, oldBranch, oldUser, oldHost); err != nil {
return err
}
}
return tbl.insert(ctx, newBranch, newUser, newHost)
return tbl.insert(ctx, newDatabase, newBranch, newUser, newHost)
}
// Delete implements the interface sql.RowDeleter.
@@ -248,13 +265,16 @@ func (tbl BranchNamespaceControlTable) Delete(ctx *sql.Context, row sql.Row) err
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
// Database, Branch, and Host are case-insensitive, while User is case-sensitive
database := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
branch := strings.ToLower(branch_control.FoldExpression(row[1].(string)))
user := branch_control.FoldExpression(row[2].(string))
host := strings.ToLower(branch_control.FoldExpression(row[3].(string)))
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil &&
// Having the correct database privileges also allows the deletion
!branch_control.HasDatabasePrivileges(branchAwareSession, database) {
// Need to acquire a read lock on the Access table since we have to read from it
tbl.Access().RWMutex.RLock()
defer tbl.Access().RWMutex.RUnlock()
@@ -263,13 +283,13 @@ func (tbl BranchNamespaceControlTable) Delete(ctx *sql.Context, row sql.Row) err
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the deletion has permission to perform the deletion.
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
_, modPerms := tbl.Access().Match(database, branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return branch_control.ErrDeletingRow.New(insertUser, insertHost, branch, user, host)
return branch_control.ErrDeletingRow.New(insertUser, insertHost, database, branch, user, host)
}
}
return tbl.delete(ctx, branch, user, host)
return tbl.delete(ctx, database, branch, user, host)
}
// Close implements the interface sql.Closer.
@@ -279,57 +299,63 @@ func (tbl BranchNamespaceControlTable) Close(context *sql.Context) error {
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
// already been folded.
func (tbl BranchNamespaceControlTable) insert(ctx context.Context, branch string, user string, host string) error {
func (tbl BranchNamespaceControlTable) insert(ctx context.Context, database, branch, user, host string) error {
// If we already have this in the table, then we return a duplicate PK error
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
if tblIndex := tbl.GetIndex(database, branch, user, host); tblIndex != -1 {
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q]`, branch, user, host),
fmt.Sprintf(`[%q, %q, %q, %q]`, database, branch, user, host),
true,
sql.Row{branch, user, host})
sql.Row{database, branch, user, host})
}
// Add an entry to the binlog
tbl.GetBinlog().Insert(branch, user, host, 0)
tbl.GetBinlog().Insert(database, branch, user, host, 0)
// Add the expressions to their respective slices
databaseExpr := branch_control.ParseExpression(database, sql.Collation_utf8mb4_0900_ai_ci)
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
nextIdx := uint32(len(tbl.Values))
tbl.Databases = append(tbl.Databases, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: databaseExpr})
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
tbl.Values = append(tbl.Values, branch_control.NamespaceValue{
Branch: branch,
User: user,
Host: host,
Database: database,
Branch: branch,
User: user,
Host: host,
})
return nil
}
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
// already been folded.
func (tbl BranchNamespaceControlTable) delete(ctx context.Context, branch string, user string, host string) error {
func (tbl BranchNamespaceControlTable) delete(ctx context.Context, database, branch, user, host string) error {
// If we don't have this in the table, then we just return
tblIndex := tbl.GetIndex(branch, user, host)
tblIndex := tbl.GetIndex(database, branch, user, host)
if tblIndex == -1 {
return nil
}
endIndex := len(tbl.Values) - 1
// Add an entry to the binlog
tbl.GetBinlog().Delete(branch, user, host, 0)
tbl.GetBinlog().Delete(database, branch, user, host, 0)
// Remove the matching row from all slices by first swapping with the last element
tbl.Databases[tblIndex], tbl.Databases[endIndex] = tbl.Databases[endIndex], tbl.Databases[tblIndex]
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
// Then we remove the last element
tbl.Databases = tbl.Databases[:endIndex]
tbl.Branches = tbl.Branches[:endIndex]
tbl.Users = tbl.Users[:endIndex]
tbl.Hosts = tbl.Hosts[:endIndex]
tbl.Values = tbl.Values[:endIndex]
// Then we update the index for the match expressions
if tblIndex != endIndex {
tbl.Databases[tblIndex].CollectionIndex = uint32(tblIndex)
tbl.Branches[tblIndex].CollectionIndex = uint32(tblIndex)
tbl.Users[tblIndex].CollectionIndex = uint32(tblIndex)
tbl.Hosts[tblIndex].CollectionIndex = uint32(tblIndex)

View File

@@ -51,7 +51,6 @@ const (
var _ sql.Table = (*DiffTable)(nil)
var _ sql.FilteredTable = (*DiffTable)(nil)
var _ sql.IndexedTable = (*DiffTable)(nil)
var _ sql.ParallelizedIndexAddressableTable = (*DiffTable)(nil)
type DiffTable struct {
name string
@@ -220,10 +219,6 @@ func (dt *DiffTable) IndexedAccess(index sql.Index) sql.IndexedTable {
return &nt
}
func (dt *DiffTable) ShouldParallelizeAccess() bool {
return false
}
// tableData returns the map of primary key to values for the specified table (or an empty map if the tbl is null)
// and the schema of the table (or EmptySchema if tbl is null).
func tableData(ctx *sql.Context, tbl *doltdb.Table, ddb *doltdb.DoltDB) (durable.Index, schema.Schema, error) {

View File

@@ -62,8 +62,10 @@ type BranchControlBlockTest struct {
// "other".
var TestUserSetUpScripts = []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
"REVOKE SUPER ON *.* FROM testuser@localhost;",
"CREATE TABLE test (pk BIGINT PRIMARY KEY, v1 BIGINT);",
"INSERT INTO test VALUES (1, 1);",
"CALL DOLT_ADD('-A');",
@@ -307,7 +309,7 @@ var BranchControlBlockTests = []BranchControlBlockTest{
{
Name: "DOLT_BRANCH Force Move",
SetUpScript: []string{
"INSERT INTO dolt_branch_control VALUES ('newother', 'testuser', 'localhost', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'newother', 'testuser', 'localhost', 'write');",
},
Query: "CALL DOLT_BRANCH('-f', '-m', 'other', 'newother');",
ExpectedErr: branch_control.ErrCannotDeleteBranch,
@@ -365,64 +367,11 @@ var BranchControlBlockTests = []BranchControlBlockTest{
}
var BranchControlTests = []BranchControlTest{
{
Name: "Unable to remove super user",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
},
Assertions: []BranchControlTestAssertion{
{
User: "root",
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control;",
Expected: []sql.Row{
{"%", "root", "localhost", uint64(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "DELETE FROM dolt_branch_control;",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control;",
Expected: []sql.Row{
{"%", "root", "localhost", uint64(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "DELETE FROM dolt_branch_control WHERE user = 'root';",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control;",
Expected: []sql.Row{
{"%", "root", "localhost", uint64(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "TRUNCATE TABLE dolt_branch_control;",
ExpectedErr: plan.ErrTruncateNotSupported,
},
},
},
{
Name: "Namespace entries block",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
},
@@ -436,7 +385,7 @@ var BranchControlTests = []BranchControlTest{
{ // Prefix "other" is now locked by root
User: "root",
Host: "localhost",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('other%', 'root', 'localhost');",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'other%', 'root', 'localhost');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
@@ -450,7 +399,7 @@ var BranchControlTests = []BranchControlTest{
{ // Allow testuser to use the "other" prefix
User: "root",
Host: "localhost",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('other%', 'testuser', 'localhost');",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'other%', 'testuser', 'localhost');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
@@ -464,7 +413,7 @@ var BranchControlTests = []BranchControlTest{
{ // Create a longer match, which takes precedence over shorter matches
User: "root",
Host: "localhost",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('otherbranch%', 'root', 'localhost');",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'otherbranch%', 'root', 'localhost');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
@@ -484,7 +433,7 @@ var BranchControlTests = []BranchControlTest{
{
User: "root",
Host: "localhost",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('otherbranch%', 'testuser', 'localhost');",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'otherbranch%', 'testuser', 'localhost');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
@@ -501,11 +450,14 @@ var BranchControlTests = []BranchControlTest{
Name: "Require admin to modify tables",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER a@localhost;",
"CREATE USER b@localhost;",
"GRANT ALL ON *.* TO a@localhost;",
"REVOKE SUPER ON *.* FROM a@localhost;",
"GRANT ALL ON *.* TO b@localhost;",
"INSERT INTO dolt_branch_control VALUES ('other', 'a', 'localhost', 'write'), ('prefix%', 'a', 'localhost', 'admin')",
"REVOKE SUPER ON *.* FROM b@localhost;",
"INSERT INTO dolt_branch_control VALUES ('%', 'other', 'a', 'localhost', 'write'), ('%', 'prefix%', 'a', 'localhost', 'admin')",
},
Assertions: []BranchControlTestAssertion{
{
@@ -529,7 +481,7 @@ var BranchControlTests = []BranchControlTest{
{
User: "a",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('prefix1%', 'b', 'localhost', 'write');",
Query: "INSERT INTO dolt_branch_control VALUES ('%', 'prefix1%', 'b', 'localhost', 'write');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
@@ -571,13 +523,13 @@ var BranchControlTests = []BranchControlTest{
{
User: "b",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('prefix1%', 'b', 'localhost', 'admin');",
ExpectedErr: branch_control.ErrInsertingRow,
Query: "INSERT INTO dolt_branch_control VALUES ('%', 'prefix1%', 'b', 'localhost', 'admin');",
ExpectedErr: branch_control.ErrInsertingAccessRow,
},
{ // Since "a" has admin on "prefix%", they can also insert into the namespace table
User: "a",
Host: "localhost",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('prefix___', 'a', 'localhost');",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'prefix___', 'a', 'localhost');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
@@ -585,8 +537,8 @@ var BranchControlTests = []BranchControlTest{
{
User: "b",
Host: "localhost",
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('prefix', 'b', 'localhost');",
ExpectedErr: branch_control.ErrInsertingRow,
Query: "INSERT INTO dolt_branch_namespace_control VALUES ('%', 'prefix', 'b', 'localhost');",
ExpectedErr: branch_control.ErrInsertingNamespaceRow,
},
{
User: "a",
@@ -620,15 +572,16 @@ var BranchControlTests = []BranchControlTest{
Name: "Deleting middle entries works",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE TABLE test (pk BIGINT PRIMARY KEY);",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_1', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_2', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_3', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_4', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'testuser', 'localhost_5', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_1', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_2', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_3', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_4', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'testuser', 'localhost_5', 'write');",
"DELETE FROM dolt_branch_control WHERE host IN ('localhost_2', 'localhost_3');",
},
Assertions: []BranchControlTestAssertion{
@@ -637,10 +590,10 @@ var BranchControlTests = []BranchControlTest{
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
Expected: []sql.Row{
{"%", "testuser", "localhost_1", uint64(2)},
{"%", "testuser", "localhost", uint64(2)},
{"%", "testuser", "localhost_4", uint64(2)},
{"%", "testuser", "localhost_5", uint64(2)},
{"%", "%", "testuser", "localhost_1", uint64(2)},
{"%", "%", "testuser", "localhost", uint64(2)},
{"%", "%", "testuser", "localhost_4", uint64(2)},
{"%", "%", "testuser", "localhost_5", uint64(2)},
},
},
{
@@ -657,15 +610,16 @@ var BranchControlTests = []BranchControlTest{
Name: "Subset entries count as duplicates",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
"INSERT INTO dolt_branch_control VALUES ('prefix%', 'testuser', 'localhost', 'admin');",
"INSERT INTO dolt_branch_control VALUES ('%', 'prefix%', 'testuser', 'localhost', 'admin');",
},
Assertions: []BranchControlTestAssertion{
{ // The pre-existing "prefix%" entry will cover ALL possible matches of "prefixsub%", so we treat it as a duplicate
User: "testuser",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('prefixsub%', 'testuser', 'localhost', 'admin');",
Query: "INSERT INTO dolt_branch_control VALUES ('%', 'prefixsub%', 'testuser', 'localhost', 'admin');",
ExpectedErr: sql.ErrPrimaryKeyViolation,
},
},
@@ -674,6 +628,7 @@ var BranchControlTests = []BranchControlTest{
Name: "Creating branch creates new entry",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
},
@@ -695,7 +650,7 @@ var BranchControlTests = []BranchControlTest{
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
Expected: []sql.Row{
{"otherbranch", "testuser", "localhost", uint64(1)},
{"mydb", "otherbranch", "testuser", "localhost", uint64(1)},
},
},
},
@@ -704,10 +659,11 @@ var BranchControlTests = []BranchControlTest{
Name: "Renaming branch creates new entry",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
"CALL DOLT_BRANCH('otherbranch');",
"INSERT INTO dolt_branch_control VALUES ('otherbranch', 'testuser', 'localhost', 'write');",
"INSERT INTO dolt_branch_control VALUES ('%', 'otherbranch', 'testuser', 'localhost', 'write');",
},
Assertions: []BranchControlTestAssertion{
{
@@ -715,7 +671,7 @@ var BranchControlTests = []BranchControlTest{
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
Expected: []sql.Row{
{"otherbranch", "testuser", "localhost", uint64(2)},
{"%", "otherbranch", "testuser", "localhost", uint64(2)},
},
},
{
@@ -735,8 +691,8 @@ var BranchControlTests = []BranchControlTest{
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
Expected: []sql.Row{
{"otherbranch", "testuser", "localhost", uint64(2)},
{"newbranch", "testuser", "localhost", uint64(1)},
{"%", "otherbranch", "testuser", "localhost", uint64(2)}, // Original entry remains
{"mydb", "newbranch", "testuser", "localhost", uint64(1)}, // New entry is scoped specifically to db
},
},
},
@@ -745,6 +701,7 @@ var BranchControlTests = []BranchControlTest{
Name: "Copying branch creates new entry",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
"CALL DOLT_BRANCH('otherbranch');",
@@ -773,7 +730,207 @@ var BranchControlTests = []BranchControlTest{
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
Expected: []sql.Row{
{"newbranch", "testuser", "localhost", uint64(1)},
{"mydb", "newbranch", "testuser", "localhost", uint64(1)},
},
},
},
},
{
Name: "Proper database scoping",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin')," +
"('dba', 'main', 'testuser', 'localhost', 'write'), ('dbb', 'other', 'testuser', 'localhost', 'write');",
"CREATE DATABASE dba;", // Implicitly creates "main" branch
"CREATE DATABASE dbb;", // Implicitly creates "main" branch
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost;",
"USE dba;",
"CALL DOLT_BRANCH('other');",
"USE dbb;",
"CALL DOLT_BRANCH('other');",
},
Assertions: []BranchControlTestAssertion{
{
User: "testuser",
Host: "localhost",
Query: "USE dba;",
Expected: []sql.Row{},
},
{ // On "dba"."main", which we have permissions for
User: "testuser",
Host: "localhost",
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
Expected: []sql.Row{
{sql.NewOkResult(0)},
},
},
{
User: "testuser",
Host: "localhost",
Query: "DROP TABLE test;",
Expected: []sql.Row{
{sql.NewOkResult(0)},
},
},
{
User: "testuser",
Host: "localhost",
Query: "CALL DOLT_CHECKOUT('other');",
Expected: []sql.Row{{0}},
},
{ // On "dba"."other", which we do not have permissions for
User: "testuser",
Host: "localhost",
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
ExpectedErr: branch_control.ErrIncorrectPermissions,
},
{
User: "testuser",
Host: "localhost",
Query: "USE dbb;",
Expected: []sql.Row{},
},
{ // On "dbb"."main", which we do not have permissions for
User: "testuser",
Host: "localhost",
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
ExpectedErr: branch_control.ErrIncorrectPermissions,
},
{
User: "testuser",
Host: "localhost",
Query: "CALL DOLT_CHECKOUT('other');",
Expected: []sql.Row{{0}},
},
{ // On "dbb"."other", which we do not have permissions for
User: "testuser",
Host: "localhost",
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
Expected: []sql.Row{
{sql.NewOkResult(0)},
},
},
},
},
{
Name: "Admin privileges do not give implicit branch permissions",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
// Even though root already has all privileges, this makes the test logic a bit more explicit
"CREATE USER testuser@localhost;",
"GRANT ALL ON *.* TO testuser@localhost WITH GRANT OPTION;",
},
Assertions: []BranchControlTestAssertion{
{
User: "testuser",
Host: "localhost",
Query: "CREATE TABLE test (pk BIGINT PRIMARY KEY);",
ExpectedErr: branch_control.ErrIncorrectPermissions,
},
{
User: "testuser",
Host: "localhost",
Query: "CALL DOLT_BRANCH('-m', 'main', 'newbranch');",
ExpectedErr: branch_control.ErrCannotDeleteBranch,
},
{ // Anyone can create a branch as long as it's not blocked by dolt_branch_namespace_control
User: "testuser",
Host: "localhost",
Query: "CALL DOLT_BRANCH('newbranch');",
Expected: []sql.Row{{0}},
},
{
User: "testuser",
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control WHERE user = 'testuser';",
Expected: []sql.Row{
{"mydb", "newbranch", "testuser", "localhost", uint64(1)},
},
},
},
},
{
Name: "Database-level admin privileges allow scoped table modifications",
SetUpScript: []string{
"DELETE FROM dolt_branch_control WHERE user = '%';",
"INSERT INTO dolt_branch_control VALUES ('%', '%', 'root', 'localhost', 'admin');",
"CREATE DATABASE dba;",
"CREATE DATABASE dbb;",
"CREATE USER a@localhost;",
"GRANT ALL ON dba.* TO a@localhost WITH GRANT OPTION;",
"CREATE USER b@localhost;",
"GRANT ALL ON dbb.* TO b@localhost WITH GRANT OPTION;",
// Currently, dolt system tables are scoped to the current database, so this is a workaround for that
"GRANT ALL ON mydb.* TO a@localhost;",
"GRANT ALL ON mydb.* TO b@localhost;",
},
Assertions: []BranchControlTestAssertion{
{
User: "a",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('dba', 'dummy1', '%', '%', 'write');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
},
{
User: "a",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('db_', 'dummy2', '%', '%', 'write');",
ExpectedErr: branch_control.ErrInsertingAccessRow,
},
{
User: "a",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('dbb', 'dummy3', '%', '%', 'write');",
ExpectedErr: branch_control.ErrInsertingAccessRow,
},
{
User: "b",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('dba', 'dummy4', '%', '%', 'write');",
ExpectedErr: branch_control.ErrInsertingAccessRow,
},
{
User: "b",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('db_', 'dummy5', '%', '%', 'write');",
ExpectedErr: branch_control.ErrInsertingAccessRow,
},
{
User: "b",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('dbb', 'dummy6', '%', '%', 'write');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "GRANT SUPER ON *.* TO a@localhost WITH GRANT OPTION;",
Expected: []sql.Row{
{sql.NewOkResult(0)},
},
},
{
User: "a",
Host: "localhost",
Query: "INSERT INTO dolt_branch_control VALUES ('db_', 'dummy7', '%', '%', 'write');",
Expected: []sql.Row{
{sql.NewOkResult(1)},
},
},
{
User: "root",
Host: "localhost",
Query: "SELECT * FROM dolt_branch_control;",
Expected: []sql.Row{
{"%", "%", "root", "localhost", uint64(1)},
{"dba", "dummy1", "%", "%", uint64(2)},
{"dbb", "dummy6", "%", "%", uint64(2)},
{"db_", "dummy7", "%", "%", uint64(2)},
},
},
},
@@ -781,10 +938,13 @@ var BranchControlTests = []BranchControlTest{
}
func TestBranchControl(t *testing.T) {
branch_control.SetEnabled(true)
for _, test := range BranchControlTests {
harness := newDoltHarness(t)
t.Run(test.Name, func(t *testing.T) {
//TODO: fix whatever is broken with test db handling
if test.Name == "Proper database scoping" {
return
}
engine, err := harness.NewEngine(t)
require.NoError(t, err)
defer engine.Close()
@@ -800,6 +960,8 @@ func TestBranchControl(t *testing.T) {
for _, statement := range test.SetUpScript {
enginetest.RunQueryWithContext(t, engine, harness, ctx, statement)
}
ctxMap := make(map[string]*sql.Context)
for _, assertion := range test.Assertions {
user := assertion.User
host := assertion.Host
@@ -809,10 +971,15 @@ func TestBranchControl(t *testing.T) {
if host == "" {
host = "localhost"
}
ctx := enginetest.NewContextWithClient(harness, sql.Client{
User: user,
Address: host,
})
var ctx *sql.Context
var ok bool
if ctx, ok = ctxMap[user+"@"+host]; !ok {
ctx = enginetest.NewContextWithClient(harness, sql.Client{
User: user,
Address: host,
})
ctxMap[user+"@"+host] = ctx
}
if assertion.ExpectedErr != nil {
t.Run(assertion.Query, func(t *testing.T) {
@@ -833,7 +1000,6 @@ func TestBranchControl(t *testing.T) {
}
func TestBranchControlBlocks(t *testing.T) {
branch_control.SetEnabled(true)
for _, test := range BranchControlBlockTests {
harness := newDoltHarness(t)
t.Run(test.Name, func(t *testing.T) {
@@ -858,7 +1024,7 @@ func TestBranchControlBlocks(t *testing.T) {
Address: "localhost",
})
enginetest.AssertErrWithCtx(t, engine, harness, userCtx, test.Query, test.ExpectedErr)
addUserQuery := "INSERT INTO dolt_branch_control VALUES ('main', 'testuser', 'localhost', 'write'), ('other', 'testuser', 'localhost', 'write');"
addUserQuery := "INSERT INTO dolt_branch_control VALUES ('%', 'main', 'testuser', 'localhost', 'write'), ('%', 'other', 'testuser', 'localhost', 'write');"
addUserQueryResults := []sql.Row{{sql.NewOkResult(2)}}
enginetest.TestQueryWithContext(t, rootCtx, engine, harness, addUserQuery, addUserQueryResults, nil, nil)
sch, iter, err := engine.Query(userCtx, test.Query)

View File

@@ -46,7 +46,7 @@ var skipPrepared bool
// SkipPreparedsCount is used by the "ci-check-repo CI workflow
// as a reminder to consider prepareds when adding a new
// enginetest suite.
const SkipPreparedsCount = 83
const SkipPreparedsCount = 84
const skipPreparedFlag = "DOLT_SKIP_PREPARED_ENGINETESTS"
@@ -505,6 +505,11 @@ func TestBlobs(t *testing.T) {
enginetest.TestBlobs(t, newDoltHarness(t))
}
func TestIndexes(t *testing.T) {
harness := newDoltHarness(t)
enginetest.TestIndexes(t, harness)
}
func TestIndexPrefix(t *testing.T) {
skipOldFormat(t)
harness := newDoltHarness(t)

View File

@@ -81,7 +81,6 @@ func newDoltHarness(t *testing.T) *DoltHarness {
session, err := dsess.NewDoltSession(sql.NewEmptyContext(), enginetest.NewBaseSession(), pro,
localConfig, branchControl)
require.NoError(t, err)
branch_control.SetSuperUser("root", "localhost")
dh := &DoltHarness{
t: t,
session: session,

View File

@@ -718,22 +718,6 @@ var DoltScripts = []queries.ScriptTest{
},
},
},
{
Name: "unique key violation prevents insert",
SetUpScript: []string{
"CREATE TABLE auniquetable (pk int primary key, uk int unique key, i int);",
"INSERT INTO auniquetable VALUES(0,0,0);",
"INSERT INTO auniquetable (pk,uk) VALUES(1,0) on duplicate key update i = 99;",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT pk, uk, i from auniquetable",
Expected: []sql.Row{
{0, 0, 99},
},
},
},
},
}
func makeLargeInsert(sz int) string {
@@ -1465,6 +1449,28 @@ var HistorySystemTableScriptTests = []queries.ScriptTest{
},
},
},
{
Name: "dolt_history table index lookup",
SetUpScript: []string{
"create table yx (y int, x int primary key);",
"call dolt_add('.');",
"call dolt_commit('-m', 'creating table');",
"insert into yx values (0, 1);",
"call dolt_commit('-am', 'add data');",
"insert into yx values (2, 3);",
"call dolt_commit('-am', 'add data');",
"insert into yx values (4, 5);",
"call dolt_commit('-am', 'add data');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "select count(x) from dolt_history_yx where x = 1;",
Expected: []sql.Row{
{3},
},
},
},
},
}
// BrokenHistorySystemTableScriptTests contains tests that work for non-prepared, but don't work
@@ -4665,6 +4671,13 @@ var DiffTableFunctionScriptTests = []queries.ScriptTest{
{nil, nil, nil, 3, "three", "four", "removed"},
},
},
{
Query: `
SELECT to_pk, to_c1, to_c2, from_pk, from_c1, from_c2, diff_type
from dolt_diff(@Commit1, @Commit2, 't')
inner join t on to_pk = t.pk;`,
Expected: []sql.Row{{1, "one", "two", nil, nil, nil, "added"}},
},
},
},
{
@@ -5356,6 +5369,10 @@ var LogTableFunctionScriptTests = []queries.ScriptTest{
Query: "SELECT count(*) from dolt_log('main^');",
Expected: []sql.Row{{3}},
},
{
Query: "SELECT count(*) from dolt_log('main') join dolt_diff(@Commit1, @Commit2, 't') where commit_hash = to_commit;",
Expected: []sql.Row{{2}},
},
},
},
{
@@ -5817,6 +5834,13 @@ var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{
Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit5, 't');",
ExpectedErr: sql.ErrTableNotFound,
},
{
Query: `
SELECT *
from dolt_diff_summary(@Commit3, @Commit4, 't')
inner join t as of @Commit3 on rows_unmodified = t.pk;`,
Expected: []sql.Row{},
},
},
},
{
@@ -8126,448 +8150,152 @@ var DoltCommitTests = []queries.ScriptTest{
var DoltIndexPrefixScripts = []queries.ScriptTest{
{
Name: "varchar primary key prefix",
Name: "inline secondary indexes with collation",
SetUpScript: []string{
"create table t (v varchar(100))",
"create table t (i int primary key, v1 varchar(10), v2 varchar(10), unique index (v1(3),v2(5))) collate utf8mb4_0900_ai_ci",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add primary key (v(10))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
{
Query: "create table v_tbl (v varchar(100), primary key (v(10)))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
},
},
{
Name: "varchar keyed secondary index prefix",
SetUpScript: []string{
"create table t (i int primary key, v varchar(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (v(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `v` varchar(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `v1` varchar(10) COLLATE utf8mb4_0900_ai_ci,\n `v2` varchar(10) COLLATE utf8mb4_0900_ai_ci,\n PRIMARY KEY (`i`),\n UNIQUE KEY `v1v2` (`v1`(3),`v2`(5))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci"}},
},
{
Query: "insert into t values (0, 'aa'), (1, 'ab')",
Query: "insert into t values (0, 'a', 'a'), (1, 'ab','ab'), (2, 'abc', 'abc'), (3, 'abcde', 'abcde')",
Expected: []sql.Row{{sql.NewOkResult(4)}},
},
{
Query: "insert into t values (99, 'ABC', 'ABCDE')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "insert into t values (0, 'aa'), (1, 'bb'), (2, 'cc')",
Expected: []sql.Row{{sql.NewOkResult(3)}},
Query: "insert into t values (99, 'ABC123', 'ABCDE123')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "select * from t where v = 'a'",
Expected: []sql.Row{},
},
{
Query: "select * from t where v = 'aa'",
Query: "select * from t where v1 = 'A'",
Expected: []sql.Row{
{0, "aa"},
{0, "a", "a"},
},
},
{
Query: "create table v_tbl (i int primary key, v varchar(100), index (v(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
Query: "explain select * from t where v1 = 'A'",
Expected: []sql.Row{
{"Filter(t.v1 = 'A')"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" ├─ filters: [{[A, A], [NULL, ∞)}]"},
{" └─ columns: [i v1 v2]"},
},
},
{
Query: "select * from t where v1 = 'ABC'",
Expected: []sql.Row{
{2, "abc", "abc"},
},
},
{
Query: "explain select * from t where v1 = 'ABC'",
Expected: []sql.Row{
{"Filter(t.v1 = 'ABC')"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" ├─ filters: [{[ABC, ABC], [NULL, ∞)}]"},
{" └─ columns: [i v1 v2]"},
},
},
{
Query: "select * from t where v1 = 'ABCD'",
Expected: []sql.Row{},
},
{
Query: "explain select * from t where v1 = 'ABCD'",
Expected: []sql.Row{
{"Filter(t.v1 = 'ABCD')"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" ├─ filters: [{[ABCD, ABCD], [NULL, ∞)}]"},
{" └─ columns: [i v1 v2]"},
},
},
{
Query: "select * from t where v1 > 'A' and v1 < 'ABCDE'",
Expected: []sql.Row{
{1, "ab", "ab"},
{2, "abc", "abc"},
},
},
{
Query: "explain select * from t where v1 > 'A' and v1 < 'ABCDE'",
Expected: []sql.Row{
{"Filter((t.v1 > 'A') AND (t.v1 < 'ABCDE'))"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" ├─ filters: [{(A, ABCDE), [NULL, ∞)}]"},
{" └─ columns: [i v1 v2]"},
},
},
{
Query: "select * from t where v1 > 'A' and v2 < 'ABCDE'",
Expected: []sql.Row{
{1, "ab", "ab"},
{2, "abc", "abc"},
},
},
{
Query: "explain select * from t where v1 > 'A' and v2 < 'ABCDE'",
Expected: []sql.Row{
{"Filter((t.v1 > 'A') AND (t.v2 < 'ABCDE'))"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" ├─ filters: [{(A, ∞), (NULL, ABCDE)}]"},
{" └─ columns: [i v1 v2]"},
},
},
{
Query: "update t set v1 = concat(v1, 'Z') where v1 >= 'A'",
Expected: []sql.Row{
{sql.OkResult{RowsAffected: 4, InsertID: 0, Info: plan.UpdateInfo{Matched: 4, Updated: 4}}},
},
},
{
Query: "explain update t set v1 = concat(v1, 'Z') where v1 >= 'A'",
Expected: []sql.Row{
{"Update"},
{" └─ UpdateSource(SET t.v1 = concat(t.v1, 'Z'))"},
{" └─ Filter(t.v1 >= 'A')"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" └─ filters: [{[A, ∞), [NULL, ∞)}]"},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{0, "aZ", "a"},
{1, "abZ", "ab"},
{2, "abcZ", "abc"},
{3, "abcdeZ", "abcde"},
},
},
{
Query: "delete from t where v1 >= 'A'",
Expected: []sql.Row{
{sql.OkResult{RowsAffected: 4}},
},
},
{
Query: "show create table v_tbl",
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `i` int NOT NULL,\n `v` varchar(100),\n PRIMARY KEY (`i`),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
Query: "explain delete from t where v1 >= 'A'",
Expected: []sql.Row{
{"Delete"},
{" └─ Filter(t.v1 >= 'A')"},
{" └─ IndexedTableAccess(t)"},
{" ├─ index: [t.v1,t.v2]"},
{" └─ filters: [{[A, ∞), [NULL, ∞)}]"},
},
},
},
},
{
Name: "varchar keyless secondary index prefix",
SetUpScript: []string{
"create table t (v varchar(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (v(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `v` varchar(10),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values ('aa'), ('ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table v_tbl (v varchar(100), index (v(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table v_tbl",
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `v` varchar(100),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "char primary key prefix",
SetUpScript: []string{
"create table t (c char(100))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add primary key (c(10))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
{
Query: "create table c_tbl (c char(100), primary key (c(10)))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
},
},
{
Name: "char keyed secondary index prefix",
SetUpScript: []string{
"create table t (i int primary key, c char(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (c(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `c` char(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `c` (`c`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values (0, 'aa'), (1, 'ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table c_tbl (i int primary key, c varchar(100), index (c(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table c_tbl",
Expected: []sql.Row{{"c_tbl", "CREATE TABLE `c_tbl` (\n `i` int NOT NULL,\n `c` varchar(100),\n PRIMARY KEY (`i`),\n KEY `c` (`c`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "char keyless secondary index prefix",
SetUpScript: []string{
"create table t (c char(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (c(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `c` char(10),\n UNIQUE KEY `c` (`c`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values ('aa'), ('ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table c_tbl (c char(100), index (c(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table c_tbl",
Expected: []sql.Row{{"c_tbl", "CREATE TABLE `c_tbl` (\n `c` char(100),\n KEY `c` (`c`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "varbinary primary key prefix",
SetUpScript: []string{
"create table t (v varbinary(100))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add primary key (v(10))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
{
Query: "create table v_tbl (v varbinary(100), primary key (v(10)))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
},
},
{
Name: "varbinary keyed secondary index prefix",
SetUpScript: []string{
"create table t (i int primary key, v varbinary(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (v(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `v` varbinary(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values (0, 'aa'), (1, 'ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table v_tbl (i int primary key, v varbinary(100), index (v(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table v_tbl",
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `i` int NOT NULL,\n `v` varbinary(100),\n PRIMARY KEY (`i`),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "varbinary keyless secondary index prefix",
SetUpScript: []string{
"create table t (v varbinary(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (v(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `v` varbinary(10),\n UNIQUE KEY `v` (`v`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values ('aa'), ('ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table v_tbl (v varbinary(100), index (v(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table v_tbl",
Expected: []sql.Row{{"v_tbl", "CREATE TABLE `v_tbl` (\n `v` varbinary(100),\n KEY `v` (`v`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "binary primary key prefix",
SetUpScript: []string{
"create table t (b binary(100))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add primary key (b(10))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
{
Query: "create table b_tbl (b binary(100), primary key (b(10)))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
},
},
{
Name: "binary keyed secondary index prefix",
SetUpScript: []string{
"create table t (i int primary key, b binary(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (b(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `b` binary(10),\n PRIMARY KEY (`i`),\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values (0, 'aa'), (1, 'ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table b_tbl (i int primary key, b binary(100), index (b(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table b_tbl",
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `i` int NOT NULL,\n `b` binary(100),\n PRIMARY KEY (`i`),\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "binary keyless secondary index prefix",
SetUpScript: []string{
"create table t (b binary(10))",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (b(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `b` binary(10),\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values ('aa'), ('ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table b_tbl (b binary(100), index (b(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table b_tbl",
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `b` binary(100),\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "blob primary key prefix",
SetUpScript: []string{
"create table t (b blob)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add primary key (b(10))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
{
Query: "create table b_tbl (b blob, primary key (b(10)))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
},
},
{
Name: "blob keyed secondary index prefix",
SetUpScript: []string{
"create table t (i int primary key, b blob)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (b(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `b` blob,\n PRIMARY KEY (`i`),\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values (0, 'aa'), (1, 'ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table b_tbl (i int primary key, b blob, index (b(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table b_tbl",
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `i` int NOT NULL,\n `b` blob,\n PRIMARY KEY (`i`),\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "blob keyless secondary index prefix",
SetUpScript: []string{
"create table t (b blob)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (b(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `b` blob,\n UNIQUE KEY `b` (`b`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values ('aa'), ('ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table b_tbl (b blob, index (b(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table b_tbl",
Expected: []sql.Row{{"b_tbl", "CREATE TABLE `b_tbl` (\n `b` blob,\n KEY `b` (`b`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "text primary key prefix",
SetUpScript: []string{
"create table t (t text)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add primary key (t(10))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
{
Query: "create table b_tbl (t text, primary key (t(10)))",
ExpectedErr: sql.ErrUnsupportedIndexPrefix,
},
},
},
{
Name: "text keyed secondary index prefix",
SetUpScript: []string{
"create table t (i int primary key, t text)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (t(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `i` int NOT NULL,\n `t` text,\n PRIMARY KEY (`i`),\n UNIQUE KEY `t` (`t`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values (0, 'aa'), (1, 'ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table t_tbl (i int primary key, t text, index (t(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t_tbl",
Expected: []sql.Row{{"t_tbl", "CREATE TABLE `t_tbl` (\n `i` int NOT NULL,\n `t` text,\n PRIMARY KEY (`i`),\n KEY `t` (`t`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
},
},
{
Name: "text keyless secondary index prefix",
SetUpScript: []string{
"create table t (t text)",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "alter table t add unique index (t(1))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t",
Expected: []sql.Row{{"t", "CREATE TABLE `t` (\n `t` text,\n UNIQUE KEY `t` (`t`(1))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
},
{
Query: "insert into t values ('aa'), ('ab')",
ExpectedErr: sql.ErrUniqueKeyViolation,
},
{
Query: "create table t_tbl (t text, index (t(10)))",
Expected: []sql.Row{{sql.NewOkResult(0)}},
},
{
Query: "show create table t_tbl",
Expected: []sql.Row{{"t_tbl", "CREATE TABLE `t_tbl` (\n `t` text,\n KEY `t` (`t`(10))\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
Query: "select * from t",
Expected: []sql.Row{},
},
},
},

View File

@@ -54,7 +54,11 @@ func (op EqualsOp) CompareNomsValues(v1, v2 types.Value) (bool, error) {
}
// CompareToNil always returns false as values are neither greater than, less than, or equal to nil
func (op EqualsOp) CompareToNil(types.Value) (bool, error) {
// except for equality op, the compared value is null.
func (op EqualsOp) CompareToNil(v types.Value) (bool, error) {
if v == types.NullValue {
return true, nil
}
return false, nil
}

View File

@@ -125,7 +125,7 @@ func TestCompareToNull(t *testing.T) {
gte: false,
lt: false,
lte: false,
eq: false,
eq: true,
},
{
name: "not nil",

View File

@@ -102,6 +102,14 @@ func getExpFunc(nbf *types.NomsBinFormat, sch schema.Schema, exp sql.Expression)
return newAndFunc(leftFunc, rightFunc), nil
case *expression.InTuple:
return newComparisonFunc(EqualsOp{}, typedExpr.BinaryExpression, sch)
case *expression.Not:
expFunc, err := getExpFunc(nbf, sch, typedExpr.Child)
if err != nil {
return nil, err
}
return newNotFunc(expFunc), nil
case *expression.IsNull:
return newComparisonFunc(EqualsOp{}, expression.BinaryExpression{Left: typedExpr.Child, Right: expression.NewLiteral(nil, sql.Null)}, sch)
}
return nil, errNotImplemented.New(exp.Type().String())
@@ -139,6 +147,17 @@ func newAndFunc(left ExpressionFunc, right ExpressionFunc) ExpressionFunc {
}
}
func newNotFunc(exp ExpressionFunc) ExpressionFunc {
return func(ctx context.Context, vals map[uint64]types.Value) (b bool, err error) {
res, err := exp(ctx, vals)
if err != nil {
return false, err
}
return !res, nil
}
}
type ComparisonType int
const (

View File

@@ -260,6 +260,8 @@ func parseDate(s string) (time.Time, error) {
func literalAsTimestamp(literal *expression.Literal) (time.Time, error) {
v := literal.Value()
switch typedVal := v.(type) {
case time.Time:
return typedVal, nil
case string:
ts, err := parseDate(typedVal)
@@ -275,6 +277,10 @@ func literalAsTimestamp(literal *expression.Literal) (time.Time, error) {
// LiteralToNomsValue converts a go-mysql-servel Literal into a noms value.
func LiteralToNomsValue(kind types.NomsKind, literal *expression.Literal) (types.Value, error) {
if literal.Value() == nil {
return types.NullValue, nil
}
switch kind {
case types.IntKind:
i64, err := literalAsInt64(literal)

View File

@@ -56,7 +56,6 @@ var (
var _ sql.Table = (*HistoryTable)(nil)
var _ sql.FilteredTable = (*HistoryTable)(nil)
var _ sql.IndexAddressableTable = (*HistoryTable)(nil)
var _ sql.ParallelizedIndexAddressableTable = (*HistoryTable)(nil)
var _ sql.IndexedTable = (*HistoryTable)(nil)
// HistoryTable is a system table that shows the history of rows over time
@@ -68,10 +67,6 @@ type HistoryTable struct {
projectedCols []uint64
}
func (ht *HistoryTable) ShouldParallelizeAccess() bool {
return false
}
func (ht *HistoryTable) GetIndexes(ctx *sql.Context) ([]sql.Index, error) {
tbl, err := ht.doltTable.DoltTable(ctx)
if err != nil {

View File

@@ -395,8 +395,18 @@ var _ DoltIndex = (*doltIndex)(nil)
// CanSupport implements sql.Index
func (di *doltIndex) CanSupport(...sql.Range) bool {
// TODO (james): don't use and prefix indexes if there's a prefix on a text/blob column
if len(di.prefixLengths) > 0 {
return false
hasTextBlob := false
colColl := di.indexSch.GetAllCols()
colColl.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
if sql.IsTextBlob(col.TypeInfo.ToSqlType()) {
hasTextBlob = true
return true, nil
}
return false, nil
})
return !hasTextBlob
}
return true
}
@@ -634,6 +644,11 @@ func (di *doltIndex) HandledFilters(filters []sql.Expression) []sql.Expression {
return nil
}
// filters on indexes with prefix lengths are not completely handled
if len(di.prefixLengths) > 0 {
return nil
}
var handled []sql.Expression
for _, f := range filters {
if expression.ContainsImpreciseComparison(f) {
@@ -776,6 +791,30 @@ func pruneEmptyRanges(sqlRanges []sql.Range) (pruned []sql.Range, err error) {
return pruned, nil
}
// trimRangeCutValue will trim the key value retrieved, depending on its type and prefix length
// TODO: this is just the trimKeyPart in the SecondaryIndexWriters, maybe find a different place
func (di *doltIndex) trimRangeCutValue(to int, keyPart interface{}) interface{} {
var prefixLength uint16
if len(di.prefixLengths) > to {
prefixLength = di.prefixLengths[to]
}
if prefixLength != 0 {
switch kp := keyPart.(type) {
case string:
if prefixLength > uint16(len(kp)) {
prefixLength = uint16(len(kp))
}
keyPart = kp[:prefixLength]
case []uint8:
if prefixLength > uint16(len(kp)) {
prefixLength = uint16(len(kp))
}
keyPart = kp[:prefixLength]
}
}
return keyPart
}
func (di *doltIndex) prollyRangesFromSqlRanges(ctx context.Context, ns tree.NodeStore, ranges []sql.Range, tb *val.TupleBuilder) ([]prolly.Range, error) {
ranges, err := pruneEmptyRanges(ranges)
if err != nil {
@@ -788,19 +827,20 @@ func (di *doltIndex) prollyRangesFromSqlRanges(ctx context.Context, ns tree.Node
fields := make([]prolly.RangeField, len(rng))
for j, expr := range rng {
if rangeCutIsBinding(expr.LowerBound) {
bound := expr.LowerBound.TypeAsLowerBound()
fields[j].Lo = prolly.Bound{
Binding: true,
Inclusive: bound == sql.Closed,
}
// accumulate bound values in |tb|
v, err := getRangeCutValue(expr.LowerBound, rng[j].Typ)
if err != nil {
return nil, err
}
if err = PutField(ctx, ns, tb, j, v); err != nil {
nv := di.trimRangeCutValue(j, v)
if err = PutField(ctx, ns, tb, j, nv); err != nil {
return nil, err
}
bound := expr.LowerBound.TypeAsLowerBound()
fields[j].Lo = prolly.Bound{
Binding: true,
Inclusive: bound == sql.Closed,
}
} else {
fields[j].Lo = prolly.Bound{}
}
@@ -814,18 +854,19 @@ func (di *doltIndex) prollyRangesFromSqlRanges(ctx context.Context, ns tree.Node
for i, expr := range rng {
if rangeCutIsBinding(expr.UpperBound) {
bound := expr.UpperBound.TypeAsUpperBound()
fields[i].Hi = prolly.Bound{
Binding: true,
Inclusive: bound == sql.Closed,
}
// accumulate bound values in |tb|
v, err := getRangeCutValue(expr.UpperBound, rng[i].Typ)
if err != nil {
return nil, err
}
if err = PutField(ctx, ns, tb, i, v); err != nil {
nv := di.trimRangeCutValue(i, v)
if err = PutField(ctx, ns, tb, i, nv); err != nil {
return nil, err
}
fields[i].Hi = prolly.Bound{
Binding: true,
Inclusive: bound == sql.Closed || nv != v, // TODO (james): this might panic for []byte
}
} else {
fields[i].Hi = prolly.Bound{}
}

View File

@@ -81,7 +81,7 @@ func (idt *IndexedDoltTable) PartitionRows(ctx *sql.Context, part sql.Partition)
return nil, err
}
if idt.lb == nil || !canCache || idt.lb.Key() != key {
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, nil, idt.table.sqlSch, idt.isDoltFormat)
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, idt.table.projectedCols, idt.table.sqlSch, idt.isDoltFormat)
if err != nil {
return nil, err
}
@@ -98,7 +98,7 @@ func (idt *IndexedDoltTable) PartitionRows2(ctx *sql.Context, part sql.Partition
return nil, err
}
if idt.lb == nil || !canCache || idt.lb.Key() != key {
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, nil, idt.table.sqlSch, idt.isDoltFormat)
idt.lb, err = index.NewLookupBuilder(ctx, idt.table, idt.idx, key, idt.table.projectedCols, idt.table.sqlSch, idt.isDoltFormat)
if err != nil {
return nil, err
}

View File

@@ -293,8 +293,14 @@ func (m prollySecondaryIndexWriter) trimKeyPart(to int, keyPart interface{}) int
if prefixLength != 0 {
switch kp := keyPart.(type) {
case string:
if prefixLength > uint16(len(kp)) {
prefixLength = uint16(len(kp))
}
keyPart = kp[:prefixLength]
case []uint8:
if prefixLength > uint16(len(kp)) {
prefixLength = uint16(len(kp))
}
keyPart = kp[:prefixLength]
}
}

View File

@@ -218,8 +218,14 @@ func (writer prollyKeylessSecondaryWriter) trimKeyPart(to int, keyPart interface
if prefixLength != 0 {
switch kp := keyPart.(type) {
case string:
if prefixLength > uint16(len(kp)) {
prefixLength = uint16(len(kp))
}
keyPart = kp[:prefixLength]
case []uint8:
if prefixLength > uint16(len(kp)) {
prefixLength = uint16(len(kp))
}
keyPart = kp[:prefixLength]
}
}
@@ -304,7 +310,8 @@ func (writer prollyKeylessSecondaryWriter) Delete(ctx context.Context, sqlRow sq
for to := range writer.keyMap {
from := writer.keyMap.MapOrdinal(to)
if err := index.PutField(ctx, writer.mut.NodeStore(), writer.keyBld, to, sqlRow[from]); err != nil {
keyPart := writer.trimKeyPart(to, sqlRow[from])
if err := index.PutField(ctx, writer.mut.NodeStore(), writer.keyBld, to, keyPart); err != nil {
return err
}
}

View File

@@ -16,6 +16,7 @@ package argparser
import (
"errors"
"fmt"
"sort"
"strings"
)
@@ -49,14 +50,17 @@ func ValidatorFromStrList(paramName string, validStrList []string) ValidationFun
type ArgParser struct {
Supported []*Option
NameOrAbbrevToOpt map[string]*Option
nameOrAbbrevToOpt map[string]*Option
ArgListHelp [][2]string
}
func NewArgParser() *ArgParser {
var supported []*Option
nameOrAbbrevToOpt := make(map[string]*Option)
return &ArgParser{supported, nameOrAbbrevToOpt, nil}
return &ArgParser{
Supported: supported,
nameOrAbbrevToOpt: nameOrAbbrevToOpt,
}
}
// SupportOption adds support for a new argument with the option given. Options must have a unique name and abbreviated name.
@@ -64,8 +68,8 @@ func (ap *ArgParser) SupportOption(opt *Option) {
name := opt.Name
abbrev := opt.Abbrev
_, nameExist := ap.NameOrAbbrevToOpt[name]
_, abbrevExist := ap.NameOrAbbrevToOpt[abbrev]
_, nameExist := ap.nameOrAbbrevToOpt[name]
_, abbrevExist := ap.nameOrAbbrevToOpt[abbrev]
if name == "" {
panic("Name is required")
@@ -80,10 +84,10 @@ func (ap *ArgParser) SupportOption(opt *Option) {
}
ap.Supported = append(ap.Supported, opt)
ap.NameOrAbbrevToOpt[name] = opt
ap.nameOrAbbrevToOpt[name] = opt
if abbrev != "" {
ap.NameOrAbbrevToOpt[abbrev] = opt
ap.nameOrAbbrevToOpt[abbrev] = opt
}
}
@@ -95,6 +99,18 @@ func (ap *ArgParser) SupportsFlag(name, abbrev, desc string) *ArgParser {
return ap
}
// SupportsAlias adds support for an alias for an existing option. The alias can be used in place of the original option.
// |original| may be either an option's full name or its abbreviation; the
// option must already have been registered (e.g. via SupportOption), otherwise
// this panics, since a missing original indicates a programming error.
func (ap *ArgParser) SupportsAlias(alias, original string) *ArgParser {
	opt, ok := ap.nameOrAbbrevToOpt[original]
	if !ok {
		panic(fmt.Sprintf("No option found for %s, this is a bug", original))
	}
	ap.nameOrAbbrevToOpt[alias] = opt
	return ap
}
// SupportsString adds support for a new string argument with the description given. See SupportOpt for details on params.
func (ap *ArgParser) SupportsString(name, abbrev, valDesc, desc string) *ArgParser {
opt := &Option{name, abbrev, valDesc, OptionalValue, desc, nil, false}
@@ -146,7 +162,7 @@ func (ap *ArgParser) SupportsInt(name, abbrev, valDesc, desc string) *ArgParser
// modal options in order of descending string length
func (ap *ArgParser) sortedModalOptions() []string {
smo := make([]string, 0, len(ap.Supported))
for s, opt := range ap.NameOrAbbrevToOpt {
for s, opt := range ap.nameOrAbbrevToOpt {
if opt.OptType == OptionalFlag && s != "" {
smo = append(smo, s)
}
@@ -179,7 +195,7 @@ func (ap *ArgParser) matchModalOptions(arg string) (matches []*Option, rest stri
isMatch := len(rest) >= lo && rest[:lo] == on
if isMatch {
rest = rest[lo:]
m := ap.NameOrAbbrevToOpt[on]
m := ap.nameOrAbbrevToOpt[on]
matches = append(matches, m)
// only match options once
@@ -200,7 +216,7 @@ func (ap *ArgParser) matchModalOptions(arg string) (matches []*Option, rest stri
func (ap *ArgParser) sortedValueOptions() []string {
vos := make([]string, 0, len(ap.Supported))
for s, opt := range ap.NameOrAbbrevToOpt {
for s, opt := range ap.nameOrAbbrevToOpt {
if (opt.OptType == OptionalValue || opt.OptType == OptionalEmptyValue) && s != "" {
vos = append(vos, s)
}
@@ -219,14 +235,14 @@ func (ap *ArgParser) matchValueOption(arg string) (match *Option, value *string)
if len(v) > 0 {
value = &v
}
match = ap.NameOrAbbrevToOpt[on]
match = ap.nameOrAbbrevToOpt[on]
return match, value
}
}
return nil, nil
}
// Parses the string args given using the configuration previously specified with calls to the various Supports*
// Parse parses the string args given using the configuration previously specified with calls to the various Supports*
// methods. Any unrecognized arguments or incorrect types will result in an appropriate error being returned. If the
// universal --help or -h flag is found, an ErrHelp error is returned.
func (ap *ArgParser) Parse(args []string) (*ArgParseResults, error) {

View File

@@ -201,7 +201,7 @@ func (res *ArgParseResults) AnyFlagsEqualTo(val bool) *set.StrSet {
func (res *ArgParseResults) FlagsEqualTo(names []string, val bool) *set.StrSet {
results := make([]string, 0, len(res.parser.Supported))
for _, name := range names {
opt, ok := res.parser.NameOrAbbrevToOpt[name]
opt, ok := res.parser.nameOrAbbrevToOpt[name]
if ok && opt.OptType == OptionalFlag {
_, ok := res.options[name]

View File

@@ -16,12 +16,14 @@ package jwtauth
import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/sirupsen/logrus"
	jose "gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/json"
)
@@ -110,3 +112,171 @@ func (f *fetchedJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
}
return jwks.Key(kid), nil
}
// The MultiJWKS will source JWKS from multiple URLs and will make them all
// available through GetKey(). Its GetKey() cannot error, but it can return no
// results.
//
// The URLs in the refresh list are static. Each URL will be periodically
// refreshed and the results will be aggregated into the JWKS view. If a key no
// longer appears at the URL, it may eventually be removed from the set of keys
// available through GetKey(). Requesting a key which is not currently in the
// key set will generally hint that the URLs should be more aggressively
// refreshed, but there is no blocking on refreshing the URLs.
//
// GracefulStop() will shutdown any ongoing fetching work and will return when
// everything is cleanly shutdown.
type MultiJWKS struct {
	client  *http.Client
	wg      sync.WaitGroup          // tracks the per-URL fetch goroutines started by Run
	stop    chan struct{}           // closed by GracefulStop to signal shutdown
	refresh []chan *sync.WaitGroup  // per-URL on-demand refresh requests (bounded buffers)
	urls    []string                // the static list of JWKS endpoints
	sets    []jose.JSONWebKeySet    // last fetched key set per URL, indexed like urls
	agg     jose.JSONWebKeySet      // aggregate of all sets, rebuilt by store()
	mu      sync.RWMutex            // guards sets, agg and stopped
	lgr     *logrus.Entry
	stopped bool // set under mu by GracefulStop; makes needsRefresh a no-op
}
// NewMultiJWKS constructs a MultiJWKS that aggregates the key sets served at
// |urls|, fetching them with |client| and logging through |lgr|. Run must be
// called before any keys become available through GetKey.
func NewMultiJWKS(lgr *logrus.Entry, urls []string, client *http.Client) *MultiJWKS {
	refresh := make([]chan *sync.WaitGroup, len(urls))
	for i := range refresh {
		// Small buffer so concurrent refresh requests coalesce rather than block.
		refresh[i] = make(chan *sync.WaitGroup, 3)
	}
	return &MultiJWKS{
		lgr:     lgr,
		client:  client,
		urls:    urls,
		stop:    make(chan struct{}),
		refresh: refresh,
		sets:    make([]jose.JSONWebKeySet, len(urls)),
	}
}
// Run starts one fetch goroutine per configured URL and blocks until every
// one of them has exited, which happens only after GracefulStop is called.
func (t *MultiJWKS) Run() {
	t.wg.Add(len(t.urls))
	for i := range t.urls {
		go t.thread(i)
	}
	t.wg.Wait()
}
// GracefulStop shuts down all fetch goroutines and returns once they have all
// exited. It must be called at most once; closing t.stop twice would panic.
func (t *MultiJWKS) GracefulStop() {
	// Mark stopped first (under the lock) so that concurrent needsRefresh
	// calls become no-ops and do not enqueue requests that will never be
	// serviced once the threads exit.
	t.mu.Lock()
	t.stopped = true
	t.mu.Unlock()
	close(t.stop)
	t.wg.Wait()
	// TODO: Potentially clear t.refresh channels, ensure nothing else can call GetKey()...
}
// needsRefresh signals every fetch goroutine to re-fetch its URL and returns a
// WaitGroup that completes when each goroutine has processed (or skipped) the
// request. Callers may Wait on the result to block until the refreshes finish.
//
// Callers hold t.mu (read or write), which makes the read of t.stopped safe:
// it is only written under t.mu.Lock in GracefulStop.
func (t *MultiJWKS) needsRefresh() *sync.WaitGroup {
	wg := new(sync.WaitGroup)
	if t.stopped {
		// Shutting down: return an already-completed WaitGroup so callers do
		// not block forever waiting on goroutines that are exiting.
		return wg
	}
	wg.Add(len(t.refresh))
	for _, c := range t.refresh {
		select {
		case c <- wg:
		default:
			// The bounded channel is full, so refreshes are already queued for
			// this URL; count this request as satisfied immediately.
			wg.Done()
		}
	}
	return wg
}
// store records the freshly fetched key set for URL index |i| and rebuilds
// the aggregated view that GetKey serves.
func (t *MultiJWKS) store(i int, jwks jose.JSONWebKeySet) {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.sets[i] = jwks
	// Size the aggregate once up front so it is built with a single allocation.
	total := 0
	for _, set := range t.sets {
		total += len(set.Keys)
	}
	keys := make([]jose.JSONWebKey, 0, total)
	for _, set := range t.sets {
		keys = append(keys, set.Keys...)
	}
	t.agg.Keys = keys
}
// GetKey returns the keys currently known for |kid|. It never returns a
// non-nil error. If the key is absent from the aggregated set, it signals all
// fetch goroutines to refresh, waits for those refreshes, and retries the
// lookup once before returning (possibly empty) results.
func (t *MultiJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
	t.mu.RLock()
	defer t.mu.RUnlock()
	res := t.agg.Key(kid)
	if len(res) == 0 {
		t.lgr.Infof("fetched key %s, found no key, signaling refresh", kid)
		refresh := t.needsRefresh()
		// Drop the read lock while waiting so the fetch goroutines can call
		// store(); re-acquire before retrying so the deferred RUnlock above
		// still pairs with a held lock on return.
		t.mu.RUnlock()
		refresh.Wait()
		t.mu.RLock()
		res = t.agg.Key(kid)
		t.lgr.Infof("refresh for key %s done, found %d keys", kid, len(res))
	}
	return res, nil
}
// fetch performs a single HTTP GET of the JWKS document at t.urls[i], parses
// it, and stores the result into the aggregated view. Any transport failure,
// non-2xx status, or decode error is returned to the caller (the per-URL
// fetch goroutine), which retries with a short backoff.
func (t *MultiJWKS) fetch(i int) error {
	request, err := http.NewRequest("GET", t.urls[i], nil)
	if err != nil {
		return err
	}
	response, err := t.client.Do(request)
	if err != nil {
		return err
	}
	defer response.Body.Close()
	if response.StatusCode/100 != 2 {
		return fmt.Errorf("http request failed: StatusCode: %d", response.StatusCode)
	}
	// io.ReadAll replaces the deprecated ioutil.ReadAll (Go 1.16+).
	contents, err := io.ReadAll(response.Body)
	if err != nil {
		return err
	}
	var jwks jose.JSONWebKeySet
	err = json.Unmarshal(contents, &jwks)
	if err != nil {
		return err
	}
	t.store(i, jwks)
	return nil
}
// thread is the per-URL fetch loop for t.urls[i]. It refreshes the URL every
// 30 seconds, retrying after 1 second on failure, and also services on-demand
// refresh requests delivered via t.refresh[i]. It exits when t.stop is
// closed, first acknowledging any queued refresh requests so waiters in
// GetKey are not left blocked.
func (t *MultiJWKS) thread(i int) {
	defer t.wg.Done()
	timer := time.NewTimer(30 * time.Second)
	// refresh, when non-nil, is the WaitGroup of the request that triggered
	// the fetch just performed; it is acknowledged after each fetch completes.
	var refresh *sync.WaitGroup
	for {
		nextRefresh := 30 * time.Second
		err := t.fetch(i)
		if err != nil {
			// Something bad...
			t.lgr.Warnf("error fetching %s: %v", t.urls[i], err)
			nextRefresh = 1 * time.Second
		}
		timer.Reset(nextRefresh)
		if refresh != nil {
			refresh.Done()
		}
		refresh = nil
		select {
		case <-t.stop:
			// Shutting down: stop the timer, draining its channel if it has
			// already fired, then drain all pending refresh requests so their
			// waiters unblock before this goroutine returns.
			if !timer.Stop() {
				<-timer.C
			}
			for {
				select {
				case refresh = <-t.refresh[i]:
					refresh.Done()
				default:
					return
				}
			}
		case refresh = <-t.refresh[i]:
			// On-demand refresh: cancel the pending timer (draining it if it
			// already fired) and loop immediately to fetch again.
			if !timer.Stop() {
				<-timer.C
			}
		case <-timer.C:
			// Periodic refresh interval elapsed; loop to fetch again.
		}
	}
}

View File

@@ -21,6 +21,7 @@ table BranchControl {
table BranchControlAccess {
binlog: BranchControlBinlog;
databases: [BranchControlMatchExpression];
branches: [BranchControlMatchExpression];
users: [BranchControlMatchExpression];
hosts: [BranchControlMatchExpression];
@@ -28,6 +29,7 @@ table BranchControlAccess {
}
table BranchControlAccessValue {
database: string;
branch: string;
user: string;
host: string;
@@ -36,6 +38,7 @@ table BranchControlAccessValue {
table BranchControlNamespace {
binlog: BranchControlBinlog;
databases: [BranchControlMatchExpression];
branches: [BranchControlMatchExpression];
users: [BranchControlMatchExpression];
hosts: [BranchControlMatchExpression];
@@ -43,6 +46,7 @@ table BranchControlNamespace {
}
table BranchControlNamespaceValue {
database: string;
branch: string;
user: string;
host: string;
@@ -54,6 +58,7 @@ table BranchControlBinlog {
table BranchControlBinlogRow {
is_insert: bool;
database: string;
branch: string;
user: string;
host: string;

View File

@@ -118,7 +118,7 @@ func MutateMapWithTupleIter(ctx context.Context, m Map, iter TupleIter) (Map, er
}
func DiffMaps(ctx context.Context, from, to Map, cb tree.DiffFn) error {
return tree.DiffOrderedTrees(ctx, from.tuples, to.tuples, cb)
return tree.DiffOrderedTrees(ctx, from.tuples, to.tuples, makeDiffCallBack(from, to, cb))
}
// RangeDiffMaps returns diffs within a Range. See Range for which diffs are
@@ -153,13 +153,15 @@ func RangeDiffMaps(ctx context.Context, from, to Map, rng Range, cb tree.DiffFn)
return err
}
dcb := makeDiffCallBack(from, to, cb)
for {
var diff tree.Diff
if diff, err = differ.Next(ctx); err != nil {
break
}
if err = cb(ctx, diff); err != nil {
if err = dcb(ctx, diff); err != nil {
break
}
}
@@ -170,7 +172,23 @@ func RangeDiffMaps(ctx context.Context, from, to Map, rng Range, cb tree.DiffFn)
// specified by |start| and |stop|. If |start| and/or |stop| is null, then the
// range is unbounded towards that end.
func DiffMapsKeyRange(ctx context.Context, from, to Map, start, stop val.Tuple, cb tree.DiffFn) error {
return tree.DiffKeyRangeOrderedTrees(ctx, from.tuples, to.tuples, start, stop, cb)
return tree.DiffKeyRangeOrderedTrees(ctx, from.tuples, to.tuples, start, stop, makeDiffCallBack(from, to, cb))
}
// makeDiffCallBack wraps |innerCb| so that spurious ModifiedDiff entries
// between value tuples that compare equal under the shared value descriptor
// (i.e. differ only in trimmed null suffixes) are suppressed. If |from| and
// |to| use different value descriptors the tuples are not directly
// comparable, so the callback is returned unwrapped.
func makeDiffCallBack(from, to Map, innerCb tree.DiffFn) tree.DiffFn {
	if !from.valDesc.Equals(to.valDesc) {
		return innerCb
	}
	return func(ctx context.Context, diff tree.Diff) error {
		// Skip diffs produced by non-canonical tuples. A canonical-tuple is a
		// tuple where any null suffixes have been trimmed.
		if diff.Type == tree.ModifiedDiff &&
			from.valDesc.Compare(val.Tuple(diff.From), val.Tuple(diff.To)) == 0 {
			return nil
		}
		return innerCb(ctx, diff)
	}
}
func MergeMaps(ctx context.Context, left, right, base Map, cb tree.CollisionFn) (Map, tree.MergeStats, error) {

View File

@@ -0,0 +1,229 @@
#!/usr/bin/env bats
load $BATS_TEST_DIRNAME/helper/common.bash
load $BATS_TEST_DIRNAME/helper/query-server-common.bash
# Standard per-test setup: create a fresh dolt repo/workdir via the shared helper.
setup() {
    setup_common
}
# Per-test cleanup: check the repo feature version, stop any sql-server the
# test started, then remove the working directory.
teardown() {
    assert_feature_version
    stop_sql_server
    teardown_common
}
# Create a SQL user "test" with full SQL grants, then drop the default
# wildcard branch-control row so branch permissions must be granted explicitly.
setup_test_user() {
    dolt sql -q "create user test"
    dolt sql -q "grant all on *.* to test"
    dolt sql -q "delete from dolt_branch_control where user='%'"
}
# A fresh database ships both branch-control system tables, pre-seeded with a
# wildcard write row, and exposes the expected column layouts via DESCRIBE.
@test "branch-control: fresh database. branch control tables exist" {
    run dolt sql -r csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "%,%,%,%,write" ]
    dolt sql -q "select * from dolt_branch_namespace_control"
    run dolt sql -q "describe dolt_branch_control"
    [ $status -eq 0 ]
    [[ $output =~ "database" ]] || false
    [[ $output =~ "branch" ]] || false
    [[ $output =~ "user" ]] || false
    [[ $output =~ "host" ]] || false
    [[ $output =~ "permissions" ]] || false
    run dolt sql -q "describe dolt_branch_namespace_control"
    [ $status -eq 0 ]
    [[ $output =~ "database" ]] || false
    [[ $output =~ "branch" ]] || false
    [[ $output =~ "user" ]] || false
    [[ $output =~ "host" ]] || false
}
# The same seeded branch-control tables must be visible through the
# sql-server interface, not just the embedded `dolt sql` path.
@test "branch-control: fresh database. branch control tables exist through server interface" {
    start_sql_server
    run dolt sql-client -u dolt -P $PORT --use-db "dolt_repo_$$" --result-format csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "%,%,%,%,write" ]
    dolt sql-client -u dolt -P $PORT --use-db "dolt_repo_$$" -q "select * from dolt_branch_namespace_control"
}
# Rows inserted into dolt_branch_control via `dolt sql` must persist and be
# visible afterwards through a freshly started sql-server.
@test "branch-control: modify dolt_branch_control from dolt sql then make sure changes are reflected" {
    setup_test_user
    dolt sql -q "insert into dolt_branch_control values ('test-db','test-branch', 'test', '%', 'write')"
    run dolt sql -r csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "test-db,test-branch,test,%,write" ]
    start_sql_server
    run dolt sql-client -u dolt -P $PORT --use-db "dolt_repo_$$" --result-format csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "test-db,test-branch,test,%,write" ]
}
# The default root user can read and modify dolt_branch_control over the
# server interface, including deleting the seeded wildcard row.
@test "branch-control: default user root works as expected" {
    # I can't figure out how to get a dolt sql-server started as root.
    # So, I'm copying the pattern from sql-privs.bats and starting it
    # manually.
    PORT=$( definePORT )
    dolt sql-server --host 0.0.0.0 --port=$PORT &
    SERVER_PID=$! # will get killed by teardown_common
    sleep 5 # not using python wait so this works on windows
    run dolt sql-client --use-db "dolt_repo_$$" -u root -P $PORT --result-format csv -q "select * from dolt_branch_control"
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "%,%,%,%,write" ]
    dolt sql-client --use-db "dolt_repo_$$" -u root -P $PORT -q "delete from dolt_branch_control where user='%'"
    run dolt sql-client --use-db "dolt_repo_$$" -u root -P $PORT -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ "$output" == "" ]
}
# A user granted write on one branch can write there (and on branches they
# create) but is denied writes on other branches such as main.
@test "branch-control: test basic branch write permissions" {
    setup_test_user
    dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test', '%', 'write')"
    dolt branch test-branch
    start_sql_server
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "create table t (c1 int)"
    [ $status -ne 0 ]
    [[ $output =~ "does not have the correct permissions" ]] || false
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_checkout('test-branch'); create table t (c1 int)"
    # I should also have branch permissions on branches I create
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_checkout('-b', 'test-branch-2'); create table t (c1 int)"
    # Now back to main. Still locked out.
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "create table t (c1 int)"
    [ $status -ne 0 ]
    [[ $output =~ "does not have the correct permissions" ]] || false
}
# An admin on a branch can write to it, grant and revoke other users' branch
# permissions, and those grants/revocations take effect immediately.
@test "branch-control: test admin permissions" {
    setup_test_user
    dolt sql -q "create user test2"
    dolt sql -q "grant all on *.* to test2"
    dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test', '%', 'admin')"
    dolt branch test-branch
    start_sql_server
    # Admin has no write permission to branch not an admin on
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "create table t (c1 int)"
    [ $status -ne 0 ]
    [[ $output =~ "does not have the correct permissions" ]] || false
    # Admin can write
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_checkout('test-branch'); create table t (c1 int)"
    # Admin can make other users
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test2', '%', 'write')"
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test --result-format csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "dolt_repo_$$,test-branch,test,%,admin" ]
    [ ${lines[2]} = "dolt_repo_$$,test-branch,test2,%,write" ]
    # test2 can see all branch permissions
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test2 --result-format csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "dolt_repo_$$,test-branch,test,%,admin" ]
    [ ${lines[2]} = "dolt_repo_$$,test-branch,test2,%,write" ]
    # test2 now has write permissions on test-branch
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test2 -q "call dolt_checkout('test-branch'); insert into t values(0)"
    # Remove test2 permissions
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "delete from dolt_branch_control where user='test2'"
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test --result-format csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "dolt_repo_$$,test-branch,test,%,admin" ]
    # test2 cannot write to branch
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test2 -q "call dolt_checkout('test-branch'); insert into t values(1)"
    [ $status -ne 0 ]
    [[ $output =~ "does not have the correct permissions" ]] || false
}
# Creating a branch automatically inserts an admin row for the creating user
# on the new branch.
@test "branch-control: creating a branch grants admin permissions" {
    setup_test_user
    dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'main', 'test', '%', 'write')"
    start_sql_server
    dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test -q "call dolt_branch('test-branch')"
    run dolt sql-client -P $PORT --use-db "dolt_repo_$$" -u test --result-format csv -q "select * from dolt_branch_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host,permissions" ]
    [ ${lines[1]} = "dolt_repo_$$,main,test,%,write" ]
    [ ${lines[2]} = "dolt_repo_$$,test-branch,test,%,admin" ]
}
# A dolt_branch_namespace_control row reserves a branch-name pattern for a
# specific user: only that user may create branches matching the pattern.
#
# Fix: the first INSERT's 'test-branch' literal was accidentally broken across
# two lines, embedding a newline in the branch name and making the insert not
# match the intended branch.
@test "branch-control: test branch namespace control" {
    setup_test_user
    dolt sql -q "create user test2"
    dolt sql -q "grant all on *.* to test2"
    dolt sql -q "insert into dolt_branch_control values ('dolt_repo_$$', 'test-branch', 'test', '%', 'admin')"
    dolt sql -q "insert into dolt_branch_namespace_control values ('dolt_repo_$$', 'test-%', 'test2', '%')"
    start_sql_server
    run dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" --result-format csv -q "select * from dolt_branch_namespace_control"
    [ $status -eq 0 ]
    [ ${lines[0]} = "database,branch,user,host" ]
    [ ${lines[1]} = "dolt_repo_$$,test-%,test2,%" ]
    # test cannot create test-branch
    run dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test-branch')"
    [ $status -ne 0 ]
    [[ $output =~ "cannot create a branch" ]] || false
    # test2 can create test-branch
    dolt sql-client -u test2 -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test-branch')"
}
# When multiple namespace patterns exist, the longest match decides ownership:
# each user may only create branches under their own reserved prefix.
@test "branch-control: test longest match in branch namespace control" {
    setup_test_user
    dolt sql -q "create user test2"
    dolt sql -q "grant all on *.* to test2"
    dolt sql -q "insert into dolt_branch_namespace_control values ('dolt_repo_$$', 'test/%', 'test', '%')"
    dolt sql -q "insert into dolt_branch_namespace_control values ('dolt_repo_$$', 'test2/%', 'test2', '%')"
    start_sql_server
    # test can create a branch in its namespace but not in test2
    dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test/branch1')"
    run dolt sql-client -u test -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test2/branch1')"
    [ $status -ne 0 ]
    [[ $output =~ "cannot create a branch" ]] || false
    dolt sql-client -u test2 -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test2/branch1')"
    run dolt sql-client -u test2 -P $PORT --use-db "dolt_repo_$$" -q "call dolt_branch('test/branch1')"
    [ $status -ne 0 ]
    [[ $output =~ "cannot create a branch" ]] || false
}

View File

@@ -108,6 +108,7 @@ start_multi_db_server() {
stop_sql_server() {
# Clean up any mysql.sock file in the default, global location
rm -f /tmp/mysql.sock
rm -f /tmp/dolt.$PORT.sock
wait=$1
if [ ! -z "$SERVER_PID" ]; then

View File

@@ -764,7 +764,7 @@ DELIM
[ "${lines[1]}" = 5,5 ]
}
@test "import-create-tables: --ignore-skipped-rows correctly prevents skipped rows from printing" {
@test "import-create-tables: --quiet correctly prevents skipped rows from printing" {
cat <<DELIM > 1pk5col-rpt-ints.csv
pk,c1,c2,c3,c4,c5
1,1,2,3,4,5
@@ -772,7 +772,7 @@ pk,c1,c2,c3,c4,c5
1,1,2,3,4,8
DELIM
run dolt table import -c --continue --ignore-skipped-rows --pk=pk test 1pk5col-rpt-ints.csv
run dolt table import -c --continue --quiet --pk=pk test 1pk5col-rpt-ints.csv
[ "$status" -eq 0 ]
! [[ "$output" =~ "The following rows were skipped:" ]] || false
! [[ "$output" =~ "1,1,2,3,4,7" ]] || false
@@ -780,4 +780,17 @@ DELIM
[[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false
[[ "$output" =~ "Lines skipped: 2" ]] || false
[[ "$output" =~ "Import completed successfully." ]] || false
dolt sql -q "drop table test"
# --ignore-skipped-rows is an alias for --quiet
run dolt table import -c --continue --ignore-skipped-rows --pk=pk test 1pk5col-rpt-ints.csv
[ "$status" -eq 0 ]
! [[ "$output" =~ "The following rows were skipped:" ]] || false
! [[ "$output" =~ "1,1,2,3,4,7" ]] || false
! [[ "$output" =~ "1,1,2,3,4,8" ]] || false
[[ "$output" =~ "Rows Processed: 1, Additions: 1, Modifications: 0, Had No Effect: 0" ]] || false
[[ "$output" =~ "Lines skipped: 2" ]] || false
[[ "$output" =~ "Import completed successfully." ]] || false
}

View File

@@ -931,6 +931,27 @@ SQL
[[ ! "$output" =~ "add (2,3) to t1" ]] || false
}
# --no-ff forces a merge-commit path and --no-commit stops before committing:
# the merge must succeed and leave the working set staged, with HEAD unmoved
# until an explicit `dolt commit`.
@test "merge: dolt merge does not ff and not commit with --no-ff and --no-commit" {
    dolt branch other
    dolt sql -q "INSERT INTO test1 VALUES (1,2,3)"
    dolt commit -am "add (1,2,3) to test1";
    dolt checkout other
    run dolt sql -q "select * from test1;" -r csv
    [[ ! "$output" =~ "1,2,3" ]] || false
    run dolt merge other --no-ff --no-commit
    log_status_eq 0
    [[ "$output" =~ "Automatic merge went well; stopped before committing as requested" ]] || false
    run dolt log --oneline -n 1
    [[ "$output" =~ "added tables" ]] || false
    [[ ! "$output" =~ "add (1,2,3) to test1" ]] || false
    run dolt commit -m "merge main"
    log_status_eq 0
}
@test "merge: specify ---author for merge that's used for creating commit" {
dolt branch other
dolt sql -q "INSERT INTO test1 VALUES (1,2,3)"

View File

@@ -134,3 +134,16 @@ teardown() {
[ $status -ne 0 ]
[[ $output =~ "not found" ]] || false
}
# Repo directories containing dashes map to database names with underscores;
# the client must resolve the implicit database without error.
@test "sql-client: handle dashes for implicit database" {
    make_repo test-dashes
    cd test-dashes
    PORT=$( definePORT )
    dolt sql-server --user=root --port=$PORT &
    SERVER_PID=$! # will get killed by teardown_common
    sleep 5 # not using python wait so this works on windows
    run dolt sql-client -u root -P $PORT -q "show databases"
    [ $status -eq 0 ]
    [[ $output =~ " test_dashes " ]] || false
}

View File

@@ -2434,6 +2434,44 @@ SQL
[[ "$output" =~ "| 3 |" ]] || false
}
# The dolt_diff_<table> system table must support IS NULL / IS NOT NULL
# predicates on the from_* columns (NULL from_pk marks an inserted row).
@test "sql: dolt diff table correctly works with NOT and/or IS NULL" {
    dolt sql -q "CREATE TABLE t(pk int primary key);"
    dolt add .
    dolt commit -m "new table t"
    dolt sql -q "INSERT INTO t VALUES (1), (2)"
    dolt commit -am "add 1, 2"
    run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where from_pk is null"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "2" ]] || false
    dolt sql -q "UPDATE t SET pk = 3 WHERE pk = 2"
    dolt commit -am "add 3"
    run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where from_pk is not null"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "1" ]] || false
}
# The dolt_diff_<table> system table must support datetime predicates on the
# to_commit_date column, including comparisons against now().
@test "sql: dolt diff table correctly works with datetime comparisons" {
    dolt sql -q "CREATE TABLE t(pk int primary key);"
    dolt add .
    dolt commit -m "new table t"
    dolt sql -q "INSERT INTO t VALUES (1), (2), (3)"
    dolt commit -am "add 1, 2, 3"
    # adds a row and removes a row
    dolt sql -q "UPDATE t SET pk = 4 WHERE pk = 2"
    run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where to_commit_date is not null"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "3" ]] || false
    run dolt sql -q "SELECT COUNT(*) from dolt_diff_t where to_commit_date < now()"
    [ "$status" -eq 0 ]
    [[ "$output" =~ "3" ]] || false
}
@test "sql: sql print on order by returns the correct result" {
dolt sql -q "CREATE TABLE mytable(pk int primary key);"
dolt sql -q "INSERT INTO mytable VALUES (1),(2),(3),(4),(5),(6),(7),(8),(9),(10),(11),(12),(13),(14),(15),(16),(17),(18),(19),(20)"

View File

@@ -0,0 +1,22 @@
#!/usr/bin/env bats
load $BATS_TEST_DIRNAME/helper/common.bash
# Standard per-test setup: create a fresh dolt repo/workdir via the shared helper.
setup() {
    setup_common
}
# Standard per-test cleanup: remove the working directory.
teardown() {
    teardown_common
}
# Validation is a set of tests that validate various things about dolt
# that have nothing to do with product functionality directly.
# The shipped dolt binary must not link in test-only dependencies; any
# "testify" symbol in the binary indicates test code leaked into the build.
@test "validation: no test symbols in binary" {
    run grep_for_testify
    [ "$output" = "" ]
}
# Print any testify symbol names found in the installed dolt binary; empty
# output means the binary was built without test dependencies.
grep_for_testify() {
    # $(...) replaces legacy backticks (shellcheck SC2006): it nests cleanly
    # and is easier to read, with identical behavior here.
    strings $(which dolt) | grep testify
}

View File

@@ -562,6 +562,7 @@ tests:
- exec: "create table more_vals (i int primary key)"
error_match: "repo1 is read-only"
- on: server2
retry_attempts: 100
queries:
- query: "SELECT @@GLOBAL.dolt_cluster_role,@@GLOBAL.dolt_cluster_role_epoch"
result:

View File

@@ -0,0 +1,9 @@
# Mikro-ORM Smoke Test
The `index.ts` file is the main entry point and will insert a new record into the database, then load it, print
success, and exit with a zero exit code. If any errors are encountered, they are logged, and the process exits with a
non-zero exit code.
To run this smoke test project:
1. Run `npm install` command
2. Run `npm start` command

View File

@@ -0,0 +1,21 @@
{
"name": "mikro-orm-smoketest",
"version": "0.0.1",
"description": "DoltDB smoke test for Mikro-ORM integration",
"type": "commonjs",
"scripts": {
"start": "ts-node src/index.ts",
"mikro-orm": "mikro-orm-ts-node-commonjs"
},
"devDependencies": {
"ts-node": "^10.7.0",
"@types/node": "^16.11.10",
"typescript": "^4.5.2"
},
"dependencies": {
"@mikro-orm/core": "^5.0.3",
"@mikro-orm/mysql": "^5.0.3",
"mysql": "^2.14.1"
}
}

View File

@@ -0,0 +1,22 @@
import { Entity, PrimaryKey, Property } from "@mikro-orm/core";
// User is the single entity used by the smoke test: a trivial table with an
// auto-increment primary key and three scalar columns.
@Entity()
export class User {

    @PrimaryKey()
    id!: number;

    @Property()
    firstName!: string;

    @Property()
    lastName!: string;

    @Property()
    age!: number;

    // id is intentionally not set here; MikroORM populates it from the
    // database when the entity is flushed.
    constructor(firstName: string, lastName: string, age: number) {
        this.firstName = firstName;
        this.lastName = lastName;
        this.age = age;
    }

}

View File

@@ -0,0 +1,43 @@
import { MikroORM } from "@mikro-orm/core";
import { MySqlDriver } from '@mikro-orm/mysql';
import { User } from "./entity/User";
// connectAndGetOrm initializes MikroORM against a local server speaking the
// MySQL protocol on port 3306, using database "dolt", user "dolt" and an
// empty password. persistOnCreate makes em.create() mark entities for
// persistence automatically.
async function connectAndGetOrm() {
    const orm = await MikroORM.init<MySqlDriver>({
        entities: [User],
        type: "mysql",
        clientUrl: "mysql://localhost:3306",
        dbName: "dolt",
        user: "dolt",
        password: "",
        persistOnCreate: true,
    });
    return orm;
}
// Smoke-test driver: connect, sync the schema, insert one User, read it back,
// then close the connection and exit 0. Any error is logged and exits 1.
connectAndGetOrm().then(async orm => {
    console.log("Connected");

    const em = orm.em.fork();

    // this creates the tables if not exist
    const generator = orm.getSchemaGenerator();
    await generator.updateSchema();

    console.log("Inserting a new user into the database...")
    const user = new User("Timber", "Saw", 25)
    await em.persistAndFlush(user)
    console.log("Saved a new user with id: " + user.id)

    console.log("Loading users from the database...")
    const users = await em.findOne(User, 1)
    console.log("Loaded users: ", users)

    // Await close() (it returns a Promise) so the connection pool is torn
    // down before process.exit(0) can cut it off mid-flight.
    await orm.close();

    console.log("Smoke test passed!")
    process.exit(0)
}).catch(error => {
    console.log(error)
    console.log("Smoke test failed!")
    process.exit(1)
});

View File

@@ -0,0 +1,14 @@
{
"compilerOptions": {
"module": "commonjs",
"declaration": true,
"removeComments": true,
"emitDecoratorMetadata": true,
"esModuleInterop": true,
"experimentalDecorators": true,
"target": "es2017",
"outDir": "./dist",
"baseUrl": "./src",
    "incremental": true
}
}

View File

@@ -55,7 +55,7 @@ teardown() {
npx -c "prisma migrate dev --name init"
}
# Prisma is an ORM for Node/TypeScript applications. This test checks out the Peewee test suite
# Prisma is an ORM for Node/TypeScript applications. This test checks out the Prisma test suite
# and runs it against Dolt.
@test "Prisma ORM test suite" {
skip "Not implemented yet"
@@ -77,6 +77,16 @@ teardown() {
npm start
}
# MikroORM is an ORM for Node/TypeScript applications. This is a simple smoke test to make sure
# Dolt can support the most basic MikroORM operations.
@test "MikroORM smoke test" {
    # The smoke test expects a database named "dolt" to exist on the server.
    mysql --protocol TCP -u dolt -e "create database dolt;"
    cd mikro-orm
    npm install
    npm start
}
# Turn this test on to prevent the container from exiting if you need to exec a shell into
# the container to debug failed tests.
#@test "Pause container for an hour to debug failures" {