mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-08 02:36:27 -05:00
Branch Control
This commit is contained in:
@@ -55,7 +55,9 @@ Similar to {{.EmphasisLeft}}dolt sql-server{{.EmphasisRight}}, this command may
|
||||
},
|
||||
}
|
||||
|
||||
type SqlClientCmd struct{}
|
||||
type SqlClientCmd struct {
|
||||
VersionStr string
|
||||
}
|
||||
|
||||
var _ cli.Command = SqlClientCmd{}
|
||||
|
||||
@@ -126,7 +128,7 @@ func (cmd SqlClientCmd) Exec(ctx context.Context, commandStr string, args []stri
|
||||
|
||||
serverController = NewServerController()
|
||||
go func() {
|
||||
_, _ = Serve(ctx, SqlServerCmd{}.VersionStr, serverConfig, serverController, dEnv)
|
||||
_, _ = Serve(ctx, cmd.VersionStr, serverConfig, serverController, dEnv)
|
||||
}()
|
||||
err = serverController.WaitForStart()
|
||||
if err != nil {
|
||||
|
||||
+1
-1
@@ -73,7 +73,7 @@ var doltCommand = cli.NewSubCommandHandler("dolt", "it's git for data", []cli.Co
|
||||
commands.SqlCmd{VersionStr: Version},
|
||||
admin.Commands,
|
||||
sqlserver.SqlServerCmd{VersionStr: Version},
|
||||
sqlserver.SqlClientCmd{},
|
||||
sqlserver.SqlClientCmd{VersionStr: Version},
|
||||
commands.LogCmd{},
|
||||
commands.BranchCmd{},
|
||||
commands.CheckoutCmd{},
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package branch_control
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
)
|
||||
|
||||
// Permissions are a set of flags that denote a user's allowed functionality on a branch.
|
||||
type Permissions uint64
|
||||
|
||||
const (
|
||||
Permissions_Admin Permissions = 1 << iota // Permissions_Admin grants unrestricted control over a branch, including modification of table entries
|
||||
Permissions_Write // Permissions_Write allows for all modifying operations on a branch, but does not allow modification of table entries
|
||||
)
|
||||
|
||||
// Access contains all of the expressions that comprise the "dolt_branch_control" table, which handles write Access to
|
||||
// branches, along with write access to the branch control system tables.
|
||||
type Access struct {
|
||||
binlog *Binlog
|
||||
|
||||
Branches []MatchExpression
|
||||
Users []MatchExpression
|
||||
Hosts []MatchExpression
|
||||
Values []AccessValue
|
||||
SuperUser string
|
||||
SuperHost string
|
||||
RWMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
// AccessValue contains the user-facing values of a particular row, along with the permissions for a row.
|
||||
type AccessValue struct {
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
Permissions Permissions
|
||||
}
|
||||
|
||||
// newAccess returns a new Access.
|
||||
func newAccess(superUser string, superHost string) *Access {
|
||||
return &Access{
|
||||
Branches: nil,
|
||||
Users: nil,
|
||||
Hosts: nil,
|
||||
Values: nil,
|
||||
SuperUser: superUser,
|
||||
SuperHost: superHost,
|
||||
RWMutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// Match returns whether any entries match the given branch, user, and host, along with their permissions. Requires
|
||||
// external synchronization handling, therefore manually manage the RWMutex.
|
||||
func (tbl *Access) Match(branch string, user string, host string) (bool, Permissions) {
|
||||
filteredIndexes := Match(tbl.Users, user, sql.Collation_utf8mb4_0900_bin)
|
||||
|
||||
filteredHosts := tbl.filterHosts(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
filteredIndexes = Match(filteredHosts, host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
matchExprPool.Put(filteredHosts)
|
||||
|
||||
filteredBranches := tbl.filterBranches(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
filteredIndexes = Match(filteredBranches, branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
matchExprPool.Put(filteredBranches)
|
||||
|
||||
bRes, pRes := len(filteredIndexes) > 0, tbl.gatherPermissions(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
return bRes, pRes
|
||||
}
|
||||
|
||||
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
|
||||
// returns -1. Assumes that the given expressions have already been folded. Requires external synchronization handling,
|
||||
// therefore manually manage the RWMutex.
|
||||
func (tbl *Access) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
|
||||
for i, value := range tbl.Values {
|
||||
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// filterBranches returns all branches that match the given collection indexes.
|
||||
func (tbl *Access) filterBranches(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Branches[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// filterUsers returns all users that match the given collection indexes.
|
||||
func (tbl *Access) filterUsers(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Users[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// filterHosts returns all hosts that match the given collection indexes.
|
||||
func (tbl *Access) filterHosts(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Hosts[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// gatherPermissions combines all permissions from the given collection indexes and returns the result.
|
||||
func (tbl *Access) gatherPermissions(collectionIndexes []uint32) Permissions {
|
||||
perms := Permissions(0)
|
||||
for _, collectionIndex := range collectionIndexes {
|
||||
perms |= tbl.Values[collectionIndex].Permissions
|
||||
}
|
||||
return perms
|
||||
}
|
||||
@@ -0,0 +1,249 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package branch_control
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
currentBinlogVersion = uint16(1)
|
||||
)
|
||||
|
||||
//TODO: add stored procedure functions for modifying the binlog
|
||||
|
||||
// Binlog is a running log file that tracks changes to tables within branch control. This is used for history purposes,
|
||||
// as well as transactional purposes through the use of the BinlogOverlay.
|
||||
type Binlog struct {
|
||||
rows []BinlogRow
|
||||
RWMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
// BinlogRow is a row within the Binlog.
|
||||
type BinlogRow struct {
|
||||
IsInsert bool
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
Permissions uint64
|
||||
}
|
||||
|
||||
// BinlogOverlay enables transactional use cases over Binlog. Unlike a Binlog, a BinlogOverlay requires external
|
||||
// synchronization.
|
||||
type BinlogOverlay struct {
|
||||
parentLength int
|
||||
rows []BinlogRow
|
||||
}
|
||||
|
||||
// Serialize returns the Binlog as a byte slice. All encoded integers are big-endian.
|
||||
func (binlog *Binlog) Serialize() []byte {
|
||||
binlog.RWMutex.RLock()
|
||||
defer binlog.RWMutex.RUnlock()
|
||||
|
||||
buffer := bytes.Buffer{}
|
||||
// Write the version bytes
|
||||
writeUint16(&buffer, currentBinlogVersion)
|
||||
// Write the number of entries
|
||||
binlogSize := uint64(len(binlog.rows))
|
||||
writeUint64(&buffer, binlogSize)
|
||||
|
||||
// Write the rows
|
||||
for _, binlogRow := range binlog.rows {
|
||||
binlogRow.Serialize(&buffer)
|
||||
}
|
||||
return buffer.Bytes()
|
||||
}
|
||||
|
||||
// Deserialize populates the binlog with the given data. Returns an error if the data cannot be deserialized, or if the
|
||||
// Binlog has already been written to. Deserialize must be called on an empty Binlog.
|
||||
func (binlog *Binlog) Deserialize(data []byte) error {
|
||||
binlog.RWMutex.Lock()
|
||||
defer binlog.RWMutex.Unlock()
|
||||
|
||||
if len(binlog.rows) != 0 {
|
||||
return fmt.Errorf("cannot deserialize to a non-empty binlog")
|
||||
}
|
||||
position := uint64(0)
|
||||
// Read the version
|
||||
version := binary.BigEndian.Uint16(data)
|
||||
position += 2
|
||||
if version != currentBinlogVersion {
|
||||
// If we ever increment the binlog version, this will instead handle the conversion from previous versions
|
||||
return fmt.Errorf(`cannot deserialize a binlog with version "%d"`, version)
|
||||
}
|
||||
// Read the number of entries
|
||||
binlogSize := binary.BigEndian.Uint64(data[position:])
|
||||
position += 8
|
||||
// Read the rows
|
||||
binlog.rows = make([]BinlogRow, binlogSize)
|
||||
for i := uint64(0); i < binlogSize; i++ {
|
||||
binlog.rows[i], position = deserializeBinlogRow(data, position)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewOverlay returns a new BinlogOverlay for the calling Binlog.
|
||||
func (binlog *Binlog) NewOverlay() *BinlogOverlay {
|
||||
binlog.RWMutex.RLock()
|
||||
defer binlog.RWMutex.RUnlock()
|
||||
|
||||
return &BinlogOverlay{
|
||||
parentLength: len(binlog.rows),
|
||||
rows: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// MergeOverlay merges the given BinlogOverlay with the calling Binlog. Fails if the Binlog has been written to since
|
||||
// the overlay was created.
|
||||
func (binlog *Binlog) MergeOverlay(overlay *BinlogOverlay) error {
|
||||
binlog.RWMutex.Lock()
|
||||
defer binlog.RWMutex.Unlock()
|
||||
|
||||
// Except for recovery situations, the binlog is an append-only structure, therefore if there are a different number
|
||||
// of entries than when the overlay was created, then it has probably been written to. The likelihood of there being
|
||||
// an outstanding overlay while the binlog is being modified is exceedingly low.
|
||||
if len(binlog.rows) != overlay.parentLength {
|
||||
return fmt.Errorf("cannot merge overlay as binlog has been modified")
|
||||
}
|
||||
binlog.rows = append(binlog.rows, overlay.rows...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert adds an insert entry to the Binlog.
|
||||
func (binlog *Binlog) Insert(branch string, user string, host string, permissions uint64) {
|
||||
binlog.RWMutex.Lock()
|
||||
defer binlog.RWMutex.Unlock()
|
||||
|
||||
binlog.rows = append(binlog.rows, BinlogRow{
|
||||
IsInsert: true,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
Permissions: permissions,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete adds a delete entry to the Binlog.
|
||||
func (binlog *Binlog) Delete(branch string, user string, host string, permissions uint64) {
|
||||
binlog.RWMutex.Lock()
|
||||
defer binlog.RWMutex.Unlock()
|
||||
|
||||
binlog.rows = append(binlog.rows, BinlogRow{
|
||||
IsInsert: false,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
Permissions: permissions,
|
||||
})
|
||||
}
|
||||
|
||||
// Serialize writes the row to the given buffer. All encoded integers are big-endian.
|
||||
func (row *BinlogRow) Serialize(buffer *bytes.Buffer) {
|
||||
// Write whether this was an insertion or deletion
|
||||
if row.IsInsert {
|
||||
buffer.WriteByte(1)
|
||||
} else {
|
||||
buffer.WriteByte(0)
|
||||
}
|
||||
// Write the branch
|
||||
branchLen := uint16(len(row.Branch))
|
||||
writeUint16(buffer, branchLen)
|
||||
buffer.WriteString(row.Branch)
|
||||
// Write the user
|
||||
userLen := uint16(len(row.User))
|
||||
writeUint16(buffer, userLen)
|
||||
buffer.WriteString(row.User)
|
||||
// Write the host
|
||||
hostLen := uint16(len(row.Host))
|
||||
writeUint16(buffer, hostLen)
|
||||
buffer.WriteString(row.Host)
|
||||
// Write the permissions
|
||||
writeUint64(buffer, row.Permissions)
|
||||
}
|
||||
|
||||
// deserializeBinlogRow returns a BinlogRow from the data at the given position. Also returns the new position. Assumes
|
||||
// that the given data's encoded integers are big-endian.
|
||||
func deserializeBinlogRow(data []byte, position uint64) (BinlogRow, uint64) {
|
||||
binlogRow := BinlogRow{}
|
||||
// Read whether this was an insert or write
|
||||
if data[position] == 1 {
|
||||
binlogRow.IsInsert = true
|
||||
} else {
|
||||
binlogRow.IsInsert = false
|
||||
}
|
||||
position += 1
|
||||
// Read the branch
|
||||
branchLen := uint64(binary.BigEndian.Uint16(data[position:]))
|
||||
position += 2
|
||||
binlogRow.Branch = string(data[position : position+branchLen])
|
||||
position += branchLen
|
||||
// Read the user
|
||||
userLen := uint64(binary.BigEndian.Uint16(data[position:]))
|
||||
position += 2
|
||||
binlogRow.User = string(data[position : position+userLen])
|
||||
position += userLen
|
||||
// Read the host
|
||||
hostLen := uint64(binary.BigEndian.Uint16(data[position:]))
|
||||
position += 2
|
||||
binlogRow.Host = string(data[position : position+hostLen])
|
||||
position += hostLen
|
||||
// Read the permissions
|
||||
binlogRow.Permissions = binary.BigEndian.Uint64(data[position:])
|
||||
position += 8
|
||||
return binlogRow, position
|
||||
}
|
||||
|
||||
// Insert adds an insert entry to the BinlogOverlay.
|
||||
func (overlay *BinlogOverlay) Insert(branch string, user string, host string, permissions uint64) {
|
||||
overlay.rows = append(overlay.rows, BinlogRow{
|
||||
IsInsert: true,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
Permissions: permissions,
|
||||
})
|
||||
}
|
||||
|
||||
// Delete adds a delete entry to the BinlogOverlay.
|
||||
func (overlay *BinlogOverlay) Delete(branch string, user string, host string, permissions uint64) {
|
||||
overlay.rows = append(overlay.rows, BinlogRow{
|
||||
IsInsert: false,
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
Permissions: permissions,
|
||||
})
|
||||
}
|
||||
|
||||
// writeUint64 writes an uint64 into the buffer.
|
||||
func writeUint64(buffer *bytes.Buffer, val uint64) {
|
||||
buffer.WriteByte(byte(val >> 56))
|
||||
buffer.WriteByte(byte(val >> 48))
|
||||
buffer.WriteByte(byte(val >> 40))
|
||||
buffer.WriteByte(byte(val >> 32))
|
||||
buffer.WriteByte(byte(val >> 24))
|
||||
buffer.WriteByte(byte(val >> 16))
|
||||
buffer.WriteByte(byte(val >> 8))
|
||||
buffer.WriteByte(byte(val))
|
||||
}
|
||||
|
||||
// writeUint16 writes an uint16 into the buffer.
|
||||
func writeUint16(buffer *bytes.Buffer, val uint16) {
|
||||
buffer.WriteByte(byte(val >> 8))
|
||||
buffer.WriteByte(byte(val))
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package branch_control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
|
||||
"gopkg.in/src-d/go-errors.v1"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrIncorrectPermissions = errors.NewKind("`%s`@`%s` does not have the correct permissions on branch `%s`")
|
||||
ErrCannotCreateBranch = errors.NewKind("`%s`@`%s` cannot create a branch named `%s`")
|
||||
ErrCannotDeleteBranch = errors.NewKind("`%s`@`%s` cannot delete the branch `%s`")
|
||||
)
|
||||
|
||||
// Context represents the interface that must be inherited from the context.
|
||||
type Context interface {
|
||||
GetBranch() (string, error)
|
||||
GetUser() string
|
||||
GetHost() string
|
||||
GetController() *Controller
|
||||
}
|
||||
|
||||
// Controller is the central hub for branch control functions. This is passed within a context.
|
||||
type Controller struct {
|
||||
Access *Access
|
||||
Namespace *Namespace
|
||||
}
|
||||
|
||||
//TODO: delete me
|
||||
var StaticController = &Controller{}
|
||||
var enabled = false
|
||||
|
||||
func init() {
|
||||
if os.Getenv("DOLT_ENABLE_BRANCH_CONTROL") != "" {
|
||||
enabled = true
|
||||
}
|
||||
StaticController = CreateControllerWithSuperUser(context.Background(), "root", "localhost")
|
||||
}
|
||||
|
||||
// CreateController returns an empty *Controller.
|
||||
func CreateController(ctx context.Context) *Controller {
|
||||
accessTbl := newAccess("", "")
|
||||
return &Controller{
|
||||
Access: accessTbl,
|
||||
Namespace: newNamespace(accessTbl, "", ""),
|
||||
}
|
||||
}
|
||||
|
||||
// CreateControllerWithSuperUser returns a controller with the given user and host set as an immutable super user.
|
||||
func CreateControllerWithSuperUser(ctx context.Context, superUser string, superHost string) *Controller {
|
||||
accessTbl := newAccess(superUser, superHost)
|
||||
return &Controller{
|
||||
Access: accessTbl,
|
||||
Namespace: newNamespace(accessTbl, superUser, superHost),
|
||||
}
|
||||
}
|
||||
|
||||
// CheckAccess returns whether the given context has the correct permissions on its selected branch. In general, SQL
|
||||
// statements will almost always return a *sql.Context, so any checks from the SQL path will correctly check for branch
|
||||
// permissions. However, not all CLI commands use *sql.Context, and therefore will not have any user associated with
|
||||
// the context. In these cases, CheckAccess will pass as we want to allow all local commands to ignore branch
|
||||
// permissions.
|
||||
func CheckAccess(ctx context.Context, flags Permissions) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
// A nil session means we're not in the SQL context, so we allow all operations
|
||||
if branchAwareSession == nil {
|
||||
return nil
|
||||
}
|
||||
StaticController.Access.RWMutex.RLock()
|
||||
defer StaticController.Access.RWMutex.RUnlock()
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
// Check if the user is the super user, which has access to all operations
|
||||
if user == StaticController.Access.SuperUser && host == StaticController.Access.SuperHost {
|
||||
return nil
|
||||
}
|
||||
branch, err := branchAwareSession.GetBranch()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Get the permissions for the branch, user, and host combination
|
||||
_, perms := StaticController.Access.Match(branch, user, host)
|
||||
// If either the flags match or the user is an admin for this branch, then we allow access
|
||||
if (perms&flags == flags) || (perms&Permissions_Admin == Permissions_Admin) {
|
||||
return nil
|
||||
}
|
||||
return ErrIncorrectPermissions.New(user, host, branch)
|
||||
}
|
||||
|
||||
// CanCreateBranch returns whether the given context can create a branch with the given name. In general, SQL statements
|
||||
// will almost always return a *sql.Context, so any checks from the SQL path will be able to validate a branch's name.
|
||||
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
|
||||
// these cases, CanCreateBranch will pass as we want to allow all local commands to freely create branches.
|
||||
func CanCreateBranch(ctx context.Context, branchName string) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
if branchAwareSession == nil {
|
||||
return nil
|
||||
}
|
||||
StaticController.Namespace.RWMutex.RLock()
|
||||
defer StaticController.Namespace.RWMutex.RUnlock()
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
if StaticController.Namespace.CanCreate(branchName, user, host) {
|
||||
return nil
|
||||
}
|
||||
return ErrCannotCreateBranch.New(user, host, branchName)
|
||||
}
|
||||
|
||||
// CanDeleteBranch returns whether the given context can delete a branch with the given name. In general, SQL statements
|
||||
// will almost always return a *sql.Context, so any checks from the SQL path will be able to validate a branch's name.
|
||||
// However, not all CLI commands use *sql.Context, and therefore will not have any user associated with the context. In
|
||||
// these cases, CanDeleteBranch will pass as we want to allow all local commands to freely delete branches.
|
||||
func CanDeleteBranch(ctx context.Context, branchName string) error {
|
||||
if !enabled {
|
||||
return nil
|
||||
}
|
||||
branchAwareSession := GetBranchAwareSession(ctx)
|
||||
// A nil session means we're not in the SQL context, so we allow the delete operation
|
||||
if branchAwareSession == nil {
|
||||
return nil
|
||||
}
|
||||
StaticController.Access.RWMutex.RLock()
|
||||
defer StaticController.Access.RWMutex.RUnlock()
|
||||
|
||||
user := branchAwareSession.GetUser()
|
||||
host := branchAwareSession.GetHost()
|
||||
// Check if the user is the super user, which is always able to delete branches
|
||||
if user == StaticController.Access.SuperUser && host == StaticController.Access.SuperHost {
|
||||
return nil
|
||||
}
|
||||
// Get the permissions for the branch, user, and host combination
|
||||
_, perms := StaticController.Access.Match(branchName, user, host)
|
||||
// If the user has the write or admin flags, then we allow access
|
||||
if (perms&Permissions_Write == Permissions_Write) || (perms&Permissions_Admin == Permissions_Admin) {
|
||||
return nil
|
||||
}
|
||||
return ErrCannotDeleteBranch.New(user, host, branchName)
|
||||
}
|
||||
|
||||
// GetBranchAwareSession returns the session contained within the context. If the context does NOT contain a session,
|
||||
// then nil is returned.
|
||||
func GetBranchAwareSession(ctx context.Context) Context {
|
||||
if sqlCtx, ok := ctx.(*sql.Context); ok {
|
||||
if bas, ok := sqlCtx.Session.(Context); ok {
|
||||
return bas
|
||||
}
|
||||
} else if bas, ok := ctx.(Context); ok {
|
||||
return bas
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package branch_control
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
)
|
||||
|
||||
const (
|
||||
singleMatch = -1 // Equivalent to the single match character '_'
|
||||
anyMatch = -2 // Equivalent to the any-length match character '%'
|
||||
)
|
||||
|
||||
// invalidMatchExpression is a match expression that does not match anything
|
||||
var invalidMatchExpression = MatchExpression{math.MaxUint32, nil}
|
||||
|
||||
// matchExprPool is a pool for MatchExpression slices. Provides a significant performance benefit.
|
||||
var matchExprPool = &sync.Pool{
|
||||
New: func() any {
|
||||
return make([]MatchExpression, 0, 32)
|
||||
},
|
||||
}
|
||||
|
||||
// indexPool is a pool for index slices (such as those returned by Match). Provides a decent performance benefit.
|
||||
var indexPool = &sync.Pool{
|
||||
New: func() any {
|
||||
return make([]uint32, 0, 32)
|
||||
},
|
||||
}
|
||||
|
||||
// MatchExpression represents a parsed expression that may be matched against. It contains a list of sort orders, which
|
||||
// each represent a comparable value to determine whether any given character is a match. A character's sort order is
|
||||
// obtained from a collation. Also contains its index in the table. MatchExpression contents are not meant to be
|
||||
// comparable to one another, therefore please use the index to compare equivalence.
|
||||
type MatchExpression struct {
|
||||
CollectionIndex uint32 // CollectionIndex represents this expression's index in its parent slice.
|
||||
SortOrders []int32 // These are the sort orders that will be compared against when matching a given rune.
|
||||
}
|
||||
|
||||
// FoldExpression folds the given expression into its smallest form. Expressions have two wildcard operators:
|
||||
// '_' and '%'. '_' matches exactly one character, and it can be any character. '%' can match zero or more of any
|
||||
// character. Taking these two ops into account, the configurations "%_" and "_%" both resolve to matching one or more
|
||||
// of any character. However, the "_%" form is more economical, as you enforce the single match first before checking
|
||||
// for remaining matches. Similarly, "%%" is equivalent to a single '%'. Both of these rules are applied in this
|
||||
// function, guaranteeing that the returned expression is the smallest form that still exactly represents the original.
|
||||
//
|
||||
// This also assumes that '\' is the escape character.
|
||||
func FoldExpression(str string) string {
|
||||
// This loop only terminates when we complete a run where no substitutions were made. Substitutions are applied
|
||||
// linearly, therefore it's possible that one substitution may create an opportunity for another substitution.
|
||||
// To keep the code simple, we continue looping until we have nothing more to do.
|
||||
for true {
|
||||
newStrRunes := make([]rune, 0, len(str))
|
||||
// Skip next is set whenever we encounter the escape character, which is used to explicitly match against '_' and '%'
|
||||
skipNext := false
|
||||
// Consider next is set whenever we encounter an unescaped '%', indicating we may need to apply the substitutions
|
||||
considerNext := false
|
||||
for _, r := range str {
|
||||
if skipNext {
|
||||
skipNext = false
|
||||
newStrRunes = append(newStrRunes, r)
|
||||
continue
|
||||
} else if considerNext {
|
||||
considerNext = false
|
||||
switch r {
|
||||
case '\\':
|
||||
newStrRunes = append(newStrRunes, '%', r) // False alarm, reinsert % before this rune
|
||||
skipNext = true // We also need to ignore the next rune
|
||||
case '_':
|
||||
newStrRunes = append(newStrRunes, r, '%') // Replacing %_ with _%
|
||||
case '%':
|
||||
newStrRunes = append(newStrRunes, r) // Replacing %% with %
|
||||
default:
|
||||
newStrRunes = append(newStrRunes, '%', r) // False alarm, reinsert % before this rune
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
newStrRunes = append(newStrRunes, r)
|
||||
skipNext = true
|
||||
case '%':
|
||||
considerNext = true
|
||||
default:
|
||||
newStrRunes = append(newStrRunes, r)
|
||||
}
|
||||
}
|
||||
// If the very last rune is '%', then this will be true and we need to append it to the end
|
||||
if considerNext {
|
||||
newStrRunes = append(newStrRunes, '%')
|
||||
}
|
||||
newStr := string(newStrRunes)
|
||||
if str == newStr {
|
||||
break
|
||||
}
|
||||
str = newStr
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
// ParseExpression parses the given string expression into a slice of sort ints, which will be used in a MatchExpression.
|
||||
// Returns nil if the string is too long. Assumes that the given string expression has already been folded.
|
||||
func ParseExpression(str string, collation sql.CollationID) []int32 {
|
||||
if len(str) > math.MaxUint16 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sortFunc := collation.Sorter()
|
||||
var orders []int32
|
||||
escaped := false
|
||||
for _, r := range str {
|
||||
if escaped {
|
||||
escaped = false
|
||||
orders = append(orders, sortFunc(r))
|
||||
} else {
|
||||
switch r {
|
||||
case '\\':
|
||||
escaped = true
|
||||
case '%':
|
||||
orders = append(orders, anyMatch)
|
||||
case '_':
|
||||
orders = append(orders, singleMatch)
|
||||
default:
|
||||
orders = append(orders, sortFunc(r))
|
||||
}
|
||||
}
|
||||
}
|
||||
return orders
|
||||
}
|
||||
|
||||
// Match takes the match expression collection, and returns a slice of which collection indexes matched against the
|
||||
// given string. The given indices may be used to further reduce the match expression collection, which will also reduce
|
||||
// the total number of comparisons as they're narrowed down.
|
||||
//
|
||||
// It is vastly more performant to return a slice of collection indexes here, rather than a slice of match expressions.
|
||||
// This is true even when the match expressions are pooled. The reason is unknown, but as we only need the collection
|
||||
// indexes anyway, we discard the match expressions and return only their indexes.
|
||||
func Match(matchExprCollection []MatchExpression, str string, collation sql.CollationID) []uint32 {
|
||||
if len(str) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sortFunc := collation.Sorter()
|
||||
// Grab the first rune and also remove it from the string
|
||||
r, rSize := utf8.DecodeRuneInString(str)
|
||||
str = str[rSize:]
|
||||
// Grab a slice from the pool, which reduces the GC pressure.
|
||||
matchSubset := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
// We do a pass using the first rune over all expressions to get the subset that we'll be testing against
|
||||
for _, testExpr := range matchExprCollection {
|
||||
if matched, next, extra := testExpr.Matches(sortFunc(r)); matched {
|
||||
if extra.IsValid() {
|
||||
matchSubset = append(matchSubset, next, extra)
|
||||
} else {
|
||||
matchSubset = append(matchSubset, next)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Bail early if there are no matches here
|
||||
if len(matchSubset) == 0 {
|
||||
matchExprPool.Put(matchSubset)
|
||||
// We return a slice from the index pool as we later will return it to the pool. We don't want to stick a
|
||||
// nil/empty slice into the pool.
|
||||
return indexPool.Get().([]uint32)[:0]
|
||||
}
|
||||
|
||||
// This is the slice that we'll put matches into. This will also flip to become the match subset. This way we reuse
|
||||
// the underlying arrays. We also grab this from the pool.
|
||||
matches := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
// Now that we have our set of expressions to test, we loop over the remainder of the input string
|
||||
for _, r = range str {
|
||||
for _, testExpr := range matchSubset {
|
||||
if matched, next, extra := testExpr.Matches(sortFunc(r)); matched {
|
||||
if extra.IsValid() {
|
||||
matches = append(matches, next, extra)
|
||||
} else {
|
||||
matches = append(matches, next)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Swap the two, and put the slice of matches to be at the beginning of the previous subset array to reuse it
|
||||
matches, matchSubset = matchSubset[:0], matches
|
||||
}
|
||||
matchExprPool.Put(matches)
|
||||
|
||||
// Grab the indices of all valid matches
|
||||
validMatches := indexPool.Get().([]uint32)[:0]
|
||||
for _, match := range matchSubset {
|
||||
if match.IsAtEnd() && (len(validMatches) == 0 ||
|
||||
(len(validMatches) > 0 && match.CollectionIndex != validMatches[len(validMatches)-1])) {
|
||||
validMatches = append(validMatches, match.CollectionIndex)
|
||||
}
|
||||
}
|
||||
matchExprPool.Put(matchSubset)
|
||||
return validMatches
|
||||
}
|
||||
|
||||
// Matches returns true when the given sort order matches the expectation of the calling match expression. Returns a
|
||||
// reduced match expression as `next`, which should take the place of the calling match function. In the event of a
|
||||
// branch, returns the branching match expression as `extra`.
|
||||
//
|
||||
// Branches occur when the '%' operator sees that the given sort order matches the sort order after the '%'. As it
|
||||
// cannot be determined which path is the correct one (whether to consume the '%' or continue using it), a branch is
|
||||
// created. The `extra` should be checked for validity by calling IsValid.
|
||||
func (matchExpr MatchExpression) Matches(sortOrder int32) (matched bool, next MatchExpression, extra MatchExpression) {
|
||||
if len(matchExpr.SortOrders) == 0 {
|
||||
return false, invalidMatchExpression, invalidMatchExpression
|
||||
}
|
||||
switch matchExpr.SortOrders[0] {
|
||||
case singleMatch:
|
||||
return true, MatchExpression{matchExpr.CollectionIndex, matchExpr.SortOrders[1:]}, invalidMatchExpression
|
||||
case anyMatch:
|
||||
if len(matchExpr.SortOrders) > 1 && matchExpr.SortOrders[1] == sortOrder {
|
||||
return true, matchExpr, MatchExpression{matchExpr.CollectionIndex, matchExpr.SortOrders[2:]}
|
||||
}
|
||||
return true, matchExpr, invalidMatchExpression
|
||||
default:
|
||||
if sortOrder == matchExpr.SortOrders[0] {
|
||||
return true, MatchExpression{matchExpr.CollectionIndex, matchExpr.SortOrders[1:]}, invalidMatchExpression
|
||||
} else {
|
||||
return false, invalidMatchExpression, invalidMatchExpression
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns whether the match expression is valid. An invalid MatchExpression will have a collection index that
|
||||
// is at the maximum value for an uint32.
|
||||
func (matchExpr MatchExpression) IsValid() bool {
|
||||
return matchExpr.CollectionIndex < math.MaxUint32
|
||||
}
|
||||
|
||||
// IsAtEnd returns whether the match expression has matched every character. There is a special case where, if the last
|
||||
// character is '%', it is considered to be at the end.
|
||||
func (matchExpr MatchExpression) IsAtEnd() bool {
|
||||
return len(matchExpr.SortOrders) == 0 || (len(matchExpr.SortOrders) == 1 && matchExpr.SortOrders[0] == anyMatch)
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package branch_control
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSingleMatch(t *testing.T) {
|
||||
tests := []struct {
|
||||
expression string
|
||||
testStr string
|
||||
matches bool
|
||||
collation sql.CollationID
|
||||
}{
|
||||
{"a__", "abc", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{"a__", "abcd", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{"a%b", "acb", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{"a%b", "acdkeflskjfdklb", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{"a%b", "ab", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{"a%b", "a", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{"a_b", "ab", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{"aa:%", "aa:bb:cc:dd:ee:ff", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{"aa:%", "AA:BB:CC:DD:EE:FF", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{"aa:%", "AA:BB:CC:DD:EE:FF", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{"a_%_b%_%c", "AaAbCc", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{"a_%_b%_%c", "AaAbBcCbCc", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{"a_%_b%_%c", "AbbbbC", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{"a_%_n%_%z", "aBcDeFgHiJkLmNoPqRsTuVwXyZ", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`a\%b`, "acb", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a\%b`, "a%b", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a\%b`, "A%B", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a\%b`, "A%B", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`a`, "a", true, sql.Collation_utf8mb4_0900_bin},
|
||||
{`ab`, "a", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a\b`, "a", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a\\b`, "a", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a\\\b`, "a", false, sql.Collation_utf8mb4_0900_bin},
|
||||
{`a`, "a", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`ab`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`a\b`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`a\\b`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`a\\\b`, "a", false, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`A%%%%`, "abc", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
{`A%%%%bc`, "abc", true, sql.Collation_utf8mb4_0900_ai_ci},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%q matches %q", test.testStr, test.expression), func(t *testing.T) {
|
||||
parsedExpression := ParseExpression(FoldExpression(test.expression), test.collation)
|
||||
matchCount := Match([]MatchExpression{{0, parsedExpression}}, test.testStr, test.collation)
|
||||
if test.matches {
|
||||
require.Len(t, matchCount, 1)
|
||||
} else {
|
||||
require.Len(t, matchCount, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleMatch(t *testing.T) {
|
||||
collation := sql.Collation_utf8mb4_0900_ai_ci
|
||||
matchExprs := []MatchExpression{
|
||||
{0, ParseExpression(FoldExpression("a__"), collation)},
|
||||
{1, ParseExpression(FoldExpression("a%b"), collation)},
|
||||
{2, ParseExpression(FoldExpression("a_b"), collation)},
|
||||
{3, ParseExpression(FoldExpression("aa:%"), collation)},
|
||||
{4, ParseExpression(FoldExpression("a_%_b%_%c"), collation)},
|
||||
{5, ParseExpression(FoldExpression(`a\%b`), collation)},
|
||||
{6, ParseExpression(FoldExpression(`a`), collation)},
|
||||
{7, ParseExpression(FoldExpression(`ab`), collation)},
|
||||
{8, ParseExpression(FoldExpression(`a\\b`), collation)},
|
||||
{9, ParseExpression(FoldExpression(`A%%%%`), collation)},
|
||||
{10, ParseExpression(FoldExpression(`A%%%%bc`), collation)},
|
||||
{11, ParseExpression(FoldExpression("a_%_b%_%c%"), collation)},
|
||||
{12, ParseExpression(FoldExpression("a_%_n%_%z"), collation)},
|
||||
}
|
||||
tests := []struct {
|
||||
testStr string
|
||||
indexes []uint32
|
||||
}{
|
||||
{"a", []uint32{6, 9}},
|
||||
{"ab", []uint32{1, 7, 9}},
|
||||
{"abc", []uint32{0, 9, 10}},
|
||||
{"acb", []uint32{0, 1, 2, 9}},
|
||||
{"abcd", []uint32{9}},
|
||||
{"acdkeflskjfdklb", []uint32{1, 9}},
|
||||
{"acdkeflskjfdklbc", []uint32{9, 10}},
|
||||
{"aa:bb:cc:dd:ee:ff", []uint32{3, 9, 11}},
|
||||
{"AA:BB:CC:DD:EE:FF", []uint32{3, 9, 11}},
|
||||
{"AaAbCc", []uint32{4, 9, 11}},
|
||||
{"AaAbBcCbCc", []uint32{4, 9, 11}},
|
||||
{"AbbbbC", []uint32{4, 9, 10, 11}},
|
||||
{"aBcDeFgHiJkLmNoPqRsTuVwXyZ", []uint32{9, 12}},
|
||||
{"a%b", []uint32{0, 1, 2, 5, 9}},
|
||||
{"A%B", []uint32{0, 1, 2, 5, 9}},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%q", test.testStr), func(t *testing.T) {
|
||||
actualMatches := Match(matchExprs, test.testStr, collation)
|
||||
require.ElementsMatch(t, test.indexes, actualMatches)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSimpleCase(b *testing.B) {
|
||||
collation := sql.Collation_utf8mb4_0900_ai_ci
|
||||
matchExprs := []MatchExpression{
|
||||
{0, ParseExpression(FoldExpression("a__"), collation)},
|
||||
{1, ParseExpression(FoldExpression("a%b"), collation)},
|
||||
{2, ParseExpression(FoldExpression("a_b"), collation)},
|
||||
{3, ParseExpression(FoldExpression("aa:%"), collation)},
|
||||
{4, ParseExpression(FoldExpression("a_%_b%_%c"), collation)},
|
||||
{5, ParseExpression(FoldExpression(`a\%b`), collation)},
|
||||
{6, ParseExpression(FoldExpression(`a`), collation)},
|
||||
{7, ParseExpression(FoldExpression(`ab`), collation)},
|
||||
{8, ParseExpression(FoldExpression(`a\\b`), collation)},
|
||||
{9, ParseExpression(FoldExpression(`A%%%%`), collation)},
|
||||
{10, ParseExpression(FoldExpression(`A%%%%bc`), collation)},
|
||||
{11, ParseExpression(FoldExpression("a_%_b%_%c%"), collation)},
|
||||
{12, ParseExpression(FoldExpression("a_%_n%_%z"), collation)},
|
||||
}
|
||||
tests := []struct {
|
||||
testStr string
|
||||
matches []uint32
|
||||
}{
|
||||
{"a", []uint32{6, 9}},
|
||||
{"ab", []uint32{1, 7, 9}},
|
||||
{"abc", []uint32{0, 9, 10}},
|
||||
{"acb", []uint32{0, 1, 2, 9}},
|
||||
{"abcd", []uint32{9}},
|
||||
{"acdkeflskjfdklb", []uint32{1, 9}},
|
||||
{"acdkeflskjfdklbc", []uint32{9, 10}},
|
||||
{"aa:bb:cc:dd:ee:ff", []uint32{3, 9, 11}},
|
||||
{"AA:BB:CC:DD:EE:FF", []uint32{3, 9, 11}},
|
||||
{"AaAbCc", []uint32{4, 9, 11}},
|
||||
{"AaAbBcCbCc", []uint32{4, 9, 11}},
|
||||
{"AbbbbC", []uint32{4, 9, 10, 11}},
|
||||
{"aBcDeFgHiJkLmNoPqRsTuVwXyZ", []uint32{9, 12}},
|
||||
{"a%b", []uint32{0, 1, 2, 5, 9}},
|
||||
{"A%B", []uint32{0, 1, 2, 5, 9}},
|
||||
}
|
||||
testLen := len(tests)
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
test := tests[i%testLen]
|
||||
indexes := Match(matchExprs, test.testStr, collation)
|
||||
indexPool.Put(indexes)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
}
|
||||
@@ -0,0 +1,154 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package branch_control
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
)
|
||||
|
||||
// Namespace contains all of the expressions that comprise the "dolt_branch_namespace_control" table, which controls
|
||||
// which users may use which branch names when creating branches. Modification of this table is handled by the Access
|
||||
// table.
|
||||
type Namespace struct {
|
||||
access *Access
|
||||
binlog *Binlog
|
||||
|
||||
Branches []MatchExpression
|
||||
Users []MatchExpression
|
||||
Hosts []MatchExpression
|
||||
Values []NamespaceValue
|
||||
SuperUser string
|
||||
SuperHost string
|
||||
RWMutex *sync.RWMutex
|
||||
}
|
||||
|
||||
// NamespaceValue contains the user-facing values of a particular row.
|
||||
type NamespaceValue struct {
|
||||
Branch string
|
||||
User string
|
||||
Host string
|
||||
}
|
||||
|
||||
// newNamespace returns a new Namespace.
|
||||
func newNamespace(accessTbl *Access, superUser string, superHost string) *Namespace {
|
||||
return &Namespace{
|
||||
access: accessTbl,
|
||||
Branches: nil,
|
||||
Users: nil,
|
||||
Hosts: nil,
|
||||
Values: nil,
|
||||
SuperUser: superUser,
|
||||
SuperHost: superHost,
|
||||
RWMutex: &sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// CanCreate checks the given branch, and returns whether the given user and host combination is able to create that
|
||||
// branch. Handles the super user case.
|
||||
func (tbl *Namespace) CanCreate(branch string, user string, host string) bool {
|
||||
// Super user can always create branches
|
||||
if user == tbl.SuperUser && host == tbl.SuperHost {
|
||||
return true
|
||||
}
|
||||
matchedSet := Match(tbl.Branches, branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
// If there are no branch entries, then the Namespace is unrestricted
|
||||
if len(matchedSet) == 0 {
|
||||
indexPool.Put(matchedSet)
|
||||
return true
|
||||
}
|
||||
|
||||
// We take either the longest match, or the set of longest matches if multiple matches have the same length
|
||||
longest := -1
|
||||
filteredIndexes := indexPool.Get().([]uint32)[:0]
|
||||
for _, matched := range matchedSet {
|
||||
matchedValue := tbl.Values[matched]
|
||||
// If we've found a longer match, then we reset the slice. We append to it in the following if statement.
|
||||
if len(matchedValue.Branch) > longest {
|
||||
filteredIndexes = filteredIndexes[:0]
|
||||
}
|
||||
if len(matchedValue.Branch) >= longest {
|
||||
filteredIndexes = append(filteredIndexes, matched)
|
||||
}
|
||||
}
|
||||
indexPool.Put(matchedSet)
|
||||
|
||||
filteredUsers := tbl.filterUsers(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
filteredIndexes = Match(filteredUsers, user, sql.Collation_utf8mb4_0900_bin)
|
||||
matchExprPool.Put(filteredUsers)
|
||||
|
||||
filteredHosts := tbl.filterHosts(filteredIndexes)
|
||||
indexPool.Put(filteredIndexes)
|
||||
filteredIndexes = Match(filteredHosts, host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
matchExprPool.Put(filteredHosts)
|
||||
|
||||
result := len(filteredIndexes) > 0
|
||||
indexPool.Put(filteredIndexes)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetIndex returns the index of the given branch, user, and host expressions. If the expressions cannot be found,
|
||||
// returns -1. Assumes that the given expressions have already been folded.
|
||||
func (tbl *Namespace) GetIndex(branchExpr string, userExpr string, hostExpr string) int {
|
||||
for i, value := range tbl.Values {
|
||||
if value.Branch == branchExpr && value.User == userExpr && value.Host == hostExpr {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Access returns the Access table.
|
||||
func (tbl *Namespace) Access() *Access {
|
||||
return tbl.access
|
||||
}
|
||||
|
||||
// filterBranches returns all branches that match the given collection indexes.
|
||||
func (tbl *Namespace) filterBranches(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Branches[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// filterUsers returns all users that match the given collection indexes.
|
||||
func (tbl *Namespace) filterUsers(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Users[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
|
||||
// filterHosts returns all hosts that match the given collection indexes.
|
||||
func (tbl *Namespace) filterHosts(filters []uint32) []MatchExpression {
|
||||
if len(filters) == 0 {
|
||||
return nil
|
||||
}
|
||||
matchExprs := matchExprPool.Get().([]MatchExpression)[:0]
|
||||
for _, filter := range filters {
|
||||
matchExprs = append(matchExprs, tbl.Hosts[filter])
|
||||
}
|
||||
return matchExprs
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/dolthub/go-mysql-server/sql/mysql_db"
|
||||
"gopkg.in/src-d/go-errors.v1"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions/commitwalk"
|
||||
@@ -361,6 +362,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
|
||||
if head == nil {
|
||||
var err error
|
||||
head, err = ds.GetHeadCommit(ctx, db.Name())
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
@@ -465,6 +467,10 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
|
||||
dt, found = dtables.NewMergeStatusTable(db.name), true
|
||||
case doltdb.TagsTableName:
|
||||
dt, found = dtables.NewTagsTable(ctx, db.ddb), true
|
||||
case dtables.AccessTableName:
|
||||
dt, found = dtables.NewBranchControlTable(branch_control.StaticController.Access), true
|
||||
case dtables.NamespaceTableName:
|
||||
dt, found = dtables.NewBranchNamespaceControlTable(branch_control.StaticController.Namespace), true
|
||||
}
|
||||
if found {
|
||||
return dt, found, nil
|
||||
@@ -728,6 +734,9 @@ func (db Database) GetHeadRoot(ctx *sql.Context) (*doltdb.RootValue, error) {
|
||||
// DropTable drops the table with the name given.
|
||||
// The planner returns the correct case sensitive name in tableName
|
||||
func (db Database) DropTable(ctx *sql.Context, tableName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if doltdb.IsReadOnlySystemTable(tableName) {
|
||||
return ErrSystemTableAlter.New(tableName)
|
||||
}
|
||||
@@ -827,6 +836,9 @@ func (db Database) removeTableFromAutoIncrementTracker(
|
||||
|
||||
// CreateTable creates a table with the name and schema given.
|
||||
func (db Database) CreateTable(ctx *sql.Context, tableName string, sch sql.PrimaryKeySchema, collation sql.CollationID) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.ToLower(tableName) == doltdb.DocTableName {
|
||||
// validate correct schema
|
||||
if !dtables.DoltDocsSqlSchema.Equals(sch.Schema) {
|
||||
@@ -938,6 +950,9 @@ func (db Database) CreateTemporaryTable(ctx *sql.Context, tableName string, pkSc
|
||||
|
||||
// RenameTable implements sql.TableRenamer
|
||||
func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
root, err := db.GetRoot(ctx)
|
||||
|
||||
if err != nil {
|
||||
@@ -1152,15 +1167,24 @@ func (db Database) GetStoredProcedures(ctx *sql.Context) ([]sql.StoredProcedureD
|
||||
|
||||
// SaveStoredProcedure implements sql.StoredProcedureDatabase.
|
||||
func (db Database) SaveStoredProcedure(ctx *sql.Context, spd sql.StoredProcedureDetails) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
return DoltProceduresAddProcedure(ctx, db, spd)
|
||||
}
|
||||
|
||||
// DropStoredProcedure implements sql.StoredProcedureDatabase.
|
||||
func (db Database) DropStoredProcedure(ctx *sql.Context, name string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
return DoltProceduresDropProcedure(ctx, db, name)
|
||||
}
|
||||
|
||||
func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, definition string, created time.Time, existingErr error) (err error) {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
tbl, err := GetOrCreateDoltSchemasTable(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1212,6 +1236,9 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
|
||||
}
|
||||
|
||||
func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
stbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/dolthub/go-mysql-server/sql/expression"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
|
||||
@@ -120,6 +121,12 @@ func renameBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseRe
|
||||
if oldBranchName == "" || newBranchName == "" {
|
||||
return EmptyBranchNameErr
|
||||
}
|
||||
if err := branch_control.CanDeleteBranch(ctx, oldBranchName); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := branch_control.CanCreateBranch(ctx, newBranchName); err != nil {
|
||||
return err
|
||||
}
|
||||
force := apr.Contains(cli.ForceFlag)
|
||||
|
||||
if !force {
|
||||
@@ -168,6 +175,13 @@ func deleteBranches(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParse
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that we can delete all branches before continuing
|
||||
for _, branchName := range apr.Args {
|
||||
if err = branch_control.CanDeleteBranch(ctx, branchName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var updateFS = false
|
||||
for _, branchName := range apr.Args {
|
||||
if len(branchName) == 0 {
|
||||
@@ -278,6 +292,9 @@ func createNewBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgPars
|
||||
return EmptyBranchNameErr
|
||||
}
|
||||
|
||||
if err := branch_control.CanCreateBranch(ctx, branchName); err != nil {
|
||||
return err
|
||||
}
|
||||
return actions.CreateBranchWithStartPt(ctx, dbData, branchName, startPt, apr.Contains(cli.ForceFlag))
|
||||
}
|
||||
|
||||
@@ -301,6 +318,9 @@ func copyBranch(ctx *sql.Context, dbData env.DbData, apr *argparser.ArgParseResu
|
||||
}
|
||||
|
||||
func copyABranch(ctx *sql.Context, dbData env.DbData, srcBr string, destBr string, force bool) error {
|
||||
if err := branch_control.CanCreateBranch(ctx, destBr); err != nil {
|
||||
return err
|
||||
}
|
||||
err := actions.CopyBranchOnDB(ctx, dbData.Ddb, srcBr, destBr, force)
|
||||
if err != nil {
|
||||
if err == doltdb.ErrBranchNotFound {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
package dsess
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
@@ -26,6 +27,7 @@ import (
|
||||
goerrors "gopkg.in/src-d/go-errors.v1"
|
||||
|
||||
"github.com/dolthub/dolt/go/cmd/dolt/cli"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
|
||||
@@ -68,6 +70,7 @@ type DoltSession struct {
|
||||
|
||||
var _ sql.Session = (*DoltSession)(nil)
|
||||
var _ sql.PersistableSession = (*DoltSession)(nil)
|
||||
var _ branch_control.Context = (*DoltSession)(nil)
|
||||
|
||||
// DefaultSession creates a DoltSession with default values
|
||||
func DefaultSession(pro DoltDatabaseProvider) *DoltSession {
|
||||
@@ -1192,6 +1195,31 @@ func (d *DoltSession) SystemVariablesInConfig() ([]sql.SystemVariable, error) {
|
||||
return sysVars, nil
|
||||
}
|
||||
|
||||
// GetBranch implements the interface branch_control.Context.
|
||||
func (d *DoltSession) GetBranch() (string, error) {
|
||||
branchRef, err := d.CWBHeadRef(sql.NewContext(context.Background(), sql.WithSession(d)), d.Session.GetCurrentDatabase())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return branchRef.GetPath(), nil
|
||||
}
|
||||
|
||||
// GetUser implements the interface branch_control.Context.
|
||||
func (d *DoltSession) GetUser() string {
|
||||
return d.Session.Client().User
|
||||
}
|
||||
|
||||
// GetHost implements the interface branch_control.Context.
|
||||
func (d *DoltSession) GetHost() string {
|
||||
return d.Session.Client().Address
|
||||
}
|
||||
|
||||
// GetController implements the interface branch_control.Context.
|
||||
func (d *DoltSession) GetController() *branch_control.Controller {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// validatePersistedSysVar checks whether a system variable exists and is dynamic
|
||||
func validatePersistableSysVar(name string) (sql.SystemVariable, interface{}, error) {
|
||||
sysVar, val, ok := sql.SystemVariables.GetGlobal(name)
|
||||
|
||||
@@ -0,0 +1,344 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing branch_control.Permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dtables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
)
|
||||
|
||||
const (
|
||||
AccessTableName = "dolt_branch_control"
|
||||
)
|
||||
|
||||
// PermissionsStrings is a slice of strings representing the available branch_control.branch_control.Permissions. The order of the
|
||||
// strings should exactly match the order of the branch_control.Permissions according to their flag value.
|
||||
var PermissionsStrings = []string{"admin", "write"}
|
||||
|
||||
// accessSchema is the schema for the "dolt_branch_control" table.
|
||||
var accessSchema = sql.Schema{
|
||||
&sql.Column{
|
||||
Name: "branch",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "user",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
|
||||
Source: AccessTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "host",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "branch_control.Permissions",
|
||||
Type: sql.MustCreateSetType(PermissionsStrings, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: AccessTableName,
|
||||
PrimaryKey: false,
|
||||
},
|
||||
}
|
||||
|
||||
// BranchControlTable provides a layer over the branch_control.Access structure, exposing it as a system table.
|
||||
type BranchControlTable struct {
|
||||
*branch_control.Access
|
||||
}
|
||||
|
||||
var _ sql.Table = BranchControlTable{}
|
||||
var _ sql.InsertableTable = BranchControlTable{}
|
||||
var _ sql.ReplaceableTable = BranchControlTable{}
|
||||
var _ sql.UpdatableTable = BranchControlTable{}
|
||||
var _ sql.DeletableTable = BranchControlTable{}
|
||||
var _ sql.RowInserter = BranchControlTable{}
|
||||
var _ sql.RowReplacer = BranchControlTable{}
|
||||
var _ sql.RowUpdater = BranchControlTable{}
|
||||
var _ sql.RowDeleter = BranchControlTable{}
|
||||
|
||||
// NewBranchControlTable returns a new BranchControlTable.
|
||||
func NewBranchControlTable(access *branch_control.Access) BranchControlTable {
|
||||
return BranchControlTable{access}
|
||||
}
|
||||
|
||||
// Name implements the interface sql.Table.
|
||||
func (tbl BranchControlTable) Name() string {
|
||||
return AccessTableName
|
||||
}
|
||||
|
||||
// String implements the interface sql.Table.
|
||||
func (tbl BranchControlTable) String() string {
|
||||
return AccessTableName
|
||||
}
|
||||
|
||||
// Schema implements the interface sql.Table.
|
||||
func (tbl BranchControlTable) Schema() sql.Schema {
|
||||
return accessSchema
|
||||
}
|
||||
|
||||
// Collation implements the interface sql.Table.
|
||||
func (tbl BranchControlTable) Collation() sql.CollationID {
|
||||
return sql.Collation_Default
|
||||
}
|
||||
|
||||
// Partitions implements the interface sql.Table.
|
||||
func (tbl BranchControlTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
|
||||
return index.SinglePartitionIterFromNomsMap(nil), nil
|
||||
}
|
||||
|
||||
// PartitionRows implements the interface sql.Table.
|
||||
func (tbl BranchControlTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
|
||||
tbl.RWMutex.RLock()
|
||||
defer tbl.RWMutex.RUnlock()
|
||||
|
||||
rows := []sql.Row{{"%", tbl.SuperUser, tbl.SuperHost, uint64(branch_control.Permissions_Admin)}}
|
||||
for _, value := range tbl.Values {
|
||||
rows = append(rows, sql.Row{
|
||||
value.Branch,
|
||||
value.User,
|
||||
value.Host,
|
||||
uint64(value.Permissions),
|
||||
})
|
||||
}
|
||||
return sql.RowsToRowIter(rows...), nil
|
||||
}
|
||||
|
||||
// Inserter implements the interface sql.InsertableTable.
|
||||
func (tbl BranchControlTable) Inserter(context *sql.Context) sql.RowInserter {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// Replacer implements the interface sql.ReplaceableTable.
|
||||
func (tbl BranchControlTable) Replacer(ctx *sql.Context) sql.RowReplacer {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// Updater implements the interface sql.UpdatableTable.
|
||||
func (tbl BranchControlTable) Updater(ctx *sql.Context) sql.RowUpdater {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// Deleter implements the interface sql.DeletableTable.
|
||||
func (tbl BranchControlTable) Deleter(context *sql.Context) sql.RowDeleter {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// StatementBegin implements the interface sql.TableEditor.
|
||||
func (tbl BranchControlTable) StatementBegin(ctx *sql.Context) {
|
||||
//TODO: will use the binlog to implement
|
||||
}
|
||||
|
||||
// DiscardChanges implements the interface sql.TableEditor.
|
||||
func (tbl BranchControlTable) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
|
||||
//TODO: will use the binlog to implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// StatementComplete implements the interface sql.TableEditor.
|
||||
func (tbl BranchControlTable) StatementComplete(ctx *sql.Context) error {
|
||||
//TODO: will use the binlog to implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert implements the interface sql.RowInserter.
|
||||
func (tbl BranchControlTable) Insert(ctx *sql.Context, row sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
perms := branch_control.Permissions(row[3].(uint64))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the insertion has permission to perform the insertion.
|
||||
_, modPerms := tbl.Match(branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(uint64(perms))
|
||||
return fmt.Errorf("`%s`@`%s` cannot add the row [%q, %q, %q, %q]",
|
||||
insertUser, insertHost, branch, user, host, permStr)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.insert(ctx, branch, user, host, perms)
|
||||
}
|
||||
|
||||
// Update implements the interface sql.RowUpdater.
|
||||
func (tbl BranchControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
|
||||
oldUser := branch_control.FoldExpression(old[1].(string))
|
||||
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
|
||||
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
|
||||
newUser := branch_control.FoldExpression(new[1].(string))
|
||||
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
|
||||
newPerms := branch_control.Permissions(new[3].(uint64))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
|
||||
return fmt.Errorf("expressions are too long [%q, %q, %q]", newBranch, newUser, newHost)
|
||||
}
|
||||
|
||||
// If we're not updating the same row, then we pre-emptively check for a row violation
|
||||
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
|
||||
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
|
||||
permBits := uint64(tbl.Values[tblIndex].Permissions)
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, newBranch, newUser, newHost, permStr),
|
||||
true,
|
||||
sql.Row{newBranch, newUser, newHost, permBits})
|
||||
}
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we'd allow the update in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the update has permission to perform the update on the old branch name.
|
||||
_, modPerms := tbl.Match(oldBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q]",
|
||||
insertUser, insertHost, oldBranch, oldUser, oldHost)
|
||||
}
|
||||
// Now we check if the user has permission use the new branch name
|
||||
_, modPerms = tbl.Match(newBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q] to the new branch expression %q",
|
||||
insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
|
||||
}
|
||||
}
|
||||
|
||||
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
|
||||
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tbl.insert(ctx, newBranch, newUser, newHost, newPerms)
|
||||
}
|
||||
|
||||
// Delete implements the interface sql.RowDeleter.
|
||||
func (tbl BranchControlTable) Delete(ctx *sql.Context, row sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the deletion has permission to perform the deletion.
|
||||
_, modPerms := tbl.Match(branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot delete the row [%q, %q, %q]",
|
||||
insertUser, insertHost, branch, user, host)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.delete(ctx, branch, user, host)
|
||||
}
|
||||
|
||||
// Close implements the interface sql.Closer.
|
||||
func (tbl BranchControlTable) Close(context *sql.Context) error {
|
||||
//TODO: write the binlog
|
||||
return nil
|
||||
}
|
||||
|
||||
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchControlTable) insert(ctx context.Context, branch string, user string, host string, perms branch_control.Permissions) error {
|
||||
// If we already have this in the table, then we return a duplicate PK error
|
||||
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
|
||||
permBits := uint64(tbl.Values[tblIndex].Permissions)
|
||||
permStr, _ := accessSchema[3].Type.(sql.SetType).BitsToString(permBits)
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q, %q]`, branch, user, host, permStr),
|
||||
true,
|
||||
sql.Row{branch, user, host, permBits})
|
||||
}
|
||||
|
||||
// Add the expressions to their respective slices
|
||||
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
|
||||
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
nextIdx := uint32(len(tbl.Values))
|
||||
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
|
||||
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
|
||||
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
|
||||
tbl.Values = append(tbl.Values, branch_control.AccessValue{
|
||||
Branch: branch,
|
||||
User: user,
|
||||
Host: host,
|
||||
Permissions: perms,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchControlTable) delete(ctx context.Context, branch string, user string, host string) error {
|
||||
// If we don't have this in the table, then we just return
|
||||
tblIndex := tbl.GetIndex(branch, user, host)
|
||||
if tblIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
endIndex := len(tbl.Values) - 1
|
||||
// Remove the matching row from all slices by first swapping with the last element
|
||||
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
|
||||
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
|
||||
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
|
||||
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
|
||||
// Then we remove the last element
|
||||
tbl.Branches = tbl.Branches[:endIndex]
|
||||
tbl.Users = tbl.Users[:endIndex]
|
||||
tbl.Hosts = tbl.Hosts[:endIndex]
|
||||
tbl.Values = tbl.Values[:endIndex]
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,333 @@
|
||||
// Copyright 2022 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package dtables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
"github.com/dolthub/vitess/go/sqltypes"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
|
||||
)
|
||||
|
||||
const (
|
||||
NamespaceTableName = "dolt_branch_namespace_control"
|
||||
)
|
||||
|
||||
// namespaceSchema is the schema for the "dolt_branch_namespace_control" table.
|
||||
var namespaceSchema = sql.Schema{
|
||||
&sql.Column{
|
||||
Name: "branch",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: NamespaceTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "user",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_bin),
|
||||
Source: NamespaceTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
&sql.Column{
|
||||
Name: "host",
|
||||
Type: sql.MustCreateString(sqltypes.VarChar, 16383, sql.Collation_utf8mb4_0900_ai_ci),
|
||||
Source: NamespaceTableName,
|
||||
PrimaryKey: true,
|
||||
},
|
||||
}
|
||||
|
||||
// BranchNamespaceControlTable provides a layer over the branch_control.Namespace structure, exposing it as a system
|
||||
// table.
|
||||
type BranchNamespaceControlTable struct {
|
||||
*branch_control.Namespace
|
||||
}
|
||||
|
||||
var _ sql.Table = BranchNamespaceControlTable{}
|
||||
var _ sql.InsertableTable = BranchNamespaceControlTable{}
|
||||
var _ sql.ReplaceableTable = BranchNamespaceControlTable{}
|
||||
var _ sql.UpdatableTable = BranchNamespaceControlTable{}
|
||||
var _ sql.DeletableTable = BranchNamespaceControlTable{}
|
||||
var _ sql.RowInserter = BranchNamespaceControlTable{}
|
||||
var _ sql.RowReplacer = BranchNamespaceControlTable{}
|
||||
var _ sql.RowUpdater = BranchNamespaceControlTable{}
|
||||
var _ sql.RowDeleter = BranchNamespaceControlTable{}
|
||||
|
||||
// NewBranchNamespaceControlTable returns a new BranchNamespaceControlTable.
|
||||
func NewBranchNamespaceControlTable(namespace *branch_control.Namespace) BranchNamespaceControlTable {
|
||||
return BranchNamespaceControlTable{namespace}
|
||||
}
|
||||
|
||||
// Name implements the interface sql.Table.
|
||||
func (tbl BranchNamespaceControlTable) Name() string {
|
||||
return NamespaceTableName
|
||||
}
|
||||
|
||||
// String implements the interface sql.Table.
|
||||
func (tbl BranchNamespaceControlTable) String() string {
|
||||
return NamespaceTableName
|
||||
}
|
||||
|
||||
// Schema implements the interface sql.Table.
|
||||
func (tbl BranchNamespaceControlTable) Schema() sql.Schema {
|
||||
return namespaceSchema
|
||||
}
|
||||
|
||||
// Collation implements the interface sql.Table.
|
||||
func (tbl BranchNamespaceControlTable) Collation() sql.CollationID {
|
||||
return sql.Collation_Default
|
||||
}
|
||||
|
||||
// Partitions implements the interface sql.Table.
|
||||
func (tbl BranchNamespaceControlTable) Partitions(context *sql.Context) (sql.PartitionIter, error) {
|
||||
return index.SinglePartitionIterFromNomsMap(nil), nil
|
||||
}
|
||||
|
||||
// PartitionRows implements the interface sql.Table.
|
||||
func (tbl BranchNamespaceControlTable) PartitionRows(context *sql.Context, partition sql.Partition) (sql.RowIter, error) {
|
||||
tbl.RWMutex.RLock()
|
||||
defer tbl.RWMutex.RUnlock()
|
||||
|
||||
var rows []sql.Row
|
||||
for _, value := range tbl.Values {
|
||||
rows = append(rows, sql.Row{
|
||||
value.Branch,
|
||||
value.User,
|
||||
value.Host,
|
||||
})
|
||||
}
|
||||
return sql.RowsToRowIter(rows...), nil
|
||||
}
|
||||
|
||||
// Inserter implements the interface sql.InsertableTable.
|
||||
func (tbl BranchNamespaceControlTable) Inserter(context *sql.Context) sql.RowInserter {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// Replacer implements the interface sql.ReplaceableTable.
|
||||
func (tbl BranchNamespaceControlTable) Replacer(ctx *sql.Context) sql.RowReplacer {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// Updater implements the interface sql.UpdatableTable.
|
||||
func (tbl BranchNamespaceControlTable) Updater(ctx *sql.Context) sql.RowUpdater {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// Deleter implements the interface sql.DeletableTable.
|
||||
func (tbl BranchNamespaceControlTable) Deleter(context *sql.Context) sql.RowDeleter {
|
||||
return tbl
|
||||
}
|
||||
|
||||
// StatementBegin implements the interface sql.TableEditor.
|
||||
func (tbl BranchNamespaceControlTable) StatementBegin(ctx *sql.Context) {
|
||||
//TODO: will use the binlog to implement
|
||||
}
|
||||
|
||||
// DiscardChanges implements the interface sql.TableEditor.
|
||||
func (tbl BranchNamespaceControlTable) DiscardChanges(ctx *sql.Context, errorEncountered error) error {
|
||||
//TODO: will use the binlog to implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// StatementComplete implements the interface sql.TableEditor.
|
||||
func (tbl BranchNamespaceControlTable) StatementComplete(ctx *sql.Context) error {
|
||||
//TODO: will use the binlog to implement
|
||||
return nil
|
||||
}
|
||||
|
||||
// Insert implements the interface sql.RowInserter.
|
||||
func (tbl BranchNamespaceControlTable) Insert(ctx *sql.Context, row sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the insertion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
// Need to acquire a read lock on the Access table since we have to read from it
|
||||
tbl.Access().RWMutex.RLock()
|
||||
defer tbl.Access().RWMutex.RUnlock()
|
||||
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the insertion has permission to perform the insertion.
|
||||
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot add the row [%q, %q, %q]",
|
||||
insertUser, insertHost, branch, user, host)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.insert(ctx, branch, user, host)
|
||||
}
|
||||
|
||||
// Update implements the interface sql.RowUpdater.
|
||||
func (tbl BranchNamespaceControlTable) Update(ctx *sql.Context, old sql.Row, new sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
oldBranch := strings.ToLower(branch_control.FoldExpression(old[0].(string)))
|
||||
oldUser := branch_control.FoldExpression(old[1].(string))
|
||||
oldHost := strings.ToLower(branch_control.FoldExpression(old[2].(string)))
|
||||
newBranch := strings.ToLower(branch_control.FoldExpression(new[0].(string)))
|
||||
newUser := branch_control.FoldExpression(new[1].(string))
|
||||
newHost := strings.ToLower(branch_control.FoldExpression(new[2].(string)))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(newBranch) > math.MaxUint16 || len(newUser) > math.MaxUint16 || len(newHost) > math.MaxUint16 {
|
||||
return fmt.Errorf("expressions are too long [%q, %q, %q]", newBranch, newUser, newHost)
|
||||
}
|
||||
|
||||
// If we're not updating the same row, then we pre-emptively check for a row violation
|
||||
if oldBranch != newBranch || oldUser != newUser || oldHost != newHost {
|
||||
if tblIndex := tbl.GetIndex(newBranch, newUser, newHost); tblIndex != -1 {
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q]`, newBranch, newUser, newHost),
|
||||
true,
|
||||
sql.Row{newBranch, newUser, newHost})
|
||||
}
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we'd allow the update in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
// Need to acquire a read lock on the Access table since we have to read from it
|
||||
tbl.Access().RWMutex.RLock()
|
||||
defer tbl.Access().RWMutex.RUnlock()
|
||||
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the update has permission to perform the update on the old branch name.
|
||||
_, modPerms := tbl.Access().Match(oldBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q]",
|
||||
insertUser, insertHost, oldBranch, oldUser, oldHost)
|
||||
}
|
||||
// Now we check if the user has permission use the new branch name
|
||||
_, modPerms = tbl.Access().Match(newBranch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot update the row [%q, %q, %q] to the new branch expression %q",
|
||||
insertUser, insertHost, oldBranch, oldUser, oldHost, newBranch)
|
||||
}
|
||||
}
|
||||
|
||||
if tblIndex := tbl.GetIndex(oldBranch, oldUser, oldHost); tblIndex != -1 {
|
||||
if err := tbl.delete(ctx, oldBranch, oldUser, oldHost); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return tbl.insert(ctx, newBranch, newUser, newHost)
|
||||
}
|
||||
|
||||
// Delete implements the interface sql.RowDeleter.
|
||||
func (tbl BranchNamespaceControlTable) Delete(ctx *sql.Context, row sql.Row) error {
|
||||
tbl.RWMutex.Lock()
|
||||
defer tbl.RWMutex.Unlock()
|
||||
|
||||
// Branch and Host are case-insensitive, while user is case-sensitive
|
||||
branch := strings.ToLower(branch_control.FoldExpression(row[0].(string)))
|
||||
user := branch_control.FoldExpression(row[1].(string))
|
||||
host := strings.ToLower(branch_control.FoldExpression(row[2].(string)))
|
||||
|
||||
// Verify that the lengths of each expression fit within an uint16
|
||||
if len(branch) > math.MaxUint16 || len(user) > math.MaxUint16 || len(host) > math.MaxUint16 {
|
||||
return fmt.Errorf("expressions are too long [%q, %q, %q]", branch, user, host)
|
||||
}
|
||||
|
||||
// A nil session means we're not in the SQL context, so we allow the deletion in such a case
|
||||
if branchAwareSession := branch_control.GetBranchAwareSession(ctx); branchAwareSession != nil {
|
||||
// Need to acquire a read lock on the Access table since we have to read from it
|
||||
tbl.Access().RWMutex.RLock()
|
||||
defer tbl.Access().RWMutex.RUnlock()
|
||||
|
||||
insertUser := branchAwareSession.GetUser()
|
||||
insertHost := branchAwareSession.GetHost()
|
||||
// As we've folded the branch expression, we can use it directly as though it were a normal branch name to
|
||||
// determine if the user attempting the deletion has permission to perform the deletion.
|
||||
_, modPerms := tbl.Access().Match(branch, insertUser, insertHost)
|
||||
if modPerms&branch_control.Permissions_Admin != branch_control.Permissions_Admin {
|
||||
return fmt.Errorf("`%s`@`%s` cannot delete the row [%q, %q, %q]",
|
||||
insertUser, insertHost, branch, user, host)
|
||||
}
|
||||
}
|
||||
|
||||
return tbl.delete(ctx, branch, user, host)
|
||||
}
|
||||
|
||||
// Close implements the interface sql.Closer.
|
||||
func (tbl BranchNamespaceControlTable) Close(context *sql.Context) error {
|
||||
//TODO: write the binlog
|
||||
return nil
|
||||
}
|
||||
|
||||
// insert adds the given branch, user, and host expression strings to the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchNamespaceControlTable) insert(ctx context.Context, branch string, user string, host string) error {
|
||||
// If we already have this in the table, then we return a duplicate PK error
|
||||
if tblIndex := tbl.GetIndex(branch, user, host); tblIndex != -1 {
|
||||
return sql.NewUniqueKeyErr(
|
||||
fmt.Sprintf(`[%q, %q, %q]`, branch, user, host),
|
||||
true,
|
||||
sql.Row{branch, user, host})
|
||||
}
|
||||
|
||||
// Add the expressions to their respective slices
|
||||
branchExpr := branch_control.ParseExpression(branch, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
userExpr := branch_control.ParseExpression(user, sql.Collation_utf8mb4_0900_bin)
|
||||
hostExpr := branch_control.ParseExpression(host, sql.Collation_utf8mb4_0900_ai_ci)
|
||||
nextIdx := uint32(len(tbl.Values))
|
||||
tbl.Branches = append(tbl.Branches, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: branchExpr})
|
||||
tbl.Users = append(tbl.Users, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: userExpr})
|
||||
tbl.Hosts = append(tbl.Hosts, branch_control.MatchExpression{CollectionIndex: nextIdx, SortOrders: hostExpr})
|
||||
return nil
|
||||
}
|
||||
|
||||
// delete removes the given branch, user, and host expression strings from the table. Assumes that the expressions have
|
||||
// already been folded.
|
||||
func (tbl BranchNamespaceControlTable) delete(ctx context.Context, branch string, user string, host string) error {
|
||||
// If we don't have this in the table, then we just return
|
||||
tblIndex := tbl.GetIndex(branch, user, host)
|
||||
if tblIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
endIndex := len(tbl.Values) - 1
|
||||
// Remove the matching row from all slices by first swapping with the last element
|
||||
tbl.Branches[tblIndex], tbl.Branches[endIndex] = tbl.Branches[endIndex], tbl.Branches[tblIndex]
|
||||
tbl.Users[tblIndex], tbl.Users[endIndex] = tbl.Users[endIndex], tbl.Users[tblIndex]
|
||||
tbl.Hosts[tblIndex], tbl.Hosts[endIndex] = tbl.Hosts[endIndex], tbl.Hosts[tblIndex]
|
||||
tbl.Values[tblIndex], tbl.Values[endIndex] = tbl.Values[endIndex], tbl.Values[tblIndex]
|
||||
// Then we remove the last element
|
||||
tbl.Branches = tbl.Branches[:endIndex]
|
||||
tbl.Users = tbl.Users[:endIndex]
|
||||
tbl.Hosts = tbl.Hosts[:endIndex]
|
||||
tbl.Values = tbl.Values[:endIndex]
|
||||
return nil
|
||||
}
|
||||
@@ -447,12 +447,11 @@ func TestJSONTableScripts(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUserPrivileges(t *testing.T) {
|
||||
t.Skip("Need to add more collations")
|
||||
enginetest.TestUserPrivileges(t, newDoltHarness(t))
|
||||
}
|
||||
|
||||
func TestUserAuthentication(t *testing.T) {
|
||||
t.Skip("Need to add more collations")
|
||||
t.Skip("Unexpected panic, need to fix")
|
||||
enginetest.TestUserAuthentication(t, newDoltHarness(t))
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
|
||||
"github.com/dolthub/go-mysql-server/sql"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
|
||||
@@ -609,6 +610,9 @@ func (t *WritableDoltTable) WithProjections(colNames []string) sql.Table {
|
||||
|
||||
// Inserter implements sql.InsertableTable
|
||||
func (t *WritableDoltTable) Inserter(ctx *sql.Context) sql.RowInserter {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
}
|
||||
te, err := t.getTableEditor(ctx)
|
||||
if err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
@@ -648,6 +652,9 @@ func (t *WritableDoltTable) getTableEditor(ctx *sql.Context) (ed writer.TableWri
|
||||
|
||||
// Deleter implements sql.DeletableTable
|
||||
func (t *WritableDoltTable) Deleter(ctx *sql.Context) sql.RowDeleter {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
}
|
||||
te, err := t.getTableEditor(ctx)
|
||||
if err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
@@ -657,6 +664,9 @@ func (t *WritableDoltTable) Deleter(ctx *sql.Context) sql.RowDeleter {
|
||||
|
||||
// Replacer implements sql.ReplaceableTable
|
||||
func (t *WritableDoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
}
|
||||
te, err := t.getTableEditor(ctx)
|
||||
if err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
@@ -666,6 +676,9 @@ func (t *WritableDoltTable) Replacer(ctx *sql.Context) sql.RowReplacer {
|
||||
|
||||
// Truncate implements sql.TruncateableTable
|
||||
func (t *WritableDoltTable) Truncate(ctx *sql.Context) (int, error) {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
table, err := t.DoltTable.DoltTable(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@@ -753,6 +766,9 @@ func (t *WritableDoltTable) truncate(
|
||||
|
||||
// Updater implements sql.UpdatableTable
|
||||
func (t *WritableDoltTable) Updater(ctx *sql.Context) sql.RowUpdater {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
}
|
||||
te, err := t.getTableEditor(ctx)
|
||||
if err != nil {
|
||||
return sqlutil.NewStaticErrorEditor(err)
|
||||
@@ -1157,6 +1173,9 @@ func (t *AlterableDoltTable) WithProjections(colNames []string) sql.Table {
|
||||
|
||||
// AddColumn implements sql.AlterableTable
|
||||
func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
root, err := t.getRoot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1290,6 +1309,9 @@ func (t *AlterableDoltTable) RewriteInserter(
|
||||
oldColumn *sql.Column,
|
||||
newColumn *sql.Column,
|
||||
) (sql.RowInserter, error) {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err := validateSchemaChange(t.Name(), oldSchema, newSchema, oldColumn, newColumn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1555,6 +1577,9 @@ func (t *AlterableDoltTable) adjustForeignKeysForDroppedPk(ctx *sql.Context, roo
|
||||
|
||||
// DropColumn implements sql.AlterableTable
|
||||
func (t *AlterableDoltTable) DropColumn(ctx *sql.Context, columnName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if types.IsFormat_DOLT(t.nbf) {
|
||||
return nil
|
||||
}
|
||||
@@ -1671,6 +1696,9 @@ func (t *AlterableDoltTable) dropColumnData(ctx *sql.Context, updatedTable *dolt
|
||||
// ModifyColumn implements sql.AlterableTable. ModifyColumn operations are only used for operations that change only
|
||||
// the schema of a table, not the data. For those operations, |RewriteInserter| is used.
|
||||
func (t *AlterableDoltTable) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Column, order *sql.ColumnOrder) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
ws, err := t.db.GetWorkingSet(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1832,6 +1860,9 @@ func (t *AlterableDoltTable) CreateIndex(
|
||||
indexColumns []sql.IndexColumn,
|
||||
comment string,
|
||||
) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if constraint != sql.IndexConstraint_None && constraint != sql.IndexConstraint_Unique {
|
||||
return fmt.Errorf("only the following types of index constraints are supported: none, unique")
|
||||
}
|
||||
@@ -1901,6 +1932,9 @@ func (t *AlterableDoltTable) CreateIndex(
|
||||
|
||||
// DropIndex implements sql.IndexAlterableTable
|
||||
func (t *AlterableDoltTable) DropIndex(ctx *sql.Context, indexName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
// We disallow removing internal dolt_ tables from SQL directly
|
||||
if strings.HasPrefix(indexName, "dolt_") {
|
||||
return fmt.Errorf("dolt internal indexes may not be dropped")
|
||||
@@ -1927,6 +1961,9 @@ func (t *AlterableDoltTable) DropIndex(ctx *sql.Context, indexName string) error
|
||||
|
||||
// RenameIndex implements sql.IndexAlterableTable
|
||||
func (t *AlterableDoltTable) RenameIndex(ctx *sql.Context, fromIndexName string, toIndexName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
// RenameIndex will error if there is a name collision or an index does not exist
|
||||
_, err := t.sch.Indexes().RenameIndex(fromIndexName, toIndexName)
|
||||
if err != nil {
|
||||
@@ -1965,6 +2002,9 @@ func (t *AlterableDoltTable) RenameIndex(ctx *sql.Context, fromIndexName string,
|
||||
|
||||
// AddForeignKey implements sql.ForeignKeyTable
|
||||
func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKeyConstraint) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if sqlFk.Name != "" && !doltdb.IsValidForeignKeyName(sqlFk.Name) {
|
||||
return fmt.Errorf("invalid foreign key name `%s` as it must match the regular expression %s", sqlFk.Name, doltdb.ForeignKeyNameRegexStr)
|
||||
}
|
||||
@@ -2152,6 +2192,9 @@ func (t *AlterableDoltTable) AddForeignKey(ctx *sql.Context, sqlFk sql.ForeignKe
|
||||
|
||||
// DropForeignKey implements sql.ForeignKeyTable
|
||||
func (t *AlterableDoltTable) DropForeignKey(ctx *sql.Context, fkName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
root, err := t.getRoot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2177,6 +2220,9 @@ func (t *AlterableDoltTable) DropForeignKey(ctx *sql.Context, fkName string) err
|
||||
|
||||
// UpdateForeignKey implements sql.ForeignKeyTable
|
||||
func (t *AlterableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, sqlFk sql.ForeignKeyConstraint) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
root, err := t.getRoot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2520,6 +2566,9 @@ func (t *AlterableDoltTable) updateFromRoot(ctx *sql.Context, root *doltdb.RootV
|
||||
}
|
||||
|
||||
func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
root, err := t.getRoot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2573,6 +2622,9 @@ func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefin
|
||||
}
|
||||
|
||||
func (t *AlterableDoltTable) DropCheck(ctx *sql.Context, chName string) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
root, err := t.getRoot(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2668,6 +2720,9 @@ func (t *AlterableDoltTable) constraintNameExists(ctx *sql.Context, name string)
|
||||
}
|
||||
|
||||
func (t *AlterableDoltTable) CreatePrimaryKey(ctx *sql.Context, columns []sql.IndexColumn) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if types.IsFormat_DOLT(t.nbf) {
|
||||
return nil
|
||||
}
|
||||
@@ -2702,6 +2757,9 @@ func (t *AlterableDoltTable) CreatePrimaryKey(ctx *sql.Context, columns []sql.In
|
||||
}
|
||||
|
||||
func (t *AlterableDoltTable) DropPrimaryKey(ctx *sql.Context) error {
|
||||
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
|
||||
return err
|
||||
}
|
||||
if types.IsFormat_DOLT(t.nbf) {
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user