From 2baa9d767c54815cdcccaee2b97a1e2f4f6be681 Mon Sep 17 00:00:00 2001 From: xujihui1985 Date: Sun, 4 Sep 2022 19:47:32 +0800 Subject: [PATCH 1/8] feat: support aliyun oss store --- go/go.mod | 1 + go/go.sum | 3 + go/libraries/doltcore/dbfactory/factory.go | 3 + .../doltcore/dbfactory/factory_test.go | 9 ++ go/libraries/doltcore/dbfactory/oss.go | 54 ++++++++ go/store/blobstore/oss.go | 116 ++++++++++++++++++ go/store/blobstore/oss_test.go | 94 ++++++++++++++ 7 files changed, 280 insertions(+) create mode 100644 go/libraries/doltcore/dbfactory/oss.go create mode 100644 go/store/blobstore/oss.go create mode 100644 go/store/blobstore/oss_test.go diff --git a/go/go.mod b/go/go.mod index 81cdc8cdad..6b270d55ca 100644 --- a/go/go.mod +++ b/go/go.mod @@ -84,6 +84,7 @@ require ( github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect + github.com/aliyun/aliyun-oss-go-sdk v2.2.5+incompatible // indirect github.com/apache/thrift v0.13.1-0.20201008052519-daf620915714 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash v1.1.0 // indirect diff --git a/go/go.sum b/go/go.sum index 97cfab4151..b153b73749 100644 --- a/go/go.sum +++ b/go/go.sum @@ -79,6 +79,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d h1:UQZhZ2O0vMHr2cI+DC1Mbh0TJxzA3RcLoMsFw+aXw7E= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/aliyun/aliyun-oss-go-sdk v2.2.5+incompatible h1:QoRMR0TCctLDqBCMyOu1eXdZyMw3F7uGA9qPn2J4+R8= +github.com/aliyun/aliyun-oss-go-sdk v2.2.5+incompatible/go.mod 
h1:T/Aws4fEfogEE9v+HPhhw+CntffsBHJ8nXQCwKr0/g8= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= @@ -996,6 +998,7 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180221164845-07fd8470d635/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/go/libraries/doltcore/dbfactory/factory.go b/go/libraries/doltcore/dbfactory/factory.go index f861ad9037..ecf6e2f8dd 100644 --- a/go/libraries/doltcore/dbfactory/factory.go +++ b/go/libraries/doltcore/dbfactory/factory.go @@ -48,6 +48,8 @@ const ( // InMemBlobstore Scheme LocalBSScheme = "localbs" + OSSScheme = "oss" + defaultScheme = HTTPSScheme defaultMemTableSize = 256 * 1024 * 1024 ) @@ -65,6 +67,7 @@ type DBFactory interface { // from external packages. 
var DBFactories = map[string]DBFactory{ AWSScheme: AWSFactory{}, + OSSScheme: OSSFactory{}, GSScheme: GSFactory{}, FileScheme: FileFactory{}, MemScheme: MemFactory{}, diff --git a/go/libraries/doltcore/dbfactory/factory_test.go b/go/libraries/doltcore/dbfactory/factory_test.go index 604a06b035..61c3c58581 100644 --- a/go/libraries/doltcore/dbfactory/factory_test.go +++ b/go/libraries/doltcore/dbfactory/factory_test.go @@ -61,3 +61,12 @@ func TestCreateMemDB(t *testing.T) { assert.NotNil(t, vrw) assert.NotNil(t, ns) } + +func TestCreateDB(t *testing.T) { + ctx := context.Background() + db, vrw, ns, err := CreateDB(ctx, types.Format_Default, "oss://aaa/bbb", nil) + assert.NoError(t, err) + assert.NotNil(t, db) + assert.NotNil(t, vrw) + assert.NotNil(t, ns) +} diff --git a/go/libraries/doltcore/dbfactory/oss.go b/go/libraries/doltcore/dbfactory/oss.go new file mode 100644 index 0000000000..456fdd5ff0 --- /dev/null +++ b/go/libraries/doltcore/dbfactory/oss.go @@ -0,0 +1,54 @@ +package dbfactory + +import ( + "context" + "errors" + "fmt" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/dolthub/dolt/go/store/blobstore" + "github.com/dolthub/dolt/go/store/chunks" + "github.com/dolthub/dolt/go/store/datas" + "github.com/dolthub/dolt/go/store/nbs" + "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/dolt/go/store/types" + "net/url" +) + +// OSSFactory is a DBFactory implementation for creating GCS backed databases +type OSSFactory struct { +} + +// CreateDB creates an GCS backed database +func (fact OSSFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) (datas.Database, types.ValueReadWriter, tree.NodeStore, error) { + ossStore, err := fact.newChunkStore(ctx, nbf, urlObj, params) + if err != nil { + return nil, nil, nil, err + } + + vrw := types.NewValueStore(ossStore) + ns := tree.NewNodeStore(ossStore) + db := datas.NewTypesDatabase(vrw, ns) + + return db, vrw, ns, nil +} + +func (fact 
OSSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) (chunks.ChunkStore, error) { + // oss://[bucket]/[key] + bucket := urlObj.Hostname() + prefix := urlObj.Path + // todo get endpoint accesskeyid and secret from env or params + ossClient, err := oss.New( + "endpoint", + "accesskey", + "secret", + ) + if err != nil { + return nil, fmt.Errorf("failed to initialize oss err: %s", err) + } + bs, err := blobstore.NewOSSBlobstore(ossClient, bucket, prefix) + if err != nil { + return nil, errors.New("failed to initialize oss blob store") + } + q := nbs.NewUnlimitedMemQuotaProvider() + return nbs.NewBSStore(ctx, nbf.VersionString(), bs, defaultMemTableSize, q) +} diff --git a/go/store/blobstore/oss.go b/go/store/blobstore/oss.go new file mode 100644 index 0000000000..5474f562ff --- /dev/null +++ b/go/store/blobstore/oss.go @@ -0,0 +1,116 @@ +package blobstore + +import ( + "context" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "io" + "net/http" + "path" + "strconv" +) + +const ( + enabled = "Enabled" +) + +// OSSBlobstore provides an Aliyun OSS implementation of the Blobstore interface +type OSSBlobstore struct { + bucket *oss.Bucket + bucketName string + enableVersion bool + prefix string +} + +// NewOSSBlobstore creates a new instance of a OSSBlobstore +func NewOSSBlobstore(ossClient *oss.Client, bucketName, prefix string) (*OSSBlobstore, error) { + prefix = normalizePrefix(prefix) + bucket, err := ossClient.Bucket(bucketName) + if err != nil { + return nil, err + } + // check if bucket enable versioning + versionStatus, err := ossClient.GetBucketVersioning(bucketName) + if err != nil { + return nil, err + } + return &OSSBlobstore{ + bucket: bucket, + bucketName: bucketName, + prefix: prefix, + enableVersion: versionStatus.Status == enabled, + }, nil +} + +func (ob *OSSBlobstore) Exists(_ context.Context, key string) (bool, error) { + return ob.bucket.IsObjectExist(ob.absKey(key)) +} + +func (ob 
*OSSBlobstore) Get(ctx context.Context, key string, br BlobRange) (io.ReadCloser, string, error) { + absKey := ob.absKey(key) + meta, err := ob.bucket.GetObjectMeta(absKey) + + if isNotFoundErr(err) { + return nil, "", NotFound{"oss://" + path.Join(ob.bucketName, absKey)} + } + + if br.isAllRange() { + reader, err := ob.bucket.GetObject(absKey) + if err != nil { + return nil, "", err + } + return reader, "", nil + } + size, err := strconv.ParseInt(meta.Get(oss.HTTPHeaderContentLength), 10, 64) + if err != nil { + return nil, "", err + } + posBr := br.positiveRange(size) + reader, err := ob.bucket.GetObject(absKey, oss.Range(posBr.offset, posBr.offset+posBr.length-1)) + if err != nil { + return nil, "", err + } + return reader, oss.GetVersionId(meta), nil +} + +func (ob *OSSBlobstore) Put(ctx context.Context, key string, reader io.Reader) (string, error) { + var meta http.Header + if err := ob.bucket.PutObject(ob.absKey(key), reader, oss.GetResponseHeader(&meta)); err != nil { + return "", err + } + return oss.GetVersionId(meta), nil +} + +func (ob *OSSBlobstore) CheckAndPut(ctx context.Context, expectedVersion, key string, reader io.Reader) (string, error) { + var options []oss.Option + if expectedVersion != "" { + options = append(options, oss.VersionId(expectedVersion)) + } else { + options = append(options, oss.ForbidOverWrite(true)) + } + var meta http.Header + if err := ob.bucket.PutObject(ob.absKey(key), reader, oss.GetResponseHeader(&meta)); err != nil { + return "", err + } + return oss.GetVersionId(meta), nil +} + +func (ob *OSSBlobstore) absKey(key string) string { + return path.Join(ob.prefix, key) +} + +func normalizePrefix(prefix string) string { + for len(prefix) > 0 && prefix[0] == '/' { + prefix = prefix[1:] + } + return prefix +} + +func isNotFoundErr(err error) bool { + switch err.(type) { + case oss.ServiceError: + if err.(oss.ServiceError).StatusCode == 404 { + return true + } + } + return false +} diff --git a/go/store/blobstore/oss_test.go 
b/go/store/blobstore/oss_test.go new file mode 100644 index 0000000000..8fc65ba668 --- /dev/null +++ b/go/store/blobstore/oss_test.go @@ -0,0 +1,94 @@ +package blobstore + +import ( + "context" + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +func TestOSSBlobstore_Put(t *testing.T) { + c, _ := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") + b, _ := c.Bucket("seanxu-version") + cfg, err := c.GetBucketVersioning("seanxu-version") + assert.Nil(t, err) + assert.Equal(t, "Enabled", cfg.Status) + f, err := os.Open("/Users/sean/code/github.com/dolthub/dolt/go/go.sum") + assert.Nil(t, err) + defer f.Close() + err = b.PutObject("testversion/go.mod", f) + assert.Nil(t, err) +} + +func TestOSSBlobstore_Get(t *testing.T) { + c, _ := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") + b, _ := c.Bucket("seanxu-version") + meta, _ := b.GetObjectMeta("testversion/go.mod") + versionID := oss.GetVersionId(meta) + // "CAEQDhiBgMDd9_CYmBgiIGY0YjE2YjY0ZTJiMzQ0NDk4YzNhZWYzNTUwMzFjYTgy" + // "CAEQDhiBgID27JiZmBgiIDJjYTMwN2U5MDkyODRjYjg5ZWUzN2FkYTk0ZWQ3MjY5" + assert.Equal(t, "test", versionID) +} + +func TestOSSBlobstore_Put1(t *testing.T) { + c, err := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") + assert.Nil(t, err) + bs, _ := NewOSSBlobstore(c, "seanxu-version", "") + f, err := os.Open("/Users/sean/code/github.com/dolthub/dolt/go/go.sum") + assert.Nil(t, err) + defer f.Close() + version, err := bs.Put(context.Background(), "dolt/TestOSSBlobstore_Put1", f) + assert.Nil(t, err) + assert.Equal(t, "aaa", version) +} + +func TestNewOSSBlobstore(t *testing.T) { + c, err := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") + assert.Nil(t, err) + bs, err := NewOSSBlobstore(c, "seanxu-version", "dolt") + assert.Nil(t, err) + assert.True(t, bs.enableVersion) + + bs, err = NewOSSBlobstore(c, "seanxu", "dolt") + assert.Nil(t, err) + assert.False(t, bs.enableVersion) +} + +func Test_normalizePrefix(t *testing.T) { + type args struct { + 
prefix string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "no_leading_slash", + args: args{ + prefix: "root", + }, + want: "root", + }, + { + name: "with_leading_slash", + args: args{ + prefix: "/root", + }, + want: "root", + }, + { + name: "with_multi_leading_slash", + args: args{ + prefix: "//root", + }, + want: "root", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, normalizePrefix(tt.args.prefix), "normalizePrefix(%v)", tt.args.prefix) + }) + } +} From c57752de177d9ca8f2b66a2beaa4bc3bdac5bde4 Mon Sep 17 00:00:00 2001 From: xujihui1985 Date: Mon, 5 Sep 2022 22:34:02 +0800 Subject: [PATCH 2/8] feat: fetch oss parameter from env --- go/libraries/doltcore/dbfactory/oss.go | 32 +++++++++++++++++---- go/libraries/doltcore/dbfactory/oss_test.go | 23 +++++++++++++++ go/store/blobstore/oss.go | 25 +++++++++++----- 3 files changed, 67 insertions(+), 13 deletions(-) create mode 100644 go/libraries/doltcore/dbfactory/oss_test.go diff --git a/go/libraries/doltcore/dbfactory/oss.go b/go/libraries/doltcore/dbfactory/oss.go index 456fdd5ff0..27ceedea6d 100644 --- a/go/libraries/doltcore/dbfactory/oss.go +++ b/go/libraries/doltcore/dbfactory/oss.go @@ -12,6 +12,13 @@ import ( "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" "net/url" + "os" +) + +const ( + ossEndpointEnvKey = "OSS_ENDPOINT" + ossAccessKeyIDEnvKey = "OSS_ACCESS_KEY_ID" + ossAccessKeySecretEnvKey = "OSS_ACCESS_KEY_SECRET" ) // OSSFactory is a DBFactory implementation for creating GCS backed databases @@ -36,12 +43,7 @@ func (fact OSSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinForm // oss://[bucket]/[key] bucket := urlObj.Hostname() prefix := urlObj.Path - // todo get endpoint accesskeyid and secret from env or params - ossClient, err := oss.New( - "endpoint", - "accesskey", - "secret", - ) + ossClient, err := getOSSClient() if err != nil { return nil, 
fmt.Errorf("failed to initialize oss err: %s", err) } @@ -52,3 +54,21 @@ func (fact OSSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinForm q := nbs.NewUnlimitedMemQuotaProvider() return nbs.NewBSStore(ctx, nbf.VersionString(), bs, defaultMemTableSize, q) } + +func getOSSClient() (*oss.Client, error) { + var endpoint, accessKeyID, accessKeySecret string + if endpoint = os.Getenv(ossEndpointEnvKey); endpoint == "" { + return nil, fmt.Errorf("failed to find endpoint from env %s", ossEndpointEnvKey) + } + if accessKeyID = os.Getenv(ossAccessKeyIDEnvKey); accessKeyID == "" { + return nil, fmt.Errorf("failed to find accessKeyID from env %s", ossAccessKeyIDEnvKey) + } + if accessKeySecret = os.Getenv(ossAccessKeySecretEnvKey); accessKeySecret == "" { + return nil, fmt.Errorf("failed to find accessKeySecret from env %s", ossAccessKeySecretEnvKey) + } + return oss.New( + endpoint, + accessKeyID, + accessKeySecret, + ) +} diff --git a/go/libraries/doltcore/dbfactory/oss_test.go b/go/libraries/doltcore/dbfactory/oss_test.go new file mode 100644 index 0000000000..a7574e9e05 --- /dev/null +++ b/go/libraries/doltcore/dbfactory/oss_test.go @@ -0,0 +1,23 @@ +package dbfactory + +import ( + "github.com/stretchr/testify/assert" + "os" + "testing" +) + +func Test_getOSSClient(t *testing.T) { + _, err := getOSSClient() + assert.Error(t, err) + os.Setenv(ossEndpointEnvKey, "testendpoint") + _, err = getOSSClient() + assert.Error(t, err) + + os.Setenv(ossAccessKeyIDEnvKey, "testAccesskey") + _, err = getOSSClient() + assert.Error(t, err) + + os.Setenv(ossAccessKeySecretEnvKey, "testAccessSecret") + _, err = getOSSClient() + assert.Nil(t, err) +} diff --git a/go/store/blobstore/oss.go b/go/store/blobstore/oss.go index 5474f562ff..03be44b351 100644 --- a/go/store/blobstore/oss.go +++ b/go/store/blobstore/oss.go @@ -2,6 +2,7 @@ package blobstore import ( "context" + "fmt" "github.com/aliyun/aliyun-oss-go-sdk/oss" "io" "net/http" @@ -58,7 +59,7 @@ func (ob *OSSBlobstore) 
Get(ctx context.Context, key string, br BlobRange) (io.R if err != nil { return nil, "", err } - return reader, "", nil + return reader, ob.getVersion(meta), nil } size, err := strconv.ParseInt(meta.Get(oss.HTTPHeaderContentLength), 10, 64) if err != nil { @@ -69,7 +70,7 @@ func (ob *OSSBlobstore) Get(ctx context.Context, key string, br BlobRange) (io.R if err != nil { return nil, "", err } - return reader, oss.GetVersionId(meta), nil + return reader, ob.getVersion(meta), nil } func (ob *OSSBlobstore) Put(ctx context.Context, key string, reader io.Reader) (string, error) { @@ -77,27 +78,37 @@ func (ob *OSSBlobstore) Put(ctx context.Context, key string, reader io.Reader) ( if err := ob.bucket.PutObject(ob.absKey(key), reader, oss.GetResponseHeader(&meta)); err != nil { return "", err } - return oss.GetVersionId(meta), nil + return ob.getVersion(meta), nil } func (ob *OSSBlobstore) CheckAndPut(ctx context.Context, expectedVersion, key string, reader io.Reader) (string, error) { var options []oss.Option if expectedVersion != "" { options = append(options, oss.VersionId(expectedVersion)) - } else { - options = append(options, oss.ForbidOverWrite(true)) } var meta http.Header - if err := ob.bucket.PutObject(ob.absKey(key), reader, oss.GetResponseHeader(&meta)); err != nil { + options = append(options, oss.GetResponseHeader(&meta)) + if err := ob.bucket.PutObject(ob.absKey(key), reader, options...); err != nil { + ossErr, ok := err.(oss.ServiceError) + if ok { + return "", CheckAndPutError{key, expectedVersion, fmt.Sprintf("unknown (OSS error code %d)", ossErr.StatusCode)} + } return "", err } - return oss.GetVersionId(meta), nil + return ob.getVersion(meta), nil } func (ob *OSSBlobstore) absKey(key string) string { return path.Join(ob.prefix, key) } +func (ob *OSSBlobstore) getVersion(meta http.Header) string { + if ob.enableVersion { + return oss.GetVersionId(meta) + } + return "" +} + func normalizePrefix(prefix string) string { for len(prefix) > 0 && prefix[0] == '/' 
{ prefix = prefix[1:] From fb9d32e09ff64c0c64fdeda0f949672847ac3a2a Mon Sep 17 00:00:00 2001 From: xujihui1985 Date: Mon, 5 Sep 2022 22:40:32 +0800 Subject: [PATCH 3/8] remove failed test --- go/store/blobstore/oss_test.go | 50 ---------------------------------- 1 file changed, 50 deletions(-) diff --git a/go/store/blobstore/oss_test.go b/go/store/blobstore/oss_test.go index 8fc65ba668..a7387d5227 100644 --- a/go/store/blobstore/oss_test.go +++ b/go/store/blobstore/oss_test.go @@ -1,60 +1,10 @@ package blobstore import ( - "context" - "github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/stretchr/testify/assert" - "os" "testing" ) -func TestOSSBlobstore_Put(t *testing.T) { - c, _ := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") - b, _ := c.Bucket("seanxu-version") - cfg, err := c.GetBucketVersioning("seanxu-version") - assert.Nil(t, err) - assert.Equal(t, "Enabled", cfg.Status) - f, err := os.Open("/Users/sean/code/github.com/dolthub/dolt/go/go.sum") - assert.Nil(t, err) - defer f.Close() - err = b.PutObject("testversion/go.mod", f) - assert.Nil(t, err) -} - -func TestOSSBlobstore_Get(t *testing.T) { - c, _ := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") - b, _ := c.Bucket("seanxu-version") - meta, _ := b.GetObjectMeta("testversion/go.mod") - versionID := oss.GetVersionId(meta) - // "CAEQDhiBgMDd9_CYmBgiIGY0YjE2YjY0ZTJiMzQ0NDk4YzNhZWYzNTUwMzFjYTgy" - // "CAEQDhiBgID27JiZmBgiIDJjYTMwN2U5MDkyODRjYjg5ZWUzN2FkYTk0ZWQ3MjY5" - assert.Equal(t, "test", versionID) -} - -func TestOSSBlobstore_Put1(t *testing.T) { - c, err := oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") - assert.Nil(t, err) - bs, _ := NewOSSBlobstore(c, "seanxu-version", "") - f, err := os.Open("/Users/sean/code/github.com/dolthub/dolt/go/go.sum") - assert.Nil(t, err) - defer f.Close() - version, err := bs.Put(context.Background(), "dolt/TestOSSBlobstore_Put1", f) - assert.Nil(t, err) - assert.Equal(t, "aaa", version) -} - -func TestNewOSSBlobstore(t *testing.T) { - c, err := 
oss.New("oss-cn-hangzhou.aliyuncs.com", "", "") - assert.Nil(t, err) - bs, err := NewOSSBlobstore(c, "seanxu-version", "dolt") - assert.Nil(t, err) - assert.True(t, bs.enableVersion) - - bs, err = NewOSSBlobstore(c, "seanxu", "dolt") - assert.Nil(t, err) - assert.False(t, bs.enableVersion) -} - func Test_normalizePrefix(t *testing.T) { type args struct { prefix string From e28eebd3c834051eafb412c213bb97bc314154c1 Mon Sep 17 00:00:00 2001 From: xujihui1985 Date: Sun, 18 Sep 2022 17:33:31 +0800 Subject: [PATCH 4/8] feat(ossstore): support oss-cred-file and oss-cred-profile the default credentials file will locate at $HOME/.oss/dolt_oss_credentials, environment variable will still be supported for container environment. dolt remote add --oss-creds-profile prod-profile --oss-creds-config origin oss://[bucket]/key --- go/cmd/dolt/cli/arg_parser_helpers.go | 31 ++- go/cmd/dolt/commands/remote.go | 10 +- go/libraries/doltcore/dbfactory/oss.go | 144 +++++++++++++- go/libraries/doltcore/dbfactory/oss_test.go | 178 ++++++++++++++++-- .../testdata/osscred/dolt_oss_credentials | 17 ++ .../dbfactory/testdata/osscred/empty_oss_cred | 0 .../testdata/osscred/single_oss_cred | 7 + .../doltcore/sqle/dprocedures/dolt_remote.go | 8 +- 8 files changed, 364 insertions(+), 31 deletions(-) create mode 100644 go/libraries/doltcore/dbfactory/testdata/osscred/dolt_oss_credentials create mode 100644 go/libraries/doltcore/dbfactory/testdata/osscred/empty_oss_cred create mode 100644 go/libraries/doltcore/dbfactory/testdata/osscred/single_oss_cred diff --git a/go/cmd/dolt/cli/arg_parser_helpers.go b/go/cmd/dolt/cli/arg_parser_helpers.go index f75d25c39b..afca385e95 100644 --- a/go/cmd/dolt/cli/arg_parser_helpers.go +++ b/go/cmd/dolt/cli/arg_parser_helpers.go @@ -170,6 +170,8 @@ func CreateCloneArgParser() *argparser.ArgParser { ap.SupportsValidatedString(dbfactory.AWSCredsTypeParam, "", "creds-type", "", argparser.ValidatorFromStrList(dbfactory.AWSCredsTypeParam, dbfactory.AWSCredTypes)) 
ap.SupportsString(dbfactory.AWSCredsFileParam, "", "file", "AWS credentials file.") ap.SupportsString(dbfactory.AWSCredsProfile, "", "profile", "AWS profile to use.") + ap.SupportsString(dbfactory.OSSCredsFileParam, "", "file", "OSS credentials file.") + ap.SupportsString(dbfactory.OSSCredsProfile, "", "profile", "OSS profile to use.") return ap } @@ -278,17 +280,20 @@ func CreateVerifyConstraintsArgParser() *argparser.ArgParser { } var awsParams = []string{dbfactory.AWSRegionParam, dbfactory.AWSCredsTypeParam, dbfactory.AWSCredsFileParam, dbfactory.AWSCredsProfile} +var ossParams = []string{dbfactory.OSSCredsFileParam, dbfactory.OSSCredsProfile} func ProcessBackupArgs(apr *argparser.ArgParseResults, scheme, backupUrl string) (map[string]string, error) { params := map[string]string{} var err error - if scheme == dbfactory.AWSScheme { + switch scheme { + case dbfactory.AWSScheme: err = AddAWSParams(backupUrl, apr, params) - } else { + case dbfactory.OSSScheme: + err = AddOSSParams(backupUrl, apr, params) + default: err = VerifyNoAwsParams(apr) } - return params, err } @@ -312,6 +317,26 @@ func AddAWSParams(remoteUrl string, apr *argparser.ArgParseResults, params map[s return nil } +func AddOSSParams(remoteUrl string, apr *argparser.ArgParseResults, params map[string]string) error { + isOSS := strings.HasPrefix(remoteUrl, "oss") + + if !isOSS { + for _, p := range ossParams { + if _, ok := apr.GetValue(p); ok { + return fmt.Errorf("%s param is only valid for oss cloud remotes in the format oss://oss-bucket/database", p) + } + } + } + + for _, p := range ossParams { + if val, ok := apr.GetValue(p); ok { + params[p] = val + } + } + + return nil +} + func VerifyNoAwsParams(apr *argparser.ArgParseResults) error { if awsParams := apr.GetValues(awsParams...); len(awsParams) > 0 { awsParamKeys := make([]string, 0, len(awsParams)) diff --git a/go/cmd/dolt/commands/remote.go b/go/cmd/dolt/commands/remote.go index af56d33e95..14032bd08c 100644 --- 
a/go/cmd/dolt/commands/remote.go +++ b/go/cmd/dolt/commands/remote.go @@ -94,6 +94,8 @@ func (cmd RemoteCmd) ArgParser() *argparser.ArgParser { ap.SupportsValidatedString(dbfactory.AWSCredsTypeParam, "", "creds-type", "", argparser.ValidatorFromStrList(dbfactory.AWSCredsTypeParam, dbfactory.AWSCredTypes)) ap.SupportsString(dbfactory.AWSCredsFileParam, "", "file", "AWS credentials file") ap.SupportsString(dbfactory.AWSCredsProfile, "", "profile", "AWS profile to use") + ap.SupportsString(dbfactory.OSSCredsFileParam, "", "file", "OSS credentials file") + ap.SupportsString(dbfactory.OSSCredsProfile, "", "profile", "OSS profile to use") return ap } @@ -191,12 +193,14 @@ func parseRemoteArgs(apr *argparser.ArgParseResults, scheme, remoteUrl string) ( params := map[string]string{} var err error - if scheme == dbfactory.AWSScheme { + switch scheme { + case dbfactory.AWSScheme: err = cli.AddAWSParams(remoteUrl, apr, params) - } else { + case dbfactory.OSSScheme: + err = cli.AddOSSParams(remoteUrl, apr, params) + default: err = cli.VerifyNoAwsParams(apr) } - if err != nil { return nil, errhand.VerboseErrorFromError(err) } diff --git a/go/libraries/doltcore/dbfactory/oss.go b/go/libraries/doltcore/dbfactory/oss.go index 27ceedea6d..e8b7ad1471 100644 --- a/go/libraries/doltcore/dbfactory/oss.go +++ b/go/libraries/doltcore/dbfactory/oss.go @@ -2,6 +2,7 @@ package dbfactory import ( "context" + "encoding/json" "errors" "fmt" "github.com/aliyun/aliyun-oss-go-sdk/oss" @@ -11,16 +12,37 @@ import ( "github.com/dolthub/dolt/go/store/nbs" "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" + "io/ioutil" "net/url" "os" + "path/filepath" ) const ( ossEndpointEnvKey = "OSS_ENDPOINT" ossAccessKeyIDEnvKey = "OSS_ACCESS_KEY_ID" ossAccessKeySecretEnvKey = "OSS_ACCESS_KEY_SECRET" + + // OSSCredsFileParam is a creation parameter that can be used to specify a credential file to use. 
+ OSSCredsFileParam = "oss-creds-file" + + // OSSCredsProfile is a creation parameter that can be used to specify which AWS profile to use. + OSSCredsProfile = "oss-creds-profile" ) +var ( + emptyOSSCredential = ossCredential{} +) + +type ossParams map[string]interface{} +type ossCredentials map[string]ossCredential + +type ossCredential struct { + Endpoint string `json:"endpoint,omitempty"` + AccessKeyID string `json:"accessKeyID,omitempty"` + AccessKeySecret string `json:"accessKeySecret,omitempty"` +} + // OSSFactory is a DBFactory implementation for creating GCS backed databases type OSSFactory struct { } @@ -43,7 +65,10 @@ func (fact OSSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinForm // oss://[bucket]/[key] bucket := urlObj.Hostname() prefix := urlObj.Path - ossClient, err := getOSSClient() + + opts := ossConfigFromParams(params) + + ossClient, err := getOSSClient(opts) if err != nil { return nil, fmt.Errorf("failed to initialize oss err: %s", err) } @@ -55,16 +80,41 @@ func (fact OSSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinForm return nbs.NewBSStore(ctx, nbf.VersionString(), bs, defaultMemTableSize, q) } -func getOSSClient() (*oss.Client, error) { - var endpoint, accessKeyID, accessKeySecret string - if endpoint = os.Getenv(ossEndpointEnvKey); endpoint == "" { - return nil, fmt.Errorf("failed to find endpoint from env %s", ossEndpointEnvKey) +func ossConfigFromParams(params map[string]interface{}) ossCredential { + // then we look for config from oss-creds-file + p := ossParams(params) + credFile, err := p.getCredFile() + if err != nil { + return emptyOSSCredential } - if accessKeyID = os.Getenv(ossAccessKeyIDEnvKey); accessKeyID == "" { - return nil, fmt.Errorf("failed to find accessKeyID from env %s", ossAccessKeyIDEnvKey) + creds, err := readOSSCredentialsFromFile(credFile) + if err != nil { + return emptyOSSCredential } - if accessKeySecret = os.Getenv(ossAccessKeySecretEnvKey); accessKeySecret == "" { - return nil, 
fmt.Errorf("failed to find accessKeySecret from env %s", ossAccessKeySecretEnvKey) + // if there is only 1 cred in the file, just use this cred regardless the profile is + if len(creds) == 1 { + return creds.First() + } + // otherwise, we try to get cred by profile from cred file + if res, ok := creds[p.getCredProfile()]; ok { + return res + } + return emptyOSSCredential +} + +func getOSSClient(opts ossCredential) (*oss.Client, error) { + var ( + endpoint, accessKeyID, accessKeySecret string + err error + ) + if endpoint, err = opts.getEndPoint(); err != nil { + return nil, err + } + if accessKeyID, err = opts.getAccessKeyID(); err != nil { + return nil, err + } + if accessKeySecret, err = opts.getAccessKeySecret(); err != nil { + return nil, err } return oss.New( endpoint, @@ -72,3 +122,79 @@ func getOSSClient() (*oss.Client, error) { accessKeySecret, ) } + +func (opt ossCredential) getEndPoint() (string, error) { + if opt.Endpoint != "" { + return opt.Endpoint, nil + } + if v := os.Getenv(ossEndpointEnvKey); v != "" { + return v, nil + } + return "", fmt.Errorf("failed to find endpoint from cred file or env %s", ossEndpointEnvKey) +} + +func (opt ossCredential) getAccessKeyID() (string, error) { + if opt.AccessKeyID != "" { + return opt.AccessKeyID, nil + } + if v := os.Getenv(ossAccessKeyIDEnvKey); v != "" { + return v, nil + } + return "", fmt.Errorf("failed to find accessKeyID from cred file or env %s", ossAccessKeyIDEnvKey) +} + +func (opt ossCredential) getAccessKeySecret() (string, error) { + if opt.AccessKeySecret != "" { + return opt.AccessKeySecret, nil + } + if v := os.Getenv(ossAccessKeySecretEnvKey); v != "" { + return v, nil + } + return "", fmt.Errorf("failed to find accessKeySecret from cred file or env %s", ossAccessKeySecretEnvKey) +} + +func readOSSCredentialsFromFile(credFile string) (ossCredentials, error) { + data, err := ioutil.ReadFile(credFile) + if err != nil { + return nil, fmt.Errorf("failed to read oss cred file %s, err: %s", credFile, 
err) + } + var res map[string]ossCredential + if err = json.Unmarshal(data, &res); err != nil { + return nil, fmt.Errorf("invalid oss credential file %s, err: %s", credFile, err) + } + if len(res) == 0 { + return nil, errors.New("empty credential file is not allowed") + } + return res, nil +} + +func (oc ossCredentials) First() ossCredential { + var res ossCredential + for _, c := range oc { + res = c + break + } + return res +} + +func (p ossParams) getCredFile() (string, error) { + // then we look for config from oss-creds-file + credFile, ok := p[OSSCredsFileParam] + if !ok { + // if oss-creds-files is + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to find oss cred file from home dir, err: %s", err) + } + credFile = filepath.Join(homeDir, ".oss", "dolt_oss_credentials") + } + return credFile.(string), nil +} + +func (p ossParams) getCredProfile() string { + credProfile, ok := p[OSSCredsProfile] + if !ok { + credProfile = "default" + } + return credProfile.(string) +} diff --git a/go/libraries/doltcore/dbfactory/oss_test.go b/go/libraries/doltcore/dbfactory/oss_test.go index a7574e9e05..02b880530f 100644 --- a/go/libraries/doltcore/dbfactory/oss_test.go +++ b/go/libraries/doltcore/dbfactory/oss_test.go @@ -1,23 +1,175 @@ package dbfactory import ( + "github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/stretchr/testify/assert" "os" "testing" ) -func Test_getOSSClient(t *testing.T) { - _, err := getOSSClient() - assert.Error(t, err) - os.Setenv(ossEndpointEnvKey, "testendpoint") - _, err = getOSSClient() - assert.Error(t, err) - - os.Setenv(ossAccessKeyIDEnvKey, "testAccesskey") - _, err = getOSSClient() - assert.Error(t, err) - - os.Setenv(ossAccessKeySecretEnvKey, "testAccessSecret") - _, err = getOSSClient() +func Test_readOssCredentialsFromFile(t *testing.T) { + creds, err := readOSSCredentialsFromFile("testdata/osscred/dolt_oss_credentials") assert.Nil(t, err) + assert.Equal(t, 3, len(creds)) +} + +func 
Test_ossConfigFromParams(t *testing.T) { + type args struct { + params map[string]interface{} + } + tests := []struct { + name string + args args + want ossCredential + }{ + { + name: "not exist", + args: args{ + params: nil, + }, + want: emptyOSSCredential, + }, + { + name: "get default profile", + args: args{ + params: map[string]interface{}{ + OSSCredsFileParam: "testdata/osscred/dolt_oss_credentials", + }, + }, + want: ossCredential{ + Endpoint: "oss-cn-hangzhou.aliyuncs.com", + AccessKeyID: "defaulttestid", + AccessKeySecret: "test secret", + }, + }, + { + name: "get default profile single cred", + args: args{ + params: map[string]interface{}{ + OSSCredsFileParam: "testdata/osscred/single_oss_cred", + }, + }, + want: ossCredential{ + Endpoint: "oss-cn-hangzhou.aliyuncs.com", + AccessKeyID: "testid", + AccessKeySecret: "test secret", + }, + }, + { + name: "get cred by profile", + args: args{ + params: map[string]interface{}{ + OSSCredsFileParam: "testdata/osscred/dolt_oss_credentials", + OSSCredsProfile: "prod", + }, + }, + want: ossCredential{ + Endpoint: "oss-cn-hangzhou.aliyuncs.com", + AccessKeyID: "prodid", + AccessKeySecret: "test secret", + }, + }, + { + name: "profile not exists", + args: args{ + params: map[string]interface{}{ + OSSCredsFileParam: "testdata/osscred/dolt_oss_credentials", + OSSCredsProfile: "notexists", + }, + }, + want: emptyOSSCredential, + }, + { + name: "empty cred file", + args: args{ + params: map[string]interface{}{ + OSSCredsFileParam: "testdata/osscred/empty_oss_cred", + }, + }, + want: emptyOSSCredential, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, ossConfigFromParams(tt.args.params), "ossConfigFromParams(%v)", tt.args.params) + }) + } +} + +func Test_getOSSClient(t *testing.T) { + type args struct { + opts ossCredential + } + tests := []struct { + name string + args args + before func() + after func() + want func(got *oss.Client) bool + wantErr bool + }{ + { + name: 
"get valid oss client", + args: args{ + opts: ossCredential{ + Endpoint: "testendpoint", + AccessKeyID: "testid", + AccessKeySecret: "testkey", + }, + }, + wantErr: false, + want: func(got *oss.Client) bool { + return got != nil + }, + }, + { + name: "get invalid oss client", + args: args{ + opts: ossCredential{ + Endpoint: "", + AccessKeyID: "testid", + AccessKeySecret: "testkey", + }, + }, + wantErr: true, + want: func(got *oss.Client) bool { + return got == nil + }, + }, + { + name: "get valid oss client from env", + before: func() { + os.Setenv(ossEndpointEnvKey, "testendpoint") + }, + after: func() { + os.Unsetenv(ossEndpointEnvKey) + }, + args: args{ + opts: ossCredential{ + Endpoint: "", + AccessKeyID: "testid", + AccessKeySecret: "testkey", + }, + }, + wantErr: false, + want: func(got *oss.Client) bool { + return got != nil + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.before != nil { + tt.before() + } + if tt.after != nil { + defer tt.after() + } + got, err := getOSSClient(tt.args.opts) + if tt.wantErr { + assert.Error(t, err) + } + assert.True(t, tt.want(got)) + }) + } } diff --git a/go/libraries/doltcore/dbfactory/testdata/osscred/dolt_oss_credentials b/go/libraries/doltcore/dbfactory/testdata/osscred/dolt_oss_credentials new file mode 100644 index 0000000000..33380e1b65 --- /dev/null +++ b/go/libraries/doltcore/dbfactory/testdata/osscred/dolt_oss_credentials @@ -0,0 +1,17 @@ +{ + "prod": { + "endpoint": "oss-cn-hangzhou.aliyuncs.com", + "accessKeyID": "prodid", + "accessKeySecret": "test secret" + }, + "dev": { + "endpoint": "oss-cn-hangzhou.aliyuncs.com", + "accessKeyID": "devid", + "accessKeySecret": "dev secret" + }, + "default": { + "endpoint": "oss-cn-hangzhou.aliyuncs.com", + "accessKeyID": "defaulttestid", + "accessKeySecret": "test secret" + } +} diff --git a/go/libraries/doltcore/dbfactory/testdata/osscred/empty_oss_cred b/go/libraries/doltcore/dbfactory/testdata/osscred/empty_oss_cred new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/go/libraries/doltcore/dbfactory/testdata/osscred/single_oss_cred b/go/libraries/doltcore/dbfactory/testdata/osscred/single_oss_cred new file mode 100644 index 0000000000..a91deec9bd --- /dev/null +++ b/go/libraries/doltcore/dbfactory/testdata/osscred/single_oss_cred @@ -0,0 +1,7 @@ +{ + "prod": { + "endpoint": "oss-cn-hangzhou.aliyuncs.com", + "accessKeyID": "testid", + "accessKeySecret": "test secret" + } +} diff --git a/go/libraries/doltcore/sqle/dprocedures/dolt_remote.go b/go/libraries/doltcore/sqle/dprocedures/dolt_remote.go index dbffbe3af8..580d8371a8 100644 --- a/go/libraries/doltcore/sqle/dprocedures/dolt_remote.go +++ b/go/libraries/doltcore/sqle/dprocedures/dolt_remote.go @@ -144,10 +144,12 @@ func remoteParams(apr *argparser.ArgParseResults, scheme, remoteUrl string) (map params := map[string]string{} var err error - if scheme == dbfactory.AWSScheme { - // TODO: get AWS params from session + switch scheme { + case dbfactory.AWSScheme: err = cli.AddAWSParams(remoteUrl, apr, params) - } else { + case dbfactory.OSSScheme: + err = cli.AddOSSParams(remoteUrl, apr, params) + default: err = cli.VerifyNoAwsParams(apr) } From a3987fa400311cd40a92c3d2e3be9a91dfc539ca Mon Sep 17 00:00:00 2001 From: xujihui1985 Date: Sun, 18 Sep 2022 17:39:06 +0800 Subject: [PATCH 5/8] fix: review suggestion --- go/libraries/doltcore/dbfactory/oss.go | 6 +++--- go/store/blobstore/oss.go | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/go/libraries/doltcore/dbfactory/oss.go b/go/libraries/doltcore/dbfactory/oss.go index e8b7ad1471..0b3a027af4 100644 --- a/go/libraries/doltcore/dbfactory/oss.go +++ b/go/libraries/doltcore/dbfactory/oss.go @@ -26,7 +26,7 @@ const ( // OSSCredsFileParam is a creation parameter that can be used to specify a credential file to use. OSSCredsFileParam = "oss-creds-file" - // OSSCredsProfile is a creation parameter that can be used to specify which AWS profile to use. 
+ // OSSCredsProfile is a creation parameter that can be used to specify which OSS profile to use. OSSCredsProfile = "oss-creds-profile" ) @@ -43,11 +43,11 @@ type ossCredential struct { AccessKeySecret string `json:"accessKeySecret,omitempty"` } -// OSSFactory is a DBFactory implementation for creating GCS backed databases +// OSSFactory is a DBFactory implementation for creating OSS backed databases type OSSFactory struct { } -// CreateDB creates an GCS backed database +// CreateDB creates an OSS backed database func (fact OSSFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) (datas.Database, types.ValueReadWriter, tree.NodeStore, error) { ossStore, err := fact.newChunkStore(ctx, nbf, urlObj, params) if err != nil { diff --git a/go/store/blobstore/oss.go b/go/store/blobstore/oss.go index 03be44b351..4448228f34 100644 --- a/go/store/blobstore/oss.go +++ b/go/store/blobstore/oss.go @@ -91,7 +91,10 @@ func (ob *OSSBlobstore) CheckAndPut(ctx context.Context, expectedVersion, key st if err := ob.bucket.PutObject(ob.absKey(key), reader, options...); err != nil { ossErr, ok := err.(oss.ServiceError) if ok { - return "", CheckAndPutError{key, expectedVersion, fmt.Sprintf("unknown (OSS error code %d)", ossErr.StatusCode)} + return "", CheckAndPutError{ + Key: key, + ExpectedVersion: expectedVersion, + ActualVersion: fmt.Sprintf("unknown (OSS error code %d)", ossErr.StatusCode)} } return "", err } From 03fe33aa0b3efb08d7495707286d008cb3d31352 Mon Sep 17 00:00:00 2001 From: Gaius Date: Mon, 19 Sep 2022 17:45:48 +0800 Subject: [PATCH 6/8] feat: add PrepareDB to OSSFactory Signed-off-by: Gaius --- go/go.mod | 3 ++- go/libraries/doltcore/dbfactory/oss.go | 15 +++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/go/go.mod b/go/go.mod index 6b270d55ca..0a23775edd 100644 --- a/go/go.mod +++ b/go/go.mod @@ -56,6 +56,7 @@ require ( ) require ( + github.com/aliyun/aliyun-oss-go-sdk 
v2.2.5+incompatible github.com/dolthub/go-mysql-server v0.12.1-0.20220917042748-7b03d007ce46 github.com/google/flatbuffers v2.0.6+incompatible github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6 @@ -84,7 +85,6 @@ require ( github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b // indirect github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aliyun/aliyun-oss-go-sdk v2.2.5+incompatible // indirect github.com/apache/thrift v0.13.1-0.20201008052519-daf620915714 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash v1.1.0 // indirect @@ -128,6 +128,7 @@ require ( golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3 // indirect golang.org/x/oauth2 v0.0.0-20200902213428-5d25da1a8d43 // indirect golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 // indirect + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect golang.org/x/tools v0.1.10 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect diff --git a/go/libraries/doltcore/dbfactory/oss.go b/go/libraries/doltcore/dbfactory/oss.go index 0b3a027af4..f8c95535e0 100644 --- a/go/libraries/doltcore/dbfactory/oss.go +++ b/go/libraries/doltcore/dbfactory/oss.go @@ -5,6 +5,11 @@ import ( "encoding/json" "errors" "fmt" + "io/ioutil" + "net/url" + "os" + "path/filepath" + "github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/dolthub/dolt/go/store/blobstore" "github.com/dolthub/dolt/go/store/chunks" @@ -12,10 +17,6 @@ import ( "github.com/dolthub/dolt/go/store/nbs" "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" - "io/ioutil" - "net/url" - "os" - "path/filepath" ) const ( @@ -47,6 +48,12 @@ type ossCredential struct { type OSSFactory struct { } +// PrepareDB prepares an OSS backed database +func (fact OSSFactory) PrepareDB(ctx context.Context, nbf 
*types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) error { + // nothing to prepare + return nil +} + // CreateDB creates an OSS backed database func (fact OSSFactory) CreateDB(ctx context.Context, nbf *types.NomsBinFormat, urlObj *url.URL, params map[string]interface{}) (datas.Database, types.ValueReadWriter, tree.NodeStore, error) { ossStore, err := fact.newChunkStore(ctx, nbf, urlObj, params) From ff04472567a61b34d3b37453e2fc5a9e2634c032 Mon Sep 17 00:00:00 2001 From: xujihui1985 Date: Thu, 22 Sep 2022 09:27:45 +0800 Subject: [PATCH 7/8] fix: copyright and failed unittest --- .../doltcore/dbfactory/factory_test.go | 9 --------- go/libraries/doltcore/dbfactory/oss.go | 17 +++++++++++++++-- go/libraries/doltcore/dbfactory/oss_test.go | 19 +++++++++++++++++-- go/store/blobstore/oss.go | 17 ++++++++++++++++- go/store/blobstore/oss_test.go | 17 ++++++++++++++++- 5 files changed, 64 insertions(+), 15 deletions(-) diff --git a/go/libraries/doltcore/dbfactory/factory_test.go b/go/libraries/doltcore/dbfactory/factory_test.go index 61c3c58581..604a06b035 100644 --- a/go/libraries/doltcore/dbfactory/factory_test.go +++ b/go/libraries/doltcore/dbfactory/factory_test.go @@ -61,12 +61,3 @@ func TestCreateMemDB(t *testing.T) { assert.NotNil(t, vrw) assert.NotNil(t, ns) } - -func TestCreateDB(t *testing.T) { - ctx := context.Background() - db, vrw, ns, err := CreateDB(ctx, types.Format_Default, "oss://aaa/bbb", nil) - assert.NoError(t, err) - assert.NotNil(t, db) - assert.NotNil(t, vrw) - assert.NotNil(t, ns) -} diff --git a/go/libraries/doltcore/dbfactory/oss.go b/go/libraries/doltcore/dbfactory/oss.go index f8c95535e0..584ebd753f 100644 --- a/go/libraries/doltcore/dbfactory/oss.go +++ b/go/libraries/doltcore/dbfactory/oss.go @@ -1,3 +1,17 @@ +// Copyright 2019 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package dbfactory import ( @@ -5,7 +19,6 @@ import ( "encoding/json" "errors" "fmt" - "io/ioutil" "net/url" "os" "path/filepath" @@ -161,7 +174,7 @@ func (opt ossCredential) getAccessKeySecret() (string, error) { } func readOSSCredentialsFromFile(credFile string) (ossCredentials, error) { - data, err := ioutil.ReadFile(credFile) + data, err := os.ReadFile(credFile) if err != nil { return nil, fmt.Errorf("failed to read oss cred file %s, err: %s", credFile, err) } diff --git a/go/libraries/doltcore/dbfactory/oss_test.go b/go/libraries/doltcore/dbfactory/oss_test.go index 02b880530f..f7393c7275 100644 --- a/go/libraries/doltcore/dbfactory/oss_test.go +++ b/go/libraries/doltcore/dbfactory/oss_test.go @@ -1,10 +1,25 @@ +// Copyright 2019 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ package dbfactory import ( - "github.com/aliyun/aliyun-oss-go-sdk/oss" - "github.com/stretchr/testify/assert" "os" "testing" + + "github.com/aliyun/aliyun-oss-go-sdk/oss" + "github.com/stretchr/testify/assert" ) func Test_readOssCredentialsFromFile(t *testing.T) { diff --git a/go/store/blobstore/oss.go b/go/store/blobstore/oss.go index 4448228f34..9f7d4fc1de 100644 --- a/go/store/blobstore/oss.go +++ b/go/store/blobstore/oss.go @@ -1,13 +1,28 @@ +// Copyright 2019 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package blobstore import ( "context" "fmt" - "github.com/aliyun/aliyun-oss-go-sdk/oss" "io" "net/http" "path" "strconv" + + "github.com/aliyun/aliyun-oss-go-sdk/oss" ) const ( diff --git a/go/store/blobstore/oss_test.go b/go/store/blobstore/oss_test.go index a7387d5227..1d70c8b7f7 100644 --- a/go/store/blobstore/oss_test.go +++ b/go/store/blobstore/oss_test.go @@ -1,8 +1,23 @@ +// Copyright 2019 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + package blobstore import ( - "github.com/stretchr/testify/assert" "testing" + + "github.com/stretchr/testify/assert" ) func Test_normalizePrefix(t *testing.T) { From 691997b87a40125477710424860f45abfb418274 Mon Sep 17 00:00:00 2001 From: jennifersp <44716627+jennifersp@users.noreply.github.com> Date: Mon, 26 Sep 2022 09:48:58 -0700 Subject: [PATCH 8/8] support dolt_diff_summary_table_function (#4363) --- go/cmd/dolt/commands/diff.go | 3 +- go/cmd/dolt/commands/diff_output.go | 39 +- go/libraries/doltcore/diff/diff_summary.go | 40 +- go/libraries/doltcore/diff/table_deltas.go | 18 +- go/libraries/doltcore/row/tagged_values.go | 7 +- go/libraries/doltcore/schema/schema.go | 5 + .../doltcore/sqle/database_provider.go | 6 +- .../sqle/dolt_diff_summary_table_function.go | 509 ++++++++++++++++ .../doltcore/sqle/dolt_diff_table_function.go | 6 +- .../sqle/enginetest/dolt_engine_test.go | 22 + .../doltcore/sqle/enginetest/dolt_queries.go | 553 ++++++++++++++++++ integration-tests/bats/diff.bats | 32 +- 12 files changed, 1186 insertions(+), 54 deletions(-) create mode 100644 go/libraries/doltcore/sqle/dolt_diff_summary_table_function.go diff --git a/go/cmd/dolt/commands/diff.go b/go/cmd/dolt/commands/diff.go index 6c9864acd2..8327826558 100644 --- a/go/cmd/dolt/commands/diff.go +++ b/go/cmd/dolt/commands/diff.go @@ -391,8 +391,7 @@ func diffUserTable( } if dArgs.diffParts&Summary != 0 { - numCols := fromSch.GetAllCols().Size() - return printDiffSummary(ctx, td, numCols) + return printDiffSummary(ctx, td, fromSch.GetAllCols().Size(), toSch.GetAllCols().Size()) } if dArgs.diffParts&SchemaOnlyDiff != 0 { diff --git a/go/cmd/dolt/commands/diff_output.go b/go/cmd/dolt/commands/diff_output.go index ff6cf8b7ca..6f6ddaeca5 100644 --- a/go/cmd/dolt/commands/diff_output.go +++ b/go/cmd/dolt/commands/diff_output.go @@ -65,7 +65,7 @@ func newDiffWriter(diffOutput diffOutput) 
(diffWriter, error) { } } -func printDiffSummary(ctx context.Context, td diff.TableDelta, colLen int) errhand.VerboseError { +func printDiffSummary(ctx context.Context, td diff.TableDelta, oldColLen, newColLen int) errhand.VerboseError { // todo: use errgroup.Group ae := atomicerr.New() ch := make(chan diff.DiffSummaryProgress) @@ -89,11 +89,13 @@ func printDiffSummary(ctx context.Context, td diff.TableDelta, colLen int) errha acc.Removes += p.Removes acc.Changes += p.Changes acc.CellChanges += p.CellChanges - acc.NewSize += p.NewSize - acc.OldSize += p.OldSize + acc.NewRowSize += p.NewRowSize + acc.OldRowSize += p.OldRowSize + acc.NewCellSize += p.NewCellSize + acc.OldCellSize += p.OldCellSize if count%10000 == 0 { - eP.Printf("prev size: %d, new size: %d, adds: %d, deletes: %d, modifications: %d\n", acc.OldSize, acc.NewSize, acc.Adds, acc.Removes, acc.Changes) + eP.Printf("prev size: %d, new size: %d, adds: %d, deletes: %d, modifications: %d\n", acc.OldRowSize, acc.NewRowSize, acc.Adds, acc.Removes, acc.Changes) eP.Display() } @@ -108,10 +110,10 @@ func printDiffSummary(ctx context.Context, td diff.TableDelta, colLen int) errha keyless, err := td.IsKeyless(ctx) if err != nil { - return nil + return errhand.BuildDError("").AddCause(err).Build() } - if (acc.Adds + acc.Removes + acc.Changes) == 0 { + if (acc.Adds+acc.Removes+acc.Changes) == 0 && (acc.OldCellSize-acc.NewCellSize) == 0 { cli.Println("No data changes. 
See schema changes by using -s or --schema.") return nil } @@ -119,24 +121,27 @@ func printDiffSummary(ctx context.Context, td diff.TableDelta, colLen int) errha if keyless { printKeylessSummary(acc) } else { - printSummary(acc, colLen) + printSummary(acc, oldColLen, newColLen) } return nil } -func printSummary(acc diff.DiffSummaryProgress, colLen int) { - rowsUnmodified := uint64(acc.OldSize - acc.Changes - acc.Removes) +func printSummary(acc diff.DiffSummaryProgress, oldColLen, newColLen int) { + numCellInserts, numCellDeletes := sqle.GetCellsAddedAndDeleted(acc, newColLen) + rowsUnmodified := uint64(acc.OldRowSize - acc.Changes - acc.Removes) unmodified := pluralize("Row Unmodified", "Rows Unmodified", rowsUnmodified) insertions := pluralize("Row Added", "Rows Added", acc.Adds) deletions := pluralize("Row Deleted", "Rows Deleted", acc.Removes) changes := pluralize("Row Modified", "Rows Modified", acc.Changes) + cellInsertions := pluralize("Cell Added", "Cells Added", numCellInserts) + cellDeletions := pluralize("Cell Deleted", "Cells Deleted", numCellDeletes) cellChanges := pluralize("Cell Modified", "Cells Modified", acc.CellChanges) - oldValues := pluralize("Entry", "Entries", acc.OldSize) - newValues := pluralize("Entry", "Entries", acc.NewSize) + oldValues := pluralize("Row Entry", "Row Entries", acc.OldRowSize) + newValues := pluralize("Row Entry", "Row Entries", acc.NewRowSize) - percentCellsChanged := float64(100*acc.CellChanges) / (float64(acc.OldSize) * float64(colLen)) + percentCellsChanged := float64(100*acc.CellChanges) / (float64(acc.OldRowSize) * float64(oldColLen)) safePercent := func(num, dom uint64) float64 { // returns +Inf for x/0 where x > 0 @@ -146,10 +151,12 @@ func printSummary(acc diff.DiffSummaryProgress, colLen int) { return float64(100*num) / (float64(dom)) } - cli.Printf("%s (%.2f%%)\n", unmodified, safePercent(rowsUnmodified, acc.OldSize)) - cli.Printf("%s (%.2f%%)\n", insertions, safePercent(acc.Adds, acc.OldSize)) - cli.Printf("%s 
(%.2f%%)\n", deletions, safePercent(acc.Removes, acc.OldSize)) - cli.Printf("%s (%.2f%%)\n", changes, safePercent(acc.Changes, acc.OldSize)) + cli.Printf("%s (%.2f%%)\n", unmodified, safePercent(rowsUnmodified, acc.OldRowSize)) + cli.Printf("%s (%.2f%%)\n", insertions, safePercent(acc.Adds, acc.OldRowSize)) + cli.Printf("%s (%.2f%%)\n", deletions, safePercent(acc.Removes, acc.OldRowSize)) + cli.Printf("%s (%.2f%%)\n", changes, safePercent(acc.Changes, acc.OldRowSize)) + cli.Printf("%s (%.2f%%)\n", cellInsertions, safePercent(numCellInserts, acc.OldCellSize)) + cli.Printf("%s (%.2f%%)\n", cellDeletions, safePercent(numCellDeletes, acc.OldCellSize)) cli.Printf("%s (%.2f%%)\n", cellChanges, percentCellsChanged) cli.Printf("(%s vs %s)\n\n", oldValues, newValues) } diff --git a/go/libraries/doltcore/diff/diff_summary.go b/go/libraries/doltcore/diff/diff_summary.go index 96d2acb437..ff8b55172a 100644 --- a/go/libraries/doltcore/diff/diff_summary.go +++ b/go/libraries/doltcore/diff/diff_summary.go @@ -33,11 +33,11 @@ import ( ) type DiffSummaryProgress struct { - Adds, Removes, Changes, CellChanges, NewSize, OldSize uint64 + Adds, Removes, Changes, CellChanges, NewRowSize, OldRowSize, NewCellSize, OldCellSize uint64 } type prollyReporter func(ctx context.Context, vMapping val.OrdinalMapping, fromD, toD val.TupleDesc, change tree.Diff, ch chan<- DiffSummaryProgress) error -type nomsReporter func(ctx context.Context, change *diff.Difference, ch chan<- DiffSummaryProgress) error +type nomsReporter func(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffSummaryProgress) error // Summary reports a summary of diff changes between two values // todo: make package private once dolthub is migrated @@ -50,7 +50,7 @@ func Summary(ctx context.Context, ch chan DiffSummaryProgress, from, to durable. 
if err != nil { return err } - ch <- DiffSummaryProgress{OldSize: fc, NewSize: tc} + ch <- DiffSummaryProgress{OldRowSize: fc, NewRowSize: tc} fk, tk := schema.IsKeyless(fromSch), schema.IsKeyless(toSch) var keyless bool @@ -64,7 +64,7 @@ func Summary(ctx context.Context, ch chan DiffSummaryProgress, from, to durable. return diffProllyTrees(ctx, ch, keyless, from, to, fromSch, toSch) } - return diffNomsMaps(ctx, ch, keyless, from, to) + return diffNomsMaps(ctx, ch, keyless, from, to, fromSch, toSch) } // SummaryForTableDelta pushes diff summary progress messages for the table delta given to the channel given @@ -91,7 +91,7 @@ func SummaryForTableDelta(ctx context.Context, ch chan DiffSummaryProgress, td T if types.IsFormat_DOLT(td.Format()) { return diffProllyTrees(ctx, ch, keyless, fromRows, toRows, fromSch, toSch) } else { - return diffNomsMaps(ctx, ch, keyless, fromRows, toRows) + return diffNomsMaps(ctx, ch, keyless, fromRows, toRows, fromSch, toSch) } } @@ -114,14 +114,18 @@ func diffProllyTrees(ctx context.Context, ch chan DiffSummaryProgress, keyless b if err != nil { return err } + cfc := uint64(len(fromSch.GetAllCols().GetColumns())) * fc tc, err := to.Count() if err != nil { return err } + ctc := uint64(len(toSch.GetAllCols().GetColumns())) * tc rpr = reportPkChanges ch <- DiffSummaryProgress{ - OldSize: fc, - NewSize: tc, + OldRowSize: fc, + NewRowSize: tc, + OldCellSize: cfc, + NewCellSize: ctc, } } @@ -134,7 +138,7 @@ func diffProllyTrees(ctx context.Context, ch chan DiffSummaryProgress, keyless b return nil } -func diffNomsMaps(ctx context.Context, ch chan DiffSummaryProgress, keyless bool, fromRows durable.Index, toRows durable.Index) error { +func diffNomsMaps(ctx context.Context, ch chan DiffSummaryProgress, keyless bool, fromRows durable.Index, toRows durable.Index, fromSch, toSch schema.Schema) error { var rpr nomsReporter if keyless { rpr = reportNomsKeylessChanges @@ -143,21 +147,25 @@ func diffNomsMaps(ctx context.Context, ch chan 
DiffSummaryProgress, keyless bool if err != nil { return err } + cfc := uint64(len(fromSch.GetAllCols().GetColumns())) * fc tc, err := toRows.Count() if err != nil { return err } + ctc := uint64(len(toSch.GetAllCols().GetColumns())) * tc rpr = reportNomsPkChanges ch <- DiffSummaryProgress{ - OldSize: fc, - NewSize: tc, + OldRowSize: fc, + NewRowSize: tc, + OldCellSize: cfc, + NewCellSize: ctc, } } - return summaryWithReporter(ctx, ch, durable.NomsMapFromIndex(fromRows), durable.NomsMapFromIndex(toRows), rpr) + return summaryWithReporter(ctx, ch, durable.NomsMapFromIndex(fromRows), durable.NomsMapFromIndex(toRows), rpr, fromSch, toSch) } -func summaryWithReporter(ctx context.Context, ch chan DiffSummaryProgress, from, to types.Map, rpr nomsReporter) (err error) { +func summaryWithReporter(ctx context.Context, ch chan DiffSummaryProgress, from, to types.Map, rpr nomsReporter, fromSch, toSch schema.Schema) (err error) { ad := NewAsyncDiffer(1024) ad.Start(ctx, from, to) defer func() { @@ -175,7 +183,7 @@ func summaryWithReporter(ctx context.Context, ch chan DiffSummaryProgress, from, } for _, df := range diffs { - err = rpr(ctx, df, ch) + err = rpr(ctx, df, fromSch, toSch, ch) if err != nil { return err } @@ -270,7 +278,7 @@ func prollyCountCellDiff(mapping val.OrdinalMapping, fromD, toD val.TupleDesc, f return changed } -func reportNomsPkChanges(ctx context.Context, change *diff.Difference, ch chan<- DiffSummaryProgress) error { +func reportNomsPkChanges(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffSummaryProgress) error { var summary DiffSummaryProgress switch change.ChangeType { case types.DiffChangeAdded: @@ -280,7 +288,7 @@ func reportNomsPkChanges(ctx context.Context, change *diff.Difference, ch chan<- case types.DiffChangeModified: oldTuple := change.OldValue.(types.Tuple) newTuple := change.NewValue.(types.Tuple) - cellChanges, err := row.CountCellDiffs(oldTuple, newTuple) + cellChanges, err := 
row.CountCellDiffs(oldTuple, newTuple, fromSch, toSch) if err != nil { return err } @@ -296,7 +304,7 @@ func reportNomsPkChanges(ctx context.Context, change *diff.Difference, ch chan<- } } -func reportNomsKeylessChanges(ctx context.Context, change *diff.Difference, ch chan<- DiffSummaryProgress) error { +func reportNomsKeylessChanges(ctx context.Context, change *diff.Difference, fromSch, toSch schema.Schema, ch chan<- DiffSummaryProgress) error { var oldCard uint64 if change.OldValue != nil { v, err := change.OldValue.(types.Tuple).Get(row.KeylessCardinalityValIdx) diff --git a/go/libraries/doltcore/diff/table_deltas.go b/go/libraries/doltcore/diff/table_deltas.go index 7d13f99e58..fa94cd4960 100644 --- a/go/libraries/doltcore/diff/table_deltas.go +++ b/go/libraries/doltcore/diff/table_deltas.go @@ -370,14 +370,20 @@ func (td TableDelta) IsKeyless(ctx context.Context) (bool, error) { return false, err } + // nil table is neither keyless nor keyed from, to := schema.IsKeyless(f), schema.IsKeyless(t) - - if from && to { - return true, nil - } else if !from && !to { - return false, nil + if td.FromTable == nil { + return to, nil + } else if td.ToTable == nil { + return from, nil } else { - return false, fmt.Errorf("mismatched keyless and keyed schemas for table %s", td.CurName()) + if from && to { + return true, nil + } else if !from && !to { + return false, nil + } else { + return false, fmt.Errorf("mismatched keyless and keyed schemas for table %s", td.CurName()) + } } } diff --git a/go/libraries/doltcore/row/tagged_values.go b/go/libraries/doltcore/row/tagged_values.go index 1b0357fc3f..bb16911360 100644 --- a/go/libraries/doltcore/row/tagged_values.go +++ b/go/libraries/doltcore/row/tagged_values.go @@ -263,7 +263,9 @@ func (tt TaggedValues) String() string { // CountCellDiffs returns the number of fields that are different between two // tuples and does not panic if tuples are different lengths. 
-func CountCellDiffs(from, to types.Tuple) (uint64, error) { +func CountCellDiffs(from, to types.Tuple, fromSch, toSch schema.Schema) (uint64, error) { + fromColLen := len(fromSch.GetAllCols().GetColumns()) + toColLen := len(toSch.GetAllCols().GetColumns()) changed := 0 f, err := ParseTaggedValues(from) if err != nil { @@ -277,7 +279,8 @@ func CountCellDiffs(from, to types.Tuple) (uint64, error) { for i, v := range f { ov, ok := t[i] - if !ok || !v.Equals(ov) { + // !ok means t[i] has NULL value, and it is not cell modify if it was from drop column or add column + if (!ok && fromColLen == toColLen) || (ok && !v.Equals(ov)) { changed++ } } diff --git a/go/libraries/doltcore/schema/schema.go b/go/libraries/doltcore/schema/schema.go index 03cfabfdff..1bd7c8a708 100644 --- a/go/libraries/doltcore/schema/schema.go +++ b/go/libraries/doltcore/schema/schema.go @@ -233,6 +233,11 @@ func MapSchemaBasedOnTagAndName(inSch, outSch Schema) ([]int, []int, error) { keyMapping := make([]int, inSch.GetPKCols().Size()) valMapping := make([]int, inSch.GetNonPKCols().Size()) + // if inSch or outSch is empty schema. This can be from added or dropped table. 
+ if len(inSch.GetAllCols().cols) == 0 || len(outSch.GetAllCols().cols) == 0 { + return keyMapping, valMapping, nil + } + err := inSch.GetPKCols().Iter(func(tag uint64, col Column) (stop bool, err error) { i := inSch.GetPKCols().TagToIdx[tag] if col, ok := outSch.GetPKCols().GetByTag(tag); ok { diff --git a/go/libraries/doltcore/sqle/database_provider.go b/go/libraries/doltcore/sqle/database_provider.go index ae9cbe3040..55ee99aa46 100644 --- a/go/libraries/doltcore/sqle/database_provider.go +++ b/go/libraries/doltcore/sqle/database_provider.go @@ -777,9 +777,13 @@ func (p DoltDatabaseProvider) ExternalStoredProcedures(_ *sql.Context, name stri func (p DoltDatabaseProvider) TableFunction(_ *sql.Context, name string) (sql.TableFunction, error) { // currently, only one table function is supported, if we extend this, we should clean this up // and store table functions in a map, similar to regular functions. - if strings.ToLower(name) == "dolt_diff" { + switch strings.ToLower(name) { + case "dolt_diff": dtf := &DiffTableFunction{} return dtf, nil + case "dolt_diff_summary": + dtf := &DiffSummaryTableFunction{} + return dtf, nil } return nil, sql.ErrTableFunctionNotFound.New(name) diff --git a/go/libraries/doltcore/sqle/dolt_diff_summary_table_function.go b/go/libraries/doltcore/sqle/dolt_diff_summary_table_function.go new file mode 100644 index 0000000000..c0cdf58ea8 --- /dev/null +++ b/go/libraries/doltcore/sqle/dolt_diff_summary_table_function.go @@ -0,0 +1,509 @@ +// Copyright 2022 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package sqle
+
+import (
+	"fmt"
+	"io"
+	"math"
+
+	"github.com/dolthub/go-mysql-server/sql"
+
+	"github.com/dolthub/dolt/go/libraries/doltcore/diff"
+	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
+	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
+	"github.com/dolthub/dolt/go/store/atomicerr"
+)
+
+// Compile-time check that *DiffSummaryTableFunction satisfies sql.TableFunction.
+var _ sql.TableFunction = (*DiffSummaryTableFunction)(nil)
+
+// DiffSummaryTableFunction is the node behind the dolt_diff_summary table
+// function. It carries a from-commit expression, a to-commit expression and an
+// optional table-name expression, together with the database it runs against.
+type DiffSummaryTableFunction struct {
+	// ctx is the context that was handed to NewInstance.
+	ctx *sql.Context
+
+	// fromCommitExpr and toCommitExpr are the first two arguments.
+	fromCommitExpr sql.Expression
+	toCommitExpr   sql.Expression
+	// tableNameExpr is the optional third argument; the methods below treat a
+	// nil value as "no table name given".
+	tableNameExpr sql.Expression
+	database      sql.Database
+}
+
+// diffSummaryTableSchema is the schema reported by Schema: a non-nullable
+// table_name column followed by nullable Int64 row and cell counters.
+var diffSummaryTableSchema = sql.Schema{
+	&sql.Column{Name: "table_name", Type: sql.LongText, Nullable: false},
+	&sql.Column{Name: "rows_unmodified", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "rows_added", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "rows_deleted", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "rows_modified", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "cells_added", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "cells_deleted", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "cells_modified", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "old_row_count", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "new_row_count", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "old_cell_count", Type: sql.Int64, Nullable: true},
+	&sql.Column{Name: "new_cell_count", Type: sql.Int64, Nullable: true},
+}
+
+// NewInstance builds a fresh DiffSummaryTableFunction bound to ctx and db and
+// passes the argument expressions to WithExpressions. Any error from
+// WithExpressions is returned with a nil node.
+func (ds *DiffSummaryTableFunction) NewInstance(ctx *sql.Context, db sql.Database, expressions []sql.Expression) (sql.Node, error) {
+	newInstance := &DiffSummaryTableFunction{
+		ctx:      ctx,
+		database: db,
+	}
+
+	node, err := newInstance.WithExpressions(expressions...)
+	if err != nil {
+		return nil, err
+	}
+
+	return node, nil
+}
+
+// Database implements the sql.Databaser interface
+func (ds *DiffSummaryTableFunction) Database() sql.Database {
+	return ds.database
+}
+
+// WithDatabase implements the sql.Databaser interface.
+// The receiver is updated in place and returned; no copy is made.
+func (ds *DiffSummaryTableFunction) WithDatabase(database sql.Database) (sql.Node, error) {
+	ds.database = database
+	return ds, nil
+}
+
+// FunctionName implements the sql.TableFunction interface
+func (ds *DiffSummaryTableFunction) FunctionName() string {
+	return "dolt_diff_summary"
+}
+
+// Resolved implements the sql.Resolvable interface.
+// The table-name expression only takes part when it is set.
+func (ds *DiffSummaryTableFunction) Resolved() bool {
+	if ds.tableNameExpr != nil {
+		return ds.fromCommitExpr.Resolved() && ds.toCommitExpr.Resolved() && ds.tableNameExpr.Resolved()
+	}
+	return ds.fromCommitExpr.Resolved() && ds.toCommitExpr.Resolved()
+}
+
+// String implements the Stringer interface.
+// It renders the call with two arguments, or three when a table name is set.
+func (ds *DiffSummaryTableFunction) String() string {
+	if ds.tableNameExpr != nil {
+		return fmt.Sprintf("DOLT_DIFF_SUMMARY(%s, %s, %s)", ds.fromCommitExpr.String(), ds.toCommitExpr.String(), ds.tableNameExpr.String())
+	}
+	return fmt.Sprintf("DOLT_DIFF_SUMMARY(%s, %s)", ds.fromCommitExpr.String(), ds.toCommitExpr.String())
+}
+
+// Schema implements the sql.Node interface.
+func (ds *DiffSummaryTableFunction) Schema() sql.Schema {
+	return diffSummaryTableSchema
+}
+
+// Children implements the sql.Node interface.
+// This node has no children.
+func (ds *DiffSummaryTableFunction) Children() []sql.Node {
+	return nil
+}
+
+// WithChildren implements the sql.Node interface.
+// Passing any child is an error; with none, the receiver itself is returned.
+func (ds *DiffSummaryTableFunction) WithChildren(children ...sql.Node) (sql.Node, error) {
+	if len(children) != 0 {
+		return nil, fmt.Errorf("unexpected children")
+	}
+	return ds, nil
+}
+
+// CheckPrivileges implements the interface sql.Node.
+func (ds *DiffSummaryTableFunction) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { + if ds.tableNameExpr != nil { + if !sql.IsText(ds.tableNameExpr.Type()) { + return false + } + + tableNameVal, err := ds.tableNameExpr.Eval(ds.ctx, nil) + if err != nil { + return false + } + tableName, ok := tableNameVal.(string) + if !ok { + return false + } + + // TODO: Add tests for privilege checking + return opChecker.UserHasPrivileges(ctx, + sql.NewPrivilegedOperation(ds.database.Name(), tableName, "", sql.PrivilegeType_Select)) + } + + tblNames, err := ds.database.GetTableNames(ctx) + if err != nil { + return false + } + + var operations []sql.PrivilegedOperation + for _, tblName := range tblNames { + operations = append(operations, sql.NewPrivilegedOperation(ds.database.Name(), tblName, "", sql.PrivilegeType_Select)) + } + + return opChecker.UserHasPrivileges(ctx, operations...) +} + +// Expressions implements the sql.Expressioner interface. +func (ds *DiffSummaryTableFunction) Expressions() []sql.Expression { + exprs := []sql.Expression{ds.fromCommitExpr, ds.toCommitExpr} + if ds.tableNameExpr != nil { + exprs = append(exprs, ds.tableNameExpr) + } + return exprs +} + +// WithExpressions implements the sql.Expressioner interface. 
+func (ds *DiffSummaryTableFunction) WithExpressions(expression ...sql.Expression) (sql.Node, error) { + if len(expression) < 2 || len(expression) > 3 { + return nil, sql.ErrInvalidArgumentNumber.New(ds.FunctionName(), "2 or 3", len(expression)) + } + + for _, expr := range expression { + if !expr.Resolved() { + return nil, ErrInvalidNonLiteralArgument.New(ds.FunctionName(), expr.String()) + } + } + + ds.fromCommitExpr = expression[0] + ds.toCommitExpr = expression[1] + if len(expression) == 3 { + ds.tableNameExpr = expression[2] + } + + // validate the expressions + if !sql.IsText(ds.fromCommitExpr.Type()) { + return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.fromCommitExpr.String()) + } + + if !sql.IsText(ds.toCommitExpr.Type()) { + return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.toCommitExpr.String()) + } + + if ds.tableNameExpr != nil { + if !sql.IsText(ds.tableNameExpr.Type()) { + return nil, sql.ErrInvalidArgumentDetails.New(ds.FunctionName(), ds.tableNameExpr.String()) + } + } + + return ds, nil +} + +// RowIter implements the sql.Node interface +func (ds *DiffSummaryTableFunction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { + fromCommitVal, toCommitVal, tableName, err := ds.evaluateArguments() + if err != nil { + return nil, err + } + + sqledb, ok := ds.database.(Database) + if !ok { + return nil, fmt.Errorf("unexpected database type: %T", ds.database) + } + + sess := dsess.DSessFromSess(ctx.Session) + fromRoot, _, err := sess.ResolveRootForRef(ctx, sqledb.Name(), fromCommitVal) + if err != nil { + return nil, err + } + + toRoot, _, err := sess.ResolveRootForRef(ctx, sqledb.Name(), toCommitVal) + if err != nil { + return nil, err + } + + deltas, err := diff.GetTableDeltas(ctx, fromRoot, toRoot) + if err != nil { + return nil, err + } + + // If tableNameExpr defined, return a single table diff summary result + if ds.tableNameExpr != nil { + delta := findMatchingDelta(deltas, tableName) + diffSum, 
hasDiff, err := getDiffSummaryNodeFromDelta(ctx, delta, fromRoot, toRoot, tableName) + if err != nil { + return nil, err + } + if !hasDiff { + return NewDiffSummaryTableFunctionRowIter([]diffSummaryNode{}), nil + } + return NewDiffSummaryTableFunctionRowIter([]diffSummaryNode{diffSum}), nil + } + + var diffSummaries []diffSummaryNode + for _, delta := range deltas { + tblName := delta.ToName + if tblName == "" { + tblName = delta.FromName + } + diffSum, hasDiff, err := getDiffSummaryNodeFromDelta(ctx, delta, fromRoot, toRoot, tblName) + if err != nil { + return nil, err + } + if hasDiff { + diffSummaries = append(diffSummaries, diffSum) + } + } + + return NewDiffSummaryTableFunctionRowIter(diffSummaries), nil +} + +// evaluateArguments returns fromCommitValStr, toCommitValStr and tableName. +// It evaluates the argument expressions to turn them into values this DiffTableFunction +// can use. Note that this method only evals the expressions, and doesn't validate the values. +func (ds *DiffSummaryTableFunction) evaluateArguments() (string, string, string, error) { + var tableName string + if ds.tableNameExpr != nil { + tableNameVal, err := ds.tableNameExpr.Eval(ds.ctx, nil) + if err != nil { + return "", "", "", err + } + tn, ok := tableNameVal.(string) + if !ok { + return "", "", "", ErrInvalidTableName.New(ds.tableNameExpr.String()) + } + tableName = tn + } + + fromCommitVal, err := ds.fromCommitExpr.Eval(ds.ctx, nil) + if err != nil { + return "", "", "", err + } + fromCommitValStr, ok := fromCommitVal.(string) + if !ok { + return "", "", "", fmt.Errorf("received '%v' when expecting commit hash string", fromCommitVal) + } + + toCommitVal, err := ds.toCommitExpr.Eval(ds.ctx, nil) + if err != nil { + return "", "", "", err + } + toCommitValStr, ok := toCommitVal.(string) + if !ok { + return "", "", "", fmt.Errorf("received '%v' when expecting commit hash string", toCommitVal) + } + + return fromCommitValStr, toCommitValStr, tableName, nil +} + +// 
getDiffSummaryNodeFromDelta returns diffSummaryNode object and whether there is data diff or not. It gets tables +// from roots and diff summary if there is a valid table exists in both fromRoot and toRoot. +func getDiffSummaryNodeFromDelta(ctx *sql.Context, delta diff.TableDelta, fromRoot, toRoot *doltdb.RootValue, tableName string) (diffSummaryNode, bool, error) { + var oldColLen int + var newColLen int + fromTable, _, fromTableExists, err := fromRoot.GetTableInsensitive(ctx, tableName) + if err != nil { + return diffSummaryNode{}, false, err + } + + if fromTableExists { + fromSch, err := fromTable.GetSchema(ctx) + if err != nil { + return diffSummaryNode{}, false, err + } + oldColLen = len(fromSch.GetAllCols().GetColumns()) + } + + toTable, _, toTableExists, err := toRoot.GetTableInsensitive(ctx, tableName) + if err != nil { + return diffSummaryNode{}, false, err + } + + if toTableExists { + toSch, err := toTable.GetSchema(ctx) + if err != nil { + return diffSummaryNode{}, false, err + } + newColLen = len(toSch.GetAllCols().GetColumns()) + } + + if !fromTableExists && !toTableExists { + return diffSummaryNode{}, false, sql.ErrTableNotFound.New(tableName) + } + + // no diff from tableDelta + if delta.FromTable == nil && delta.ToTable == nil { + return diffSummaryNode{}, false, nil + } + + diffSum, hasDiff, keyless, err := getDiffSummary(ctx, delta) + if err != nil { + return diffSummaryNode{}, false, err + } + + return diffSummaryNode{tableName, diffSum, oldColLen, newColLen, keyless}, hasDiff, nil +} + +// getDiffSummary returns diff.DiffSummaryProgress object and whether there is a data diff or not. 
+func getDiffSummary(ctx *sql.Context, td diff.TableDelta) (diff.DiffSummaryProgress, bool, bool, error) { + // got this method from diff_output.go + // todo: use errgroup.Group + ae := atomicerr.New() + ch := make(chan diff.DiffSummaryProgress) + go func() { + defer close(ch) + err := diff.SummaryForTableDelta(ctx, ch, td) + + ae.SetIfError(err) + }() + + acc := diff.DiffSummaryProgress{} + var count int64 + for p := range ch { + if ae.IsSet() { + break + } + + acc.Adds += p.Adds + acc.Removes += p.Removes + acc.Changes += p.Changes + acc.CellChanges += p.CellChanges + acc.NewRowSize += p.NewRowSize + acc.OldRowSize += p.OldRowSize + acc.NewCellSize += p.NewCellSize + acc.OldCellSize += p.OldCellSize + + count++ + } + + if err := ae.Get(); err != nil { + return diff.DiffSummaryProgress{}, false, false, err + } + + keyless, err := td.IsKeyless(ctx) + if err != nil { + return diff.DiffSummaryProgress{}, false, keyless, err + } + + if (acc.Adds+acc.Removes+acc.Changes) == 0 && (acc.OldCellSize-acc.NewCellSize) == 0 { + return diff.DiffSummaryProgress{}, false, keyless, nil + } + + return acc, true, keyless, nil +} + +//------------------------------------ +// diffSummaryTableFunctionRowIter +//------------------------------------ + +var _ sql.RowIter = &diffSummaryTableFunctionRowIter{} + +type diffSummaryTableFunctionRowIter struct { + diffSums []diffSummaryNode + diffIdx int +} + +func (d *diffSummaryTableFunctionRowIter) incrementIndexes() { + d.diffIdx++ + if d.diffIdx >= len(d.diffSums) { + d.diffIdx = 0 + d.diffSums = nil + } +} + +type diffSummaryNode struct { + tblName string + diffSummary diff.DiffSummaryProgress + oldColLen int + newColLen int + keyless bool +} + +func NewDiffSummaryTableFunctionRowIter(ds []diffSummaryNode) sql.RowIter { + return &diffSummaryTableFunctionRowIter{ + diffSums: ds, + } +} + +func (d *diffSummaryTableFunctionRowIter) Next(ctx *sql.Context) (sql.Row, error) { + defer d.incrementIndexes() + if d.diffIdx >= len(d.diffSums) { + 
return nil, io.EOF + } + + if d.diffSums == nil { + return nil, io.EOF + } + + ds := d.diffSums[d.diffIdx] + return getRowFromDiffSummary(ds.tblName, ds.diffSummary, ds.newColLen, ds.oldColLen, ds.keyless), nil +} + +func (d *diffSummaryTableFunctionRowIter) Close(context *sql.Context) error { + return nil +} + +// getRowFromDiffSummary takes diff.DiffSummaryProgress and calculates the row_modified, cell_added, cell_deleted. +// If the number of cell change from old to new cell count does not equal to cell_added and/or cell_deleted, there +// must be schema changes that affects cell_added and cell_deleted value addition to the row count * col length number. +func getRowFromDiffSummary(tblName string, dsp diff.DiffSummaryProgress, newColLen, oldColLen int, keyless bool) sql.Row { + // if table is keyless table, match current CLI command result + if keyless { + return sql.Row{ + tblName, // table_name + nil, // rows_unmodified + int64(dsp.Adds), // rows_added + int64(dsp.Removes), // rows_deleted + nil, // rows_modified + nil, // cells_added + nil, // cells_deleted + nil, // cells_modified + nil, // old_row_count + nil, // new_row_count + nil, // old_cell_count + nil, // new_cell_count + } + } + + numCellInserts, numCellDeletes := GetCellsAddedAndDeleted(dsp, newColLen) + rowsUnmodified := dsp.OldRowSize - dsp.Changes - dsp.Removes + + return sql.Row{ + tblName, // table_name + int64(rowsUnmodified), // rows_unmodified + int64(dsp.Adds), // rows_added + int64(dsp.Removes), // rows_deleted + int64(dsp.Changes), // rows_modified + int64(numCellInserts), // cells_added + int64(numCellDeletes), // cells_deleted + int64(dsp.CellChanges), // cells_modified + int64(dsp.OldRowSize), // old_row_count + int64(dsp.NewRowSize), // new_row_count + int64(dsp.OldCellSize), // old_cell_count + int64(dsp.NewCellSize), // new_cell_count + } +} + +// GetCellsAddedAndDeleted calculates cells added and deleted given diff.DiffSummaryProgress and toCommit table +// column length. 
We use rows added and deleted to calculate cells added and deleted, but it does not include +// cells added and deleted from schema changes. Here we fill those in using total number of cells in each commit table. +func GetCellsAddedAndDeleted(acc diff.DiffSummaryProgress, newColLen int) (uint64, uint64) { + var numCellInserts, numCellDeletes float64 + rowToCellInserts := float64(acc.Adds) * float64(newColLen) + rowToCellDeletes := float64(acc.Removes) * float64(newColLen) + cellDiff := float64(acc.NewCellSize) - float64(acc.OldCellSize) + if cellDiff > 0 { + numCellInserts = cellDiff + rowToCellDeletes + numCellDeletes = rowToCellDeletes + } else if cellDiff < 0 { + numCellInserts = rowToCellInserts + numCellDeletes = math.Abs(cellDiff) + rowToCellInserts + } else { + if rowToCellInserts != rowToCellDeletes { + numCellDeletes = math.Max(rowToCellDeletes, rowToCellInserts) + numCellInserts = math.Max(rowToCellDeletes, rowToCellInserts) + } else { + numCellDeletes = rowToCellDeletes + numCellInserts = rowToCellInserts + } + } + return uint64(numCellInserts), uint64(numCellDeletes) +} diff --git a/go/libraries/doltcore/sqle/dolt_diff_table_function.go b/go/libraries/doltcore/sqle/dolt_diff_table_function.go index cdbd6a0d5b..7da02c1a7b 100644 --- a/go/libraries/doltcore/sqle/dolt_diff_table_function.go +++ b/go/libraries/doltcore/sqle/dolt_diff_table_function.go @@ -50,7 +50,7 @@ type DiffTableFunction struct { toDate *types.Timestamp } -// NewInstance implements the TableFunction interface +// NewInstance creates a new instance of TableFunction interface func (dtf *DiffTableFunction) NewInstance(ctx *sql.Context, database sql.Database, expressions []sql.Expression) (sql.Node, error) { newInstance := &DiffTableFunction{ ctx: ctx, @@ -192,7 +192,7 @@ func loadDetailsForRef( // WithChildren implements the sql.Node interface func (dtf *DiffTableFunction) WithChildren(node ...sql.Node) (sql.Node, error) { if len(node) != 0 { - panic("unexpected children") + return nil, 
fmt.Errorf("unexpected children") } return dtf, nil } @@ -257,7 +257,7 @@ func (dtf *DiffTableFunction) generateSchema(ctx *sql.Context, tableName string, sqledb, ok := dtf.database.(Database) if !ok { - panic(fmt.Sprintf("unexpected database type: %T", dtf.database)) + return fmt.Errorf("unexpected database type: %T", dtf.database) } delta, err := dtf.cacheTableDelta(ctx, tableName, fromCommitVal, toCommitVal, sqledb) diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go index 9dc204f2cb..5acf95ab0a 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_engine_test.go @@ -1134,6 +1134,28 @@ func TestDiffTableFunctionPrepared(t *testing.T) { } } +func TestDiffSummaryTableFunction(t *testing.T) { + harness := newDoltHarness(t) + harness.Setup(setup.MydbData) + for _, test := range DiffSummaryTableFunctionScriptTests { + harness.engine = nil + t.Run(test.Name, func(t *testing.T) { + enginetest.TestScript(t, harness, test) + }) + } +} + +func TestDiffSummaryTableFunctionPrepared(t *testing.T) { + harness := newDoltHarness(t) + harness.Setup(setup.MydbData) + for _, test := range DiffSummaryTableFunctionScriptTests { + harness.engine = nil + t.Run(test.Name, func(t *testing.T) { + enginetest.TestScriptPrepared(t, harness, test) + }) + } +} + func TestCommitDiffSystemTable(t *testing.T) { harness := newDoltHarness(t) harness.Setup(setup.MydbData) diff --git a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go index 90c45f0a7c..17c38368d3 100644 --- a/go/libraries/doltcore/sqle/enginetest/dolt_queries.go +++ b/go/libraries/doltcore/sqle/enginetest/dolt_queries.go @@ -734,6 +734,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{ Query: "SELECT * FROM dolt_diff('test', 'main~', 'main');", ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, }, + { + // Without access to 
the database, dolt_diff_summary should fail with a database access error + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff_summary('main~', 'main', 'test');", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, { // Grant single-table access to the underlying user table User: "root", @@ -755,6 +762,20 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{ Query: "SELECT * FROM dolt_diff('test2', 'main~', 'main');", ExpectedErr: sql.ErrPrivilegeCheckFailed, }, + { + // With access to the db, but not the table, dolt_diff_summary should fail + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff_summary('main~', 'main', 'test2');", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, + { + // With access to the db, dolt_diff_summary should fail for all tables if no access any of tables + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff_summary('main~', 'main');", + ExpectedErr: sql.ErrPrivilegeCheckFailed, + }, { // Revoke select on mydb.test User: "root", @@ -783,6 +804,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{ Query: "SELECT COUNT(*) FROM dolt_diff('test', 'main~', 'main');", Expected: []sql.Row{{1}}, }, + { + // After granting access to the entire db, dolt_diff_summary should work + User: "tester", + Host: "localhost", + Query: "SELECT COUNT(*) FROM dolt_diff_summary('main~', 'main');", + Expected: []sql.Row{{1}}, + }, { // Revoke multi-table access User: "root", @@ -797,6 +825,13 @@ var DoltUserPrivTests = []queries.UserPrivilegeTest{ Query: "SELECT * FROM dolt_diff('test', 'main~', 'main');", ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, }, + { + // After revoking access, dolt_diff_summary should fail + User: "tester", + Host: "localhost", + Query: "SELECT * FROM dolt_diff_summary('main~', 'main', 'test');", + ExpectedErr: sql.ErrDatabaseAccessDeniedForUser, + }, { // Grant global access to *.* User: "root", @@ -4851,6 +4886,524 @@ var DiffTableFunctionScriptTests = 
[]queries.ScriptTest{ }, } +var DiffSummaryTableFunctionScriptTests = []queries.ScriptTest{ + { + Name: "invalid arguments", + SetUpScript: []string{ + "create table t (pk int primary key, c1 varchar(20), c2 varchar(20));", + "call dolt_add('.')", + "set @Commit1 = dolt_commit('-am', 'creating table t');", + + "insert into t values(1, 'one', 'two'), (2, 'two', 'three');", + "set @Commit2 = dolt_commit('-am', 'inserting into t');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary('t');", + ExpectedErr: sql.ErrInvalidArgumentNumber, + }, + { + Query: "SELECT * from dolt_diff_summary('t', @Commit1, @Commit2, 'extra');", + ExpectedErr: sql.ErrInvalidArgumentNumber, + }, + { + Query: "SELECT * from dolt_diff_summary(null, null, null);", + ExpectedErr: sql.ErrInvalidArgumentDetails, + }, + { + Query: "SELECT * from dolt_diff_summary(123, @Commit1, @Commit2);", + ExpectedErr: sql.ErrInvalidArgumentDetails, + }, + { + Query: "SELECT * from dolt_diff_summary('t', 123, @Commit2);", + ExpectedErr: sql.ErrInvalidArgumentDetails, + }, + { + Query: "SELECT * from dolt_diff_summary('t', @Commit1, 123);", + ExpectedErr: sql.ErrInvalidArgumentDetails, + }, + { + Query: "SELECT * from dolt_diff_summary('fake-branch', @Commit2, 't');", + ExpectedErrStr: "branch not found: fake-branch", + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, 'fake-branch', 't');", + ExpectedErrStr: "branch not found: fake-branch", + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, 'doesnotexist');", + ExpectedErr: sql.ErrTableNotFound, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, concat('fake', '-', 'branch'), 't');", + ExpectedErr: sqle.ErrInvalidNonLiteralArgument, + }, + { + Query: "SELECT * from dolt_diff_summary(hashof('main'), @Commit2, 't');", + ExpectedErr: sqle.ErrInvalidNonLiteralArgument, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, LOWER('T'));", + ExpectedErr: 
sqle.ErrInvalidNonLiteralArgument, + }, + }, + }, + { + Name: "basic case with single table", + SetUpScript: []string{ + "set @Commit0 = HashOf('HEAD');", + "set @Commit1 = dolt_commit('--allow-empty', '-m', 'creating table t');", + + // create table t only + "create table t (pk int primary key, c1 varchar(20), c2 varchar(20));", + "call dolt_add('.')", + "set @Commit2 = dolt_commit('-am', 'creating table t');", + + // insert 1 row into t + "insert into t values(1, 'one', 'two');", + "set @Commit3 = dolt_commit('-am', 'inserting 1 into table t');", + + // insert 2 rows into t and update two cells + "insert into t values(2, 'two', 'three'), (3, 'three', 'four');", + "update t set c1='uno', c2='dos' where pk=1;", + "set @Commit4 = dolt_commit('-am', 'inserting 2 into table t');", + + // drop table t only + "drop table t;", + "set @Commit5 = dolt_commit('-am', 'drop table t');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // table is added, no data diff, result is empty + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, 't');", + Expected: []sql.Row{}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit2, @Commit3, 't');", + Expected: []sql.Row{{"t", 0, 1, 0, 0, 3, 0, 0, 0, 1, 0, 3}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit3, @Commit4, 't');", + Expected: []sql.Row{{"t", 0, 2, 0, 1, 6, 0, 2, 1, 3, 3, 9}}, + }, + { + // change from and to commits + Query: "SELECT * from dolt_diff_summary(@Commit4, @Commit3, 't');", + Expected: []sql.Row{{"t", 0, 0, 2, 1, 0, 6, 2, 3, 1, 9, 3}}, + }, + { + // table is dropped + Query: "SELECT * from dolt_diff_summary(@Commit4, @Commit5, 't');", + Expected: []sql.Row{{"t", 0, 0, 3, 0, 0, 9, 0, 3, 0, 9, 0}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit4, 't');", + Expected: []sql.Row{{"t", 0, 3, 0, 0, 9, 0, 0, 0, 3, 0, 9}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit5, 't');", + ExpectedErr: sql.ErrTableNotFound, + }, + }, + }, + { + Name: 
"basic case with single keyless table", + SetUpScript: []string{ + "set @Commit0 = HashOf('HEAD');", + "set @Commit1 = dolt_commit('--allow-empty', '-m', 'creating table t');", + + // create table t only + "create table t (id int, c1 varchar(20), c2 varchar(20));", + "call dolt_add('.')", + "set @Commit2 = dolt_commit('-am', 'creating table t');", + + // insert 1 row into t + "insert into t values(1, 'one', 'two');", + "set @Commit3 = dolt_commit('-am', 'inserting 1 into table t');", + + // insert 2 rows into t and update two cells + "insert into t values(2, 'two', 'three'), (3, 'three', 'four');", + "update t set c1='uno', c2='dos' where id=1;", + "set @Commit4 = dolt_commit('-am', 'inserting 2 into table t');", + + // drop table t only + "drop table t;", + "set @Commit5 = dolt_commit('-am', 'drop table t');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + // table is added, no data diff, result is empty + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, 't');", + Expected: []sql.Row{}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit2, @Commit3, 't');", + Expected: []sql.Row{{"t", nil, 1, 0, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + { + // TODO : (correct result is commented out) + // update row for keyless table deletes the row and insert the new row + // this causes row added = 3 and row deleted = 1 + Query: "SELECT * from dolt_diff_summary(@Commit3, @Commit4, 't');", + //Expected: []sql.Row{{"t", nil, 2, 0, nil, nil, nil, nil, nil, nil, nil, nil}}, + Expected: []sql.Row{{"t", nil, 3, 1, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit4, @Commit3, 't');", + //Expected: []sql.Row{{"t", nil, 0, 2, nil, nil, nil, nil, nil, nil, nil, nil}}, + Expected: []sql.Row{{"t", nil, 1, 3, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + { + // table is dropped + Query: "SELECT * from dolt_diff_summary(@Commit4, @Commit5, 't');", + Expected: []sql.Row{{"t", nil, 0, 3, nil, nil, nil, nil, 
nil, nil, nil, nil}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit4, 't');", + Expected: []sql.Row{{"t", nil, 3, 0, nil, nil, nil, nil, nil, nil, nil, nil}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit5, 't');", + ExpectedErr: sql.ErrTableNotFound, + }, + }, + }, + { + Name: "basic case with multiple tables", + SetUpScript: []string{ + "set @Commit0 = HashOf('HEAD');", + + // add table t with 1 row + "create table t (pk int primary key, c1 varchar(20), c2 varchar(20));", + "insert into t values(1, 'one', 'two');", + "call dolt_add('.')", + "set @Commit1 = dolt_commit('-am', 'inserting into table t');", + + // add table t2 with 1 row + "create table t2 (pk int primary key, c1 varchar(20), c2 varchar(20));", + "insert into t2 values(100, 'hundred', 'hundert');", + "call dolt_add('.')", + "set @Commit2 = dolt_commit('-am', 'inserting into table t2');", + + // changes on both tables + "insert into t values(2, 'two', 'three'), (3, 'three', 'four'), (4, 'four', 'five');", + "update t set c1='uno', c2='dos' where pk=1;", + "insert into t2 values(101, 'hundred one', 'one');", + "set @Commit3 = dolt_commit('-am', 'inserting into table t');", + + // changes on both tables + "delete from t where c2 = 'four';", + "update t2 set c2='zero' where pk=100;", + "set @Commit4 = dolt_commit('-am', 'inserting into table t');", + + // create keyless table + "create table keyless (id int);", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary(@Commit0, @Commit1);", + Expected: []sql.Row{{"t", 0, 1, 0, 0, 3, 0, 0, 0, 1, 0, 3}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2);", + Expected: []sql.Row{{"t2", 0, 1, 0, 0, 3, 0, 0, 0, 1, 0, 3}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit2, @Commit3);", + Expected: []sql.Row{{"t", 0, 3, 0, 1, 9, 0, 2, 1, 4, 3, 12}, {"t2", 1, 1, 0, 0, 3, 0, 0, 1, 2, 3, 6}}, + }, + { + Query: "SELECT * from 
dolt_diff_summary(@Commit3, @Commit4);", + Expected: []sql.Row{{"t", 3, 0, 1, 0, 0, 3, 0, 4, 3, 12, 9}, {"t2", 1, 0, 0, 1, 0, 0, 1, 2, 2, 6, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit4, @Commit2);", + Expected: []sql.Row{{"t", 0, 0, 2, 1, 0, 6, 2, 3, 1, 9, 3}, {"t2", 0, 0, 1, 1, 0, 3, 1, 2, 1, 6, 3}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit3, 'WORKING');", + Expected: []sql.Row{{"t", 3, 0, 1, 0, 0, 3, 0, 4, 3, 12, 9}, {"t2", 1, 0, 0, 1, 0, 0, 1, 2, 2, 6, 6}}, + }, + }, + }, + { + Name: "WORKING and STAGED", + SetUpScript: []string{ + "set @Commit0 = HashOf('HEAD');", + + "create table t (pk int primary key, c1 text, c2 text);", + "call dolt_add('.')", + "insert into t values (1, 'one', 'two'), (2, 'three', 'four');", + "set @Commit1 = dolt_commit('-am', 'inserting two rows into table t');", + + "insert into t values (3, 'five', 'six');", + "delete from t where pk = 2", + "update t set c2 = '100' where pk = 1", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary(@Commit1, 'WORKING', 't')", + Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary('STAGED', 'WORKING', 't')", + Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary('WORKING', 'STAGED', 't')", + Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary('WORKING', 'WORKING', 't')", + Expected: []sql.Row{}, + }, + { + Query: "SELECT * from dolt_diff_summary('STAGED', 'STAGED', 't')", + Expected: []sql.Row{}, + }, + { + Query: "call dolt_add('.')", + SkipResultsCheck: true, + }, + { + Query: "SELECT * from dolt_diff_summary('WORKING', 'STAGED', 't')", + Expected: []sql.Row{}, + }, + { + Query: "SELECT * from dolt_diff_summary('HEAD', 'STAGED', 't')", + Expected: []sql.Row{{"t", 0, 1, 1, 1, 3, 3, 1, 2, 2, 6, 6}}, + }, + }, + }, + { + Name: "diff with branch 
refs", + SetUpScript: []string{ + "create table t (pk int primary key, c1 varchar(20), c2 varchar(20));", + "call dolt_add('.')", + "set @Commit1 = dolt_commit('-am', 'creating table t');", + + "insert into t values(1, 'one', 'two');", + "set @Commit2 = dolt_commit('-am', 'inserting row 1 into t in main');", + + "select dolt_checkout('-b', 'branch1');", + "alter table t drop column c2;", + "set @Commit3 = dolt_commit('-am', 'dropping column c2 in branch1');", + + "delete from t where pk=1;", + "set @Commit4 = dolt_commit('-am', 'deleting row 1 in branch1');", + + "insert into t values (2, 'two');", + "set @Commit5 = dolt_commit('-am', 'inserting row 2 in branch1');", + + "select dolt_checkout('main');", + "insert into t values (2, 'two', 'three');", + "set @Commit6 = dolt_commit('-am', 'inserting row 2 in main');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary('main', 'branch1', 't');", + Expected: []sql.Row{{"t", 0, 0, 1, 1, 0, 4, 0, 2, 1, 6, 2}}, + }, + { + Query: "SELECT * from dolt_diff_summary('branch1', 'main', 't');", + Expected: []sql.Row{{"t", 0, 1, 0, 1, 4, 0, 1, 1, 2, 2, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary('main~', 'branch1', 't');", + Expected: []sql.Row{{"t", 0, 1, 1, 0, 2, 3, 0, 1, 1, 3, 2}}, + }, + }, + }, + { + Name: "schema modification: drop and add column", + SetUpScript: []string{ + "create table t (pk int primary key, c1 varchar(20), c2 varchar(20));", + "call dolt_add('.');", + "insert into t values (1, 'one', 'two'), (2, 'two', 'three');", + "set @Commit1 = dolt_commit('-am', 'inserting row 1, 2 into t');", + + // drop 1 column and add 1 row + "alter table t drop column c2;", + "set @Commit2 = dolt_commit('-am', 'dropping column c2');", + + // drop 1 column and add 1 row + "insert into t values (3, 'three');", + "set @Commit3 = dolt_commit('-am', 'inserting row 3');", + + // add 1 column and 1 row and update + "alter table t add column c2 varchar(20);", + "insert into t 
values (4, 'four', 'five');", + "update t set c2='foo' where pk=1;", + "set @Commit4 = dolt_commit('-am', 'adding column c2, inserting, and updating data');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, 't');", + Expected: []sql.Row{{"t", 0, 0, 0, 2, 0, 2, 0, 2, 2, 6, 4}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit2, @Commit3, 't');", + Expected: []sql.Row{{"t", 2, 1, 0, 0, 2, 0, 0, 2, 3, 4, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit3, 't');", + Expected: []sql.Row{{"t", 0, 1, 0, 2, 2, 2, 0, 2, 3, 6, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit3, @Commit4, 't');", + Expected: []sql.Row{{"t", 2, 1, 0, 1, 6, 0, 1, 3, 4, 6, 12}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit4, 't');", + Expected: []sql.Row{{"t", 0, 2, 0, 2, 6, 0, 2, 2, 4, 6, 12}}, + }, + }, + }, + { + Name: "schema modification: rename columns", + SetUpScript: []string{ + "create table t (pk int primary key, c1 varchar(20), c2 int);", + "call dolt_add('.')", + "set @Commit1 = dolt_commit('-am', 'creating table t');", + + "insert into t values(1, 'one', -1), (2, 'two', -2);", + "set @Commit2 = dolt_commit('-am', 'inserting into t');", + + "alter table t rename column c2 to c3;", + "set @Commit3 = dolt_commit('-am', 'renaming column c2 to c3');", + + "insert into t values (3, 'three', -3);", + "update t set c3=1 where pk=1;", + "set @Commit4 = dolt_commit('-am', 'inserting and updating data');", + + "alter table t rename column c3 to c2;", + "insert into t values (4, 'four', -4);", + "set @Commit5 = dolt_commit('-am', 'renaming column c3 to c2, and inserting data');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit2, 't');", + Expected: []sql.Row{{"t", 0, 2, 0, 0, 6, 0, 0, 0, 2, 0, 6}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit2, @Commit3, 't');", + Expected: 
[]sql.Row{}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit3, @Commit4, 't');", + Expected: []sql.Row{{"t", 1, 1, 0, 1, 3, 0, 1, 2, 3, 6, 9}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit4, @Commit5, 't');", + Expected: []sql.Row{{"t", 3, 1, 0, 0, 3, 0, 0, 3, 4, 9, 12}}, + }, + { + Query: "SELECT * from dolt_diff_summary(@Commit1, @Commit5, 't');", + Expected: []sql.Row{{"t", 0, 4, 0, 0, 12, 0, 0, 0, 4, 0, 12}}, + }, + }, + }, + { + Name: "new table", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_diff_summary('HEAD', 'WORKING')", + Expected: []sql.Row{}, + }, + { + Query: "select * from dolt_diff_summary('WORKING', 'HEAD')", + Expected: []sql.Row{}, + }, + { + Query: "insert into t1 values (1,2)", + SkipResultsCheck: true, + }, + { + Query: "select * from dolt_diff_summary('HEAD', 'WORKING', 't1')", + Expected: []sql.Row{{"t1", 0, 1, 0, 0, 2, 0, 0, 0, 1, 0, 2}}, + }, + { + Query: "select * from dolt_diff_summary('WORKING', 'HEAD', 't1')", + Expected: []sql.Row{{"t1", 0, 0, 1, 0, 0, 2, 0, 1, 0, 2, 0}}, + }, + }, + }, + { + Name: "dropped table", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "call dolt_add('.')", + "insert into t1 values (1,2)", + "call dolt_commit('-am', 'new table')", + "drop table t1", + "call dolt_commit('-am', 'dropped table')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_diff_summary('HEAD~', 'HEAD', 't1')", + Expected: []sql.Row{{"t1", 0, 0, 1, 0, 0, 2, 0, 1, 0, 2, 0}}, + }, + { + Query: "select * from dolt_diff_summary('HEAD', 'HEAD~', 't1')", + Expected: []sql.Row{{"t1", 0, 1, 0, 0, 2, 0, 0, 0, 1, 0, 2}}, + }, + }, + }, + { + Name: "renamed table", + SetUpScript: []string{ + "create table t1 (a int primary key, b int)", + "call dolt_add('.')", + "insert into t1 values (1,2)", + "call dolt_commit('-am', 'new table')", + "alter table t1 rename 
to t2", + "call dolt_add('.')", + "insert into t2 values (3,4)", + "call dolt_commit('-am', 'renamed table')", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "select * from dolt_diff_summary('HEAD~', 'HEAD', 't2')", + Expected: []sql.Row{{"t2", 1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 4}}, + }, + { + // Old table name can be matched as well + Query: "select * from dolt_diff_summary('HEAD~', 'HEAD', 't1')", + Expected: []sql.Row{{"t1", 1, 1, 0, 0, 2, 0, 0, 1, 2, 2, 4}}, + }, + }, + }, + { + Name: "add multiple columns, then set and unset a value. Should not show a diff", + SetUpScript: []string{ + "CREATE table t (pk int primary key);", + "Insert into t values (1);", + "CALL DOLT_ADD('.');", + "CALL DOLT_COMMIT('-am', 'setup');", + "alter table t add column col1 int;", + "alter table t add column col2 int;", + "CALL DOLT_ADD('.');", + "CALL DOLT_COMMIT('-am', 'add columns');", + "UPDATE t set col1 = 1 where pk = 1;", + "UPDATE t set col1 = null where pk = 1;", + "CALL DOLT_COMMIT('--allow-empty', '-am', 'fix short tuple');", + }, + Assertions: []queries.ScriptTestAssertion{ + { + Query: "SELECT * from dolt_diff_summary('HEAD~2', 'HEAD');", + Expected: []sql.Row{{"t", 1, 0, 0, 0, 2, 0, 0, 1, 1, 1, 3}}, + }, + { + Query: "SELECT * from dolt_diff_summary('HEAD~', 'HEAD');", + Expected: []sql.Row{}, + }, + }, + }, +} + var LargeJsonObjectScriptTests = []queries.ScriptTest{ { Name: "JSON under max length limit", diff --git a/integration-tests/bats/diff.bats b/integration-tests/bats/diff.bats index c3701622a5..416d23273c 100644 --- a/integration-tests/bats/diff.bats +++ b/integration-tests/bats/diff.bats @@ -326,8 +326,10 @@ SQL [[ "$output" =~ "2 Rows Added (100.00%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "0 Rows Modified (0.00%)" ]] || false + [[ "$output" =~ "12 Cells Added (100.00%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "0 Cells Modified (0.00%)" ]] || false - [[ "$output" 
=~ "(2 Entries vs 4 Entries)" ]] || false + [[ "$output" =~ "(2 Row Entries vs 4 Row Entries)" ]] || false dolt add test dolt commit -m "added two rows" @@ -338,8 +340,10 @@ SQL [[ "$output" =~ "0 Rows Added (0.00%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "1 Row Modified (25.00%)" ]] || false + [[ "$output" =~ "0 Cells Added (0.00%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "2 Cells Modified (8.33%)" ]] || false - [[ "$output" =~ "(4 Entries vs 4 Entries)" ]] || false + [[ "$output" =~ "(4 Row Entries vs 4 Row Entries)" ]] || false dolt add test dolt commit -m "modified first row" @@ -350,8 +354,10 @@ SQL [[ "$output" =~ "0 Rows Added (0.00%)" ]] || false [[ "$output" =~ "1 Row Deleted (25.00%)" ]] || false [[ "$output" =~ "0 Rows Modified (0.00%)" ]] || false + [[ "$output" =~ "0 Cells Added (0.00%)" ]] || false + [[ "$output" =~ "6 Cells Deleted (25.00%)" ]] || false [[ "$output" =~ "0 Cells Modified (0.00%)" ]] || false - [[ "$output" =~ "(4 Entries vs 3 Entries)" ]] || false + [[ "$output" =~ "(4 Row Entries vs 3 Row Entries)" ]] || false } @test "diff: summary comparing row with a deleted cell and an added cell" { @@ -367,8 +373,10 @@ SQL [[ "$output" =~ "0 Rows Added (0.00%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "1 Row Modified (100.00%)" ]] || false + [[ "$output" =~ "0 Cells Added (0.00%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "1 Cell Modified (16.67%)" ]] || false - [[ "$output" =~ "(1 Entry vs 1 Entry)" ]] || false + [[ "$output" =~ "(1 Row Entry vs 1 Row Entry)" ]] || false dolt add test dolt commit -m "row modified" dolt sql -q "replace into test values (0, 1, 2, 3, 4, 5)" @@ -378,8 +386,10 @@ SQL [[ "$output" =~ "0 Rows Added (0.00%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "1 Row Modified (100.00%)" ]] || false + [[ "$output" =~ "0 Cells 
Added (0.00%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "1 Cell Modified (16.67%)" ]] || false - [[ "$output" =~ "(1 Entry vs 1 Entry)" ]] || false + [[ "$output" =~ "(1 Row Entry vs 1 Row Entry)" ]] || false } @test "diff: summary comparing two branches" { @@ -397,8 +407,10 @@ SQL [[ "$output" =~ "1 Row Added (100.00%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "0 Rows Modified (0.00%)" ]] || false + [[ "$output" =~ "6 Cells Added (100.00%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "0 Cells Modified (0.00%)" ]] || false - [[ "$output" =~ "(1 Entry vs 2 Entries)" ]] || false + [[ "$output" =~ "(1 Row Entry vs 2 Row Entries)" ]] || false } @test "diff: summary shows correct changes after schema change" { @@ -423,8 +435,10 @@ DELIM [[ "$output" =~ "1 Row Added (33.33%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "0 Rows Modified (0.00%)" ]] || false + [[ "$output" =~ "10 Cells Added (55.56%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "0 Cells Modified (0.00%)" ]] || false - [[ "$output" =~ "(3 Entries vs 4 Entries)" ]] || false + [[ "$output" =~ "(3 Row Entries vs 4 Row Entries)" ]] || false dolt sql -q "replace into employees values (0, 'tim', 'sehn', 'ceo', '2 years ago', '', 'Santa Monica')" @@ -435,8 +449,10 @@ DELIM [[ "$output" =~ "1 Row Added (33.33%)" ]] || false [[ "$output" =~ "0 Rows Deleted (0.00%)" ]] || false [[ "$output" =~ "1 Row Modified (33.33%)" ]] || false + [[ "$output" =~ "10 Cells Added (55.56%)" ]] || false + [[ "$output" =~ "0 Cells Deleted (0.00%)" ]] || false [[ "$output" =~ "2 Cells Modified (11.11%)" ]] || false - [[ "$output" =~ "(3 Entries vs 4 Entries)" ]] || false + [[ "$output" =~ "(3 Row Entries vs 4 Row Entries)" ]] || false } @test "diff: summary gets summaries for all tables with changes" {