Branch Control

This commit is contained in:
Daylon Wilkins
2022-10-12 08:24:03 -07:00
parent 0f55b04583
commit 54225ca657
15 changed files with 1957 additions and 5 deletions
+4 -2
View File
@@ -55,7 +55,9 @@ Similar to {{.EmphasisLeft}}dolt sql-server{{.EmphasisRight}}, this command may
},
}
type SqlClientCmd struct{}
type SqlClientCmd struct {
VersionStr string
}
var _ cli.Command = SqlClientCmd{}
@@ -126,7 +128,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
serverController = NewServerController()
go func() {
_, _ = Serve(ctx, SqlServerCmd{}.VersionStr, serverConfig, serverController, dEnv)
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, serverController, dEnv)
}()
err = serverController.WaitForStart()
if err != nil {
+1 -1
View File
@@ -73,7 +73,7 @@ var doltCommand = cli.NewSubCommandHandler("dolt", "it's git for data", []cli.Co
commands.SqlCmd{VersionStr: Version},
admin.Commands,
sqlserver.SqlServerCmd{VersionStr: Version},
sqlserver.SqlClientCmd{},
sqlserver.SqlClientCmd{VersionStr: Version},
commands.LogCmd{},
commands.BranchCmd{},
commands.CheckoutCmd{},
@@ -0,0 +1,141 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package branch_control
import (
"sync"
"github.com/dolthub/go-mysql-server/sql"
)
// Permissions are a set of flags that denote a user's allowed functionality on a branch.
type Permissions uint64
const (
Permissions_Admin Permissions = 1 << iota // Permissions_Admin grants unrestricted control over a branch, including modification of table entries
Permissions_Write // Permissions_Write allows for all modifying operations on a branch, but does not allow modification of table entries
)
// Access contains all of the expressions that comprise the "dolt_branch_control" table, which handles write Access to
// branches, along with write access to the branch control system tables.
type Access struct {
binlog *Binlog
Branches []MatchExpression
Users []MatchExpression
Hosts []MatchExpression
Values []AccessValue
SuperUser string
SuperHost string
RWMutex *sync.RWMutex
}
// AccessValue contains the user-facing values of a particular row, along with the permissions for a row.
type AccessValue struct {
Branch string
User string
Host string
Permissions Permissions
}
// newAccess returns a new Access.
func newAccess(superUser string, superHost string) *Access {
return &Access{
Branches: nil,
Users: nil,
Hosts: nil,
Values: nil,
SuperUser: superUser,
SuperHost: superHost,
RWMutex: &sync.RWMutex{},
}
}
// Match returns whether any entries match the given branch, user, and host, along with their permissions. Requires
// external synchronization handling, therefore manually manage the RWMutex.
func (tbl *Access) Match(branch string, user string, host string) (bool, Permissions) {
filteredIndexes := Match(tbl.Users, user, sql.Collation_utf8mb4_0900_bin)
filteredHosts := tbl.filterHosts(filteredIndexes)
indexPool.Put(filteredIndexes)
filteredIndexes = Match(filteredHosts, host, sql.Collation_utf8mb4_0900_ai_ci)
matchExprPool.Put(filteredHosts)
filteredBranches := tbl.filterBranches(filteredIndexes)
indexPool.Put(filteredIndexes)
filteredIndexes = Match(filteredBranches, branch, sql.Collation_utf8mb4_0900_ai_ci)
matchExprPool.Put(filteredBranches)
bRes, pRes := len(filteredIndexes) > 0, tbl.gatherPermissions(filteredIndexes)
indexPool.Put(filteredIndexes)
return bRes, pRes
}
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
// returns -1. Assumes that the given expressions have already been folded. Requires external synchronization handling,
// therefore manually manage the RWMutex.
func (tbl *Access) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
for i, value := range tbl.Values {
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
return i
}
}
return -1
}
// filterBranches returns all branches that match the given collection indexes.
func (tbl *Access) filterBranches(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Branches[filter])
}
return matchExprs
}
// filterUsers returns all users that match the given collection indexes.
func (tbl *Access) filterUsers(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Users[filter])
}
return matchExprs
}
// filterHosts returns all hosts that match the given collection indexes.
func (tbl *Access) filterHosts(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Hosts[filter])
}
return matchExprs
}
// gatherPermissions combines all permissions from the given collection indexes and returns the result.
func (tbl *Access) gatherPermissions(collectionIndexes []uint32) Permissions {
perms := Permissions(0)
for _, collectionIndex := range collectionIndexes {
perms |= tbl.Values[collectionIndex].Permissions
}
return perms
}
@@ -0,0 +1,249 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package branch_control
import (
"bytes"
"encoding/binary"
"fmt"
"sync"
)
const (
currentBinlogVersion = uint16(1)
)
//TODO: add stored procedure functions for modifying the binlog
// Binlog is a running log file that tracks changes to tables within branch control. This is used for history purposes,
// as well as transactional purposes through the use of the BinlogOverlay.
type Binlog struct {
rows []BinlogRow
RWMutex *sync.RWMutex
}
// BinlogRow is a row within the Binlog.
type BinlogRow struct {
IsInsert bool
Branch string
User string
Host string
Permissions uint64
}
// BinlogOverlay enables transactional use cases over Binlog. Unlike a Binlog, a BinlogOverlay requires external
// synchronization.
type BinlogOverlay struct {
parentLength int
rows []BinlogRow
}
// Serialize returns the Binlog as a byte slice. All encoded integers are big-endian.
func (binlog *Binlog) Serialize() []byte {
binlog.RWMutex.RLock()
defer binlog.RWMutex.RUnlock()
buffer := bytes.Buffer{}
// Write the version bytes
writeUint16(&buffer, currentBinlogVersion)
// Write the number of entries
binlogSize := uint64(len(binlog.rows))
writeUint64(&buffer, binlogSize)
// Write the rows
for _, binlogRow := range binlog.rows {
binlogRow.Serialize(&buffer)
}
return buffer.Bytes()
}
// Deserialize populates the binlog with the given data. Returns an error if the data cannot be deserialized, or if the
// Binlog has already been written to. Deserialize must be called on an empty Binlog.
func (binlog *Binlog) Deserialize(data []byte) error {
binlog.RWMutex.Lock()
defer binlog.RWMutex.Unlock()
if len(binlog.rows) != 0 {
return fmt.Errorf("cannot deserialize to a non-empty binlog")
}
position := uint64(0)
// Read the version
version := binary.BigEndian.Uint16(data)
position += 2
if version != currentBinlogVersion {
// If we ever increment the binlog version, this will instead handle the conversion from previous versions
return fmt.Errorf(`cannot deserialize a binlog with version "%d"`, version)
}
// Read the number of entries
binlogSize := binary.BigEndian.Uint64(data[position:])
position += 8
// Read the rows
binlog.rows = make([]BinlogRow, binlogSize)
for i := uint64(0); i < binlogSize; i++ {
binlog.rows[i], position = deserializeBinlogRow(data, position)
}
return nil
}
// NewOverlay returns a new BinlogOverlay for the calling Binlog.
func (binlog *Binlog) NewOverlay() *BinlogOverlay {
binlog.RWMutex.RLock()
defer binlog.RWMutex.RUnlock()
return &BinlogOverlay{
parentLength: len(binlog.rows),
rows: nil,
}
}
// MergeOverlay merges the given BinlogOverlay with the calling Binlog. Fails if the Binlog has been written to since
// the overlay was created.
func (binlog *Binlog) MergeOverlay(overlay *BinlogOverlay) error {
binlog.RWMutex.Lock()
defer binlog.RWMutex.Unlock()
// Except for recovery situations, the binlog is an append-only structure, therefore if there are a different number
// of entries than when the overlay was created, then it has probably been written to. The likelihood of there being
// an outstanding overlay while the binlog is being modified is exceedingly low.
if len(binlog.rows) != overlay.parentLength {
return fmt.Errorf("cannot merge overlay as binlog has been modified")
}
binlog.rows = append(binlog.rows, overlay.rows...)
return nil
}
// Insert adds an insert entry to the Binlog.
func (binlog *Binlog) Insert(branch string, user string, host string, permissions uint64) {
binlog.RWMutex.Lock()
defer binlog.RWMutex.Unlock()
binlog.rows = append(binlog.rows, BinlogRow{
IsInsert: true,
Branch: branch,
User: user,
Host: host,
Permissions: permissions,
})
}
// Delete adds a delete entry to the Binlog.
func (binlog *Binlog) Delete(branch string, user string, host string, permissions uint64) {
binlog.RWMutex.Lock()
defer binlog.RWMutex.Unlock()
binlog.rows = append(binlog.rows, BinlogRow{
IsInsert: false,
Branch: branch,
User: user,
Host: host,
Permissions: permissions,
})
}
// Serialize writes the row to the given buffer. All encoded integers are big-endian.
func (row *BinlogRow) Serialize(buffer *bytes.Buffer) {
// Write whether this was an insertion or deletion
if row.IsInsert {
buffer.WriteByte(1)
} else {
buffer.WriteByte(0)
}
// Write the branch
branchLen := uint16(len(row.Branch))
writeUint16(buffer, branchLen)
buffer.WriteString(row.Branch)
// Write the user
userLen := uint16(len(row.User))
writeUint16(buffer, userLen)
buffer.WriteString(row.User)
// Write the host
hostLen := uint16(len(row.Host))
writeUint16(buffer, hostLen)
buffer.WriteString(row.Host)
// Write the permissions
writeUint64(buffer, row.Permissions)
}
// deserializeBinlogRow returns a BinlogRow from the data at the given position. Also returns the new position. Assumes
// that the given data's encoded integers are big-endian.
func deserializeBinlogRow(data []byte, position uint64) (BinlogRow, uint64) {
binlogRow := BinlogRow{}
// Read whether this was an insert or write
if data[position] == 1 {
binlogRow.IsInsert = true
} else {
binlogRow.IsInsert = false
}
position += 1
// Read the branch
branchLen := uint64(binary.BigEndian.Uint16(data[position:]))
position += 2
binlogRow.Branch = string(data[position : position+branchLen])
position += branchLen
// Read the user
userLen := uint64(binary.BigEndian.Uint16(data[position:]))
position += 2
binlogRow.User = string(data[position : position+userLen])
position += userLen
// Read the host
hostLen := uint64(binary.BigEndian.Uint16(data[position:]))
position += 2
binlogRow.Host = string(data[position : position+hostLen])
position += hostLen
// Read the permissions
binlogRow.Permissions = binary.BigEndian.Uint64(data[position:])
position += 8
return binlogRow, position
}
// Insert adds an insert entry to the BinlogOverlay.
func (overlay *BinlogOverlay) Insert(branch string, user string, host string, permissions uint64) {
overlay.rows = append(overlay.rows, BinlogRow{
IsInsert: true,
Branch: branch,
User: user,
Host: host,
Permissions: permissions,
})
}
// Delete adds a delete entry to the BinlogOverlay.
func (overlay *BinlogOverlay) Delete(branch string, user string, host string, permissions uint64) {
overlay.rows = append(overlay.rows, BinlogRow{
IsInsert: false,
Branch: branch,
User: user,
Host: host,
Permissions: permissions,
})
}
// writeUint64 writes an uint64 into the buffer.
func writeUint64(buffer *bytes.Buffer, val uint64) {
buffer.WriteByte(byte(val >> 56))
buffer.WriteByte(byte(val >> 48))
buffer.WriteByte(byte(val >> 40))
buffer.WriteByte(byte(val >> 32))
buffer.WriteByte(byte(val >> 24))
buffer.WriteByte(byte(val >> 16))
buffer.WriteByte(byte(val >> 8))
buffer.WriteByte(byte(val))
}
// writeUint16 writes an uint16 into the buffer.
func writeUint16(buffer *bytes.Buffer, val uint16) {
buffer.WriteByte(byte(val >> 8))
buffer.WriteByte(byte(val))
}
@@ -0,0 +1,176 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package branch_control
import (
"context"
"os"
"gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/go-mysql-server/sql"
)
var (
ErrIncorrectPermissions = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
ErrCannotCreateBranch = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
ErrCannotDeleteBranch = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
)
// Context represents the interface that must be inherited from the context.
type Context interface {
GetBranch() (string, error)
GetUser() string
GetHost() string
GetController() *Controller
}
// Controller is the central hub for branch control functions. This is passed within a context.
type Controller struct {
Access *Access
Namespace *Namespace
}
//TODO: delete me
var StaticController = &Controller{}
var enabled = false
func init() {
if os.Getenv("DOLT_ENABLE_BRANCH_CONTROL") != "" {
enabled = true
}
StaticController = CreateControllerWithSuperUser(context.Background(), "root", "localhost")
}
// CreateController returns an empty *Controller.
func CreateController(ctx context.Context) *Controller {
accessTbl := newAccess("", "")
return &Controller{
Access: accessTbl,
Namespace: newNamespace(accessTbl, "", ""),
}
}
// CreateControllerWithSuperUser returns a controller with the given user and host set as an immutable super user.
func CreateControllerWithSuperUser(ctx context.Context, superUser string, superHost string) *Controller {
accessTbl := newAccess(superUser, superHost)
return &Controller{
Access: accessTbl,
Namespace: newNamespace(accessTbl, superUser, superHost),
}
}
// CheckAccess returns whether the given context has the correct permissions on its selected branch. In general, SQL
// statements will almost always return a *sql.Context, so any checks from the SQL path will correctly check for branch
// permissions. However, not all CLI commands use *sql.Context, and therefore will not have any user associated with
// the context. In these cases, CheckAccess will pass as we want to allow all local commands to ignore branch
// permissions.
func CheckAccess(ctx context.Context, flags Permissions) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
// A nil session means we're not in the SQL context, so we allow all operations
if branchAwareSession == nil {
return nil
}
StaticController.Access.RWMutex.RLock()
defer StaticController.Access.RWMutex.RUnlock()
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
// Check if the user is the super user, which has access to all operations
if user == StaticController.Access.SuperUser && host == StaticController.Access.SuperHost {
return nil
}
branch, err := branchAwareSession.GetBranch()
if err != nil {
return err
}
// Get the permissions for the branch, user, and host combination
_, perms := StaticController.Access.Match(branch, user, host)
// If either the flags match or the user is an admin for this branch, then we allow access
if (perms&flags == flags) || (perms&Permissions_Admin == Permissions_Admin) {
return nil
}
return ErrIncorrectPermissions.New(user, host, branch)
}
// CanCreateBranch returns whether the given context can create a branch with the given name. In general, SQL statements
// will almost always return a *sql.Context, so any checks from the SQL path will be able to validate a branch's name.
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
// these cases, CanCreateBranch will pass as we want to allow all local commands to freely create branches.
func CanCreateBranch(ctx context.Context, branchName string) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
if branchAwareSession == nil {
return nil
}
StaticController.Namespace.RWMutex.RLock()
defer StaticController.Namespace.RWMutex.RUnlock()
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
if StaticController.Namespace.CanCreate(branchName, user, host) {
return nil
}
return ErrCannotCreateBranch.New(user, host, branchName)
}
// CanDeleteBranch returns whether the given context can delete a branch with the given name. In general, SQL statements
// will almost always return a *sql.Context, so any checks from the SQL path will be able to validate a branch's name.
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
// these cases, CanDeleteBranch will pass as we want to allow all local commands to freely delete branches.
func CanDeleteBranch(ctx context.Context, branchName string) error {
if !enabled {
return nil
}
branchAwareSession := GetBranchAwareSession(ctx)
// A nil session means we're not in the SQL context, so we allow the delete operation
if branchAwareSession == nil {
return nil
}
StaticController.Access.RWMutex.RLock()
defer StaticController.Access.RWMutex.RUnlock()
user := branchAwareSession.GetUser()
host := branchAwareSession.GetHost()
// Check if the user is the super user, which is always able to delete branches
if user == StaticController.Access.SuperUser && host == StaticController.Access.SuperHost {
return nil
}
// Get the permissions for the branch, user, and host combination
_, perms := StaticController.Access.Match(branchName, user, host)
// If the user has the write or admin flags, then we allow access
if (perms&Permissions_Write == Permissions_Write) || (perms&Permissions_Admin == Permissions_Admin) {
return nil
}
return ErrCannotDeleteBranch.New(user, host, branchName)
}
// GetBranchAwareSession returns the session contained within the context. If the context does NOT contain a session,
// then nil is returned.
func GetBranchAwareSession(ctx context.Context) Context {
if sqlCtx, ok := ctx.(*sql.Context); ok {
if bas, ok := sqlCtx.Session.(Context); ok {
return bas
}
} else if bas, ok := ctx.(Context); ok {
return bas
}
return nil
}
@@ -0,0 +1,253 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package branch_control
import (
"math"
"sync"
"unicode/utf8"
"github.com/dolthub/go-mysql-server/sql"
)
const (
singleMatch = -1 // Equivalent to the single match character '_'
anyMatch = -2 // Equivalent to the any-length match character '%'
)
// invalidMatchExpression is a match expression that does not match anything
var invalidMatchExpression = MatchExpression{math.MaxUint32, nil}
// matchExprPool is a pool for MatchExpression slices. Provides a significant performance benefit.
var matchExprPool = &sync.Pool{
New: func() any {
return make([]MatchExpression, 0, 32)
},
}
// indexPool is a pool for index slices (such as those returned by Match). Provides a decent performance benefit.
var indexPool = &sync.Pool{
New: func() any {
return make([]uint32, 0, 32)
},
}
// MatchExpression represents a parsed expression that may be matched against. It contains a list of sort orders, which
// each represent a comparable value to determine whether any given character is a match. A character's sort order is
// obtained from a collation. Also contains its index in the table. MatchExpression contents are not meant to be
// comparable to one another, therefore please use the index to compare equivalence.
type MatchExpression struct {
CollectionIndex uint32 // CollectionIndex represents this expression's index in its parent slice.
SortOrders []int32 // These are the sort orders that will be compared against when matching a given rune.
}
// FoldExpression folds the given expression into its smallest form. Expressions have two wildcard operators:
// '_' and '%'. '_' matches exactly one character, and it can be any character. '%' can match zero or more of any
// character. Taking these two ops into account, the configurations "%_" and "_%" both resolve to matching one or more
// of any character. However, the "_%" form is more economical, as you enforce the single match first before checking
// for remaining matches. Similarly, "%%" is equivalent to a single '%'. Both of these rules are applied in this
// function, guaranteeing that the returned expression is the smallest form that still exactly represents the original.
//
// This also assumes that '\' is the escape character.
func FoldExpression(str string) string {
// This loop only terminates when we complete a run where no substitutions were made. Substitutions are applied
// linearly, therefore it's possible that one substitution may create an opportunity for another substitution.
// To keep the code simple, we continue looping until we have nothing more to do.
for true {
newStrRunes := make([]rune, 0, len(str))
// Skip next is set whenever we encounter the escape character, which is used to explicitly match against '_' and '%'
skipNext := false
// Consider next is set whenever we encounter an unescaped '%', indicating we may need to apply the substitutions
considerNext := false
for _, r := range str {
if skipNext {
skipNext = false
newStrRunes = append(newStrRunes, r)
continue
} else if considerNext {
considerNext = false
switch r {
case '\\':
newStrRunes = append(newStrRunes, '%', r) // False alarm, reinsert % before this rune
skipNext = true // We also need to ignore the next rune
case '_':
newStrRunes = append(newStrRunes, r, '%') // Replacing %_ with _%
case '%':
newStrRunes = append(newStrRunes, r) // Replacing %% with %
default:
newStrRunes = append(newStrRunes, '%', r) // False alarm, reinsert % before this rune
}
continue
}
switch r {
case '\\':
newStrRunes = append(newStrRunes, r)
skipNext = true
case '%':
considerNext = true
default:
newStrRunes = append(newStrRunes, r)
}
}
// If the very last rune is '%', then this will be true and we need to append it to the end
if considerNext {
newStrRunes = append(newStrRunes, '%')
}
newStr := string(newStrRunes)
if str == newStr {
break
}
str = newStr
}
return str
}
// ParseExpression parses the given string expression into a slice of sort ints, which will be used in a MatchExpression.
// Returns nil if the string is too long. Assumes that the given string expression has already been folded.
func ParseExpression(str string, collation sql.CollationID) []int32 {
if len(str) > math.MaxUint16 {
return nil
}
sortFunc := collation.Sorter()
var orders []int32
escaped := false
for _, r := range str {
if escaped {
escaped = false
orders = append(orders, sortFunc(r))
} else {
switch r {
case '\\':
escaped = true
case '%':
orders = append(orders, anyMatch)
case '_':
orders = append(orders, singleMatch)
default:
orders = append(orders, sortFunc(r))
}
}
}
return orders
}
// Match takes the match expression collection, and returns a slice of which collection indexes matched against the
// given string. The given indices may be used to further reduce the match expression collection, which will also reduce
// the total number of comparisons as they're narrowed down.
//
// It is vastly more performant to return a slice of collection indexes here, rather than a slice of match expressions.
// This is true even when the match expressions are pooled. The reason is unknown, but as we only need the collection
// indexes anyway, we discard the match expressions and return only their indexes.
func Match(matchExprCollection []MatchExpression, str string, collation sql.CollationID) []uint32 {
if len(str) == 0 {
return nil
}
sortFunc := collation.Sorter()
// Grab the first rune and also remove it from the string
r, rSize := utf8.DecodeRuneInString(str)
str = str[rSize:]
// Grab a slice from the pool, which reduces the GC pressure.
matchSubset := matchExprPool.Get().([]MatchExpression)[:0]
// We do a pass using the first rune over all expressions to get the subset that we'll be testing against
for _, testExpr := range matchExprCollection {
if matched, next, extra := testExpr.Matches(sortFunc(r)); matched {
if extra.IsValid() {
matchSubset = append(matchSubset, next, extra)
} else {
matchSubset = append(matchSubset, next)
}
}
}
// Bail early if there are no matches here
if len(matchSubset) == 0 {
matchExprPool.Put(matchSubset)
// We return a slice from the index pool as we later will return it to the pool. We don't want to stick a
// nil/empty slice into the pool.
return indexPool.Get().([]uint32)[:0]
}
// This is the slice that we'll put matches into. This will also flip to become the match subset. This way we reuse
// the underlying arrays. We also grab this from the pool.
matches := matchExprPool.Get().([]MatchExpression)[:0]
// Now that we have our set of expressions to test, we loop over the remainder of the input string
for _, r = range str {
for _, testExpr := range matchSubset {
if matched, next, extra := testExpr.Matches(sortFunc(r)); matched {
if extra.IsValid() {
matches = append(matches, next, extra)
} else {
matches = append(matches, next)
}
}
}
// Swap the two, and put the slice of matches to be at the beginning of the previous subset array to reuse it
matches, matchSubset = matchSubset[:0], matches
}
matchExprPool.Put(matches)
// Grab the indices of all valid matches
validMatches := indexPool.Get().([]uint32)[:0]
for _, match := range matchSubset {
if match.IsAtEnd() && (len(validMatches) == 0 ||
(len(validMatches) > 0 && match.CollectionIndex != validMatches[len(validMatches)-1])) {
validMatches = append(validMatches, match.CollectionIndex)
}
}
matchExprPool.Put(matchSubset)
return validMatches
}
// Matches returns true when the given sort order matches the expectation of the calling match expression. Returns a
// reduced match expression as `next`, which should take the place of the calling match function. In the event of a
// branch, returns the branching match expression as `extra`.
//
// Branches occur when the '%' operator sees that the given sort order matches the sort order after the '%'. As it
// cannot be determined which path is the correct one (whether to consume the '%' or continue using it), a branch is
// created. The `extra` should be checked for validity by calling IsValid.
func (matchExpr MatchExpression) Matches(sortOrder int32) (matched bool, next MatchExpression, extra MatchExpression) {
if len(matchExpr.SortOrders) == 0 {
return false, invalidMatchExpression, invalidMatchExpression
}
switch matchExpr.SortOrders[0] {
case singleMatch:
return true, MatchExpression{matchExpr.CollectionIndex, matchExpr.SortOrders[1:]}, invalidMatchExpression
case anyMatch:
if len(matchExpr.SortOrders) > 1 && matchExpr.SortOrders[1] == sortOrder {
return true, matchExpr, MatchExpression{matchExpr.CollectionIndex, matchExpr.SortOrders[2:]}
}
return true, matchExpr, invalidMatchExpression
default:
if sortOrder == matchExpr.SortOrders[0] {
return true, MatchExpression{matchExpr.CollectionIndex, matchExpr.SortOrders[1:]}, invalidMatchExpression
} else {
return false, invalidMatchExpression, invalidMatchExpression
}
}
}
// IsValid returns whether the match expression is valid. An invalid MatchExpression will have a collection index that
// is at the maximum value for an uint32.
func (matchExpr MatchExpression) IsValid() bool {
return matchExpr.CollectionIndex < math.MaxUint32
}
// IsAtEnd returns whether the match expression has matched every character. There is a special case where, if the last
// character is '%', it is considered to be at the end.
func (matchExpr MatchExpression) IsAtEnd() bool {
return len(matchExpr.SortOrders) == 0 || (len(matchExpr.SortOrders) == 1 && matchExpr.SortOrders[0] == anyMatch)
}
@@ -0,0 +1,168 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package branch_control
import (
"fmt"
"testing"
"github.com/dolthub/go-mysql-server/sql"
"github.com/stretchr/testify/require"
)
func TestSingleMatch(t *testing.T) {
tests := []struct {
expression string
testStr string
matches bool
collation sql.CollationID
}{
{"a__", "abc", true, sql.Collation_utf8mb4_0900_bin},
{"a__", "abcd", false, sql.Collation_utf8mb4_0900_bin},
{"a%b", "acb", true, sql.Collation_utf8mb4_0900_bin},
{"a%b", "acdkeflskjfdklb", true, sql.Collation_utf8mb4_0900_bin},
{"a%b", "ab", true, sql.Collation_utf8mb4_0900_bin},
{"a%b", "a", false, sql.Collation_utf8mb4_0900_bin},
{"a_b", "ab", false, sql.Collation_utf8mb4_0900_bin},
{"aa:%", "aa:bb:cc:dd:ee:ff", true, sql.Collation_utf8mb4_0900_bin},
{"aa:%", "AA:BB:CC:DD:EE:FF", false, sql.Collation_utf8mb4_0900_bin},
{"aa:%", "AA:BB:CC:DD:EE:FF", true, sql.Collation_utf8mb4_0900_ai_ci},
{"a_%_b%_%c", "AaAbCc", true, sql.Collation_utf8mb4_0900_ai_ci},
{"a_%_b%_%c", "AaAbBcCbCc", true, sql.Collation_utf8mb4_0900_ai_ci},
{"a_%_b%_%c", "AbbbbC", true, sql.Collation_utf8mb4_0900_ai_ci},
{"a_%_n%_%z", "aBcDeFgHiJkLmNoPqRsTuVwXyZ", true, sql.Collation_utf8mb4_0900_ai_ci},
{`a\%b`, "acb", false, sql.Collation_utf8mb4_0900_bin},
{`a\%b`, "a%b", true, sql.Collation_utf8mb4_0900_bin},
{`a\%b`, "A%B", false, sql.Collation_utf8mb4_0900_bin},
{`a\%b`, "A%B", true, sql.Collation_utf8mb4_0900_ai_ci},
{`a`, "a", true, sql.Collation_utf8mb4_0900_bin},
{`ab`, "a", false, sql.Collation_utf8mb4_0900_bin},
{`a\b`, "a", false, sql.Collation_utf8mb4_0900_bin},
{`a\\b`, "a", false, sql.Collation_utf8mb4_0900_bin},
{`a\\\b`, "a", false, sql.Collation_utf8mb4_0900_bin},
{`a`, "a", true, sql.Collation_utf8mb4_0900_ai_ci},
{`ab`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
{`a\b`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
{`a\\b`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
{`a\\\b`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
{`A%%%%`, "abc", true, sql.Collation_utf8mb4_0900_ai_ci},
{`A%%%%bc`, "abc", true, sql.Collation_utf8mb4_0900_ai_ci},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%q matches %q", test.testStr, test.expression), func(t *testing.T) {
parsedExpression := ParseExpression(FoldExpression(test.expression), test.collation)
matchCount := Match([]MatchExpression{{0, parsedExpression}}, test.testStr, test.collation)
if test.matches {
require.Len(t, matchCount, 1)
} else {
require.Len(t, matchCount, 0)
}
})
}
}
func TestMultipleMatch(t *testing.T) {
collation := sql.Collation_utf8mb4_0900_ai_ci
matchExprs := []MatchExpression{
{0, ParseExpression(FoldExpression("a__"), collation)},
{1, ParseExpression(FoldExpression("a%b"), collation)},
{2, ParseExpression(FoldExpression("a_b"), collation)},
{3, ParseExpression(FoldExpression("aa:%"), collation)},
{4, ParseExpression(FoldExpression("a_%_b%_%c"), collation)},
{5, ParseExpression(FoldExpression(`a\%b`), collation)},
{6, ParseExpression(FoldExpression(`a`), collation)},
{7, ParseExpression(FoldExpression(`ab`), collation)},
{8, ParseExpression(FoldExpression(`a\\b`), collation)},
{9, ParseExpression(FoldExpression(`A%%%%`), collation)},
{10, ParseExpression(FoldExpression(`A%%%%bc`), collation)},
{11, ParseExpression(FoldExpression("a_%_b%_%c%"), collation)},
{12, ParseExpression(FoldExpression("a_%_n%_%z"), collation)},
}
tests := []struct {
testStr string
indexes []uint32
}{
{"a", []uint32{6, 9}},
{"ab", []uint32{1, 7, 9}},
{"abc", []uint32{0, 9, 10}},
{"acb", []uint32{0, 1, 2, 9}},
{"abcd", []uint32{9}},
{"acdkeflskjfdklb", []uint32{1, 9}},
{"acdkeflskjfdklbc", []uint32{9, 10}},
{"aa:bb:cc:dd:ee:ff", []uint32{3, 9, 11}},
{"AA:BB:CC:DD:EE:FF", []uint32{3, 9, 11}},
{"AaAbCc", []uint32{4, 9, 11}},
{"AaAbBcCbCc", []uint32{4, 9, 11}},
{"AbbbbC", []uint32{4, 9, 10, 11}},
{"aBcDeFgHiJkLmNoPqRsTuVwXyZ", []uint32{9, 12}},
{"a%b", []uint32{0, 1, 2, 5, 9}},
{"A%B", []uint32{0, 1, 2, 5, 9}},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%q", test.testStr), func(t *testing.T) {
actualMatches := Match(matchExprs, test.testStr, collation)
require.ElementsMatch(t, test.indexes, actualMatches)
})
}
}
func BenchmarkSimpleCase(b *testing.B) {
collation := sql.Collation_utf8mb4_0900_ai_ci
matchExprs := []MatchExpression{
{0, ParseExpression(FoldExpression("a__"), collation)},
{1, ParseExpression(FoldExpression("a%b"), collation)},
{2, ParseExpression(FoldExpression("a_b"), collation)},
{3, ParseExpression(FoldExpression("aa:%"), collation)},
{4, ParseExpression(FoldExpression("a_%_b%_%c"), collation)},
{5, ParseExpression(FoldExpression(`a\%b`), collation)},
{6, ParseExpression(FoldExpression(`a`), collation)},
{7, ParseExpression(FoldExpression(`ab`), collation)},
{8, ParseExpression(FoldExpression(`a\\b`), collation)},
{9, ParseExpression(FoldExpression(`A%%%%`), collation)},
{10, ParseExpression(FoldExpression(`A%%%%bc`), collation)},
{11, ParseExpression(FoldExpression("a_%_b%_%c%"), collation)},
{12, ParseExpression(FoldExpression("a_%_n%_%z"), collation)},
}
tests := []struct {
testStr string
matches []uint32
}{
{"a", []uint32{6, 9}},
{"ab", []uint32{1, 7, 9}},
{"abc", []uint32{0, 9, 10}},
{"acb", []uint32{0, 1, 2, 9}},
{"abcd", []uint32{9}},
{"acdkeflskjfdklb", []uint32{1, 9}},
{"acdkeflskjfdklbc", []uint32{9, 10}},
{"aa:bb:cc:dd:ee:ff", []uint32{3, 9, 11}},
{"AA:BB:CC:DD:EE:FF", []uint32{3, 9, 11}},
{"AaAbCc", []uint32{4, 9, 11}},
{"AaAbBcCbCc", []uint32{4, 9, 11}},
{"AbbbbC", []uint32{4, 9, 10, 11}},
{"aBcDeFgHiJkLmNoPqRsTuVwXyZ", []uint32{9, 12}},
{"a%b", []uint32{0, 1, 2, 5, 9}},
{"A%B", []uint32{0, 1, 2, 5, 9}},
}
testLen := len(tests)
b.ResetTimer()
for i := 0; i < b.N; i++ {
test := tests[i%testLen]
indexes := Match(matchExprs, test.testStr, collation)
indexPool.Put(indexes)
}
b.ReportAllocs()
}
@@ -0,0 +1,154 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package branch_control
import (
"sync"
"github.com/dolthub/go-mysql-server/sql"
)
// Namespace contains all of the expressions that comprise the "dolt_branch_namespace_control" table, which controls
// which users may use which branch names when creating branches. Modification of this table is handled by the Access
// table.
type Namespace struct {
access *Access
binlog *Binlog
Branches []MatchExpression
Users []MatchExpression
Hosts []MatchExpression
Values []NamespaceValue
SuperUser string
SuperHost string
RWMutex *sync.RWMutex
}
// NamespaceValue contains the user-facing values of a particular row.
type NamespaceValue struct {
Branch string
User string
Host string
}
// newNamespace returns a new Namespace.
func newNamespace(accessTbl *Access, superUser string, superHost string) *Namespace {
return &Namespace{
access: accessTbl,
Branches: nil,
Users: nil,
Hosts: nil,
Values: nil,
SuperUser: superUser,
SuperHost: superHost,
RWMutex: &sync.RWMutex{},
}
}
// CanCreate checks the given branch, and returns whether the given user and host combination is able to create that
// branch. Handles the super user case.
func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
// Super user can always create branches
if user == tbl.SuperUser && host == tbl.SuperHost {
return true
}
matchedSet := Match(tbl.Branches, branch, sql.Collation_utf8mb4_0900_ai_ci)
// If there are no branch entries, then the Namespace is unrestricted
if len(matchedSet) == 0 {
indexPool.Put(matchedSet)
return true
}
// We take either the longest match, or the set of longest matches if multiple matches have the same length
longest := -1
filteredIndexes := indexPool.Get().([]uint32)[:0]
for _, matched := range matchedSet {
matchedValue := tbl.Values[matched]
// If we've found a longer match, then we reset the slice. We append to it in the following if statement.
if len(matchedValue.Branch) > longest {
filteredIndexes = filteredIndexes[:0]
}
if len(matchedValue.Branch) >= longest {
filteredIndexes = append(filteredIndexes, matched)
}
}
indexPool.Put(matchedSet)
filteredUsers := tbl.filterUsers(filteredIndexes)
indexPool.Put(filteredIndexes)
filteredIndexes = Match(filteredUsers, user, sql.Collation_utf8mb4_0900_bin)
matchExprPool.Put(filteredUsers)
filteredHosts := tbl.filterHosts(filteredIndexes)
indexPool.Put(filteredIndexes)
filteredIndexes = Match(filteredHosts, host, sql.Collation_utf8mb4_0900_ai_ci)
matchExprPool.Put(filteredHosts)
result := len(filteredIndexes) > 0
indexPool.Put(filteredIndexes)
return result
}
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
// returns -1. Assumes that the given expressions have already been folded.
func (tbl *Namespace) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
for i, value := range tbl.Values {
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
return i
}
}
return -1
}
// Access returns the Access table.
func (tbl *Namespace) Access() *Access {
return tbl.access
}
// filterBranches returns all branches that match the given collection indexes.
func (tbl *Namespace) filterBranches(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Branches[filter])
}
return matchExprs
}
// filterUsers returns all users that match the given collection indexes.
func (tbl *Namespace) filterUsers(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Users[filter])
}
return matchExprs
}
// filterHosts returns all hosts that match the given collection indexes.
func (tbl *Namespace) filterHosts(filters []uint32) []MatchExpression {
if len(filters) == 0 {
return nil
}
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
for _, filter := range filters {
matchExprs = append(matchExprs, tbl.Hosts[filter])
}
return matchExprs
}
+27
View File
@@ -26,6 +26,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/mysql_db"
"gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk"
@@ -361,6 +362,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
if head == nil {
var err error
head, err = ds.GetHeadCommit(ctx, db.Name())
if err != nil {
return nil, false, err
}
@@ -465,6 +467,10 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
dt, found = dtables.NewMergeStatusTable(db.name), true
case doltdb.TagsTableName:
dt, found = dtables.NewTagsTable(ctx, db.ddb), true
case dtables.AccessTableName:
dt, found = dtables.NewBranchControlTable(branch_control.StaticController.Access), true
case dtables.NamespaceTableName:
dt, found = dtables.NewBranchNamespaceControlTable(branch_control.StaticController.Namespace), true
}
if found {
return dt, found, nil
@@ -728,6 +734,9 @@ func (db Database) GetHeadRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
// DropTable drops the table with the name given.
// The planner returns the correct case sensitive name in tableName
func (db Database) DropTable(ctx *sql.Context, tableName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if doltdb.IsReadOnlySystemTable(tableName) {
return ErrSystemTableAlter.New(tableName)
}
@@ -827,6 +836,9 @@ func (db Database) removeTableFromAutoIncrementTracker(
// CreateTable creates a table with the name and schema given.
func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema, collation sql.CollationID) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if strings.ToLower(tableName) == doltdb.DocTableName {
// validate correct schema
if !dtables.DoltDocsSqlSchema.Equals(sch.Schema) {
@@ -938,6 +950,9 @@ func (db Database) CreateTemporaryTable(ctx *sql.Context, tableName string, pkSc
// RenameTable implements sql.TableRenamer
func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
root, err := db.GetRoot(ctx)
if err != nil {
@@ -1152,15 +1167,24 @@ func (db Database) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureD
// SaveStoredProcedure implements sql.StoredProcedureDatabase.
func (db Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
return DoltProceduresAddProcedure(ctx, db, spd)
}
// DropStoredProcedure implements sql.StoredProcedureDatabase.
func (db Database) DropStoredProcedure(ctx *sql.Context, name string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
return DoltProceduresDropProcedure(ctx, db, name)
}
func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, definition string, created time.Time, existingErr error) (err error) {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
tbl, err := GetOrCreateDoltSchemasTable(ctx, db)
if err != nil {
return err
@@ -1212,6 +1236,9 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
}
func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
stbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
if err != nil {
return err
@@ -23,6 +23,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
@@ -120,6 +121,12 @@ func renameBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseRe
if oldBranchName == "" || newBranchName == "" {
return EmptyBranchNameErr
}
if err := branch_control.CanDeleteBranch(ctx, oldBranchName); err != nil {
return err
}
if err := branch_control.CanCreateBranch(ctx, newBranchName); err != nil {
return err
}
force := apr.Contains(cli.ForceFlag)
if !force {
@@ -168,6 +175,13 @@ func deleteBranches(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParse
}
}
// Verify that we can delete all branches before continuing
for _, branchName := range apr.Args {
if err = branch_control.CanDeleteBranch(ctx, branchName); err != nil {
return err
}
}
var updateFS = false
for _, branchName := range apr.Args {
if len(branchName) == 0 {
@@ -278,6 +292,9 @@ func createNewBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgPars
return EmptyBranchNameErr
}
if err := branch_control.CanCreateBranch(ctx, branchName); err != nil {
return err
}
return actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, apr.Contains(cli.ForceFlag))
}
@@ -301,6 +318,9 @@ func copyBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResu
}
func copyABranch(ctx *sql.Context, dbData env.DbData, srcBr string, destBr string, force bool) error {
if err := branch_control.CanCreateBranch(ctx, destBr); err != nil {
return err
}
err := actions.CopyBranchOnDB(ctx, dbData.Ddb, srcBr, destBr, force)
if err != nil {
if err == doltdb.ErrBranchNotFound {
@@ -15,6 +15,7 @@
package dsess
import (
"context"
"errors"
"fmt"
"strconv"
@@ -26,6 +27,7 @@ import (
goerrors "gopkg.in/src-d/go-errors.v1"
"github.com/dolthub/dolt/go/cmd/dolt/cli"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
@@ -68,6 +70,7 @@ type DoltSession struct {
var _ sql.Session = (*DoltSession)(nil)
var _ sql.PersistableSession = (*DoltSession)(nil)
var _ branch_control.Context = (*DoltSession)(nil)
// DefaultSession creates a DoltSession with default values
func DefaultSession(pro DoltDatabaseProvider) *DoltSession {
@@ -1192,6 +1195,31 @@ func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
return sysVars, nil
}
// GetBranch implements the interface branch_control.Context.
func (d *DoltSession) GetBranch() (string, error) {
branchRef, err := d.CWBHeadRef(sql.NewContext(context.Background(), sql.WithSession(d)), d.Session.GetCurrentDatabase())
if err != nil {
return "", err
}
return branchRef.GetPath(), nil
}
// GetUser implements the interface branch_control.Context.
func (d *DoltSession) GetUser() string {
return d.Session.Client().User
}
// GetHost implements the interface branch_control.Context.
func (d *DoltSession) GetHost() string {
return d.Session.Client().Address
}
// GetController implements the interface branch_control.Context.
func (d *DoltSession) GetController() *branch_control.Controller {
//TODO implement me
panic("implement me")
}
// validatePersistedSysVar checks whether a system variable exists and is dynamic
func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) {
sysVar, val, ok := sql.SystemVariables.GetGlobal(name)
@@ -0,0 +1,344 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing branch_control.Permissions and
// limitations under the License.
package dtables
import (
"context"
"fmt"
"math"
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
)
const (
AccessTableName = "dolt_branch_control"
)
// PermissionsStrings is a slice of strings representing the available branch_control.branch_control.Permissions. The order of the
// strings should exactly match the order of the branch_control.Permissions according to their flag value.
var PermissionsStrings = []string{"admin", "write"}
// accessSchema is the schema for the "dolt_branch_control" table.
var accessSchema = sql.Schema{
&sql.Column{
Name: "branch",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
Source: AccessTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "user",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
Source: AccessTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "host",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
Source: AccessTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "branch_control.Permissions",
Type: sql.MustCreateSetType(PermissionsStrings, sql.Collation_utf8mb4_0900_ai_ci),
Source: AccessTableName,
PrimaryKey: false,
},
}
// BranchControlTable provides a layer over the branch_control.Access structure, exposing it as a system table.
type BranchControlTable struct {
*branch_control.Access
}
var _ sql.Table = BranchControlTable{}
var _ sql.InsertableTable = BranchControlTable{}
var _ sql.ReplaceableTable = BranchControlTable{}
var _ sql.UpdatableTable = BranchControlTable{}
var _ sql.DeletableTable = BranchControlTable{}
var _ sql.RowInserter = BranchControlTable{}
var _ sql.RowReplacer = BranchControlTable{}
var _ sql.RowUpdater = BranchControlTable{}
var _ sql.RowDeleter = BranchControlTable{}
// NewBranchControlTable returns a new BranchControlTable.
func NewBranchControlTable(access *branch_control.Access) BranchControlTable {
return BranchControlTable{access}
}
// Name implements the interface sql.Table.
func (tbl BranchControlTable) Name() string {
return AccessTableName
}
// String implements the interface sql.Table.
func (tbl BranchControlTable) String() string {
return AccessTableName
}
// Schema implements the interface sql.Table.
func (tbl BranchControlTable) Schema() sql.Schema {
return accessSchema
}
// Collation implements the interface sql.Table.
func (tbl BranchControlTable) Collation() sql.CollationID {
return sql.Collation_Default
}
// Partitions implements the interface sql.Table.
func (tbl BranchControlTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
return index.SinglePartitionIterFromNomsMap(nil), nil
}
// PartitionRows implements the interface sql.Table.
func (tbl BranchControlTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
tbl.RWMutex.RLock()
defer tbl.RWMutex.RUnlock()
rows := []sql.Row{{"%", tbl.SuperUser, tbl.SuperHost, uint64(branch_control.Permissions_Admin)}}
for _, value := range tbl.Values {
rows = append(rows, sql.Row{
value.Branch,
value.User,
value.Host,
uint64(value.Permissions),
})
}
return sql.RowsToRowIter(rows...), nil
}
// Inserter implements the interface sql.InsertableTable.
func (tbl BranchControlTable) Inserter(context *sql.Context) sql.RowInserter {
return tbl
}
// Replacer implements the interface sql.ReplaceableTable.
func (tbl BranchControlTable) Replacer(ctx *sql.Context) sql.RowReplacer {
return tbl
}
// Updater implements the interface sql.UpdatableTable.
func (tbl BranchControlTable) Updater(ctx *sql.Context) sql.RowUpdater {
return tbl
}
// Deleter implements the interface sql.DeletableTable.
func (tbl BranchControlTable) Deleter(context *sql.Context) sql.RowDeleter {
return tbl
}
// StatementBegin implements the interface sql.TableEditor.
func (tbl BranchControlTable) StatementBegin(ctx *sql.Context) {
//TODO: will use the binlog to implement
}
// DiscardChanges implements the interface sql.TableEditor.
func (tbl BranchControlTable) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
//TODO: will use the binlog to implement
return nil
}
// StatementComplete implements the interface sql.TableEditor.
func (tbl BranchControlTable) StatementComplete(ctx *sql.Context) error {
//TODO: will use the binlog to implement
return nil
}
// Insert implements the interface sql.RowInserter.
func (tbl BranchControlTable) Insert(ctx *sql.Context, row sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
perms := branch_control.Permissions(row[3].(uint64))
// Verify that the lengths of each expression fit within an uint16
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
}
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the insertion has permission to perform the insertion.
_, modPerms := tbl.Match(branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(uint64(perms))
return fmt.Errorf("`%s`@`%s` cannot add the row [%q, %q, %q, %q]",
insertUser, insertHost, branch, user, host, permStr)
}
}
return tbl.insert(ctx, branch, user, host, perms)
}
// Update implements the interface sql.RowUpdater.
func (tbl BranchControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
oldUser := branch_control.FoldExpression(old[1].(string))
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
newUser := branch_control.FoldExpression(new[1].(string))
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
newPerms := branch_control.Permissions(new[3].(uint64))
// Verify that the lengths of each expression fit within an uint16
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
return fmt.Errorf("expressions are too long [%q, %q, %q]", newBranch, newUser, newHost)
}
// If we're not updating the same row, then we pre-emptively check for a row violation
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
permBits := uint64(tbl.Values[tblIndex].Permissions)
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q, %q]`, newBranch, newUser, newHost, permStr),
true,
sql.Row{newBranch, newUser, newHost, permBits})
}
}
// A nil session means we're not in the SQL context, so we'd allow the update in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the update has permission to perform the update on the old branch name.
_, modPerms := tbl.Match(oldBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q]",
insertUser, insertHost, oldBranch, oldUser, oldHost)
}
// Now we check if the user has permission use the new branch name
_, modPerms = tbl.Match(newBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q] to the new branch expression %q",
insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
}
}
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
return err
}
}
return tbl.insert(ctx, newBranch, newUser, newHost, newPerms)
}
// Delete implements the interface sql.RowDeleter.
func (tbl BranchControlTable) Delete(ctx *sql.Context, row sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
// Verify that the lengths of each expression fit within an uint16
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
}
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the deletion has permission to perform the deletion.
_, modPerms := tbl.Match(branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot delete the row [%q, %q, %q]",
insertUser, insertHost, branch, user, host)
}
}
return tbl.delete(ctx, branch, user, host)
}
// Close implements the interface sql.Closer.
func (tbl BranchControlTable) Close(context *sql.Context) error {
//TODO: write the binlog
return nil
}
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
// already been folded.
func (tbl BranchControlTable) insert(ctx context.Context, branch string, user string, host string, perms branch_control.Permissions) error {
// If we already have this in the table, then we return a duplicate PK error
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
permBits := uint64(tbl.Values[tblIndex].Permissions)
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q, %q]`, branch, user, host, permStr),
true,
sql.Row{branch, user, host, permBits})
}
// Add the expressions to their respective slices
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
nextIdx := uint32(len(tbl.Values))
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
tbl.Values = append(tbl.Values, branch_control.AccessValue{
Branch: branch,
User: user,
Host: host,
Permissions: perms,
})
return nil
}
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
// already been folded.
func (tbl BranchControlTable) delete(ctx context.Context, branch string, user string, host string) error {
// If we don't have this in the table, then we just return
tblIndex := tbl.GetIndex(branch, user, host)
if tblIndex == -1 {
return nil
}
endIndex := len(tbl.Values) - 1
// Remove the matching row from all slices by first swapping with the last element
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
// Then we remove the last element
tbl.Branches = tbl.Branches[:endIndex]
tbl.Users = tbl.Users[:endIndex]
tbl.Hosts = tbl.Hosts[:endIndex]
tbl.Values = tbl.Values[:endIndex]
return nil
}
@@ -0,0 +1,333 @@
// Copyright 2022 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dtables
import (
"context"
"fmt"
"math"
"strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
)
const (
NamespaceTableName = "dolt_branch_namespace_control"
)
// namespaceSchema is the schema for the "dolt_branch_namespace_control" table.
var namespaceSchema = sql.Schema{
&sql.Column{
Name: "branch",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
Source: NamespaceTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "user",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
Source: NamespaceTableName,
PrimaryKey: true,
},
&sql.Column{
Name: "host",
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
Source: NamespaceTableName,
PrimaryKey: true,
},
}
// BranchNamespaceControlTable provides a layer over the branch_control.Namespace structure, exposing it as a system
// table.
type BranchNamespaceControlTable struct {
*branch_control.Namespace
}
var _ sql.Table = BranchNamespaceControlTable{}
var _ sql.InsertableTable = BranchNamespaceControlTable{}
var _ sql.ReplaceableTable = BranchNamespaceControlTable{}
var _ sql.UpdatableTable = BranchNamespaceControlTable{}
var _ sql.DeletableTable = BranchNamespaceControlTable{}
var _ sql.RowInserter = BranchNamespaceControlTable{}
var _ sql.RowReplacer = BranchNamespaceControlTable{}
var _ sql.RowUpdater = BranchNamespaceControlTable{}
var _ sql.RowDeleter = BranchNamespaceControlTable{}
// NewBranchNamespaceControlTable returns a new BranchNamespaceControlTable.
func NewBranchNamespaceControlTable(namespace *branch_control.Namespace) BranchNamespaceControlTable {
return BranchNamespaceControlTable{namespace}
}
// Name implements the interface sql.Table.
func (tbl BranchNamespaceControlTable) Name() string {
return NamespaceTableName
}
// String implements the interface sql.Table.
func (tbl BranchNamespaceControlTable) String() string {
return NamespaceTableName
}
// Schema implements the interface sql.Table.
func (tbl BranchNamespaceControlTable) Schema() sql.Schema {
return namespaceSchema
}
// Collation implements the interface sql.Table.
func (tbl BranchNamespaceControlTable) Collation() sql.CollationID {
return sql.Collation_Default
}
// Partitions implements the interface sql.Table.
func (tbl BranchNamespaceControlTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
return index.SinglePartitionIterFromNomsMap(nil), nil
}
// PartitionRows implements the interface sql.Table.
func (tbl BranchNamespaceControlTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
tbl.RWMutex.RLock()
defer tbl.RWMutex.RUnlock()
var rows []sql.Row
for _, value := range tbl.Values {
rows = append(rows, sql.Row{
value.Branch,
value.User,
value.Host,
})
}
return sql.RowsToRowIter(rows...), nil
}
// Inserter implements the interface sql.InsertableTable.
func (tbl BranchNamespaceControlTable) Inserter(context *sql.Context) sql.RowInserter {
return tbl
}
// Replacer implements the interface sql.ReplaceableTable.
func (tbl BranchNamespaceControlTable) Replacer(ctx *sql.Context) sql.RowReplacer {
return tbl
}
// Updater implements the interface sql.UpdatableTable.
func (tbl BranchNamespaceControlTable) Updater(ctx *sql.Context) sql.RowUpdater {
return tbl
}
// Deleter implements the interface sql.DeletableTable.
func (tbl BranchNamespaceControlTable) Deleter(context *sql.Context) sql.RowDeleter {
return tbl
}
// StatementBegin implements the interface sql.TableEditor.
func (tbl BranchNamespaceControlTable) StatementBegin(ctx *sql.Context) {
//TODO: will use the binlog to implement
}
// DiscardChanges implements the interface sql.TableEditor.
func (tbl BranchNamespaceControlTable) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
//TODO: will use the binlog to implement
return nil
}
// StatementComplete implements the interface sql.TableEditor.
func (tbl BranchNamespaceControlTable) StatementComplete(ctx *sql.Context) error {
//TODO: will use the binlog to implement
return nil
}
// Insert implements the interface sql.RowInserter.
func (tbl BranchNamespaceControlTable) Insert(ctx *sql.Context, row sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
// Verify that the lengths of each expression fit within an uint16
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
}
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
// Need to acquire a read lock on the Access table since we have to read from it
tbl.Access().RWMutex.RLock()
defer tbl.Access().RWMutex.RUnlock()
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the insertion has permission to perform the insertion.
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot add the row [%q, %q, %q]",
insertUser, insertHost, branch, user, host)
}
}
return tbl.insert(ctx, branch, user, host)
}
// Update implements the interface sql.RowUpdater.
func (tbl BranchNamespaceControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
oldUser := branch_control.FoldExpression(old[1].(string))
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
newUser := branch_control.FoldExpression(new[1].(string))
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
// Verify that the lengths of each expression fit within an uint16
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
return fmt.Errorf("expressions are too long [%q, %q, %q]", newBranch, newUser, newHost)
}
// If we're not updating the same row, then we pre-emptively check for a row violation
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q]`, newBranch, newUser, newHost),
true,
sql.Row{newBranch, newUser, newHost})
}
}
// A nil session means we're not in the SQL context, so we'd allow the update in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
// Need to acquire a read lock on the Access table since we have to read from it
tbl.Access().RWMutex.RLock()
defer tbl.Access().RWMutex.RUnlock()
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the update has permission to perform the update on the old branch name.
_, modPerms := tbl.Access().Match(oldBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q]",
insertUser, insertHost, oldBranch, oldUser, oldHost)
}
// Now we check if the user has permission use the new branch name
_, modPerms = tbl.Access().Match(newBranch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q] to the new branch expression %q",
insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
}
}
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
return err
}
}
return tbl.insert(ctx, newBranch, newUser, newHost)
}
// Delete implements the interface sql.RowDeleter.
func (tbl BranchNamespaceControlTable) Delete(ctx *sql.Context, row sql.Row) error {
tbl.RWMutex.Lock()
defer tbl.RWMutex.Unlock()
// Branch and Host are case-insensitive, while user is case-sensitive
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
user := branch_control.FoldExpression(row[1].(string))
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
// Verify that the lengths of each expression fit within an uint16
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
}
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
// Need to acquire a read lock on the Access table since we have to read from it
tbl.Access().RWMutex.RLock()
defer tbl.Access().RWMutex.RUnlock()
insertUser := branchAwareSession.GetUser()
insertHost := branchAwareSession.GetHost()
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
// determine if the user attempting the deletion has permission to perform the deletion.
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
return fmt.Errorf("`%s`@`%s` cannot delete the row [%q, %q, %q]",
insertUser, insertHost, branch, user, host)
}
}
return tbl.delete(ctx, branch, user, host)
}
// Close implements the interface sql.Closer.
func (tbl BranchNamespaceControlTable) Close(context *sql.Context) error {
//TODO: write the binlog
return nil
}
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
// already been folded.
func (tbl BranchNamespaceControlTable) insert(ctx context.Context, branch string, user string, host string) error {
// If we already have this in the table, then we return a duplicate PK error
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
return sql.NewUniqueKeyErr(
fmt.Sprintf(`[%q, %q, %q]`, branch, user, host),
true,
sql.Row{branch, user, host})
}
// Add the expressions to their respective slices
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
nextIdx := uint32(len(tbl.Values))
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
return nil
}
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
// already been folded.
func (tbl BranchNamespaceControlTable) delete(ctx context.Context, branch string, user string, host string) error {
// If we don't have this in the table, then we just return
tblIndex := tbl.GetIndex(branch, user, host)
if tblIndex == -1 {
return nil
}
endIndex := len(tbl.Values) - 1
// Remove the matching row from all slices by first swapping with the last element
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
// Then we remove the last element
tbl.Branches = tbl.Branches[:endIndex]
tbl.Users = tbl.Users[:endIndex]
tbl.Hosts = tbl.Hosts[:endIndex]
tbl.Values = tbl.Values[:endIndex]
return nil
}
@@ -447,12 +447,11 @@ func TestJSONTableScripts(t *testing.T) {
}
func TestUserPrivileges(t *testing.T) {
t.Skip("Need to add more collations")
enginetest.TestUserPrivileges(t, newDoltHarness(t))
}
func TestUserAuthentication(t *testing.T) {
t.Skip("Need to add more collations")
t.Skip("Unexpected panic, need to fix")
enginetest.TestUserAuthentication(t, newDoltHarness(t))
}
+58
View File
@@ -31,6 +31,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
@@ -609,6 +610,9 @@ func (t *WritableDoltTable) WithProjections(colNames []string) sql.Table {
// Inserter implements sql.InsertableTable
func (t *WritableDoltTable) Inserter(ctx *sql.Context) sql.RowInserter {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return sqlutil.NewStaticErrorEditor(err)
}
te, err := t.getTableEditor(ctx)
if err != nil {
return sqlutil.NewStaticErrorEditor(err)
@@ -648,6 +652,9 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri
// Deleter implements sql.DeletableTable
func (t *WritableDoltTable) Deleter(ctx *sql.Context) sql.RowDeleter {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return sqlutil.NewStaticErrorEditor(err)
}
te, err := t.getTableEditor(ctx)
if err != nil {
return sqlutil.NewStaticErrorEditor(err)
@@ -657,6 +664,9 @@ func (t *WritableDoltTable) Deleter(ctx *sql.Context) sql.RowDeleter {
// Replacer implements sql.ReplaceableTable
func (t *WritableDoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return sqlutil.NewStaticErrorEditor(err)
}
te, err := t.getTableEditor(ctx)
if err != nil {
return sqlutil.NewStaticErrorEditor(err)
@@ -666,6 +676,9 @@ func (t *WritableDoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
// Truncate implements sql.TruncateableTable
func (t *WritableDoltTable) Truncate(ctx *sql.Context) (int, error) {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return 0, err
}
table, err := t.DoltTable.DoltTable(ctx)
if err != nil {
return 0, err
@@ -753,6 +766,9 @@ func (t *WritableDoltTable) truncate(
// Updater implements sql.UpdatableTable
func (t *WritableDoltTable) Updater(ctx *sql.Context) sql.RowUpdater {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return sqlutil.NewStaticErrorEditor(err)
}
te, err := t.getTableEditor(ctx)
if err != nil {
return sqlutil.NewStaticErrorEditor(err)
@@ -1157,6 +1173,9 @@ func (t *AlterableDoltTable) WithProjections(colNames []string) sql.Table {
// AddColumn implements sql.AlterableTable
func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
root, err := t.getRoot(ctx)
if err != nil {
return err
@@ -1290,6 +1309,9 @@ func (t *AlterableDoltTable) RewriteInserter(
oldColumn *sql.Column,
newColumn *sql.Column,
) (sql.RowInserter, error) {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return nil, err
}
err := validateSchemaChange(t.Name(), oldSchema, newSchema, oldColumn, newColumn)
if err != nil {
return nil, err
@@ -1555,6 +1577,9 @@ func (t *AlterableDoltTable) adjustForeignKeysForDroppedPk(ctx *sql.Context, roo
// DropColumn implements sql.AlterableTable
func (t *AlterableDoltTable) DropColumn(ctx *sql.Context, columnName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if types.IsFormat_DOLT(t.nbf) {
return nil
}
@@ -1671,6 +1696,9 @@ func (t *AlterableDoltTable) dropColumnData(ctx *sql.Context, updatedTable *dolt
// ModifyColumn implements sql.AlterableTable. ModifyColumn operations are only used for operations that change only
// the schema of a table, not the data. For those operations, |RewriteInserter| is used.
func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Column, order *sql.ColumnOrder) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
ws, err := t.db.GetWorkingSet(ctx)
if err != nil {
return err
@@ -1832,6 +1860,9 @@ func (t *AlterableDoltTable) CreateIndex(
indexColumns []sql.IndexColumn,
comment string,
) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if constraint != sql.IndexConstraint_None && constraint != sql.IndexConstraint_Unique {
return fmt.Errorf("only the following types of index constraints are supported: none, unique")
}
@@ -1901,6 +1932,9 @@ func (t *AlterableDoltTable) CreateIndex(
// DropIndex implements sql.IndexAlterableTable
func (t *AlterableDoltTable) DropIndex(ctx *sql.Context, indexName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
// We disallow removing internal dolt_ tables from SQL directly
if strings.HasPrefix(indexName, "dolt_") {
return fmt.Errorf("dolt internal indexes may not be dropped")
@@ -1927,6 +1961,9 @@ func (t *AlterableDoltTable) DropIndex(ctx *sql.Context, indexName string) error
// RenameIndex implements sql.IndexAlterableTable
func (t *AlterableDoltTable) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
// RenameIndex will error if there is a name collision or an index does not exist
_, err := t.sch.Indexes().RenameIndex(fromIndexName, toIndexName)
if err != nil {
@@ -1965,6 +2002,9 @@ func (t *AlterableDoltTable) RenameIndex(ctx *sql.Context, fromIndexName string,
// AddForeignKey implements sql.ForeignKeyTable
func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKeyConstraint) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if sqlFk.Name != "" && !doltdb.IsValidForeignKeyName(sqlFk.Name) {
return fmt.Errorf("invalid foreign key name `%s` as it must match the regular expression %s", sqlFk.Name, doltdb.ForeignKeyNameRegexStr)
}
@@ -2152,6 +2192,9 @@ func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKe
// DropForeignKey implements sql.ForeignKeyTable
func (t *AlterableDoltTable) DropForeignKey(ctx *sql.Context, fkName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
root, err := t.getRoot(ctx)
if err != nil {
return err
@@ -2177,6 +2220,9 @@ func (t *AlterableDoltTable) DropForeignKey(ctx *sql.Context, fkName string) err
// UpdateForeignKey implements sql.ForeignKeyTable
func (t *AlterableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, sqlFk sql.ForeignKeyConstraint) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
root, err := t.getRoot(ctx)
if err != nil {
return err
@@ -2520,6 +2566,9 @@ func (t *AlterableDoltTable) updateFromRoot(ctx *sql.Context, root *doltdb.RootV
}
func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
root, err := t.getRoot(ctx)
if err != nil {
return err
@@ -2573,6 +2622,9 @@ func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefin
}
func (t *AlterableDoltTable) DropCheck(ctx *sql.Context, chName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
root, err := t.getRoot(ctx)
if err != nil {
return err
@@ -2668,6 +2720,9 @@ func (t *AlterableDoltTable) constraintNameExists(ctx *sql.Context, name string)
}
func (t *AlterableDoltTable) CreatePrimaryKey(ctx *sql.Context, columns []sql.IndexColumn) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if types.IsFormat_DOLT(t.nbf) {
return nil
}
@@ -2702,6 +2757,9 @@ func (t *AlterableDoltTable) CreatePrimaryKey(ctx *sql.Context, columns []sql.In
}
func (t *AlterableDoltTable) DropPrimaryKey(ctx *sql.Context) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if types.IsFormat_DOLT(t.nbf) {
return nil
}