Check constraint, working except for DROP CONSTRAINT. Also needs bats tests

This commit is contained in:
Zach Musgrave
2021-04-08 15:10:31 -07:00
parent eefa7abe3b
commit 241a7fb342
7 changed files with 244 additions and 1 deletions

View File

@@ -0,0 +1,94 @@
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package schema
import (
"fmt"
)
type Check interface {
Name() string
Expression() string
Enforced() bool
}
// CheckCollection is the set of `check` constraints on a table's schema
type CheckCollection interface {
// AddCheck adds a check to this collection and returns it
AddCheck(name, expression string, enforce bool) (Check, error)
DropCheck(name string) error
AllChecks() []Check
Count() int
}
type check struct {
name string
expression string
enforced bool
}
func (c check) Name() string {
return c.name
}
func (c check) Expression() string {
return c.expression
}
func (c check) Enforced() bool {
return c.enforced
}
type checkCollection struct {
checks map[string]check
}
func (c checkCollection) AddCheck(name, expression string, enforce bool) (Check, error) {
if _, ok := c.checks[name]; ok {
return nil, fmt.Errorf("name %s in use", name)
}
c.checks[name] = check{
name: name,
expression: expression,
enforced: enforce,
}
return c.checks[name], nil
}
func (c checkCollection) DropCheck(name string) error {
delete(c.checks, name)
return nil
}
func (c checkCollection) AllChecks() []Check {
checks := make([]Check, len(c.checks))
i := 0
for _, check := range c.checks {
checks[i] = check
i++
}
return checks
}
func (c checkCollection) Count() int {
return len(c.checks)
}
func NewCheckCollection() CheckCollection {
return checkCollection{
checks: make(map[string]check),
}
}

View File

@@ -146,9 +146,16 @@ type encodedIndex struct {
IsSystemDefined bool `noms:"hidden,omitempty" json:"hidden,omitempty"` // Was previously named Hidden, do not change noms name
}
type encodedCheck struct {
Name string `noms:"name" json:"name"`
Expression string `noms:"expression" json:"expression"`
Enforced bool `noms:"enforced" json:"enforced"`
}
type schemaData struct {
Columns []encodedColumn `noms:"columns" json:"columns"`
IndexCollection []encodedIndex `noms:"idxColl,omitempty" json:"idxColl,omitempty"`
CheckConstraints []encodedCheck `noms:"checks,omitempty" json:"checks,omitempty"`
}
func toSchemaData(sch schema.Schema) (schemaData, error) {
@@ -178,7 +185,21 @@ func toSchemaData(sch schema.Schema) (schemaData, error) {
}
}
return schemaData{encCols, encodedIndexes}, nil
encodedChecks := make([]encodedCheck, sch.Checks().Count())
checks := sch.Checks()
for i, check := range checks.AllChecks() {
encodedChecks[i] = encodedCheck{
Name: check.Name(),
Expression: check.Expression(),
Enforced: check.Enforced(),
}
}
return schemaData{
Columns: encCols,
IndexCollection: encodedIndexes,
CheckConstraints: encodedChecks,
}, nil
}
func (sd schemaData) decodeSchema() (schema.Schema, error) {
@@ -215,6 +236,17 @@ func (sd schemaData) decodeSchema() (schema.Schema, error) {
}
}
for _, encodedCheck := range sd.CheckConstraints {
_, err = sch.Checks().AddCheck(
encodedCheck.Name,
encodedCheck.Expression,
encodedCheck.Enforced,
)
if err != nil {
return nil, err
}
}
return sch, nil
}

View File

@@ -171,6 +171,7 @@ func (ix *indexImpl) Schema() Schema {
nonPKCols: nonPkCols,
allCols: allCols,
indexCollection: NewIndexCollection(nil),
checkCollection: NewCheckCollection(),
}
}

View File

@@ -27,6 +27,9 @@ type Schema interface {
// Indexes returns a collection of all indexes on the table that this schema belongs to.
Indexes() IndexCollection
// Checks returns a collection of all check constraints on the table that this schema belongs to.
Checks() CheckCollection
}
// ColFromTag returns a schema.Column from a schema and a tag

View File

@@ -32,6 +32,7 @@ var EmptySchema = &schemaImpl{
type schemaImpl struct {
pkCols, nonPKCols, allCols *ColCollection
indexCollection IndexCollection
checkCollection CheckCollection
}
// SchemaFromCols creates a Schema from a collection of columns
@@ -59,6 +60,7 @@ func SchemaFromCols(allCols *ColCollection) (Schema, error) {
nonPKCols: nonPKColColl,
allCols: allCols,
indexCollection: NewIndexCollection(allCols),
checkCollection: NewCheckCollection(),
}, nil
}
@@ -123,6 +125,7 @@ func UnkeyedSchemaFromCols(allCols *ColCollection) Schema {
nonPKCols: nonPKColColl,
allCols: nonPKColColl,
indexCollection: NewIndexCollection(nil),
checkCollection: NewCheckCollection(),
}
}
@@ -156,6 +159,7 @@ func SchemaFromPKAndNonPKCols(pkCols, nonPKCols *ColCollection) (Schema, error)
nonPKCols: nonPKCols,
allCols: allColColl,
indexCollection: NewIndexCollection(allColColl),
checkCollection: NewCheckCollection(),
}, nil
}
@@ -207,3 +211,7 @@ func (si *schemaImpl) String() string {
func (si *schemaImpl) Indexes() IndexCollection {
return si.indexCollection
}
func (si *schemaImpl) Checks() CheckCollection {
return si.checkCollection
}

View File

@@ -220,6 +220,23 @@ func TestDropForeignKeys(t *testing.T) {
enginetest.TestDropForeignKeys(t, newDoltHarness(t))
}
func TestCreateCheckConstraints(t *testing.T) {
enginetest.TestCreateCheckConstraints(t, newDoltHarness(t))
}
func TestChecksOnInsert(t *testing.T) {
enginetest.TestChecksOnInsert(t, newDoltHarness(t))
}
func TestTestDisallowedCheckConstraints(t *testing.T) {
enginetest.TestDisallowedCheckConstraints(t, newDoltHarness(t))
}
func TestDropCheckConstraints(t *testing.T) {
enginetest.TestDropCheckConstraints(t, newDoltHarness(t))
}
func TestExplode(t *testing.T) {
t.Skipf("Unsupported types")
enginetest.TestExplode(t, newDoltHarness(t))

View File

@@ -388,6 +388,7 @@ var _ sql.InsertableTable = (*WritableDoltTable)(nil)
var _ sql.ReplaceableTable = (*WritableDoltTable)(nil)
var _ sql.AutoIncrementTable = (*WritableDoltTable)(nil)
var _ sql.TruncateableTable = (*WritableDoltTable)(nil)
var _ sql.CheckTable = (*WritableDoltTable)(nil)
func (t *WritableDoltTable) WithIndexLookup(lookup sql.IndexLookup) sql.Table {
dil, ok := lookup.(*doltIndexLookup)
@@ -526,6 +527,24 @@ func (t *WritableDoltTable) GetAutoIncrementValue(ctx *sql.Context) (interface{}
return t.DoltTable.GetAutoIncrementValue(ctx)
}
func (t *WritableDoltTable) GetChecks(ctx *sql.Context) ([]sql.CheckDefinition, error) {
sch, err := t.table.GetSchema(ctx)
if err != nil {
return nil, err
}
checks := make([]sql.CheckDefinition, sch.Checks().Count())
for i, check := range sch.Checks().AllChecks() {
checks[i] = sql.CheckDefinition{
Name: check.Name(),
CheckExpression: check.Expression(),
Enforced: check.Enforced(),
}
}
return checks, nil
}
// GetForeignKeys implements sql.ForeignKeyTable
func (t *DoltTable) GetForeignKeys(ctx *sql.Context) ([]sql.ForeignKeyConstraint, error) {
root, err := t.db.GetRoot(ctx)
@@ -677,6 +696,7 @@ var _ sql.AlterableTable = (*AlterableDoltTable)(nil)
var _ sql.IndexAlterableTable = (*AlterableDoltTable)(nil)
var _ sql.ForeignKeyAlterableTable = (*AlterableDoltTable)(nil)
var _ sql.ForeignKeyTable = (*AlterableDoltTable)(nil)
var _ sql.CheckAlterableTable = (*AlterableDoltTable)(nil)
// AddColumn implements sql.AlterableTable
func (t *AlterableDoltTable) AddColumn(ctx *sql.Context, column *sql.Column, order *sql.ColumnOrder) error {
@@ -1422,3 +1442,71 @@ func (t *AlterableDoltTable) updateFromRoot(ctx *sql.Context, root *doltdb.RootV
t.WritableDoltTable.DoltTable = updatedTable.WritableDoltTable.DoltTable
return nil
}
func (t *AlterableDoltTable) CreateCheck(ctx *sql.Context, check *sql.CheckDefinition) error {
sch, err := t.table.GetSchema(ctx)
if err != nil {
return err
}
_, err = sch.Checks().AddCheck(check.Name, check.CheckExpression, check.Enforced)
if err != nil {
return err
}
newTable, err := t.table.UpdateSchema(ctx, sch)
if err != nil {
return err
}
root, err := t.db.GetRoot(ctx)
if err != nil {
return err
}
newRoot, err := root.PutTable(ctx, t.name, newTable)
if err != nil {
return err
}
err = t.db.SetRoot(ctx, newRoot)
if err != nil {
return err
}
return t.updateFromRoot(ctx, newRoot)
}
func (t *AlterableDoltTable) DropCheck(ctx *sql.Context, chName string) error {
sch, err := t.table.GetSchema(ctx)
if err != nil {
return err
}
err = sch.Checks().DropCheck(chName)
if err != nil {
return err
}
newTable, err := t.table.UpdateSchema(ctx, sch)
if err != nil {
return err
}
root, err := t.db.GetRoot(ctx)
if err != nil {
return err
}
newRoot, err := root.PutTable(ctx, t.name, newTable)
if err != nil {
return err
}
err = t.db.SetRoot(ctx, newRoot)
if err != nil {
return err
}
return t.updateFromRoot(ctx, newRoot)
}