diff --git a/go/go.mod b/go/go.mod index 23764d5218..4aa35bc80c 100644 --- a/go/go.mod +++ b/go/go.mod @@ -57,7 +57,7 @@ require ( github.com/cespare/xxhash v1.1.0 github.com/creasty/defaults v1.6.0 github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 - github.com/dolthub/go-mysql-server v0.18.2-0.20240604235838-5d11cec1718f + github.com/dolthub/go-mysql-server v0.18.2-0.20240606230452-b64d0222abc5 github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 github.com/dolthub/swiss v0.1.0 github.com/goccy/go-json v0.10.2 diff --git a/go/go.sum b/go/go.sum index e908e3902f..0ffa0aa349 100644 --- a/go/go.sum +++ b/go/go.sum @@ -183,8 +183,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.18.2-0.20240604235838-5d11cec1718f h1:e2Xyty29+ht/mL8ffvPyeKiVjaFcTo3N1OYuj6EnlmA= -github.com/dolthub/go-mysql-server v0.18.2-0.20240604235838-5d11cec1718f/go.mod h1:GT7JcQavIf7bAO17/odujkgHM/N0t4b1HfAPBJ2jzXo= +github.com/dolthub/go-mysql-server v0.18.2-0.20240606230452-b64d0222abc5 h1:UX5N4VwYOrrPNF15IJIlWp7ka27K6/nJSuMqzU7d04Y= +github.com/dolthub/go-mysql-server v0.18.2-0.20240606230452-b64d0222abc5/go.mod h1:GT7JcQavIf7bAO17/odujkgHM/N0t4b1HfAPBJ2jzXo= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= diff --git a/go/libraries/doltcore/remotesrv/grpc.go b/go/libraries/doltcore/remotesrv/grpc.go index 751fc31a76..e9aa6d1d6e 100644 --- a/go/libraries/doltcore/remotesrv/grpc.go +++ b/go/libraries/doltcore/remotesrv/grpc.go @@ -337,10 +337,20 @@ func (rs *RemoteChunkStore) getHost(md metadata.MD) string { return host } +func (rs *RemoteChunkStore) getScheme(md metadata.MD) string { + scheme := rs.httpScheme + forwardedSchemes := md.Get("x-forwarded-proto") + if len(forwardedSchemes) > 0 { + scheme = forwardedSchemes[0] + } + return scheme +} + func (rs *RemoteChunkStore) getDownloadUrl(md metadata.MD, path string) *url.URL { host := rs.getHost(md) + scheme := rs.getScheme(md) return &url.URL{ - Scheme: rs.httpScheme, + Scheme: scheme, Host: host, Path: path, } diff --git a/go/libraries/doltcore/remotesrv/grpc_test.go b/go/libraries/doltcore/remotesrv/grpc_test.go new file mode 100644 index 0000000000..44bc07104d --- /dev/null +++ b/go/libraries/doltcore/remotesrv/grpc_test.go @@ -0,0 +1,36 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package remotesrv + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" +) + +func TestGRPCSchemeSelection(t *testing.T) { + rs := &RemoteChunkStore{ + httpScheme: "http", + } + + md := metadata.New(nil) + scheme := rs.getScheme(md) + assert.Equal(t, scheme, "http") + + md.Append("x-forwarded-proto", "https") + scheme = rs.getScheme(md) + assert.Equal(t, scheme, "https") +} diff --git a/go/libraries/doltcore/sqle/database.go b/go/libraries/doltcore/sqle/database.go index 29e14d1e93..cc2d85fbf6 100644 --- a/go/libraries/doltcore/sqle/database.go +++ b/go/libraries/doltcore/sqle/database.go @@ -677,29 +677,17 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s } } - var tbl *doltdb.Table - // TODO: dolt_schemas needs work to be compatible with multiple schemas - if resolve.UseSearchPath && db.schemaName == "" && !doltdb.HasDoltPrefix(tableName) { - var tblName doltdb.TableName - tblName, tbl, ok, err = resolve.TableWithSearchPath(ctx, root, tableName) - if err != nil { - return nil, false, err - } else if !ok { - return nil, false, nil - } - - // For the remainder of this method, we will use the schema name that was resolved and the table resolved - // will inherit it - db.schemaName = tblName.Schema - } else { - tableName, tbl, ok, err = db.resolveTable(ctx, root, tableName) - if err != nil { - return nil, false, err - } else if !ok { - return nil, false, nil - } + tblName, tbl, tblExists, err := db.resolveUserTable(ctx, root, tableName) + if err != nil { + return nil, false, err + } else if !tblExists { + return nil, false, nil } + tableName = tblName.Name + // for remainder of this operation, all db operations will use the name resolved here + db.schemaName = tblName.Schema + sch, err := tbl.GetSchema(ctx) if err != nil { return nil, false, err @@ -729,27 +717,57 @@ func (db Database) getTable(ctx *sql.Context, root doltdb.RootValue, tableName s return table, true, nil } -func (db Database) resolveTable(ctx *sql.Context, root doltdb.RootValue, tableName string) (string, *doltdb.Table, bool, error) { +// resolveUserTable returns the table with the given name from the root given. The table name is resolved in a +// case-insensitive manner. The table is returned along with its case-sensitive matched name. An error is returned if +// no such table exists. +func (db Database) resolveUserTable(ctx *sql.Context, root doltdb.RootValue, tableName string) (doltdb.TableName, *doltdb.Table, bool, error) { + var tbl *doltdb.Table + var tblName doltdb.TableName + var tblExists bool + + // TODO: dolt_schemas needs work to be compatible with multiple schemas + if resolve.UseSearchPath && db.schemaName == "" && !doltdb.HasDoltPrefix(tableName) { + var err error + tblName, tbl, tblExists, err = resolve.TableWithSearchPath(ctx, root, tableName) + if err != nil { + return doltdb.TableName{}, nil, false, err + } + } else { + var err error + tblName, tbl, tblExists, err = db.tableInsensitive(ctx, root, tableName) + if err != nil { + return doltdb.TableName{}, nil, false, err + } + } + + return tblName, tbl, tblExists, nil +} + +// tableInsensitive returns the name of this table in the root given with the db's schema name, if it exists. +// Name matching is applied in a case-insensitive manner, and the table's case-corrected name is returned as the +// first result. +func (db Database) tableInsensitive(ctx *sql.Context, root doltdb.RootValue, tableName string) (doltdb.TableName, *doltdb.Table, bool, error) { tableNames, err := db.getAllTableNames(ctx, root) if err != nil { - return "", nil, false, err + return doltdb.TableName{}, nil, false, err } tableName, ok := sql.GetTableNameInsensitive(tableName, tableNames) if !ok { - return "", nil, false, nil + return doltdb.TableName{}, nil, false, nil } // TODO: should we short-circuit the schema name for system tables? - tbl, ok, err := root.GetTable(ctx, doltdb.TableName{Name: tableName, Schema: db.schemaName}) + tname := doltdb.TableName{Name: tableName, Schema: db.schemaName} + tbl, ok, err := root.GetTable(ctx, tname) if err != nil { - return "", nil, false, err + return doltdb.TableName{}, nil, false, err } else if !ok { // Should be impossible - return "", nil, false, doltdb.ErrTableNotFound + return doltdb.TableName{}, nil, false, doltdb.ErrTableNotFound } - return tableName, tbl, true, nil + return tname, tbl, true, nil } // newDoltTable returns a sql.Table wrapping the given underlying dolt table @@ -901,16 +919,18 @@ func (db Database) dropTable(ctx *sql.Context, tableName string) error { } root := ws.WorkingRoot() - tbl, tableExists, err := root.GetTable(ctx, doltdb.TableName{Name: tableName}) + tblName, tbl, tblExists, err := db.resolveUserTable(ctx, root, tableName) if err != nil { return err - } - - if !tableExists { + } else if !tblExists { return sql.ErrTableNotFound.New(tableName) } - newRoot, err := root.RemoveTables(ctx, true, false, doltdb.TableName{Name: tableName, Schema: db.schemaName}) + tableName = tblName.Name + // for remainder of this operation, all db operations will use the name resolved here + db.schemaName = tblName.Schema + + newRoot, err := root.RemoveTables(ctx, true, false, tblName) if err != nil { return err } diff --git a/go/libraries/doltcore/sqle/tables.go b/go/libraries/doltcore/sqle/tables.go index 62a94ab5c0..f5a90bcde0 100644 --- a/go/libraries/doltcore/sqle/tables.go +++ b/go/libraries/doltcore/sqle/tables.go @@ -2709,6 +2709,7 @@ func (t *WritableDoltTable) UpdateForeignKey(ctx *sql.Context, fkName string, sq return sql.ErrForeignKeyNotFound.New(fkName, t.tableName) } fkc.RemoveKeyByName(doltFk.Name) + doltFk.Name = sqlFk.Name doltFk.TableName = sqlFk.Table doltFk.ReferencedTableName = sqlFk.ParentTable doltFk.UnresolvedFKDetails.TableColumns = sqlFk.Columns diff --git a/go/store/types/incremental_test.go b/go/store/types/incremental_test.go index cb00959d9d..6203807631 100644 --- a/go/store/types/incremental_test.go +++ b/go/store/types/incremental_test.go @@ -58,6 +58,7 @@ func TestIncrementalLoadList(t *testing.T) { ts := &chunks.TestStorage{} cs := ts.NewView() vs := NewValueStore(cs) + vs.skipWriteCaching = true expected, err := NewList(context.Background(), vs, getTestVals(vs)...) require.NoError(t, err) diff --git a/go/store/types/list_test.go b/go/store/types/list_test.go index a13859c3cc..dbccaadf10 100644 --- a/go/store/types/list_test.go +++ b/go/store/types/list_test.go @@ -1134,6 +1134,8 @@ func TestListDiffLargeWithSameMiddle(t *testing.T) { cs1 := storage.NewView() vs1 := NewValueStore(cs1) + vs1.skipWriteCaching = true + nums1 := generateNumbersAsValues(vs1.Format(), 4000) l1, err := NewList(context.Background(), vs1, nums1...) require.NoError(t, err) @@ -1151,6 +1153,8 @@ func TestListDiffLargeWithSameMiddle(t *testing.T) { cs2 := storage.NewView() vs2 := NewValueStore(cs2) + vs2.skipWriteCaching = true + nums2 := generateNumbersAsValuesFromToBy(vs2.Format(), 5, 3550, 1) l2, err := NewList(context.Background(), vs2, nums2...) require.NoError(t, err) diff --git a/go/store/types/value_store.go b/go/store/types/value_store.go index d41468f2f2..a28a68a752 100644 --- a/go/store/types/value_store.go +++ b/go/store/types/value_store.go @@ -78,6 +78,7 @@ type ValueStore struct { decodedChunks *sizecache.SizeCache nbf *NomsBinFormat versOnce sync.Once + skipWriteCaching bool gcMu sync.Mutex gcCond *sync.Cond @@ -343,6 +344,10 @@ func (lvs *ValueStore) WriteValue(ctx context.Context, v Value) (Ref, error) { return Ref{}, err } + if !lvs.skipWriteCaching { + lvs.decodedChunks.Add(c.Hash(), uint64(c.Size()), v) + } + return r, nil } diff --git a/go/store/types/value_store_test.go b/go/store/types/value_store_test.go index 51d0139f80..5f15c3f928 100644 --- a/go/store/types/value_store_test.go +++ b/go/store/types/value_store_test.go @@ -37,6 +37,7 @@ func TestValueReadWriteRead(t *testing.T) { s := String("hello") vs := newTestValueStore() + vs.skipWriteCaching = true assert.Nil(vs.ReadValue(context.Background(), mustHash(s.Hash(vs.Format())))) // nil h := mustRef(vs.WriteValue(context.Background(), s)).TargetHash() rt, err := vs.Root(context.Background()) @@ -55,6 +56,7 @@ func TestReadWriteCache(t *testing.T) { storage := &chunks.TestStorage{} ts := storage.NewView() vs := NewValueStore(ts) + vs.skipWriteCaching = true var v Value = Bool(true) r, err := vs.WriteValue(context.Background(), v) @@ -82,6 +84,7 @@ func TestValueReadMany(t *testing.T) { vals := ValueSlice{String("hello"), Bool(true), Float(42)} vs := newTestValueStore() + vs.skipWriteCaching = true hashes := hash.HashSlice{} for _, v := range vals { h := mustRef(vs.WriteValue(context.Background(), v)).TargetHash() @@ -149,6 +152,7 @@ func TestPanicOnBadVersion(t *testing.T) { func TestErrorIfDangling(t *testing.T) { t.Skip("WriteValue errors with dangling ref error") vs := newTestValueStore() + vs.skipWriteCaching = true r, err := NewRef(Bool(true), vs.Format()) require.NoError(t, err) @@ -168,6 +172,7 @@ func TestGC(t *testing.T) { ctx := context.Background() vs := newTestValueStore() + vs.skipWriteCaching = true r1 := mustRef(vs.WriteValue(ctx, String("committed"))) r2 := mustRef(vs.WriteValue(ctx, String("unreferenced"))) set1 := mustSet(NewSet(ctx, vs, r1))