Changed check storage to a slice to preserve ordering. Added name generation

This commit is contained in:
Zach Musgrave
2021-04-11 12:57:20 -07:00
parent 788a714f87
commit 6c81b48324
3 changed files with 114 additions and 41 deletions

View File

@@ -190,6 +190,8 @@ func (fkc *ForeignKeyCollection) AddKeys(fks ...ForeignKey) error {
if key.Name == "" {
// assign a name based on the hash
// 8 char = 5 base32 bytes, should be collision resistant
// TODO: constraint names should be unique, and this isn't guaranteed to be.
// This logic needs to live at the table / DB level.
key.Name = key.HashOf().String()[:8]
}

View File

@@ -16,6 +16,7 @@ package schema
import (
"fmt"
"strings"
)
type Check interface {
@@ -28,8 +29,11 @@ type Check interface {
type CheckCollection interface {
// AddCheck adds a check to this collection and returns it
AddCheck(name, expression string, enforce bool) (Check, error)
// DropCheck removes the check with the name given
DropCheck(name string) error
// AllChecks returns all the checks in the collection
AllChecks() []Check
// Count returns the size of the collection
Count() int
}
@@ -52,43 +56,50 @@ func (c check) Enforced() bool {
}
type checkCollection struct {
checks map[string]check
checks []check
}
func (c checkCollection) AddCheck(name, expression string, enforce bool) (Check, error) {
if _, ok := c.checks[name]; ok {
return nil, fmt.Errorf("name %s in use", name)
}
c.checks[name] = check{
name: name,
expression: expression,
enforced: enforce,
}
func (c *checkCollection) AddCheck(name, expression string, enforce bool) (Check, error) {
for _, chk := range c.checks {
if strings.ToLower(name) == strings.ToLower(chk.name) {
return nil, fmt.Errorf("name %s in use", name)
}
}
return c.checks[name], nil
newCheck := check{
name: name,
expression: expression,
enforced: enforce,
}
c.checks = append(c.checks, newCheck)
return newCheck, nil
}
func (c checkCollection) DropCheck(name string) error {
delete(c.checks, name)
return nil
func (c *checkCollection) DropCheck(name string) error {
for i, chk := range c.checks {
if strings.ToLower(name) == strings.ToLower(chk.name) {
c.checks = append(c.checks[:i], c.checks[i+1:]...)
return nil
}
}
return nil
}
func (c checkCollection) AllChecks() []Check {
func (c *checkCollection) AllChecks() []Check {
checks := make([]Check, len(c.checks))
i := 0
for _, check := range c.checks {
for i, check := range c.checks {
checks[i] = check
i++
}
return checks
}
func (c checkCollection) Count() int {
func (c *checkCollection) Count() int {
return len(c.checks)
}
func NewCheckCollection() CheckCollection {
return checkCollection{
checks: make(map[string]check),
return &checkCollection{
checks: make([]check, 0),
}
}

View File

@@ -15,28 +15,30 @@
package sqle
import (
"context"
"errors"
"fmt"
"io"
"os"
"runtime"
"strconv"
"strings"
"sync"
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
"runtime"
"strconv"
"strings"
"sync"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/sqltypes"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/encoding"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/set"
"github.com/dolthub/dolt/go/store/types"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/alterschema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/encoding"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/set"
"github.com/dolthub/dolt/go/store/types"
)
const (
@@ -1449,6 +1451,14 @@ func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefin
return err
}
check = &(*check)
if check.Name == "" {
check.Name, err = t.generateCheckName(ctx, check)
if err != nil {
return err
}
}
_, err = sch.Checks().AddCheck(check.Name, check.CheckExpression, check.Enforced)
if err != nil {
return err
@@ -1510,3 +1520,53 @@ func (t *AlterableDoltTable) DropCheck(ctx *sql.Context, chName string) error {
return t.updateFromRoot(ctx, newRoot)
}
func (t *AlterableDoltTable) generateCheckName(ctx *sql.Context, check *sql.CheckDefinition) (string, error) {
var bb bytes.Buffer
bb.Write([]byte(check.CheckExpression))
hash := hash.Of(bb.Bytes())
hashedName := fmt.Sprintf("chk_%s", hash.String()[:8])
name := hashedName
var i int
for {
exists, err := t.constraintNameExists(ctx, name)
if err != nil {
return "", err
}
if !exists {
break
}
name = fmt.Sprintf("%s_%d", hashedName, i)
}
return name, nil
}
func (t *AlterableDoltTable) constraintNameExists(ctx *sql.Context, name string) (bool, error) {
keys, err := t.GetForeignKeys(ctx)
if err != nil {
return false, err
}
for _, key := range keys {
if strings.ToLower(key.Name) == strings.ToLower(name) {
return true, nil
}
}
checks, err := t.GetChecks(ctx)
if err != nil {
return false, err
}
for _, check := range checks {
if strings.ToLower(check.Name) == strings.ToLower(name) {
return true, nil
}
}
return false, nil
}