merge with main

This commit is contained in:
James Cor
2023-02-20 10:01:38 -08:00
65 changed files with 1212 additions and 903 deletions
+18 -1
View File
@@ -145,7 +145,7 @@ func (cmd ImportCmd) ArgParser() *argparser.ArgParser {
ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"table", "Name of the table to be created."})
ap.ArgListHelp = append(ap.ArgListHelp, [2]string{"file", "The file being used to infer the schema."})
ap.SupportsFlag(createFlag, "c", "Create a table with the schema inferred from the {{.LessThan}}file{{.GreaterThan}} provided.")
ap.SupportsFlag(updateFlag, "u", "Update a table to match the inferred schema of the {{.LessThan}}file{{.GreaterThan}} provided")
ap.SupportsFlag(updateFlag, "u", "Update a table to match the inferred schema of the {{.LessThan}}file{{.GreaterThan}} provided. All previous data will be lost.")
ap.SupportsFlag(replaceFlag, "r", "Replace a table with a new schema that has the inferred schema from the {{.LessThan}}file{{.GreaterThan}} provided. All previous data will be lost.")
ap.SupportsFlag(dryRunFlag, "", "Print the sql statement that would be run if executed without the flag.")
ap.SupportsFlag(keepTypesParam, "", "When a column already exists in the table, and it's also in the {{.LessThan}}file{{.GreaterThan}} provided, use the type from the table.")
@@ -219,6 +219,23 @@ func getSchemaImportArgs(ctx context.Context, apr *argparser.ArgParseResults, dE
return nil, errhand.BuildDError("error: failed to create table.").AddDetails("A table named '%s' already exists.", tblName).AddDetails("Use --replace or --update instead of --create.").Build()
}
if op != CreateOp {
rows, err := tbl.GetRowData(ctx)
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
}
rowCnt, err := rows.Count()
if err != nil {
return nil, errhand.VerboseErrorFromError(err)
}
if rowCnt > 0 {
return nil, errhand.BuildDError("This operation will delete all row data. If this is your intent, "+
"run dolt sql -q 'delete from %s' to delete all row data, then re-run this command.", tblName).Build()
}
}
var existingSch schema.Schema = schema.EmptySchema
if tblExists {
existingSch, err = tbl.GetSchema(ctx)
+36 -8
View File
@@ -19,6 +19,7 @@ import (
"net/http"
"os"
"strings"
"sync"
"testing"
_ "github.com/go-sql-driver/mysql"
@@ -61,10 +62,12 @@ var (
func TestServerArgs(t *testing.T) {
serverController := NewServerController()
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, dEnv.DoltDB.Close())
}()
go func() {
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
startServer(context.Background(), "0.0.0", "dolt sql-server", []string{
"-H", "localhost",
"-P", "15200",
@@ -75,7 +78,7 @@ func TestServerArgs(t *testing.T) {
"-r",
}, dEnv, serverController)
}()
err := serverController.WaitForStart()
err = serverController.WaitForStart()
require.NoError(t, err)
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
require.NoError(t, err)
@@ -103,17 +106,20 @@ listener:
read_timeout_millis: 5000
write_timeout_millis: 5000
`
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, dEnv.DoltDB.Close())
}()
serverController := NewServerController()
go func() {
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
dEnv.FS.WriteFile("config.yaml", []byte(yamlConfig))
startServer(context.Background(), "0.0.0", "dolt sql-server", []string{
"--config", "config.yaml",
}, dEnv, serverController)
}()
err := serverController.WaitForStart()
err = serverController.WaitForStart()
require.NoError(t, err)
conn, err := dbr.Open("mysql", "username:password@tcp(localhost:15200)/", nil)
require.NoError(t, err)
@@ -127,6 +133,9 @@ listener:
func TestServerBadArgs(t *testing.T) {
env, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, env.DoltDB.Close())
}()
tests := [][]string{
{"-H", "127.0.0.0.1"},
@@ -156,6 +165,9 @@ func TestServerBadArgs(t *testing.T) {
func TestServerGoodParams(t *testing.T) {
env, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, env.DoltDB.Close())
}()
tests := []ServerConfig{
DefaultServerConfig(),
@@ -195,6 +207,9 @@ func TestServerGoodParams(t *testing.T) {
func TestServerSelect(t *testing.T) {
env, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, env.DoltDB.Close())
}()
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15300)
@@ -254,8 +269,16 @@ func TestServerFailsIfPortInUse(t *testing.T) {
}
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, dEnv.DoltDB.Close())
}()
go server.ListenAndServe()
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
server.ListenAndServe()
}()
go func() {
startServer(context.Background(), "test", "dolt sql-server", []string{
"-H", "localhost",
@@ -271,11 +294,15 @@ func TestServerFailsIfPortInUse(t *testing.T) {
err = serverController.WaitForStart()
require.Error(t, err)
server.Close()
wg.Wait()
}
func TestServerSetDefaultBranch(t *testing.T) {
dEnv, err := sqle.CreateEnvWithSeedData()
require.NoError(t, err)
defer func() {
assert.NoError(t, dEnv.DoltDB.Close())
}()
serverConfig := DefaultServerConfig().withLogLevel(LogLevel_Fatal).WithPort(15302)
@@ -408,6 +435,7 @@ func TestReadReplica(t *testing.T) {
defer os.Chdir(cwd)
multiSetup := testcommands.NewMultiRepoTestSetup(t.Fatal)
defer multiSetup.Close()
defer os.RemoveAll(multiSetup.Root)
multiSetup.NewDB("read_replica")
+2 -2
View File
@@ -15,7 +15,7 @@ require (
github.com/dolthub/fslock v0.0.3
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20230210003150-3065f526d869
github.com/dolthub/vitess v0.0.0-20230216234925-189ffe819e56
github.com/dustin/go-humanize v1.0.0
github.com/fatih/color v1.13.0
github.com/flynn-archive/go-shlex v0.0.0-20150515145356-3f9db97f8568
@@ -58,7 +58,7 @@ require (
github.com/cenkalti/backoff/v4 v4.1.3
github.com/cespare/xxhash v1.1.0
github.com/creasty/defaults v1.6.0
github.com/dolthub/go-mysql-server v0.14.1-0.20230217230235-6c5f5e129a67
github.com/dolthub/go-mysql-server v0.14.1-0.20230217225532-09205e0f234f
github.com/google/flatbuffers v2.0.6+incompatible
github.com/jmoiron/sqlx v1.3.4
github.com/kch42/buzhash v0.0.0-20160816060738-9bdec3dec7c6
+4 -4
View File
@@ -166,16 +166,16 @@ github.com/dolthub/flatbuffers v1.13.0-dh.1 h1:OWJdaPep22N52O/0xsUevxJ6Qfw1M2txC
github.com/dolthub/flatbuffers v1.13.0-dh.1/go.mod h1:CorYGaDmXjHz1Z7i50PYXG1Ricn31GcA2wNOTFIQAKE=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-mysql-server v0.14.1-0.20230217230235-6c5f5e129a67 h1:5RItVRX5BkR5v+PN+MFbIXDlLRuLrR9HQRye2ypQbo4=
github.com/dolthub/go-mysql-server v0.14.1-0.20230217230235-6c5f5e129a67/go.mod h1:3PGGtLcVPnJumgozqqAKZPae88QmvkOd1KGS+Z2/RXU=
github.com/dolthub/go-mysql-server v0.14.1-0.20230217225532-09205e0f234f h1:yuOrpt0Gwf8aYe7SmimteAjt/eNyUvBeaNRCq0RCMfA=
github.com/dolthub/go-mysql-server v0.14.1-0.20230217225532-09205e0f234f/go.mod h1:BRFyf6PUuoR+iSLZ+JdpjtqgHzo5cT+tF7oHIpVdytY=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474 h1:xTrR+l5l+1Lfq0NvhiEsctylXinUMFhhsqaEcl414p8=
github.com/dolthub/jsonpath v0.0.0-20210609232853-d49537a30474/go.mod h1:kMz7uXOXq4qRriCEyZ/LUeTqraLJCjf0WVZcUi6TxUY=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20230210003150-3065f526d869 h1:RiSFAJqwBJmFbISgxWEdpljUak1uFtNCKG0zGT8xzA4=
github.com/dolthub/vitess v0.0.0-20230210003150-3065f526d869/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs=
github.com/dolthub/vitess v0.0.0-20230216234925-189ffe819e56 h1:dHuKfUwaDUe847BVN3Wo+4GUGUNdlhuUif4RWkvG3Go=
github.com/dolthub/vitess v0.0.0-20230216234925-189ffe819e56/go.mod h1:oVFIBdqMFEkt4Xz2fzFJBNtzKhDEjwdCF0dzde39iKs=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
+14 -6
View File
@@ -66,6 +66,12 @@ func IsReadOnlySystemTable(name string) bool {
return HasDoltPrefix(name) && !set.NewStrSet(writeableSystemTables).Contains(name)
}
// IsNonAlterableSystemTable returns whether the table name given is a system table that cannot be dropped or altered
// by the user.
func IsNonAlterableSystemTable(name string) bool {
return IsReadOnlySystemTable(name) || strings.ToLower(name) == SchemasTableName
}
// GetNonSystemTableNames gets non-system table names
func GetNonSystemTableNames(ctx context.Context, root *RootValue) ([]string, error) {
tn, err := root.GetTableNames(ctx)
@@ -224,17 +230,19 @@ const (
// SchemasTableName is the name of the dolt schema fragment table
SchemasTableName = "dolt_schemas"
// SchemasTablesIdCol is an incrementing integer that represents the insertion index.
// Deprecated: This column is no longer used and will be removed in a future release.
SchemasTablesIdCol = "id"
// Currently: `view` or `trigger`.
// SchemasTablesTypeCol is the name of the column that stores the type of a schema fragment in the dolt_schemas table
SchemasTablesTypeCol = "type"
// The name of the database entity.
// SchemasTablesNameCol The name of the column that stores the name of a schema fragment in the dolt_schemas table
SchemasTablesNameCol = "name"
// The schema fragment associated with the database entity.
// For example, the SELECT statement for a CREATE VIEW.
// SchemasTablesFragmentCol The name of the column that stores the SQL fragment of a schema element in the
// dolt_schemas table
SchemasTablesFragmentCol = "fragment"
// The extra information for schema; currently contains creation time for triggers and views
// SchemasTablesExtraCol The name of the column that stores extra information about a schema element in the
// dolt_schemas table
SchemasTablesExtraCol = "extra"
// The name of the index that is on the table.
//
SchemasTablesIndexName = "fragment_name"
)
@@ -84,6 +84,15 @@ func (mr *MultiRepoTestSetup) homeProv() (string, error) {
return mr.Home, nil
}
func (mr *MultiRepoTestSetup) Close() {
for _, db := range mr.DoltDBs {
err := db.Close()
if err != nil {
mr.Errhand(err)
}
}
}
func (mr *MultiRepoTestSetup) Cleanup(dbName string) {
os.RemoveAll(mr.Root)
}
@@ -107,7 +107,7 @@ func mergeProllySecondaryIndexes(
leftSet, rightSet durable.IndexSet,
finalSch schema.Schema,
finalRows durable.Index,
artifacts prolly.ArtifactsEditor,
artifacts *prolly.ArtifactsEditor,
) (durable.IndexSet, error) {
ancSet, err := tm.ancTbl.GetIndexSet(ctx)
@@ -197,7 +197,7 @@ func mergeProllySecondaryIndexes(
return mergedIndexSet, nil
}
func buildIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, postMergeSchema schema.Schema, index schema.Index, m prolly.Map, artEditor prolly.ArtifactsEditor, theirRootIsh doltdb.Rootish, tblName string) (durable.Index, error) {
func buildIndex(ctx context.Context, vrw types.ValueReadWriter, ns tree.NodeStore, postMergeSchema schema.Schema, index schema.Index, m prolly.Map, artEditor *prolly.ArtifactsEditor, theirRootIsh doltdb.Rootish, tblName string) (durable.Index, error) {
if index.IsUnique() {
meta, err := makeUniqViolMeta(postMergeSchema, index)
if err != nil {
@@ -377,7 +377,7 @@ func (m *valueMerger) processColumn(i int, left, right, base val.Tuple) ([]byte,
}
type conflictProcessor interface {
process(ctx context.Context, conflictChan chan confVals, artEditor prolly.ArtifactsEditor) error
process(ctx context.Context, conflictChan chan confVals, artEditor *prolly.ArtifactsEditor) error
}
func makeConflictProcessor(ctx context.Context, tm TableMerger) (conflictProcessor, error) {
@@ -434,7 +434,7 @@ func newInsertingProcessor(theirRootIsh, baseRootIsh doltdb.Rootish) (*inserting
return &p, nil
}
func (p *insertingProcessor) process(ctx context.Context, conflictChan chan confVals, artEditor prolly.ArtifactsEditor) error {
func (p *insertingProcessor) process(ctx context.Context, conflictChan chan confVals, artEditor *prolly.ArtifactsEditor) error {
for {
select {
case conflict, ok := <-conflictChan:
@@ -453,7 +453,7 @@ func (p *insertingProcessor) process(ctx context.Context, conflictChan chan conf
type abortingProcessor struct{}
func (p abortingProcessor) process(ctx context.Context, conflictChan chan confVals, artEditor prolly.ArtifactsEditor) error {
func (p abortingProcessor) process(ctx context.Context, conflictChan chan confVals, _ *prolly.ArtifactsEditor) error {
select {
case _, ok := <-conflictChan:
if !ok {
+1 -1
View File
@@ -212,7 +212,7 @@ type foreignKeyViolationWriter struct {
currTbl *doltdb.Table
// prolly
artEditor prolly.ArtifactsEditor
artEditor *prolly.ArtifactsEditor
kd val.TupleDesc
cInfoJsonData []byte
@@ -39,7 +39,7 @@ func addUniqIdxViols(
index schema.Index,
left, right, base prolly.Map,
m prolly.Map,
artEditor prolly.ArtifactsEditor,
artEditor *prolly.ArtifactsEditor,
theirRootIsh doltdb.Rootish,
tblName string) error {
@@ -153,7 +153,7 @@ func (m UniqCVMeta) PrettyPrint() string {
return jsonStr
}
func replaceUniqueKeyViolation(ctx context.Context, edt prolly.ArtifactsEditor, m prolly.Map, k val.Tuple, kd val.TupleDesc, theirRootIsh doltdb.Rootish, vInfo []byte, tblName string) error {
func replaceUniqueKeyViolation(ctx context.Context, edt *prolly.ArtifactsEditor, m prolly.Map, k val.Tuple, kd val.TupleDesc, theirRootIsh doltdb.Rootish, vInfo []byte, tblName string) error {
var value val.Tuple
err := m.Get(ctx, k, func(_, v val.Tuple) error {
value = v
@@ -52,6 +52,10 @@ var (
StringDefaultType = &varStringType{gmstypes.MustCreateStringWithDefaults(sqltypes.VarChar, 16383)}
)
func CreateVarStringTypeFromSqlType(stringType sql.StringType) TypeInfo {
return &varStringType{stringType}
}
func CreateVarStringTypeFromParams(params map[string]string) (TypeInfo, error) {
var length int64
var collation sql.CollationID
+53 -24
View File
@@ -425,7 +425,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
}
}
dt, found = dtables.NewUnscopedDiffTable(ctx, db.ddb, head), true
dt, found = dtables.NewUnscopedDiffTable(ctx, db.name, db.ddb, head), true
case doltdb.TableOfTablesInConflictName:
dt, found = dtables.NewTableOfTablesInConflict(ctx, db.name, db.ddb), true
case doltdb.TableOfTablesWithViolationsName:
@@ -467,6 +467,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
}
}
}
if found {
return dt, found, nil
}
@@ -732,10 +733,15 @@ func (db Database) DropTable(ctx *sql.Context, tableName string) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
if doltdb.IsReadOnlySystemTable(tableName) {
if doltdb.IsNonAlterableSystemTable(tableName) {
return ErrSystemTableAlter.New(tableName)
}
return db.dropTable(ctx, tableName)
}
// dropTable drops the table with the name given, without any business logic checks
func (db Database) dropTable(ctx *sql.Context, tableName string) error {
ds := dsess.DSessFromSess(ctx.Session)
if _, ok := ds.GetTemporaryTable(ctx, db.Name(), tableName); ok {
ds.DropTemporaryTable(ctx, db.Name(), tableName)
@@ -1025,7 +1031,7 @@ func (db Database) RenameTable(ctx *sql.Context, oldName, newName string) error
return err
}
if doltdb.IsReadOnlySystemTable(oldName) {
if doltdb.IsNonAlterableSystemTable(oldName) {
return ErrSystemTableAlter.New(oldName)
}
@@ -1269,7 +1275,7 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
tbl, err := GetOrCreateDoltSchemasTable(ctx, db)
tbl, err := getOrCreateDoltSchemasTable(ctx, db)
if err != nil {
return err
}
@@ -1282,24 +1288,6 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
return existingErr
}
sess := dsess.DSessFromSess(ctx.Session)
dbState, _, err := sess.LookupDbState(ctx, db.Name())
if err != nil {
return err
}
ts := dbState.WriteSession
ws, err := ts.Flush(ctx)
if err != nil {
return err
}
// If rows exist, then grab the highest id and add 1 to get the new id
idx, err := nextSchemasTableIndex(ctx, ws.WorkingRoot())
if err != nil {
return err
}
// Insert the new row into the db
inserter := tbl.Inserter(ctx)
defer func() {
@@ -1316,13 +1304,15 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
if err != nil {
return err
}
return inserter.Insert(ctx, sql.Row{fragType, name, definition, idx, extraJSON})
return inserter.Insert(ctx, sql.Row{fragType, name, definition, extraJSON})
}
func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
return err
}
stbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
if err != nil {
return err
@@ -1345,7 +1335,46 @@ func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name str
return err
}
return deleter.Close(ctx)
err = deleter.Close(ctx)
if err != nil {
return err
}
// If the dolt schemas table is now empty, drop it entirely. This is necessary to prevent the creation and
// immediate dropping of views or triggers, when none previously existed, from changing the database state.
return db.dropTableIfEmpty(ctx, doltdb.SchemasTableName)
}
// dropTableIfEmpty drops the table named if it exists and has at least one row.
func (db Database) dropTableIfEmpty(ctx *sql.Context, tableName string) error {
stbl, found, err := db.GetTableInsensitive(ctx, tableName)
if err != nil {
return err
}
if !found {
return nil
}
table, err := stbl.(*WritableDoltTable).DoltTable.DoltTable(ctx)
if err != nil {
return err
}
rows, err := table.GetRowData(ctx)
if err != nil {
return err
}
numRows, err := rows.Count()
if err != nil {
return err
}
if numRows == 0 {
return db.dropTable(ctx, tableName)
}
return nil
}
// GetAllTemporaryTables returns all temporary tables
@@ -539,7 +539,7 @@ type prollyConflictDeleter struct {
kd, vd val.TupleDesc
kB, vB *val.TupleBuilder
pool pool.BuffPool
ed prolly.ArtifactsEditor
ed *prolly.ArtifactsEditor
ct ProllyConflictsTable
rs RootSetter
ourDiffTypeIdx int
@@ -217,7 +217,7 @@ type prollyCVDeleter struct {
kd val.TupleDesc
kb *val.TupleBuilder
pool pool.BuffPool
ed prolly.ArtifactsEditor
ed *prolly.ArtifactsEditor
cvt *prollyConstraintViolationsTable
}
@@ -45,6 +45,7 @@ var _ sql.FilteredTable = (*UnscopedDiffTable)(nil)
// UnscopedDiffTable is a sql.Table implementation of a system table that shows which tables have
// changed in each commit, across all branches.
type UnscopedDiffTable struct {
dbName string
ddb *doltdb.DoltDB
head *doltdb.Commit
partitionFilters []sql.Expression
@@ -60,8 +61,8 @@ type tableChange struct {
}
// NewUnscopedDiffTable creates an UnscopedDiffTable
func NewUnscopedDiffTable(_ *sql.Context, ddb *doltdb.DoltDB, head *doltdb.Commit) sql.Table {
return &UnscopedDiffTable{ddb: ddb, head: head}
func NewUnscopedDiffTable(_ *sql.Context, dbName string, ddb *doltdb.DoltDB, head *doltdb.Commit) sql.Table {
return &UnscopedDiffTable{dbName: dbName, ddb: ddb, head: head}
}
// Filters returns the list of filters that are applied to this table.
@@ -193,9 +194,9 @@ func (dt *UnscopedDiffTable) LookupPartitions(ctx *sql.Context, lookup sql.Index
func (dt *UnscopedDiffTable) newWorkingSetRowItr(ctx *sql.Context) (sql.RowIter, error) {
sess := dsess.DSessFromSess(ctx.Session)
roots, ok := sess.GetRoots(ctx, ctx.GetCurrentDatabase())
roots, ok := sess.GetRoots(ctx, dt.dbName)
if !ok {
return nil, fmt.Errorf("unable to lookup roots for database %s", ctx.GetCurrentDatabase())
return nil, fmt.Errorf("unable to lookup roots for database %s", dt.dbName)
}
staged, unstaged, err := diff.GetStagedUnstagedTableDeltas(ctx, roots)
@@ -744,10 +744,35 @@ var DoltScripts = []queries.ScriptTest{
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "SELECT type, name, fragment, id FROM dolt_schemas ORDER BY 1, 2",
Query: "SELECT type, name, fragment FROM dolt_schemas ORDER BY 1, 2",
Expected: []sql.Row{
{"view", "view1", "CREATE VIEW view1 AS SELECT v1 FROM viewtest", int64(1)},
{"view", "view2", "CREATE VIEW view2 AS SELECT v2 FROM viewtest", int64(2)},
{"view", "view1", "CREATE VIEW view1 AS SELECT v1 FROM viewtest"},
{"view", "view2", "CREATE VIEW view2 AS SELECT v2 FROM viewtest"},
},
},
{
Query: "CREATE VIEW VIEW1 AS SELECT v2 FROM viewtest",
ExpectedErr: sql.ErrExistingView,
},
{
Query: "drop view view1",
SkipResultsCheck: true,
},
{
Query: "SELECT type, name, fragment FROM dolt_schemas ORDER BY 1, 2",
Expected: []sql.Row{
{"view", "view2", "CREATE VIEW view2 AS SELECT v2 FROM viewtest"},
},
},
{
Query: "CREATE VIEW VIEW1 AS SELECT v1 FROM viewtest",
SkipResultsCheck: true,
},
{
Query: "SELECT type, name, fragment FROM dolt_schemas ORDER BY 1, 2",
Expected: []sql.Row{
{"view", "view1", "CREATE VIEW VIEW1 AS SELECT v1 FROM viewtest"},
{"view", "view2", "CREATE VIEW view2 AS SELECT v2 FROM viewtest"},
},
},
},
@@ -995,20 +995,18 @@ var MergeScripts = []queries.ScriptTest{
"CALL dolt_checkout('other')",
"CREATE TRIGGER trigger3 BEFORE INSERT ON x FOR EACH ROW SET new.a = (new.a * 2) + 100",
"CREATE TRIGGER trigger4 BEFORE INSERT ON x FOR EACH ROW SET new.a = (new.a * 2) + 1000",
"UPDATE dolt_schemas SET id = id + 1 WHERE name = 'trigger4'",
"CALL dolt_commit('-am', 'created triggers 3 & 4 on other');",
"CALL dolt_checkout('main');",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "CALL DOLT_MERGE('other');",
Expected: []sql.Row{{0, 1}},
Expected: []sql.Row{{0, 0}},
},
{
Query: "select count(*) from dolt_schemas where type = 'trigger';",
Expected: []sql.Row{{4}},
},
// todo: merge triggers correctly
//{
// Query: "select count(*) from dolt_schemas where type = 'trigger';",
// Expected: []sql.Row{{4}},
//},
},
},
{
+151 -216
View File
@@ -17,24 +17,18 @@ package sqle
import (
"fmt"
"io"
"strings"
"time"
"github.com/dolthub/go-mysql-server/sql"
gmstypes "github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb/durable"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/doltcore/schema"
"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/index"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/store/types"
)
var errDoltSchemasTableFormat = fmt.Errorf("`%s` schema in unexpected format", doltdb.SchemasTableName)
var noSchemaIndexDefined = fmt.Errorf("could not find index `%s` on system table `%s`", doltdb.SchemasTablesIndexName, doltdb.SchemasTableName)
const (
viewFragment = "view"
triggerFragment = "trigger"
@@ -44,61 +38,57 @@ type Extra struct {
CreatedAt int64
}
// The fixed dolt schema for the `dolt_schemas` table.
func SchemasTableSchema() schema.Schema {
typeCol, err := schema.NewColumnWithTypeInfo(doltdb.SchemasTablesTypeCol, schema.DoltSchemasTypeTag, typeinfo.StringDefaultType, false, "", false, "")
func mustNewColWithTypeInfo(name string, tag uint64, typeInfo typeinfo.TypeInfo, partOfPK bool, defaultVal string, autoIncrement bool, comment string, constraints ...schema.ColConstraint) schema.Column {
col, err := schema.NewColumnWithTypeInfo(name, tag, typeInfo, partOfPK, defaultVal, autoIncrement, comment, constraints...)
if err != nil {
panic(err)
}
nameCol, err := schema.NewColumnWithTypeInfo(doltdb.SchemasTablesNameCol, schema.DoltSchemasNameTag, typeinfo.StringDefaultType, false, "", false, "")
if err != nil {
panic(err)
}
fragmentCol, err := schema.NewColumnWithTypeInfo(doltdb.SchemasTablesFragmentCol, schema.DoltSchemasFragmentTag, typeinfo.StringDefaultType, false, "", false, "")
if err != nil {
panic(err)
}
idCol, err := schema.NewColumnWithTypeInfo(doltdb.SchemasTablesIdCol, schema.DoltSchemasIdTag, typeinfo.Int64Type, true, "", false, "", schema.NotNullConstraint{})
if err != nil {
panic(err)
}
extraCol, err := schema.NewColumnWithTypeInfo(doltdb.SchemasTablesExtraCol, schema.DoltSchemasExtraTag, typeinfo.JSONType, false, "", false, "")
if err != nil {
panic(err)
}
colColl := schema.NewColCollection(typeCol, nameCol, fragmentCol, idCol, extraCol)
return schema.MustSchemaFromCols(colColl)
return col
}
// GetOrCreateDoltSchemasTable returns the `dolt_schemas` table in `db`, creating it if it does not already exist.
func GetOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *WritableDoltTable, retErr error) {
func mustCreateStringType(baseType query.Type, length int64, collation sql.CollationID) sql.StringType {
ti, err := gmstypes.CreateString(baseType, length, collation)
if err != nil {
panic(err)
}
return ti
}
// dolt_schemas columns
var schemasTableCols = schema.NewColCollection(
mustNewColWithTypeInfo(doltdb.SchemasTablesTypeCol, schema.DoltSchemasTypeTag, typeinfo.CreateVarStringTypeFromSqlType(mustCreateStringType(query.Type_VARCHAR, 64, sql.Collation_utf8mb4_0900_ai_ci)), true, "", false, ""),
mustNewColWithTypeInfo(doltdb.SchemasTablesNameCol, schema.DoltSchemasNameTag, typeinfo.CreateVarStringTypeFromSqlType(mustCreateStringType(query.Type_VARCHAR, 64, sql.Collation_utf8mb4_0900_ai_ci)), true, "", false, ""),
mustNewColWithTypeInfo(doltdb.SchemasTablesFragmentCol, schema.DoltSchemasFragmentTag, typeinfo.CreateVarStringTypeFromSqlType(gmstypes.LongText), false, "", false, ""),
mustNewColWithTypeInfo(doltdb.SchemasTablesExtraCol, schema.DoltSchemasExtraTag, typeinfo.JSONType, false, "", false, ""),
)
var schemaTableSchema = schema.MustSchemaFromCols(schemasTableCols)
// getOrCreateDoltSchemasTable returns the `dolt_schemas` table in `db`, creating it if it does not already exist.
// Also migrates data to the correct format if necessary.
func getOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *WritableDoltTable, retErr error) {
tbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
if err != nil {
return nil, err
}
var rowsToAdd []sql.Row
if found {
schemasTable := tbl.(*WritableDoltTable)
// Old schemas table does not contain the `id` or `extra` column.
if !tbl.Schema().Contains(doltdb.SchemasTablesIdCol, doltdb.SchemasTableName) || !tbl.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.SchemasTableName) {
root, err := db.GetRoot(ctx)
if err != nil {
return nil, err
}
root, rowsToAdd, err = migrateOldSchemasTableToNew(ctx, db, root, schemasTable)
if err != nil {
return nil, err
}
// Old schemas table contains the `id` column or is missing an `extra` column.
if tbl.Schema().Contains(doltdb.SchemasTablesIdCol, doltdb.SchemasTableName) || !tbl.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.SchemasTableName) {
return migrateOldSchemasTableToNew(ctx, db, schemasTable)
} else {
return schemasTable, nil
}
}
root, err := db.GetRoot(ctx)
if err != nil {
return nil, err
}
// Create the schemas table as an empty table
err = db.createDoltTable(ctx, doltdb.SchemasTableName, root, SchemasTableSchema())
// Create new empty table
err = db.createDoltTable(ctx, doltdb.SchemasTableName, root, schemaTableSchema)
if err != nil {
return nil, err
}
@@ -109,199 +99,129 @@ func GetOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *Writabl
if !found {
return nil, sql.ErrTableNotFound.New("dolt_schemas")
}
// Create a unique index on the old primary key columns (type, name)
t := (&AlterableDoltTable{*tbl.(*WritableDoltTable)})
err = t.CreateIndex(ctx, sql.IndexDef{
Name: doltdb.SchemasTablesIndexName,
Columns: []sql.IndexColumn{
{Name: doltdb.SchemasTablesTypeCol, Length: 0},
{Name: doltdb.SchemasTablesNameCol, Length: 0},
},
Constraint: sql.IndexConstraint_Unique,
Storage: sql.IndexUsing_Default,
})
return tbl.(*WritableDoltTable), nil
}
func migrateOldSchemasTableToNew(ctx *sql.Context, db Database, schemasTable *WritableDoltTable) (newTable *WritableDoltTable, rerr error) {
// Copy all of the old data over and add an index column and an extra column
iter, err := SqlTableToRowIter(ctx, schemasTable.DoltTable, nil)
if err != nil {
return nil, err
}
// If there was an old schemas table that contained rows, then add that data here
tbl, found, err = db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
// need to get the column indexes from the current schema
nameIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
typeIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
fragmentIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
extraIdx := schemasTable.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
defer func(iter sql.RowIter, ctx *sql.Context) {
err := iter.Close(ctx)
if err != nil && rerr == nil {
rerr = err
}
}(iter, ctx)
var newRows []sql.Row
for {
sqlRow, err := iter.Next(ctx)
if err == io.EOF {
break
}
if err != nil {
return nil, err
}
newRow := make(sql.Row, schemasTableCols.Size())
newRow[0] = sqlRow[typeIdx]
newRow[1] = sqlRow[nameIdx]
newRow[2] = sqlRow[fragmentIdx]
if extraIdx >= 0 {
newRow[3] = sqlRow[extraIdx]
}
newRows = append(newRows, newRow)
}
err = db.dropTable(ctx, doltdb.SchemasTableName)
if err != nil {
return nil, err
}
root, err := db.GetRoot(ctx)
if err != nil {
return nil, err
}
err = db.createDoltTable(ctx, doltdb.SchemasTableName, root, schemaTableSchema)
if err != nil {
return nil, err
}
tbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
if err != nil {
return nil, err
}
if !found {
return nil, sql.ErrTableNotFound.New("dolt_schemas")
}
if len(rowsToAdd) != 0 {
err = func() (retErr error) {
inserter := tbl.(*WritableDoltTable).Inserter(ctx)
defer func() {
err := inserter.Close(ctx)
if retErr == nil {
retErr = err
}
}()
for _, sqlRow := range rowsToAdd {
err = inserter.Insert(ctx, sqlRow)
if err != nil {
return err
}
}
return nil
}()
inserter := tbl.(*WritableDoltTable).Inserter(ctx)
for _, row := range newRows {
err = inserter.Insert(ctx, row)
if err != nil {
return nil, err
}
}
err = inserter.Close(ctx)
if err != nil {
return nil, err
}
return tbl.(*WritableDoltTable), nil
}
// migrateOldSchemasTableToNew copies every row out of an older version of the
// dolt_schemas table, padding each row out to the current schema (synthesizing
// a sequential id when the old schema lacks an id column, and appending a nil
// value for the newer "extra" column), then drops the old table. It returns
// the root value after the drop together with the rows to re-insert into the
// recreated table.
func migrateOldSchemasTableToNew(
	ctx *sql.Context,
	db Database,
	root *doltdb.RootValue,
	schemasTable *WritableDoltTable,
) (*doltdb.RootValue, []sql.Row, error) {
	doltTable, err := schemasTable.DoltTable.DoltTable(ctx)
	if err != nil {
		return nil, nil, err
	}
	rowData, err := doltTable.GetNomsRowData(ctx)
	if err != nil {
		return nil, nil, err
	}

	// The schema is fixed for the duration of the iteration, so compute the
	// id-column check once up front rather than per row.
	hasIdCol := schemasTable.sqlSchema().Contains(doltdb.SchemasTablesIdCol, doltdb.SchemasTableName)

	var migrated []sql.Row
	nextId := int64(1)
	err = rowData.IterAll(ctx, func(key, val types.Value) error {
		dRow, ierr := row.FromNoms(schemasTable.sch, key.(types.Tuple), val.(types.Tuple))
		if ierr != nil {
			return ierr
		}
		sqlRow, ierr := sqlutil.DoltRowToSqlRow(dRow, schemasTable.sch)
		if ierr != nil {
			return ierr
		}
		// Older table versions have no id column; assign ids in iteration order.
		if !hasIdCol {
			sqlRow = append(sqlRow, nextId)
		}
		// The new schema has an "extra" column; old rows carry no value for it.
		sqlRow = append(sqlRow, nil)
		migrated = append(migrated, sqlRow)
		nextId++
		return nil
	})
	if err != nil {
		return nil, nil, err
	}

	if err = db.DropTable(ctx, doltdb.SchemasTableName); err != nil {
		return nil, nil, err
	}
	root, err = db.GetRoot(ctx)
	if err != nil {
		return nil, nil, err
	}
	return root, migrated, nil
}
// nextSchemasTableIndex returns the next id to use for a new row in the
// dolt_schemas table: one greater than the largest id currently stored, or 1
// when the table is missing or empty.
func nextSchemasTableIndex(ctx *sql.Context, root *doltdb.RootValue) (int64, error) {
	tbl, found, err := root.GetTable(ctx, doltdb.SchemasTableName)
	if err != nil {
		return 0, err
	}
	// Bug fix: the previous version discarded |found| and would nil-deref on
	// tbl.GetRowData when the table doesn't exist. Treat a missing table the
	// same as an empty one.
	if !found {
		return 1, nil
	}
	rows, err := tbl.GetRowData(ctx)
	if err != nil {
		return 0, err
	}
	empty, err := rows.Empty()
	if err != nil {
		return 0, err
	}
	if empty {
		return 1, nil
	}

	if types.IsFormat_DOLT(tbl.Format()) {
		// New (prolly) storage format: the id is the first field of the last
		// key in the map.
		idx := durable.ProllyMapFromIndex(rows)
		lastKey := idx.LastKey(ctx)
		keyDesc, _ := idx.Descriptors() // the value descriptor is unused here
		lastID, _ := keyDesc.GetInt64(0, lastKey)
		return lastID + 1, nil
	}

	// Old (noms) storage format: key tuples interleave tags and values, so
	// Get(1) retrieves the id value itself — presumably the sole key field;
	// TODO(review) confirm against the noms key layout.
	m := durable.NomsMapFromIndex(rows)
	lastKeyTpl, _, err := m.Last(ctx)
	if err != nil {
		return 0, err
	}
	if lastKeyTpl == nil {
		return 1, nil
	}
	idVal, err := lastKeyTpl.(types.Tuple).Get(1)
	if err != nil {
		return 0, err
	}
	return int64(idVal.(types.Int)) + 1, nil
}
// fragFromSchemasTable returns the row with the given schema fragment if it exists.
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType string, name string) (sql.Row, bool, error) {
indexes, err := tbl.GetIndexes(ctx)
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType string, name string) (r sql.Row, found bool, rerr error) {
fragType, name = strings.ToLower(fragType), strings.ToLower(name)
// This performs a full table scan in the worst case, but it's only used when adding or dropping a trigger or view
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
if err != nil {
return nil, false, err
}
var fragNameIndex sql.Index
for _, index := range indexes {
if index.ID() == doltdb.SchemasTablesIndexName {
fragNameIndex = index
break
defer func(iter sql.RowIter, ctx *sql.Context) {
err := iter.Close(ctx)
if err != nil && rerr == nil {
rerr = err
}
}
if fragNameIndex == nil {
return nil, false, noSchemaIndexDefined
}
}(iter, ctx)
exprs := fragNameIndex.Expressions()
lookup, err := sql.NewIndexBuilder(fragNameIndex).Equals(ctx, exprs[0], fragType).Equals(ctx, exprs[1], name).Build(ctx)
if err != nil {
return nil, false, err
}
// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
// need to get the column indexes from the current schema
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
iter, err := index.RowIterForIndexLookup(ctx, tbl.DoltTable, lookup, tbl.sqlSch, nil)
if err != nil {
return nil, false, err
}
defer func() {
if cerr := iter.Close(ctx); cerr != nil {
err = cerr
}
}()
// todo(andy): use filtered reader?
for {
sqlRow, err := iter.Next(ctx)
if err == io.EOF {
return nil, false, nil
break
}
if err != nil {
return nil, false, err
}
if sqlRow[0] != fragType || sqlRow[1] != name {
continue
// These columns are case insensitive, make sure to do a case-insensitive comparison
if strings.ToLower(sqlRow[typeIdx].(string)) == fragType && strings.ToLower(sqlRow[nameIdx].(string)) == name {
return sqlRow, true, nil
}
return sqlRow, true, nil
}
return nil, false, nil
}
type schemaFragment struct {
@@ -310,12 +230,26 @@ type schemaFragment struct {
created time.Time
}
func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType string) ([]schemaFragment, error) {
func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType string) (sf []schemaFragment, rerr error) {
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
if err != nil {
return nil, err
}
// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
// need to get the column indexes from the current schema
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
fragmentIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
extraIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
defer func(iter sql.RowIter, ctx *sql.Context) {
err := iter.Close(ctx)
if err != nil && rerr == nil {
rerr = err
}
}(iter, ctx)
var frags []schemaFragment
for {
sqlRow, err := iter.Next(ctx)
@@ -326,34 +260,35 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
return nil, err
}
if sqlRow[0] != fragType {
if sqlRow[typeIdx] != fragType {
continue
}
// For tables that haven't been converted yet or are filled with nil, use 1 as the trigger creation time
if len(sqlRow) < 5 || sqlRow[4] == nil {
// For older tables, use 1 as the trigger creation time
if extraIdx < 0 || sqlRow[extraIdx] == nil {
frags = append(frags, schemaFragment{
name: sqlRow[1].(string),
fragment: sqlRow[2].(string),
name: sqlRow[nameIdx].(string),
fragment: sqlRow[fragmentIdx].(string),
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
})
continue
}
// Extract Created Time from JSON column
createdTime, err := getCreatedTime(ctx, sqlRow)
createdTime, err := getCreatedTime(ctx, sqlRow[extraIdx].(gmstypes.JSONValue))
frags = append(frags, schemaFragment{
name: sqlRow[1].(string),
fragment: sqlRow[2].(string),
name: sqlRow[nameIdx].(string),
fragment: sqlRow[fragmentIdx].(string),
created: time.Unix(createdTime, 0).UTC(),
})
}
return frags, nil
}
func getCreatedTime(ctx *sql.Context, row sql.Row) (int64, error) {
doc, err := row[4].(gmstypes.JSONValue).Unmarshall(ctx)
func getCreatedTime(ctx *sql.Context, extraCol gmstypes.JSONValue) (int64, error) {
doc, err := extraCol.Unmarshall(ctx)
if err != nil {
return 0, err
}
+71 -98
View File
@@ -16,6 +16,8 @@ package sqle
import (
"context"
"io"
"strings"
"testing"
"github.com/dolthub/go-mysql-server/sql"
@@ -25,17 +27,12 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
"github.com/dolthub/dolt/go/libraries/doltcore/row"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/json"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/store/types"
)
func TestSchemaTableRecreationOlder(t *testing.T) {
if types.Format_Default != types.Format_LD_1 {
t.Skip() // schema table migrations predate NBF __DOLT__
}
func TestSchemaTableMigrationOriginal(t *testing.T) {
ctx := NewTestSQLCtx(context.Background())
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
@@ -51,15 +48,17 @@ func TestSchemaTableRecreationOlder(t *testing.T) {
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{ // schema of dolt_schemas table before the change
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{ // original schema of dolt_schemas table
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
{Name: doltdb.SchemasTablesNameCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
{Name: doltdb.SchemasTablesFragmentCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: false},
}), sql.Collation_Default)
require.NoError(t, err)
sqlTbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
require.NoError(t, err)
require.True(t, found)
inserter := sqlTbl.(*WritableDoltTable).Inserter(ctx)
err = inserter.Insert(ctx, sql.Row{"view", "view1", "SELECT v1 FROM test;"})
require.NoError(t, err)
@@ -68,59 +67,33 @@ func TestSchemaTableRecreationOlder(t *testing.T) {
err = inserter.Close(ctx)
require.NoError(t, err)
table, err := sqlTbl.(*WritableDoltTable).DoltTable.DoltTable(ctx)
tbl, err := getOrCreateDoltSchemasTable(ctx, db) // removes the old table and recreates it with the new schema
require.NoError(t, err)
rowData, err := table.GetNomsRowData(ctx)
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
require.NoError(t, err)
expectedVals := []sql.Row{
{"view", "view1", "SELECT v1 FROM test;"},
{"view", "view2", "SELECT v2 FROM test;"},
var rows []sql.Row
for {
row, err := iter.Next(ctx)
if err == io.EOF {
break
}
require.NoError(t, err)
rows = append(rows, row)
}
index := 0
_ = rowData.IterAll(ctx, func(keyTpl, valTpl types.Value) error {
dRow, err := row.FromNoms(sqlTbl.(*WritableDoltTable).sch, keyTpl.(types.Tuple), valTpl.(types.Tuple))
require.NoError(t, err)
sqlRow, err := sqlutil.DoltRowToSqlRow(dRow, sqlTbl.(*WritableDoltTable).sch)
require.NoError(t, err)
assert.Equal(t, expectedVals[index], sqlRow)
index++
return nil
})
tbl, err := GetOrCreateDoltSchemasTable(ctx, db) // removes the old table and recreates it with the new schema
require.NoError(t, err)
table, err = tbl.DoltTable.DoltTable(ctx)
require.NoError(t, err)
rowData, err = table.GetNomsRowData(ctx)
require.NoError(t, err)
expectedVals = []sql.Row{
{"view", "view1", "SELECT v1 FROM test;", int64(1), nil},
{"view", "view2", "SELECT v2 FROM test;", int64(2), nil},
require.NoError(t, iter.Close(ctx))
expectedRows := []sql.Row{
{"view", "view1", "SELECT v1 FROM test;", nil},
{"view", "view2", "SELECT v2 FROM test;", nil},
}
index = 0
_ = rowData.IterAll(ctx, func(keyTpl, valTpl types.Value) error {
dRow, err := row.FromNoms(tbl.sch, keyTpl.(types.Tuple), valTpl.(types.Tuple))
require.NoError(t, err)
sqlRow, err := sqlutil.DoltRowToSqlRow(dRow, tbl.sch)
require.NoError(t, err)
assert.Equal(t, expectedVals[index], sqlRow)
index++
return nil
})
indexes := tbl.sch.Indexes().AllIndexes()
require.Len(t, indexes, 1)
assert.Equal(t, true, indexes[0].IsUnique())
assert.Equal(t, doltdb.SchemasTablesIndexName, indexes[0].Name())
assert.Equal(t, expectedRows, rows)
}
func TestSchemaTableRecreation(t *testing.T) {
if types.Format_Default != types.Format_LD_1 {
t.Skip() // schema table migrations predate NBF __DOLT__
}
func TestSchemaTableMigrationV1(t *testing.T) {
ctx := NewTestSQLCtx(context.Background())
dEnv := dtestutils.CreateTestEnv()
tmpDir, err := dEnv.TempTableFilesDir()
@@ -136,70 +109,70 @@ func TestSchemaTableRecreation(t *testing.T) {
require.NoError(t, err)
ctx.SetCurrentDatabase(db.Name())
// This is the schema of dolt_schemas table after the change adding the ID column, but before adding the extra column
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{ //
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
{Name: doltdb.SchemasTablesNameCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: true},
// original schema of dolt_schemas table with the ID column
err = db.createSqlTable(ctx, doltdb.SchemasTableName, sql.NewPrimaryKeySchema(sql.Schema{
{Name: doltdb.SchemasTablesTypeCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: false},
{Name: doltdb.SchemasTablesNameCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: false},
{Name: doltdb.SchemasTablesFragmentCol, Type: gmstypes.Text, Source: doltdb.SchemasTableName, PrimaryKey: false},
{Name: doltdb.SchemasTablesIdCol, Type: gmstypes.Int64, Source: doltdb.SchemasTableName, PrimaryKey: false},
{Name: doltdb.SchemasTablesIdCol, Type: gmstypes.Int64, Source: doltdb.SchemasTableName, PrimaryKey: true},
{Name: doltdb.SchemasTablesExtraCol, Type: gmstypes.JsonType{}, Source: doltdb.SchemasTableName, PrimaryKey: false, Nullable: true},
}), sql.Collation_Default)
require.NoError(t, err)
sqlTbl, found, err := db.GetTableInsensitive(ctx, doltdb.SchemasTableName)
require.NoError(t, err)
require.True(t, found)
inserter := sqlTbl.(*WritableDoltTable).Inserter(ctx)
err = inserter.Insert(ctx, sql.Row{"view", "view1", "SELECT v1 FROM test;", int64(1)})
// JSON string has no spaces because our various JSON libraries don't agree on how to marshall it
err = inserter.Insert(ctx, sql.Row{"view", "view1", "SELECT v1 FROM test;", 1, `{"extra":"data"}`})
require.NoError(t, err)
err = inserter.Insert(ctx, sql.Row{"view", "view2", "SELECT v2 FROM test;", int64(2)})
err = inserter.Insert(ctx, sql.Row{"view", "view2", "SELECT v2 FROM test;", 2, nil})
require.NoError(t, err)
err = inserter.Close(ctx)
require.NoError(t, err)
table, err := sqlTbl.(*WritableDoltTable).DoltTable.DoltTable(ctx)
tbl, err := getOrCreateDoltSchemasTable(ctx, db) // removes the old table and recreates it with the new schema
require.NoError(t, err)
rowData, err := table.GetNomsRowData(ctx)
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
require.NoError(t, err)
expectedVals := []sql.Row{
{"view", "view1", "SELECT v1 FROM test;", int64(1)},
{"view", "view2", "SELECT v2 FROM test;", int64(2)},
var rows []sql.Row
for {
row, err := iter.Next(ctx)
if err == io.EOF {
break
}
require.NoError(t, err)
// convert the JSONDocument to a string for comparison
if row[3] != nil {
// Annoying difference in representation between storage versions here
jsonDoc, ok := row[3].(gmstypes.JSONDocument)
if ok {
row[3], err = jsonDoc.ToString(nil)
row[3] = strings.ReplaceAll(row[3].(string), " ", "") // remove spaces
}
nomsJson, ok := row[3].(json.NomsJSON)
if ok {
row[3], err = nomsJson.ToString(ctx)
row[3] = strings.ReplaceAll(row[3].(string), " ", "") // remove spaces
}
require.NoError(t, err)
}
rows = append(rows, row)
}
index := 0
_ = rowData.IterAll(ctx, func(keyTpl, valTpl types.Value) error {
dRow, err := row.FromNoms(sqlTbl.(*WritableDoltTable).sch, keyTpl.(types.Tuple), valTpl.(types.Tuple))
require.NoError(t, err)
sqlRow, err := sqlutil.DoltRowToSqlRow(dRow, sqlTbl.(*WritableDoltTable).sch)
require.NoError(t, err)
assert.Equal(t, expectedVals[index], sqlRow)
index++
return nil
})
tbl, err := GetOrCreateDoltSchemasTable(ctx, db) // removes the old table and recreates it with the new schema
require.NoError(t, err)
require.NoError(t, iter.Close(ctx))
table, err = tbl.DoltTable.DoltTable(ctx)
require.NoError(t, err)
rowData, err = table.GetNomsRowData(ctx)
require.NoError(t, err)
expectedVals = []sql.Row{
{"view", "view1", "SELECT v1 FROM test;", int64(1), nil},
{"view", "view2", "SELECT v2 FROM test;", int64(2), nil},
expectedRows := []sql.Row{
{"view", "view1", "SELECT v1 FROM test;", `{"extra":"data"}`},
{"view", "view2", "SELECT v2 FROM test;", nil},
}
index = 0
_ = rowData.IterAll(ctx, func(keyTpl, valTpl types.Value) error {
dRow, err := row.FromNoms(tbl.sch, keyTpl.(types.Tuple), valTpl.(types.Tuple))
require.NoError(t, err)
sqlRow, err := sqlutil.DoltRowToSqlRow(dRow, tbl.sch)
require.NoError(t, err)
assert.Equal(t, expectedVals[index], sqlRow)
index++
return nil
})
indexes := tbl.sch.Indexes().AllIndexes()
require.Len(t, indexes, 1)
assert.Equal(t, true, indexes[0].IsUnique())
assert.Equal(t, doltdb.SchemasTablesIndexName, indexes[0].Name())
assert.Equal(t, expectedRows, rows)
}
+7 -12
View File
@@ -795,8 +795,8 @@ func TestRenameTableStatements(t *testing.T) {
}
func TestAlterSystemTables(t *testing.T) {
systemTableNames := []string{"dolt_log", "dolt_history_people", "dolt_diff_people", "dolt_commit_diff_people"} // "dolt_docs",
reservedTableNames := []string{"dolt_schemas", "dolt_query_catalog"}
systemTableNames := []string{"dolt_log", "dolt_history_people", "dolt_diff_people", "dolt_commit_diff_people", "dolt_schemas"} // "dolt_docs",
reservedTableNames := []string{"dolt_query_catalog"}
var dEnv *env.DoltEnv
var err error
@@ -807,15 +807,15 @@ func TestAlterSystemTables(t *testing.T) {
err := CreateEmptyTestTable(dEnv, "dolt_docs", doltdb.DocsSchema)
require.NoError(t, err)
err = CreateEmptyTestTable(dEnv, doltdb.SchemasTableName, SchemasTableSchema())
err = CreateEmptyTestTable(dEnv, doltdb.SchemasTableName, schemaTableSchema)
require.NoError(t, err)
CreateTestTable(t, dEnv, "dolt_docs", doltdb.DocsSchema,
"INSERT INTO dolt_docs VALUES ('LICENSE.md','A license')")
CreateTestTable(t, dEnv, doltdb.DoltQueryCatalogTableName, dtables.DoltQueryCatalogSchema,
"INSERT INTO dolt_query_catalog VALUES ('abc123', 1, 'example', 'select 2+2 from dual', 'description')")
CreateTestTable(t, dEnv, doltdb.SchemasTableName, SchemasTableSchema(),
"INSERT INTO dolt_schemas (type, name, fragment, id) VALUES ('view', 'name', 'create view name as select 2+2 from dual', 1)")
CreateTestTable(t, dEnv, doltdb.SchemasTableName, schemaTableSchema,
"INSERT INTO dolt_schemas (type, name, fragment) VALUES ('view', 'name', 'create view name as select 2+2 from dual')")
}
t.Run("Create", func(t *testing.T) {
@@ -825,15 +825,10 @@ func TestAlterSystemTables(t *testing.T) {
}
})
// The _history and _diff tables give not found errors right now because of https://github.com/dolthub/dolt/issues/373.
// We can remove the divergent failure logic when the issue is fixed.
t.Run("Drop", func(t *testing.T) {
setup()
for _, tableName := range systemTableNames {
expectedErr := "system table"
if strings.HasPrefix(tableName, "dolt_diff") || strings.HasPrefix(tableName, "dolt_history") {
expectedErr = "system tables cannot be dropped or altered"
}
for _, tableName := range append(systemTableNames) {
expectedErr := "system tables cannot be dropped or altered"
assertFails(t, dEnv, fmt.Sprintf("drop table %s", tableName), expectedErr)
}
for _, tableName := range reservedTableNames {
+3 -3
View File
@@ -207,12 +207,12 @@ var systemTableDeleteTests = []DeleteTest{
},
{
Name: "delete dolt_schemas",
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, SchemasTableSchema(),
"INSERT INTO dolt_schemas (type, name, fragment, id) VALUES ('view', 'name', 'create view name as select 2+2 from dual', 1)"),
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, schemaTableSchema,
"INSERT INTO dolt_schemas (type, name, fragment) VALUES ('view', 'name', 'create view name as select 2+2 from dual')"),
DeleteQuery: "delete from dolt_schemas",
SelectQuery: "select * from dolt_schemas",
ExpectedRows: ToSqlRows(dtables.DoltQueryCatalogSchema),
ExpectedSchema: SchemasTableSchema(),
ExpectedSchema: schemaTableSchema,
},
}
+6 -6
View File
@@ -397,13 +397,13 @@ var systemTableInsertTests = []InsertTest{
},
{
Name: "insert into dolt_schemas",
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, SchemasTableSchema(), ""),
InsertQuery: "insert into dolt_schemas (id, type, name, fragment) values (1, 'view', 'name', 'create view name as select 2+2 from dual')",
SelectQuery: "select * from dolt_schemas ORDER BY id",
ExpectedRows: ToSqlRows(CompressSchema(SchemasTableSchema()),
NewRow(types.String("view"), types.String("name"), types.String("create view name as select 2+2 from dual"), types.Int(1)),
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, schemaTableSchema, ""),
InsertQuery: "insert into dolt_schemas (type, name, fragment) values ('view', 'name', 'create view name as select 2+2 from dual')",
SelectQuery: "select * from dolt_schemas ORDER BY name",
ExpectedRows: ToSqlRows(CompressSchema(schemaTableSchema),
NewRow(types.String("view"), types.String("name"), types.String("create view name as select 2+2 from dual")),
),
ExpectedSchema: CompressSchema(SchemasTableSchema()),
ExpectedSchema: CompressSchema(schemaTableSchema),
},
}
@@ -272,12 +272,12 @@ var systemTableReplaceTests = []ReplaceTest{
},
{
Name: "replace into dolt_schemas",
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, SchemasTableSchema(),
"INSERT INTO dolt_schemas VALUES ('view', 'name', 'create view name as select 2+2 from dual', 1, NULL)"),
ReplaceQuery: "replace into dolt_schemas (id, type, name, fragment) values ('1', 'view', 'name', 'create view name as select 1+1 from dual')",
SelectQuery: "select type, name, fragment, id, extra from dolt_schemas",
ExpectedRows: []sql.Row{{"view", "name", "create view name as select 1+1 from dual", int64(1), nil}},
ExpectedSchema: CompressSchema(SchemasTableSchema()),
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, schemaTableSchema,
"INSERT INTO dolt_schemas VALUES ('view', 'name', 'create view name as select 2+2 from dual', NULL)"),
ReplaceQuery: "replace into dolt_schemas (type, name, fragment) values ('view', 'name', 'create view name as select 1+1 from dual')",
SelectQuery: "select type, name, fragment, extra from dolt_schemas",
ExpectedRows: []sql.Row{{"view", "name", "create view name as select 1+1 from dual", nil}},
ExpectedSchema: CompressSchema(schemaTableSchema),
},
}
+4 -11
View File
@@ -1298,21 +1298,14 @@ var systemTableSelectTests = []SelectTest{
},
{
Name: "select from dolt_schemas",
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, SchemasTableSchema(),
`INSERT INTO dolt_schemas VALUES ('view', 'name', 'create view name as select 2+2 from dual', 1, NULL)`),
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, schemaTableSchema,
`INSERT INTO dolt_schemas VALUES ('view', 'name', 'create view name as select 2+2 from dual', NULL)`),
Query: "select * from dolt_schemas",
ExpectedRows: []sql.Row{{"view", "name", "create view name as select 2+2 from dual", int64(1), nil}},
ExpectedSchema: CompressSchema(SchemasTableSchema()),
ExpectedRows: []sql.Row{{"view", "name", "create view name as select 2+2 from dual", nil}},
ExpectedSchema: CompressSchema(schemaTableSchema),
},
}
func CreateTestJSON() types.JSON {
vrw := types.NewMemoryValueStore()
extraJSON, _ := types.NewMap(nil, vrw, types.String("CreatedAt"), types.Float(1))
res, _ := types.NewJSONDoc(types.Format_Default, vrw, extraJSON)
return res
}
func TestSelectSystemTables(t *testing.T) {
for _, test := range systemTableSelectTests {
t.Run(test.Name, func(t *testing.T) {
+4 -4
View File
@@ -377,12 +377,12 @@ var systemTableUpdateTests = []UpdateTest{
},
{
Name: "update dolt_schemas",
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, SchemasTableSchema(),
`INSERT INTO dolt_schemas VALUES ('view', 'name', 'create view name as select 2+2 from dual', 1, NULL)`),
AdditionalSetup: CreateTableFn(doltdb.SchemasTableName, schemaTableSchema,
`INSERT INTO dolt_schemas VALUES ('view', 'name', 'create view name as select 2+2 from dual', NULL)`),
UpdateQuery: "update dolt_schemas set type = 'not a view'",
SelectQuery: "select * from dolt_schemas",
ExpectedRows: []sql.Row{{"not a view", "name", "create view name as select 2+2 from dual", int64(1), nil}},
ExpectedSchema: CompressSchema(SchemasTableSchema()),
ExpectedRows: []sql.Row{{"not a view", "name", "create view name as select 2+2 from dual", nil}},
ExpectedSchema: CompressSchema(schemaTableSchema),
},
}
+1
View File
@@ -62,6 +62,7 @@ func runRoot(ctx context.Context, args []string) int {
cfg := config.NewResolver()
cs, err := cfg.GetChunkStore(ctx, args[0])
util.CheckErrorNoUsage(err)
defer cs.Close()
currRoot, err := cs.Root(ctx)
+2 -1
View File
@@ -164,6 +164,7 @@ type datasFactory func(context.Context) (types.ValueReadWriter, datas.Database)
func testPuller(t *testing.T, makeDB datasFactory) {
ctx := context.Background()
vs, db := makeDB(ctx)
defer db.Close()
deltas := []struct {
name string
@@ -325,6 +326,7 @@ func testPuller(t *testing.T, makeDB datasFactory) {
}()
sinkvs, sinkdb := makeDB(ctx)
defer sinkdb.Close()
tmpDir := filepath.Join(os.TempDir(), uuid.New().String())
err = os.MkdirAll(tmpDir, os.ModePerm)
@@ -351,7 +353,6 @@ func testPuller(t *testing.T, makeDB datasFactory) {
eq, err := pullerAddrEquality(ctx, rootAddr, sinkRootAddr, vs, sinkvs)
require.NoError(t, err)
assert.True(t, eq)
})
}
}
+3 -3
View File
@@ -98,14 +98,14 @@ func newAWSChunkSource(ctx context.Context, ddb *ddbTableStore, s3 *s3ObjectRead
func loadTableIndex(ctx context.Context, stats *Stats, cnt uint32, q MemoryQuotaProvider, loadIndexBytes func(p []byte) error) (tableIndex, error) {
idxSz := int(indexSize(cnt) + footerSize)
offsetSz := int((cnt - (cnt / 2)) * offsetSize)
buf, err := q.AcquireQuotaBytes(ctx, uint64(idxSz+offsetSz))
buf, err := q.AcquireQuotaBytes(ctx, idxSz+offsetSz)
if err != nil {
return nil, err
}
t1 := time.Now()
if err := loadIndexBytes(buf[:idxSz]); err != nil {
q.ReleaseQuotaBytes(buf)
q.ReleaseQuotaBytes(len(buf))
return nil, err
}
stats.IndexReadLatency.SampleTimeSince(t1)
@@ -113,7 +113,7 @@ func loadTableIndex(ctx context.Context, stats *Stats, cnt uint32, q MemoryQuota
idx, err := parseTableIndexWithOffsetBuff(buf[:idxSz], buf[idxSz:], q)
if err != nil {
q.ReleaseQuotaBytes(buf)
q.ReleaseQuotaBytes(len(buf))
}
return idx, err
}
+2
View File
@@ -67,6 +67,7 @@ func TestAWSChunkSource(t *testing.T) {
t.Run("Has Chunks", func(t *testing.T) {
src := makeSrc(len(chunks) + 1)
assertChunksInReader(chunks, src, assert.New(t))
src.close()
})
})
@@ -76,6 +77,7 @@ func TestAWSChunkSource(t *testing.T) {
t.Run("Has Chunks", func(t *testing.T) {
src := makeSrc(len(chunks) - 1)
assertChunksInReader(chunks, src, assert.New(t))
src.close()
})
})
}
+16 -11
View File
@@ -25,6 +25,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/url"
"sort"
@@ -392,9 +393,9 @@ func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionP
readWg.Add(1)
go func(m manualPart) {
defer readWg.Done()
n, _ := m.srcR.Read(buff[m.dstStart:m.dstEnd])
if int64(n) < m.dstEnd-m.dstStart {
ae.SetIfError(errors.New("failed to read all the table data"))
err := m.run(ctx, buff)
if err != nil {
ae.SetIfError(fmt.Errorf("failed to read conjoin table data: %w", err))
}
}(man)
}
@@ -507,8 +508,17 @@ type copyPart struct {
}
type manualPart struct {
srcR io.Reader
dstStart, dstEnd int64
src chunkSource
start, end int64
}
func (mp manualPart) run(ctx context.Context, buff []byte) error {
reader, _, err := mp.src.reader(ctx)
if err != nil {
return err
}
_, err = io.ReadFull(reader, buff[mp.start:mp.end])
return err
}
// dividePlan assumes that plan.sources (which is of type chunkSourcesByDescendingDataSize) is correctly sorted by descending data size.
@@ -545,12 +555,7 @@ func dividePlan(ctx context.Context, plan compactionPlan, minPartSize, maxPartSi
var offset int64
for ; i < len(plan.sources.sws); i++ {
sws := plan.sources.sws[i]
rdr, _, err := sws.source.reader(ctx)
if err != nil {
return nil, nil, nil, err
}
manuals = append(manuals, manualPart{rdr, offset, offset + int64(sws.dataLen)})
manuals = append(manuals, manualPart{sws.source, offset, offset + int64(sws.dataLen)})
offset += int64(sws.dataLen)
buffSize += sws.dataLen
}
+37 -1
View File
@@ -58,10 +58,12 @@ func TestAWSTablePersisterPersist(t *testing.T) {
src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
defer src.close()
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTableWithNamespace(ctx, ns, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(testChunks, r, assert)
r.close()
}
}
})
@@ -75,9 +77,11 @@ func TestAWSTablePersisterPersist(t *testing.T) {
src, err := s3p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
defer src.close()
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTableWithNamespace(ctx, ns, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(testChunks, r, assert)
r.close()
}
}
})
@@ -99,6 +103,7 @@ func TestAWSTablePersisterPersist(t *testing.T) {
src, err := s3p.Persist(context.Background(), mt, existingTable, &Stats{})
require.NoError(t, err)
defer src.close()
assert.True(mustUint32(src.count()) == 0)
_, present := s3svc.data[src.hash().String()]
@@ -268,6 +273,11 @@ func TestAWSTablePersisterDividePlan(t *testing.T) {
tooBig := bytesToChunkSource(t, bigUns...)
sources := chunkSources{justRight, tooBig, tooSmall}
defer func() {
for _, s := range sources {
s.close()
}
}()
plan, err := planRangeCopyConjoin(sources, &Stats{})
require.NoError(t, err)
copies, manuals, _, err := dividePlan(context.Background(), plan, minPartSize, maxPartSize)
@@ -294,7 +304,7 @@ func TestAWSTablePersisterDividePlan(t *testing.T) {
assert.Len(manuals, 1)
ti, err = tooSmall.index()
require.NoError(t, err)
assert.EqualValues(calcChunkRangeSize(ti), manuals[0].dstEnd-manuals[0].dstStart)
assert.EqualValues(calcChunkRangeSize(ti), manuals[0].end-manuals[0].start)
}
func TestAWSTablePersisterCalcPartSizes(t *testing.T) {
@@ -349,6 +359,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
ti, err := src.index()
require.NoError(t, err)
smallChunkTotal += calcChunkRangeSize(ti)
ti.Close()
}
t.Run("Small", func(t *testing.T) {
@@ -372,10 +383,15 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
sources := makeSources(s3p, chunks)
src, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
require.NoError(t, err)
defer src.close()
for _, s := range sources {
s.close()
}
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(chunks, r, assert)
r.close()
}
}
})
@@ -388,10 +404,15 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
sources := makeSources(s3p, smallChunks)
src, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
require.NoError(t, err)
defer src.close()
for _, s := range sources {
s.close()
}
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(smallChunks, r, assert)
r.close()
}
}
})
@@ -424,11 +445,16 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
}
src, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
require.NoError(t, err)
defer src.close()
for _, s := range sources {
s.close()
}
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(bigUns1, r, assert)
assertChunksInReader(bigUns2, r, assert)
r.close()
}
}
})
@@ -460,11 +486,16 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
src, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
require.NoError(t, err)
defer src.close()
for _, s := range sources {
s.close()
}
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(bigUns1, r, assert)
assertChunksInReader(medChunks, r, assert)
r.close()
}
}
})
@@ -510,12 +541,17 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
src, err := s3p.ConjoinAll(context.Background(), sources, &Stats{})
require.NoError(t, err)
defer src.close()
for _, s := range sources {
s.close()
}
if assert.True(mustUint32(src.count()) > 0) {
if r, err := s3svc.readerForTable(ctx, src.hash()); assert.NotNil(r) && assert.NoError(err) {
assertChunksInReader(smallChunks, r, assert)
assertChunksInReader(bigUns1, r, assert)
assertChunksInReader(medChunks, r, assert)
r.close()
}
}
})
+3
View File
@@ -320,6 +320,7 @@ func (suite *BlockStoreSuite) TestChunkStoreFlushOptimisticLockFail() {
interloper, err := suite.factory(context.Background(), suite.dir)
suite.NoError(err)
defer interloper.Close()
err = interloper.Put(context.Background(), c1, noopGetAddrs)
suite.NoError(err)
h, err := interloper.Root(context.Background())
@@ -369,6 +370,7 @@ func (suite *BlockStoreSuite) TestChunkStoreRebaseOnNoOpFlush() {
interloper, err := suite.factory(context.Background(), suite.dir)
suite.NoError(err)
defer interloper.Close()
err = interloper.Put(context.Background(), c1, noopGetAddrs)
suite.NoError(err)
root, err := interloper.Root(context.Background())
@@ -408,6 +410,7 @@ func (suite *BlockStoreSuite) TestChunkStorePutWithRebase() {
interloper, err := suite.factory(context.Background(), suite.dir)
suite.NoError(err)
defer interloper.Close()
err = interloper.Put(context.Background(), c1, noopGetAddrs)
suite.NoError(err)
h, err := interloper.Root(context.Background())
@@ -40,6 +40,7 @@ func TestCmpChunkTableWriter(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), fileBlockSize)
require.NoError(t, err)
defer tr.close()
hashes := make(hash.HashSet)
for _, chnk := range testMDChunks {
@@ -89,6 +90,7 @@ func TestCmpChunkTableWriter(t *testing.T) {
require.NoError(t, err)
outputTR, err := newTableReader(outputTI, tableReaderAtFromBytes(buff), fileBlockSize)
require.NoError(t, err)
defer outputTR.close()
compareContentsOfTables(t, ctx, hashes, tr, outputTR)
}
+8
View File
@@ -197,6 +197,13 @@ func conjoinTables(ctx context.Context, conjoinees []tableSpec, p tablePersister
return
})
}
defer func() {
for _, cs := range toConjoin {
if cs != nil {
cs.close()
}
}
}()
if err = eg.Wait(); err != nil {
return tableSpec{}, err
}
@@ -207,6 +214,7 @@ func conjoinTables(ctx context.Context, conjoinees []tableSpec, p tablePersister
if err != nil {
return tableSpec{}, err
}
defer conjoinedSrc.close()
stats.ConjoinLatency.SampleTimeSince(t1)
stats.TablesPerConjoin.SampleLen(len(toConjoin))
+10
View File
@@ -68,6 +68,7 @@ func makeTestSrcs(t *testing.T, tableSizes []uint32, p tablePersister) (srcs chu
c, err := cs.clone()
require.NoError(t, err)
srcs = append(srcs, c)
cs.close()
}
return
}
@@ -143,6 +144,14 @@ func testConjoin(t *testing.T, factory func(t *testing.T) tablePersister) {
}
expectSrcs, actualSrcs := open(expect), open(actual)
defer func() {
for _, s := range expectSrcs {
s.close()
}
for _, s := range actualSrcs {
s.close()
}
}()
ctx := context.Background()
for _, src := range expectSrcs {
@@ -173,6 +182,7 @@ func testConjoin(t *testing.T, factory func(t *testing.T) tablePersister) {
mt.addChunk(computeAddr(data), data)
src, err := p.Persist(context.Background(), mt, nil, &Stats{})
require.NoError(t, err)
defer src.close()
return tableSpec{src.hash(), mustUint32(src.count())}
}
+25 -3
View File
@@ -54,8 +54,9 @@ func TestFSTableCacheOnOpen(t *testing.T) {
names = append(names, name)
}
for _, name := range names {
_, err := fts.Open(context.Background(), name, 1, nil)
tr, err := fts.Open(context.Background(), name, 1, nil)
require.NoError(t, err)
defer tr.close()
}
}()
@@ -63,6 +64,7 @@ func TestFSTableCacheOnOpen(t *testing.T) {
for i, name := range names {
src, err := fts.Open(context.Background(), name, 1, nil)
require.NoError(t, err)
defer src.close()
h := computeAddr([]byte{byte(i)})
assert.True(src.has(h))
}
@@ -70,8 +72,9 @@ func TestFSTableCacheOnOpen(t *testing.T) {
// Kick a table out of the cache
name, err := writeTableData(dir, []byte{0xff})
require.NoError(t, err)
_, err = fts.Open(context.Background(), name, 1, nil)
tr, err := fts.Open(context.Background(), name, 1, nil)
require.NoError(t, err)
defer tr.close()
present := fc.reportEntries()
// Since 0 refcount entries are evicted randomly, the only thing we can validate is that fc remains at its target size
@@ -125,6 +128,7 @@ func TestFSTablePersisterPersist(t *testing.T) {
src, err := persistTableData(fts, testChunks...)
require.NoError(t, err)
defer src.close()
if assert.True(mustUint32(src.count()) > 0) {
buff, err := os.ReadFile(filepath.Join(dir, src.hash().String()))
require.NoError(t, err)
@@ -132,6 +136,7 @@ func TestFSTablePersisterPersist(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), fileBlockSize)
require.NoError(t, err)
defer tr.close()
assertChunksInReader(testChunks, tr, assert)
}
}
@@ -182,17 +187,20 @@ func TestFSTablePersisterCacheOnPersist(t *testing.T) {
func() {
src, err := persistTableData(fts, testChunks...)
require.NoError(t, err)
defer src.close()
name = src.hash()
}()
// Table should still be cached
src, err := fts.Open(context.Background(), name, uint32(len(testChunks)), nil)
require.NoError(t, err)
defer src.close()
assertChunksInReader(testChunks, src, assert)
// Evict |name| from cache
_, err = persistTableData(fts, []byte{0xff})
tr, err := persistTableData(fts, []byte{0xff})
require.NoError(t, err)
defer tr.close()
present := fc.reportEntries()
// Since 0 refcount entries are evicted randomly, the only thing we can validate is that fc remains at its target size
@@ -223,9 +231,15 @@ func TestFSTablePersisterConjoinAll(t *testing.T) {
sources[i], err = fts.Open(ctx, name, 2, nil)
require.NoError(t, err)
}
defer func() {
for _, s := range sources {
s.close()
}
}()
src, err := fts.ConjoinAll(ctx, sources, &Stats{})
require.NoError(t, err)
defer src.close()
if assert.True(mustUint32(src.count()) > 0) {
buff, err := os.ReadFile(filepath.Join(dir, src.hash().String()))
@@ -234,6 +248,7 @@ func TestFSTablePersisterConjoinAll(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), fileBlockSize)
require.NoError(t, err)
defer tr.close()
assertChunksInReader(testChunks, tr, assert)
}
@@ -263,9 +278,15 @@ func TestFSTablePersisterConjoinAllDups(t *testing.T) {
sources[i], err = fts.Persist(ctx, mt, nil, &Stats{})
require.NoError(t, err)
}
defer func() {
for _, s := range sources {
s.close()
}
}()
src, err := fts.ConjoinAll(ctx, sources, &Stats{})
require.NoError(t, err)
defer src.close()
if assert.True(mustUint32(src.count()) > 0) {
buff, err := os.ReadFile(filepath.Join(dir, src.hash().String()))
@@ -274,6 +295,7 @@ func TestFSTablePersisterConjoinAllDups(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), fileBlockSize)
require.NoError(t, err)
defer tr.close()
assertChunksInReader(testChunks, tr, assert)
assert.EqualValues(reps*len(testChunks), mustUint32(tr.count()))
}
+10 -5
View File
@@ -86,29 +86,34 @@ func newFileTableReader(ctx context.Context, dir string, h addr, chunkCount uint
indexOffset := sz - idxSz
r := io.NewSectionReader(f, indexOffset, idxSz)
if int64(int(idxSz)) != idxSz {
err = fmt.Errorf("table file %s/%s is too large to read on this platform. index size %d > max int.", dir, h.String(), idxSz)
return
}
var b []byte
b, err = q.AcquireQuotaBytes(ctx, uint64(idxSz))
b, err = q.AcquireQuotaBytes(ctx, int(idxSz))
if err != nil {
return
}
_, err = io.ReadFull(r, b)
if err != nil {
q.ReleaseQuotaBytes(b)
q.ReleaseQuotaBytes(len(b))
return
}
defer func() {
unrefErr := fc.UnrefFile(path)
if unrefErr != nil {
if unrefErr != nil && err == nil {
q.ReleaseQuotaBytes(len(b))
err = unrefErr
}
}()
ti, err = parseTableIndex(ctx, b, q)
if err != nil {
q.ReleaseQuotaBytes(b)
q.ReleaseQuotaBytes(len(b))
return
}
+1
View File
@@ -56,5 +56,6 @@ func TestMmapTableReader(t *testing.T) {
trc, err := newFileTableReader(ctx, dir, h, uint32(len(chunks)), &UnlimitedQuotaProvider{}, fc)
require.NoError(t, err)
defer trc.close()
assertChunksInReader(chunks, trc, assert)
}
+7 -12
View File
@@ -56,7 +56,6 @@ const (
// both memTable persists and manifest updates to a single file.
type chunkJournal struct {
wr *journalWriter
src journalChunkSource
path string
contents manifestContents
@@ -105,7 +104,7 @@ func (j *chunkJournal) openJournal(ctx context.Context) (err error) {
return err
}
_, j.src, err = j.wr.ProcessJournal(ctx)
_, err = j.wr.bootstrapJournal(ctx)
if err != nil {
return err
}
@@ -117,7 +116,7 @@ func (j *chunkJournal) openJournal(ctx context.Context) (err error) {
}
if ok {
// write the current root hash to the journal file
if err = j.wr.WriteRootHash(contents.root); err != nil {
if err = j.wr.writeRootHash(contents.root); err != nil {
return
}
j.contents = contents
@@ -133,8 +132,7 @@ func (j *chunkJournal) openJournal(ctx context.Context) (err error) {
}
// parse existing journal file
var root hash.Hash
root, j.src, err = j.wr.ProcessJournal(ctx)
root, err := j.wr.bootstrapJournal(ctx)
if err != nil {
return err
}
@@ -178,15 +176,12 @@ func (j *chunkJournal) Persist(ctx context.Context, mt *memTable, haver chunkRea
continue
}
c := chunks.NewChunkWithHash(hash.Hash(*record.a), mt.chunks[*record.a])
cc := ChunkToCompressedChunk(c)
lookup, err := j.wr.WriteChunk(cc)
err := j.wr.writeCompressedChunk(ChunkToCompressedChunk(c))
if err != nil {
return nil, err
}
j.src.lookups.put(*record.a, lookup)
j.src.uncompressedSz += uint64(c.Size())
}
return j.src, nil
return journalChunkSource{journal: j.wr}, nil
}
// ConjoinAll implements tablePersister.
@@ -200,7 +195,7 @@ func (j *chunkJournal) Open(ctx context.Context, name addr, chunkCount uint32, s
if err := j.maybeInit(ctx); err != nil {
return nil, err
}
return j.src, nil
return journalChunkSource{journal: j.wr}, nil
}
return j.persister.Open(ctx, name, chunkCount, stats)
}
@@ -263,7 +258,7 @@ func (j *chunkJournal) Update(ctx context.Context, lastLock addr, next manifestC
}
}
if err := j.wr.WriteRootHash(next.root); err != nil {
if err := j.wr.writeRootHash(next.root); err != nil {
return manifestContents{}, err
}
j.contents = next
+19 -71
View File
@@ -18,7 +18,6 @@ import (
"context"
"fmt"
"io"
"sync"
"golang.org/x/sync/errgroup"
@@ -45,65 +44,28 @@ func rangeFromLookup(l recLookup) Range {
return Range{
// see journalRec for serialization format
Offset: uint64(l.journalOff) + uint64(l.payloadOff),
Length: l.recordLen - (l.payloadOff + recChecksumSz),
Length: l.recordLen - (l.payloadOff + journalRecChecksumSz),
}
}
// lookupMap is a thread-safe collection of recLookups.
type lookupMap struct {
data map[addr]recLookup
lock *sync.RWMutex
}
func newLookupMap() lookupMap {
return lookupMap{
data: make(map[addr]recLookup),
lock: new(sync.RWMutex),
}
}
func (m lookupMap) get(a addr) (l recLookup, ok bool) {
m.lock.RLock()
defer m.lock.RUnlock()
l, ok = m.data[a]
return
}
func (m lookupMap) put(a addr, l recLookup) {
m.lock.Lock()
defer m.lock.Unlock()
m.data[a] = l
return
}
func (m lookupMap) count() int {
m.lock.RLock()
defer m.lock.RUnlock()
return len(m.data)
}
// journalChunkSource is a chunkSource that reads chunks
// from a chunkJournal. Unlike other NBS chunkSources,
// it is not immutable and its set of chunks grows as
// more commits are made to the chunkJournal.
type journalChunkSource struct {
address addr
journal snapshotReader
lookups lookupMap
uncompressedSz uint64
journal *journalWriter
}
var _ chunkSource = journalChunkSource{}
func (s journalChunkSource) has(h addr) (bool, error) {
_, ok := s.lookups.get(h)
return ok, nil
return s.journal.hasAddr(h), nil
}
func (s journalChunkSource) hasMany(addrs []hasRecord) (missing bool, err error) {
for i := range addrs {
a := addrs[i].a
if _, ok := s.lookups.get(*a); ok {
ok := s.journal.hasAddr(*addrs[i].a)
if ok {
addrs[i].has = true
} else {
missing = true
@@ -113,28 +75,11 @@ func (s journalChunkSource) hasMany(addrs []hasRecord) (missing bool, err error)
}
func (s journalChunkSource) getCompressed(_ context.Context, h addr, _ *Stats) (CompressedChunk, error) {
l, ok := s.lookups.get(h)
if !ok {
return CompressedChunk{}, nil
}
buf := make([]byte, l.recordLen)
if _, err := s.journal.ReadAt(buf, l.journalOff); err != nil {
return CompressedChunk{}, nil
}
rec, err := readJournalRecord(buf)
if err != nil {
return CompressedChunk{}, err
} else if h != rec.address {
return CompressedChunk{}, fmt.Errorf("chunk record hash does not match lookup hash (%s != %s)",
h.String(), rec.address.String())
}
return NewCompressedChunk(hash.Hash(h), rec.payload)
return s.journal.getCompressedChunk(h)
}
func (s journalChunkSource) get(ctx context.Context, h addr, stats *Stats) ([]byte, error) {
cc, err := s.getCompressed(ctx, h, stats)
func (s journalChunkSource) get(_ context.Context, h addr, _ *Stats) ([]byte, error) {
cc, err := s.journal.getCompressedChunk(h)
if err != nil {
return nil, err
} else if cc.IsEmpty() {
@@ -181,20 +126,20 @@ func (s journalChunkSource) getManyCompressed(ctx context.Context, _ *errgroup.G
}
func (s journalChunkSource) count() (uint32, error) {
return uint32(s.lookups.count()), nil
return s.journal.recordCount(), nil
}
func (s journalChunkSource) uncompressedLen() (uint64, error) {
return s.uncompressedSz, nil
return s.journal.uncompressedSize(), nil
}
func (s journalChunkSource) hash() addr {
return s.address
return journalAddr
}
// reader implements chunkSource.
func (s journalChunkSource) reader(context.Context) (io.ReadCloser, uint64, error) {
rdr, sz, err := s.journal.Snapshot()
rdr, sz, err := s.journal.snapshot()
return io.NopCloser(rdr), uint64(sz), err
}
@@ -204,12 +149,14 @@ func (s journalChunkSource) getRecordRanges(requests []getRecord) (map[hash.Hash
if req.found {
continue
}
l, ok := s.lookups.get(*req.a)
if !ok {
rng, ok, err := s.journal.getRange(*req.a)
if err != nil {
return nil, err
} else if !ok {
continue
}
req.found = true // update |requests|
ranges[hash.Hash(*req.a)] = rangeFromLookup(l)
ranges[hash.Hash(*req.a)] = rng
}
return ranges, nil
}
@@ -217,7 +164,7 @@ func (s journalChunkSource) getRecordRanges(requests []getRecord) (map[hash.Hash
// size implements chunkSource.
// size returns the total size of the chunkSource: chunks, index, and footer
func (s journalChunkSource) currentSize() uint64 {
return uint64(s.journal.CurrentSize())
return uint64(s.journal.currentSize())
}
// index implements chunkSource.
@@ -230,6 +177,7 @@ func (s journalChunkSource) clone() (chunkSource, error) {
}
func (s journalChunkSource) close() error {
// |s.journal| closed via chunkJournal
return nil
}
+66 -66
View File
@@ -24,7 +24,7 @@ import (
"github.com/dolthub/dolt/go/store/d"
)
// journalRec is a record in a chunk journal. It's serialization format uses
// journalRec is a record in a chunk journal. Its serialization format uses
// uint8 tag prefixes to identify fields and allow for format evolution.
//
// There are two kinds of journalRecs: chunk records and root hash records.
@@ -43,7 +43,7 @@ import (
// offset. See recLookup for more detail.
type journalRec struct {
length uint32
kind recKind
kind journalRecKind
address addr
payload []byte
checksum uint32
@@ -52,7 +52,7 @@ type journalRec struct {
// payloadOffset returns the journalOffset of the payload within the record
// assuming only the checksum field follows the payload.
func (r journalRec) payloadOffset() uint32 {
return r.length - uint32(len(r.payload)+recChecksumSz)
return r.length - uint32(len(r.payload)+journalRecChecksumSz)
}
// uncompressedPayloadSize returns the uncompressed size of the payload.
@@ -63,76 +63,76 @@ func (r journalRec) uncompressedPayloadSize() (sz uint64) {
return
}
type recKind uint8
type journalRecKind uint8
const (
unknownKind recKind = 0
rootHashRecKind recKind = 1
chunkRecKind recKind = 2
unknownJournalRecKind journalRecKind = 0
rootHashJournalRecKind journalRecKind = 1
chunkJournalRecKind journalRecKind = 2
)
type recTag uint8
type journalRecTag uint8
const (
unknownTag recTag = 0
kindTag recTag = 1
addrTag recTag = 2
payloadTag recTag = 3
unknownJournalRecTag journalRecTag = 0
kindJournalRecTag journalRecTag = 1
addrJournalRecTag journalRecTag = 2
payloadJournalRecTag journalRecTag = 3
)
const (
recTagSz = 1
recLenSz = 4
recKindSz = 1
recAddrSz = 20
recChecksumSz = 4
journalRecTagSz = 1
journalRecLenSz = 4
journalRecKindSz = 1
journalRecAddrSz = 20
journalRecChecksumSz = 4
// todo(andy): less arbitrary
recMaxSz = 128 * 1024
journalRecMaxSz = 128 * 1024
)
func chunkRecordSize(c CompressedChunk) (recordSz, payloadOff uint32) {
recordSz += recLenSz
recordSz += recTagSz + recKindSz
recordSz += recTagSz + recAddrSz
recordSz += recTagSz // payload tag
recordSz += journalRecLenSz
recordSz += journalRecTagSz + journalRecKindSz
recordSz += journalRecTagSz + journalRecAddrSz
recordSz += journalRecTagSz // payload tag
payloadOff = recordSz
recordSz += uint32(len(c.FullCompressedChunk))
recordSz += recChecksumSz
recordSz += journalRecChecksumSz
return
}
func rootHashRecordSize() (recordSz int) {
recordSz += recLenSz
recordSz += recTagSz + recKindSz
recordSz += recTagSz + recAddrSz
recordSz += recChecksumSz
recordSz += journalRecLenSz
recordSz += journalRecTagSz + journalRecKindSz
recordSz += journalRecTagSz + journalRecAddrSz
recordSz += journalRecChecksumSz
return
}
func writeChunkRecord(buf []byte, c CompressedChunk) (n uint32) {
// length
l, _ := chunkRecordSize(c)
writeUint(buf[:recLenSz], l)
n += recLenSz
writeUint(buf[:journalRecLenSz], l)
n += journalRecLenSz
// kind
buf[n] = byte(kindTag)
n += recTagSz
buf[n] = byte(chunkRecKind)
n += recKindSz
buf[n] = byte(kindJournalRecTag)
n += journalRecTagSz
buf[n] = byte(chunkJournalRecKind)
n += journalRecKindSz
// address
buf[n] = byte(addrTag)
n += recTagSz
buf[n] = byte(addrJournalRecTag)
n += journalRecTagSz
copy(buf[n:], c.H[:])
n += recAddrSz
n += journalRecAddrSz
// payload
buf[n] = byte(payloadTag)
n += recTagSz
buf[n] = byte(payloadJournalRecTag)
n += journalRecTagSz
copy(buf[n:], c.FullCompressedChunk)
n += uint32(len(c.FullCompressedChunk))
// checksum
writeUint(buf[n:], crc(buf[:n]))
n += recChecksumSz
n += journalRecChecksumSz
d.PanicIfFalse(l == n)
return
}
@@ -140,62 +140,62 @@ func writeChunkRecord(buf []byte, c CompressedChunk) (n uint32) {
func writeRootHashRecord(buf []byte, root addr) (n uint32) {
// length
l := rootHashRecordSize()
writeUint(buf[:recLenSz], uint32(l))
n += recLenSz
writeUint(buf[:journalRecLenSz], uint32(l))
n += journalRecLenSz
// kind
buf[n] = byte(kindTag)
n += recTagSz
buf[n] = byte(rootHashRecKind)
n += recKindSz
buf[n] = byte(kindJournalRecTag)
n += journalRecTagSz
buf[n] = byte(rootHashJournalRecKind)
n += journalRecKindSz
// address
buf[n] = byte(addrTag)
n += recTagSz
buf[n] = byte(addrJournalRecTag)
n += journalRecTagSz
copy(buf[n:], root[:])
n += recAddrSz
n += journalRecAddrSz
// empty payload
// checksum
writeUint(buf[n:], crc(buf[:n]))
n += recChecksumSz
n += journalRecChecksumSz
return
}
func readJournalRecord(buf []byte) (rec journalRec, err error) {
rec.length = readUint(buf)
buf = buf[recLenSz:]
for len(buf) > recChecksumSz {
tag := recTag(buf[0])
buf = buf[recTagSz:]
buf = buf[journalRecLenSz:]
for len(buf) > journalRecChecksumSz {
tag := journalRecTag(buf[0])
buf = buf[journalRecTagSz:]
switch tag {
case kindTag:
rec.kind = recKind(buf[0])
buf = buf[recKindSz:]
case addrTag:
case kindJournalRecTag:
rec.kind = journalRecKind(buf[0])
buf = buf[journalRecKindSz:]
case addrJournalRecTag:
copy(rec.address[:], buf)
buf = buf[recAddrSz:]
case payloadTag:
sz := len(buf) - recChecksumSz
buf = buf[journalRecAddrSz:]
case payloadJournalRecTag:
sz := len(buf) - journalRecChecksumSz
rec.payload = buf[:sz]
buf = buf[sz:]
case unknownTag:
case unknownJournalRecTag:
fallthrough
default:
err = fmt.Errorf("unknown record field tag: %d", tag)
return
}
}
rec.checksum = readUint(buf[:recChecksumSz])
rec.checksum = readUint(buf[:journalRecChecksumSz])
return
}
func validateJournalRecord(buf []byte) (ok bool) {
if len(buf) > (recLenSz + recChecksumSz) {
off := len(buf) - recChecksumSz
if len(buf) > (journalRecLenSz + journalRecChecksumSz) {
off := len(buf) - journalRecChecksumSz
ok = crc(buf[:off]) == readUint(buf[off:])
}
return
}
func processRecords(ctx context.Context, r io.ReadSeeker, cb func(o int64, r journalRec) error) (int64, error) {
func processJournalRecords(ctx context.Context, r io.ReadSeeker, cb func(o int64, r journalRec) error) (int64, error) {
var (
buf []byte
off int64
@@ -210,7 +210,7 @@ func processRecords(ctx context.Context, r io.ReadSeeker, cb func(o int64, r jou
}
l := readUint(buf)
if l > recMaxSz {
if l > journalRecMaxSz {
break
} else if buf, err = rdr.Peek(int(l)); err != nil {
break
+32 -32
View File
@@ -97,7 +97,7 @@ func TestProcessRecords(t *testing.T) {
return
}
n, err := processRecords(ctx, bytes.NewReader(journal), check)
n, err := processJournalRecords(ctx, bytes.NewReader(journal), check)
assert.Equal(t, cnt, i)
assert.Equal(t, int(off), int(n))
require.NoError(t, err)
@@ -105,7 +105,7 @@ func TestProcessRecords(t *testing.T) {
i, sum = 0, 0
// write a bogus record to the end and process again
writeCorruptRecord(journal[off:])
n, err = processRecords(ctx, bytes.NewReader(journal), check)
n, err = processJournalRecords(ctx, bytes.NewReader(journal), check)
assert.Equal(t, cnt, i)
assert.Equal(t, int(off), int(n))
require.NoError(t, err)
@@ -134,29 +134,29 @@ func makeChunkRecord() (journalRec, []byte) {
buf := make([]byte, sz)
// length
writeUint(buf[n:], uint32(len(buf)))
n += recLenSz
n += journalRecLenSz
// kind
buf[n] = byte(kindTag)
n += recTagSz
buf[n] = byte(chunkRecKind)
n += recKindSz
buf[n] = byte(kindJournalRecTag)
n += journalRecTagSz
buf[n] = byte(chunkJournalRecKind)
n += journalRecKindSz
// address
buf[n] = byte(addrTag)
n += recTagSz
buf[n] = byte(addrJournalRecTag)
n += journalRecTagSz
copy(buf[n:], cc.H[:])
n += recAddrSz
n += journalRecAddrSz
// payload
buf[n] = byte(payloadTag)
n += recTagSz
buf[n] = byte(payloadJournalRecTag)
n += journalRecTagSz
copy(buf[n:], payload)
n += len(payload)
// checksum
c := crc(buf[:len(buf)-recChecksumSz])
writeUint(buf[len(buf)-recChecksumSz:], c)
c := crc(buf[:len(buf)-journalRecChecksumSz])
writeUint(buf[len(buf)-journalRecChecksumSz:], c)
r := journalRec{
length: uint32(len(buf)),
kind: chunkRecKind,
kind: chunkJournalRecKind,
address: addr(cc.H),
payload: payload,
checksum: c,
@@ -170,23 +170,23 @@ func makeRootHashRecord() (journalRec, []byte) {
buf := make([]byte, rootHashRecordSize())
// length
writeUint(buf[n:], uint32(len(buf)))
n += recLenSz
n += journalRecLenSz
// kind
buf[n] = byte(kindTag)
n += recTagSz
buf[n] = byte(rootHashRecKind)
n += recKindSz
buf[n] = byte(kindJournalRecTag)
n += journalRecTagSz
buf[n] = byte(rootHashJournalRecKind)
n += journalRecKindSz
// address
buf[n] = byte(addrTag)
n += recTagSz
buf[n] = byte(addrJournalRecTag)
n += journalRecTagSz
copy(buf[n:], a[:])
n += recAddrSz
n += journalRecAddrSz
// checksum
c := crc(buf[:len(buf)-recChecksumSz])
writeUint(buf[len(buf)-recChecksumSz:], c)
c := crc(buf[:len(buf)-journalRecChecksumSz])
writeUint(buf[len(buf)-journalRecChecksumSz:], c)
r := journalRec{
length: uint32(len(buf)),
kind: rootHashRecKind,
kind: rootHashJournalRecKind,
address: a,
checksum: c,
}
@@ -194,13 +194,13 @@ func makeRootHashRecord() (journalRec, []byte) {
}
func makeUnknownTagRecord() (buf []byte) {
const fakeTag recTag = 111
const fakeTag journalRecTag = 111
_, buf = makeRootHashRecord()
// overwrite recKind
buf[recLenSz] = byte(fakeTag)
buf[journalRecLenSz] = byte(fakeTag)
// redo checksum
c := crc(buf[:len(buf)-recChecksumSz])
writeUint(buf[len(buf)-recChecksumSz:], c)
c := crc(buf[:len(buf)-journalRecChecksumSz])
writeUint(buf[len(buf)-journalRecChecksumSz:], c)
return
}
@@ -210,12 +210,12 @@ func writeCorruptRecord(buf []byte) (n uint32) {
rand.Read(buf[:n])
// write a valid size, kind
writeUint(buf, n)
buf[recLenSz] = byte(rootHashRecKind)
buf[journalRecLenSz] = byte(rootHashJournalRecKind)
return
}
func mustCompressedChunk(rec journalRec) CompressedChunk {
d.PanicIfFalse(rec.kind == chunkRecKind)
d.PanicIfFalse(rec.kind == chunkJournalRecKind)
cc, err := NewCompressedChunk(hash.Hash(rec.address), rec.payload)
d.PanicIfError(err)
return cc
+1 -1
View File
@@ -96,7 +96,7 @@ func TestReadRecordRanges(t *testing.T) {
jcs, err := j.Persist(ctx, mt, emptyChunkSource{}, &Stats{})
require.NoError(t, err)
rdr, sz, err := jcs.(journalChunkSource).journal.Snapshot()
rdr, sz, err := jcs.(journalChunkSource).journal.snapshot()
require.NoError(t, err)
buf = make([]byte, sz)
+199 -147
View File
@@ -78,9 +78,10 @@ func openJournalWriter(ctx context.Context, path string) (wr *journalWriter, exi
}
return &journalWriter{
buf: make([]byte, 0, journalWriterBuffSize),
file: f,
path: path,
buf: make([]byte, 0, journalWriterBuffSize),
lookups: make(map[addr]recLookup),
file: f,
path: path,
}, true, nil
}
@@ -117,34 +118,139 @@ func createJournalWriter(ctx context.Context, path string) (wr *journalWriter, e
}
return &journalWriter{
buf: make([]byte, 0, journalWriterBuffSize),
file: f,
path: path,
buf: make([]byte, 0, journalWriterBuffSize),
lookups: make(map[addr]recLookup),
file: f,
path: path,
}, nil
}
type snapshotReader interface {
io.ReaderAt
// Snapshot returns an io.Reader that provides a consistent view
// of the current state of the snapshotReader.
Snapshot() (io.Reader, int64, error)
// currentSize returns the current size.
CurrentSize() int64
}
type journalWriter struct {
buf []byte
file *os.File
off int64
path string
lock sync.RWMutex
buf []byte
lookups map[addr]recLookup
file *os.File
off int64
uncmpSz uint64
path string
lock sync.RWMutex
}
var _ io.WriteCloser = &journalWriter{}
var _ snapshotReader = &journalWriter{}
var _ io.Closer = &journalWriter{}
func (wr *journalWriter) ReadAt(p []byte, off int64) (n int, err error) {
// bootstrapJournal reads the journal file collecting a recLookup for each record and
// returning the latest committed root hash.
func (wr *journalWriter) bootstrapJournal(ctx context.Context) (last hash.Hash, err error) {
wr.lock.Lock()
defer wr.lock.Unlock()
wr.off, err = processJournalRecords(ctx, wr.file, func(o int64, r journalRec) error {
switch r.kind {
case chunkJournalRecKind:
wr.lookups[r.address] = recLookup{
journalOff: o,
recordLen: r.length,
payloadOff: r.payloadOffset(),
}
wr.uncmpSz += r.uncompressedPayloadSize()
case rootHashJournalRecKind:
last = hash.Hash(r.address)
default:
return fmt.Errorf("unknown journal record kind (%d)", r.kind)
}
return nil
})
if err != nil {
return hash.Hash{}, err
}
return
}
// hasAddr returns true if the journal contains a chunk with addr |h|.
func (wr *journalWriter) hasAddr(h addr) (ok bool) {
wr.lock.RLock()
defer wr.lock.RUnlock()
_, ok = wr.lookups[h]
return
}
// getCompressedChunk reads the CompressedChunks with addr |h|.
func (wr *journalWriter) getCompressedChunk(h addr) (CompressedChunk, error) {
wr.lock.RLock()
defer wr.lock.RUnlock()
l, ok := wr.lookups[h]
if !ok {
return CompressedChunk{}, nil
}
buf := make([]byte, l.recordLen)
if _, err := wr.readAt(buf, l.journalOff); err != nil {
return CompressedChunk{}, nil
}
rec, err := readJournalRecord(buf)
if err != nil {
return CompressedChunk{}, err
} else if h != rec.address {
err = fmt.Errorf("chunk record hash does not match (%s != %s)",
h.String(), rec.address.String())
return CompressedChunk{}, err
}
return NewCompressedChunk(hash.Hash(h), rec.payload)
}
// getRange returns a Range for the chunk with addr |h|.
func (wr *journalWriter) getRange(h addr) (rng Range, ok bool, err error) {
// callers will use |rng| to read directly from the
// journal file, so we must flush here
if err = wr.maybeFlush(); err != nil {
return
}
wr.lock.RLock()
defer wr.lock.RUnlock()
var l recLookup
l, ok = wr.lookups[h]
if ok {
rng = rangeFromLookup(l)
}
return
}
// writeCompressedChunk writes |cc| to the journal.
func (wr *journalWriter) writeCompressedChunk(cc CompressedChunk) error {
wr.lock.Lock()
defer wr.lock.Unlock()
l, o := chunkRecordSize(cc)
rec := recLookup{
journalOff: wr.offset(),
recordLen: l,
payloadOff: o,
}
buf, err := wr.getBytes(int(rec.recordLen))
if err != nil {
return err
}
_ = writeChunkRecord(buf, cc)
wr.lookups[addr(cc.H)] = rec
return nil
}
// writeRootHash commits |root| to the journal and syncs the file to disk.
func (wr *journalWriter) writeRootHash(root hash.Hash) error {
wr.lock.Lock()
defer wr.lock.Unlock()
buf, err := wr.getBytes(rootHashRecordSize())
if err != nil {
return err
}
_ = writeRootHashRecord(buf, addr(root))
if err = wr.flush(); err != nil {
return err
}
return wr.file.Sync()
}
// readAt reads len(p) bytes from the journal at offset |off|.
func (wr *journalWriter) readAt(p []byte, off int64) (n int, err error) {
wr.lock.RLock()
defer wr.lock.RUnlock()
var bp []byte
@@ -169,129 +275,7 @@ func (wr *journalWriter) ReadAt(p []byte, off int64) (n int, err error) {
return
}
func (wr *journalWriter) Snapshot() (io.Reader, int64, error) {
wr.lock.Lock()
defer wr.lock.Unlock()
if err := wr.flush(); err != nil {
return nil, 0, err
}
// open a new file descriptor with an
// independent lifecycle from |wr.file|
f, err := os.Open(wr.path)
if err != nil {
return nil, 0, err
}
return io.LimitReader(f, wr.off), wr.off, nil
}
func (wr *journalWriter) CurrentSize() int64 {
wr.lock.RLock()
defer wr.lock.RUnlock()
return wr.off
}
func (wr *journalWriter) Write(p []byte) (n int, err error) {
wr.lock.Lock()
defer wr.lock.Unlock()
if len(p) > len(wr.buf) {
// write directly to |wr.file|
if err = wr.flush(); err != nil {
return 0, err
}
n, err = wr.file.WriteAt(p, wr.off)
wr.off += int64(n)
return
}
var buf []byte
if buf, err = wr.getBytes(len(p)); err != nil {
return 0, err
}
n = copy(buf, p)
return
}
func (wr *journalWriter) ProcessJournal(ctx context.Context) (last hash.Hash, cs journalChunkSource, err error) {
wr.lock.Lock()
defer wr.lock.Unlock()
src := journalChunkSource{
journal: wr,
address: journalAddr,
lookups: newLookupMap(),
}
wr.off, err = processRecords(ctx, wr.file, func(o int64, r journalRec) error {
switch r.kind {
case chunkRecKind:
src.lookups.put(r.address, recLookup{
journalOff: o,
recordLen: r.length,
payloadOff: r.payloadOffset(),
})
src.uncompressedSz += r.uncompressedPayloadSize()
case rootHashRecKind:
last = hash.Hash(r.address)
default:
return fmt.Errorf("unknown journal record kind (%d)", r.kind)
}
return nil
})
if err != nil {
return hash.Hash{}, journalChunkSource{}, err
}
cs = src
return
}
func (wr *journalWriter) WriteChunk(cc CompressedChunk) (recLookup, error) {
wr.lock.Lock()
defer wr.lock.Unlock()
l, o := chunkRecordSize(cc)
rec := recLookup{
journalOff: wr.offset(),
recordLen: l,
payloadOff: o,
}
buf, err := wr.getBytes(int(rec.recordLen))
if err != nil {
return recLookup{}, err
}
_ = writeChunkRecord(buf, cc)
return rec, nil
}
func (wr *journalWriter) WriteRootHash(root hash.Hash) error {
wr.lock.Lock()
defer wr.lock.Unlock()
buf, err := wr.getBytes(rootHashRecordSize())
if err != nil {
return err
}
_ = writeRootHashRecord(buf, addr(root))
if err = wr.flush(); err != nil {
return err
}
return wr.file.Sync()
}
func (wr *journalWriter) Close() (err error) {
wr.lock.Lock()
defer wr.lock.Unlock()
if err = wr.flush(); err != nil {
return err
}
if cerr := wr.file.Sync(); cerr != nil {
err = cerr
}
if cerr := wr.file.Close(); cerr != nil {
err = cerr
}
return
}
func (wr *journalWriter) offset() int64 {
return wr.off + int64(len(wr.buf))
}
// getBytes returns a buffer for writers to copy data into.
func (wr *journalWriter) getBytes(n int) (buf []byte, err error) {
c, l := cap(wr.buf), len(wr.buf)
if n > c {
@@ -308,6 +292,7 @@ func (wr *journalWriter) getBytes(n int) (buf []byte, err error) {
return
}
// flush writes buffered data into the journal file.
func (wr *journalWriter) flush() (err error) {
if _, err = wr.file.WriteAt(wr.buf, wr.off); err != nil {
return err
@@ -316,3 +301,70 @@ func (wr *journalWriter) flush() (err error) {
wr.buf = wr.buf[:0]
return
}
// maybeFlush flushes buffered data, if any exists.
func (wr *journalWriter) maybeFlush() (err error) {
wr.lock.RLock()
empty := len(wr.buf) == 0
wr.lock.RUnlock()
if empty {
return
}
wr.lock.Lock()
defer wr.lock.Unlock()
return wr.flush()
}
// snapshot returns an io.Reader with a consistent view of
// the current state of the journal file.
func (wr *journalWriter) snapshot() (io.Reader, int64, error) {
wr.lock.Lock()
defer wr.lock.Unlock()
if err := wr.flush(); err != nil {
return nil, 0, err
}
// open a new file descriptor with an
// independent lifecycle from |wr.file|
f, err := os.Open(wr.path)
if err != nil {
return nil, 0, err
}
return io.LimitReader(f, wr.off), wr.off, nil
}
func (wr *journalWriter) offset() int64 {
return wr.off + int64(len(wr.buf))
}
func (wr *journalWriter) currentSize() int64 {
wr.lock.RLock()
defer wr.lock.RUnlock()
return wr.offset()
}
func (wr *journalWriter) uncompressedSize() uint64 {
wr.lock.RLock()
defer wr.lock.RUnlock()
return wr.uncmpSz
}
func (wr *journalWriter) recordCount() uint32 {
wr.lock.RLock()
defer wr.lock.RUnlock()
return uint32(len(wr.lookups))
}
func (wr *journalWriter) Close() (err error) {
wr.lock.Lock()
defer wr.lock.Unlock()
if err = wr.flush(); err != nil {
return err
}
if cerr := wr.file.Sync(); cerr != nil {
err = cerr
}
if cerr := wr.file.Close(); cerr != nil {
err = cerr
}
return
}
+22 -29
View File
@@ -123,17 +123,6 @@ func TestJournalWriter(t *testing.T) {
{kind: readOp, buf: []byte("loremipsumdolorsitamet"), readAt: 0},
},
},
{
name: "write larger that buffer",
size: 8,
ops: []operation{
{kind: writeOp, buf: []byte("loremipsum")},
{kind: flushOp},
{kind: writeOp, buf: []byte("dolorsitamet")},
{kind: readOp, buf: []byte("dolorsitamet"), readAt: 10},
{kind: readOp, buf: []byte("loremipsumdolorsitamet"), readAt: 0},
},
},
{
name: "flush empty buffer",
size: 16,
@@ -160,19 +149,23 @@ func TestJournalWriter(t *testing.T) {
j, err := createJournalWriter(ctx, newTestFilePath(t))
require.NotNil(t, j)
require.NoError(t, err)
// set specific buffer size
j.buf = make([]byte, 0, test.size)
var off int64
for i, op := range test.ops {
switch op.kind {
case readOp:
act := make([]byte, len(op.buf))
n, err := j.ReadAt(act, op.readAt)
n, err := j.readAt(act, op.readAt)
assert.NoError(t, err, "operation %d errored", i)
assert.Equal(t, len(op.buf), n, "operation %d failed", i)
assert.Equal(t, op.buf, act, "operation %d failed", i)
case writeOp:
n, err := j.Write(op.buf)
assert.NoError(t, err, "operation %d errored", i)
var p []byte
p, err = j.getBytes(len(op.buf))
require.NoError(t, err, "operation %d errored", i)
n := copy(p, op.buf)
assert.Equal(t, len(op.buf), n, "operation %d failed", i)
off += int64(n)
case flushOp:
@@ -188,22 +181,21 @@ func TestJournalWriter(t *testing.T) {
}
}
func TestJournalWriterWriteChunk(t *testing.T) {
func TestJournalWriterWriteCompressedChunk(t *testing.T) {
ctx := context.Background()
j, err := createJournalWriter(ctx, newTestFilePath(t))
require.NotNil(t, j)
require.NoError(t, err)
data := randomCompressedChunks()
lookups := make(map[addr]recLookup)
for a, cc := range data {
l, err := j.WriteChunk(cc)
err = j.writeCompressedChunk(cc)
require.NoError(t, err)
lookups[a] = l
l := j.lookups[a]
validateLookup(t, j, l, cc)
}
for a, l := range lookups {
for a, l := range j.lookups {
validateLookup(t, j, l, data[a])
}
require.NoError(t, j.Close())
@@ -217,22 +209,22 @@ func TestJournalWriterBootstrap(t *testing.T) {
require.NoError(t, err)
data := randomCompressedChunks()
lookups := make(map[addr]recLookup)
for a, cc := range data {
l, err := j.WriteChunk(cc)
for _, cc := range data {
err = j.writeCompressedChunk(cc)
require.NoError(t, err)
lookups[a] = l
}
assert.NoError(t, j.Close())
j, _, err = openJournalWriter(ctx, path)
require.NoError(t, err)
_, source, err := j.ProcessJournal(ctx)
_, err = j.bootstrapJournal(ctx)
require.NoError(t, err)
for a, l := range lookups {
for a, l := range j.lookups {
validateLookup(t, j, l, data[a])
}
source := journalChunkSource{journal: j}
for a, cc := range data {
buf, err := source.get(ctx, a, nil)
require.NoError(t, err)
@@ -245,7 +237,7 @@ func TestJournalWriterBootstrap(t *testing.T) {
func validateLookup(t *testing.T, j *journalWriter, l recLookup, cc CompressedChunk) {
b := make([]byte, l.recordLen)
n, err := j.ReadAt(b, l.journalOff)
n, err := j.readAt(b, l.journalOff)
require.NoError(t, err)
assert.Equal(t, int(l.recordLen), n)
rec, err := readJournalRecord(b)
@@ -259,13 +251,14 @@ func TestJournalWriterSyncClose(t *testing.T) {
j, err := createJournalWriter(ctx, newTestFilePath(t))
require.NotNil(t, j)
require.NoError(t, err)
_, _, err = j.ProcessJournal(ctx)
_, err = j.bootstrapJournal(ctx)
require.NoError(t, err)
// close triggers flush
n, err := j.Write([]byte("sit"))
p := []byte("sit")
buf, err := j.getBytes(len(p))
require.NoError(t, err)
assert.Equal(t, 3, n)
copy(buf, p)
err = j.Close()
require.NoError(t, err)
assert.Equal(t, 0, len(j.buf))
+3
View File
@@ -156,6 +156,7 @@ func TestMemTableWrite(t *testing.T) {
require.NoError(t, err)
tr1, err := newTableReader(ti1, tableReaderAtFromBytes(td1), fileBlockSize)
require.NoError(t, err)
defer tr1.close()
assert.True(tr1.has(computeAddr(chunks[1])))
td2, _, err := buildTable(chunks[2:])
@@ -164,6 +165,7 @@ func TestMemTableWrite(t *testing.T) {
require.NoError(t, err)
tr2, err := newTableReader(ti2, tableReaderAtFromBytes(td2), fileBlockSize)
require.NoError(t, err)
defer tr2.close()
assert.True(tr2.has(computeAddr(chunks[2])))
_, data, count, err := mt.write(chunkReaderGroup{tr1, tr2}, &Stats{})
@@ -174,6 +176,7 @@ func TestMemTableWrite(t *testing.T) {
require.NoError(t, err)
outReader, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize)
require.NoError(t, err)
defer outReader.close()
assert.True(outReader.has(computeAddr(chunks[0])))
assert.False(outReader.has(computeAddr(chunks[1])))
assert.False(outReader.has(computeAddr(chunks[2])))
+10 -9
View File
@@ -20,8 +20,8 @@ import (
)
type MemoryQuotaProvider interface {
AcquireQuotaBytes(ctx context.Context, sz uint64) ([]byte, error)
ReleaseQuotaBytes(buf []byte) error
AcquireQuotaBytes(ctx context.Context, sz int) ([]byte, error)
ReleaseQuotaBytes(sz int)
Usage() uint64
}
@@ -34,23 +34,24 @@ func NewUnlimitedMemQuotaProvider() *UnlimitedQuotaProvider {
return &UnlimitedQuotaProvider{}
}
func (q *UnlimitedQuotaProvider) AcquireQuotaBytes(ctx context.Context, sz uint64) ([]byte, error) {
func (q *UnlimitedQuotaProvider) AcquireQuotaBytes(ctx context.Context, sz int) ([]byte, error) {
buf := make([]byte, sz)
q.mu.Lock()
defer q.mu.Unlock()
q.used += sz
q.used += uint64(sz)
return buf, nil
}
func (q *UnlimitedQuotaProvider) ReleaseQuotaBytes(buf []byte) error {
func (q *UnlimitedQuotaProvider) ReleaseQuotaBytes(sz int) {
if sz < 0 {
panic("tried to release negative bytes")
}
q.mu.Lock()
defer q.mu.Unlock()
memory := uint64(len(buf))
if memory > q.used {
if uint64(sz) > q.used {
panic("tried to release too much quota")
}
q.used -= memory
return nil
q.used -= uint64(sz)
}
func (q *UnlimitedQuotaProvider) Usage() uint64 {
+1 -1
View File
@@ -78,12 +78,12 @@ func (m *fakeS3) readerForTable(ctx context.Context, name addr) (chunkReader, er
defer m.mu.Unlock()
if buff, present := m.data[name.String()]; present {
ti, err := parseTableIndexByCopy(ctx, buff, &UnlimitedQuotaProvider{})
if err != nil {
return nil, err
}
tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), s3BlockSize)
if err != nil {
ti.Close()
return nil, err
}
return tr, nil
+2 -2
View File
@@ -1625,9 +1625,9 @@ func (nbs *NomsBlockStore) swapTables(ctx context.Context, specs []tableSpec) (e
if err != nil {
return err
}
oldTables := nbs.tables
nbs.tables, nbs.upstream = ts, upstream
return nil
return oldTables.close()
}
// SetRootChunk changes the root chunk hash from the previous value to the new root.
+3
View File
@@ -124,6 +124,7 @@ func TestNBSAsTableFileStore(t *testing.T) {
func TestConcurrentPuts(t *testing.T) {
st, _, _ := makeTestLocalStore(t, 100)
defer st.Close()
errgrp, ctx := errgroup.WithContext(context.Background())
@@ -190,6 +191,7 @@ func TestNBSPruneTableFiles(t *testing.T) {
maxTableFiles := 16
st, nomsDir, _ := makeTestLocalStore(t, maxTableFiles)
fileToData := populateLocalStore(t, st, numTableFiles)
defer st.Close()
// add a chunk and flush to trigger a conjoin
c := chunks.NewChunk([]byte("it's a boy!"))
@@ -272,6 +274,7 @@ func makeChunkSet(N, size int) (s map[hash.Hash]chunks.Chunk) {
func TestNBSCopyGC(t *testing.T) {
ctx := context.Background()
st, _, _ := makeTestLocalStore(t, 8)
defer st.Close()
keepers := makeChunkSet(64, 64)
tossers := makeChunkSet(64, 64)
+45 -20
View File
@@ -19,7 +19,11 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"runtime"
"runtime/debug"
"sync/atomic"
"github.com/dolthub/dolt/go/libraries/utils/iohelp"
@@ -31,6 +35,25 @@ var (
ErrWrongCopySize = errors.New("could not copy enough bytes")
)
// By setting the environment variable DOLT_ASSERT_TABLE_FILES_CLOSED to any
// non-empty string, dolt will run some sanity checks on table file lifecycle
// management. In particular, dolt will install a GC finalizer on the table
// file index buffer to assert that it has been properly closed at the time
// that it gets garbage collected.
//
// This is mostly intended for developers. It is a recommended mode in tests and
// can make sense in other contexts as well. At the time of this writing---
// (2023/02, aaron@)---lifecycle management in tests in particular is not good
// enough to globally enable this.
var TableIndexAssertClosedWithGCFinalizer bool
func init() {
if os.Getenv("DOLT_ASSERT_TABLE_FILES_CLOSED") != "" {
TableIndexAssertClosedWithGCFinalizer = true
}
}
type tableIndex interface {
// entrySuffixMatches returns true if the entry at index |idx| matches
// the suffix of the address |h|. Used by |lookup| after finding
@@ -109,13 +132,13 @@ func parseTableIndex(ctx context.Context, buff []byte, q MemoryQuotaProvider) (o
chunks2 := chunkCount / 2
chunks1 := chunkCount - chunks2
offsetsBuff1, err := q.AcquireQuotaBytes(ctx, uint64(chunks1*offsetSize))
offsetsBuff1, err := q.AcquireQuotaBytes(ctx, int(chunks1*offsetSize))
if err != nil {
return onHeapTableIndex{}, err
}
idx, err := newOnHeapTableIndex(buff, offsetsBuff1, chunkCount, totalUncompressedData, q)
if err != nil {
q.ReleaseQuotaBytes(offsetsBuff1)
q.ReleaseQuotaBytes(len(offsetsBuff1))
}
return idx, err
}
@@ -150,28 +173,32 @@ func readTableIndexByCopy(ctx context.Context, rd io.ReadSeeker, q MemoryQuotaPr
return onHeapTableIndex{}, err
}
buff, err := q.AcquireQuotaBytes(ctx, uint64(idxSz))
if int64(int(idxSz)) != idxSz {
return onHeapTableIndex{}, fmt.Errorf("table file index is too large to read on this platform. index size %d > max int.", idxSz)
}
buff, err := q.AcquireQuotaBytes(ctx, int(idxSz))
if err != nil {
return onHeapTableIndex{}, err
}
_, err = io.ReadFull(rd, buff)
if err != nil {
q.ReleaseQuotaBytes(buff)
q.ReleaseQuotaBytes(len(buff))
return onHeapTableIndex{}, err
}
chunks1 := chunkCount - (chunkCount / 2)
offsets1Buff, err := q.AcquireQuotaBytes(ctx, uint64(chunks1*offsetSize))
offsets1Buff, err := q.AcquireQuotaBytes(ctx, int(chunks1*offsetSize))
if err != nil {
q.ReleaseQuotaBytes(buff)
q.ReleaseQuotaBytes(len(buff))
return onHeapTableIndex{}, err
}
idx, err := newOnHeapTableIndex(buff, offsets1Buff, chunkCount, totalUncompressedData, q)
if err != nil {
q.ReleaseQuotaBytes(buff)
q.ReleaseQuotaBytes(offsets1Buff)
q.ReleaseQuotaBytes(len(buff))
q.ReleaseQuotaBytes(len(offsets1Buff))
}
return idx, err
}
@@ -251,6 +278,12 @@ func newOnHeapTableIndex(indexBuff []byte, offsetsBuff1 []byte, count uint32, to
refCnt := new(int32)
*refCnt = 1
if TableIndexAssertClosedWithGCFinalizer {
stack := string(debug.Stack())
runtime.SetFinalizer(refCnt, func(i *int32) {
panic(fmt.Sprintf("OnHeapTableIndex not closed:\n%s", stack))
})
}
return onHeapTableIndex{
refCnt: refCnt,
@@ -506,19 +539,11 @@ func (ti onHeapTableIndex) Close() error {
return nil
}
if err := ti.q.ReleaseQuotaBytes(ti.prefixTuples); err != nil {
return err
if TableIndexAssertClosedWithGCFinalizer {
runtime.SetFinalizer(ti.refCnt, nil)
}
if err := ti.q.ReleaseQuotaBytes(ti.offsets1); err != nil {
return err
}
if err := ti.q.ReleaseQuotaBytes(ti.offsets2); err != nil {
return err
}
if err := ti.q.ReleaseQuotaBytes(ti.suffixes); err != nil {
return err
}
return ti.q.ReleaseQuotaBytes(ti.footer)
ti.q.ReleaseQuotaBytes(len(ti.prefixTuples) + len(ti.offsets1) + len(ti.offsets2) + len(ti.suffixes) + len(ti.footer))
return nil
}
func (ti onHeapTableIndex) clone() (tableIndex, error) {
+3
View File
@@ -129,6 +129,7 @@ func TestResolveOneHash(t *testing.T) {
td, _, err := buildTable(chunks)
tIdx, err := parseTableIndexByCopy(ctx, td, &UnlimitedQuotaProvider{})
require.NoError(t, err)
defer tIdx.Close()
// get hashes out
hashes := make([]string, len(chunks))
@@ -161,6 +162,7 @@ func TestResolveFewHash(t *testing.T) {
td, _, err := buildTable(chunks)
tIdx, err := parseTableIndexByCopy(ctx, td, &UnlimitedQuotaProvider{})
require.NoError(t, err)
defer tIdx.Close()
// get hashes out
hashes := make([]string, len(chunks))
@@ -194,6 +196,7 @@ func TestAmbiguousShortHash(t *testing.T) {
td, _, err := buildFakeChunkTable(chunks)
idx, err := parseTableIndexByCopy(ctx, td, &UnlimitedQuotaProvider{})
require.NoError(t, err)
defer idx.Close()
tests := []struct {
pre string
+10 -2
View File
@@ -38,6 +38,8 @@ func TestPlanCompaction(t *testing.T) {
{[]byte("solo")},
}
q := &UnlimitedQuotaProvider{}
var sources chunkSources
var dataLens []uint64
var totalUnc uint64
@@ -47,7 +49,7 @@ func TestPlanCompaction(t *testing.T) {
}
data, name, err := buildTable(content)
require.NoError(t, err)
ti, err := parseTableIndexByCopy(ctx, data, &UnlimitedQuotaProvider{})
ti, err := parseTableIndexByCopy(ctx, data, q)
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(data), fileBlockSize)
require.NoError(t, err)
@@ -55,6 +57,11 @@ func TestPlanCompaction(t *testing.T) {
dataLens = append(dataLens, uint64(len(data))-indexSize(mustUint32(src.count()))-footerSize)
sources = append(sources, src)
}
defer func() {
for _, s := range sources {
s.close()
}
}()
plan, err := planRangeCopyConjoin(sources, &Stats{})
require.NoError(t, err)
@@ -65,7 +72,7 @@ func TestPlanCompaction(t *testing.T) {
totalChunks += mustUint32(src.count())
}
idx, err := parseTableIndex(ctx, plan.mergedIndex, &UnlimitedQuotaProvider{})
idx, err := parseTableIndexByCopy(ctx, plan.mergedIndex, q)
require.NoError(t, err)
assert.Equal(totalChunks, idx.chunkCount())
@@ -73,6 +80,7 @@ func TestPlanCompaction(t *testing.T) {
tr, err := newTableReader(idx, tableReaderAtFromBytes(nil), fileBlockSize)
require.NoError(t, err)
defer tr.close()
for _, content := range tableContents {
assertChunksInReader(content, tr, assert)
}
+12
View File
@@ -49,6 +49,9 @@ func TestTableSetPrepend(t *testing.T) {
assert := assert.New(t)
ts := newFakeTableSet(&UnlimitedQuotaProvider{})
specs, err := ts.toSpecs()
defer func() {
ts.close()
}()
require.NoError(t, err)
assert.Empty(specs)
mt := newMemTable(testMemTableSize)
@@ -75,6 +78,9 @@ func TestTableSetPrepend(t *testing.T) {
func TestTableSetToSpecsExcludesEmptyTable(t *testing.T) {
assert := assert.New(t)
ts := newFakeTableSet(&UnlimitedQuotaProvider{})
defer func() {
ts.close()
}()
specs, err := ts.toSpecs()
require.NoError(t, err)
assert.Empty(specs)
@@ -101,6 +107,9 @@ func TestTableSetToSpecsExcludesEmptyTable(t *testing.T) {
func TestTableSetFlattenExcludesEmptyTable(t *testing.T) {
assert := assert.New(t)
ts := newFakeTableSet(&UnlimitedQuotaProvider{})
defer func() {
ts.close()
}()
specs, err := ts.toSpecs()
require.NoError(t, err)
assert.Empty(specs)
@@ -183,6 +192,9 @@ func TestTableSetRebase(t *testing.T) {
func TestTableSetPhysicalLen(t *testing.T) {
assert := assert.New(t)
ts := newFakeTableSet(&UnlimitedQuotaProvider{})
defer func() {
ts.close()
}()
specs, err := ts.toSpecs()
require.NoError(t, err)
assert.Empty(specs)
+9
View File
@@ -83,6 +83,7 @@ func TestSimple(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(t, err)
defer tr.close()
assertChunksInReader(chunks, tr, assert)
@@ -131,6 +132,7 @@ func TestHasMany(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(t, err)
defer tr.close()
addrs := addrSlice{computeAddr(chunks[0]), computeAddr(chunks[1]), computeAddr(chunks[2])}
hasAddrs := []hasRecord{
@@ -183,6 +185,7 @@ func TestHasManySequentialPrefix(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(buff), fileBlockSize)
require.NoError(t, err)
defer tr.close()
hasAddrs := make([]hasRecord, 2)
// Leave out the first address
@@ -237,6 +240,7 @@ func BenchmarkHasMany(b *testing.B) {
require.NoError(b, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(b, err)
defer tr.close()
b.ResetTimer()
b.Run("dense has many", func(b *testing.B) {
@@ -277,6 +281,7 @@ func TestGetMany(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(t, err)
defer tr.close()
addrs := addrSlice{computeAddr(data[0]), computeAddr(data[1]), computeAddr(data[2])}
getBatch := []getRecord{
@@ -312,6 +317,7 @@ func TestCalcReads(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), 0)
require.NoError(t, err)
defer tr.close()
addrs := addrSlice{computeAddr(chunks[0]), computeAddr(chunks[1]), computeAddr(chunks[2])}
getBatch := []getRecord{
{&addrs[0], binary.BigEndian.Uint64(addrs[0][:addrPrefixSize]), false},
@@ -350,6 +356,7 @@ func TestExtract(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(t, err)
defer tr.close()
addrs := addrSlice{computeAddr(chunks[0]), computeAddr(chunks[1]), computeAddr(chunks[2])}
@@ -390,6 +397,7 @@ func Test65k(t *testing.T) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(t, err)
defer tr.close()
for i := 0; i < count; i++ {
data := dataFn(i)
@@ -444,6 +452,7 @@ func doTestNGetMany(t *testing.T, count int) {
require.NoError(t, err)
tr, err := newTableReader(ti, tableReaderAtFromBytes(tableData), fileBlockSize)
require.NoError(t, err)
defer tr.close()
getBatch := make([]getRecord, len(data))
for i := 0; i < count; i++ {
+7 -7
View File
@@ -155,9 +155,9 @@ func (m ArtifactMap) Pool() pool.BuffPool {
return m.tuples.NodeStore.Pool()
}
func (m ArtifactMap) Editor() ArtifactsEditor {
func (m ArtifactMap) Editor() *ArtifactsEditor {
artKD, artVD := m.Descriptors()
return ArtifactsEditor{
return &ArtifactsEditor{
srcKeyDesc: m.srcKeyDesc,
mut: MutableMap{
tuples: m.tuples.Mutate(),
@@ -316,7 +316,7 @@ type ArtifactsEditor struct {
pool pool.BuffPool
}
func (wr ArtifactsEditor) Add(ctx context.Context, srcKey val.Tuple, theirRootIsh hash.Hash, artType ArtifactType, meta []byte) error {
func (wr *ArtifactsEditor) Add(ctx context.Context, srcKey val.Tuple, theirRootIsh hash.Hash, artType ArtifactType, meta []byte) error {
for i := 0; i < srcKey.Count(); i++ {
wr.artKB.PutRaw(i, srcKey.GetField(i))
}
@@ -344,7 +344,7 @@ func (e *ErrMergeArtifactCollision) Error() string {
// the given will be inserted. Returns true if a violation was replaced. If an
// existing violation exists but has a different |meta.VInfo| value then
// ErrMergeArtifactCollision is a returned.
func (wr ArtifactsEditor) ReplaceConstraintViolation(ctx context.Context, srcKey val.Tuple, theirRootIsh hash.Hash, artType ArtifactType, meta ConstraintViolationMeta) error {
func (wr *ArtifactsEditor) ReplaceConstraintViolation(ctx context.Context, srcKey val.Tuple, theirRootIsh hash.Hash, artType ArtifactType, meta ConstraintViolationMeta) error {
itr, err := wr.mut.IterRange(ctx, PrefixRange(srcKey, wr.srcKeyDesc))
if err != nil {
return err
@@ -406,11 +406,11 @@ func (wr ArtifactsEditor) ReplaceConstraintViolation(ctx context.Context, srcKey
return nil
}
func (wr ArtifactsEditor) Delete(ctx context.Context, key val.Tuple) error {
func (wr *ArtifactsEditor) Delete(ctx context.Context, key val.Tuple) error {
return wr.mut.Delete(ctx, key)
}
func (wr ArtifactsEditor) Flush(ctx context.Context) (ArtifactMap, error) {
func (wr *ArtifactsEditor) Flush(ctx context.Context) (ArtifactMap, error) {
s := message.NewMergeArtifactSerializer(wr.artKB.Desc, wr.NodeStore().Pool())
m, err := wr.mut.flushWithSerializer(ctx, s)
@@ -426,7 +426,7 @@ func (wr ArtifactsEditor) Flush(ctx context.Context) (ArtifactMap, error) {
}, nil
}
func (wr ArtifactsEditor) NodeStore() tree.NodeStore {
func (wr *ArtifactsEditor) NodeStore() tree.NodeStore {
return wr.mut.NodeStore()
}
-4
View File
@@ -63,10 +63,6 @@ func (mut *MutableMap) Map(ctx context.Context) (Map, error) {
}
func (mut *MutableMap) flushWithSerializer(ctx context.Context, s message.Serializer) (Map, error) {
if err := mut.Checkpoint(ctx); err != nil {
return Map{}, err
}
sm := mut.tuples.StaticMap
fn := tree.ApplyMutations[val.Tuple, val.TupleDesc, message.Serializer]
+9 -1
View File
@@ -239,14 +239,18 @@ func TestHref(t *testing.T) {
sp, _ := ForDatabase("aws://table/foo/bar/baz")
assert.Equal("aws://table/foo/bar/baz", sp.Href())
sp.Close()
sp, _ = ForDataset("aws://[table:bucket]/foo/bar/baz::myds")
assert.Equal("aws://[table:bucket]/foo/bar/baz", sp.Href())
sp.Close()
sp, _ = ForPath("aws://[table:bucket]/foo/bar/baz::myds.my.path")
assert.Equal("aws://[table:bucket]/foo/bar/baz", sp.Href())
sp.Close()
sp, err := ForPath("mem::myds.my.path")
assert.NoError(err)
assert.Equal("", sp.Href())
sp.Close()
}
func TestForDatabase(t *testing.T) {
@@ -323,8 +327,9 @@ func TestForDataset(t *testing.T) {
validDatasetNames := []string{"a", "Z", "0", "/", "-", "_"}
for _, s := range validDatasetNames {
_, err := ForDataset("mem::" + s)
spec, err := ForDataset("mem::" + s)
assert.NoError(t, err)
spec.Close()
}
tmpDir, err := os.MkdirTemp("", "spec_test")
@@ -416,6 +421,8 @@ func TestMultipleSpecsSameNBS(t *testing.T) {
assert.NoError(err1)
assert.NoError(err2)
defer spec1.Close()
defer spec2.Close()
s := types.String("hello")
db := spec1.GetDatabase(context.Background())
@@ -479,6 +486,7 @@ func TestExternalProtocol(t *testing.T) {
sp, err := ForDataset("test:foo::bar")
assert.NoError(err)
defer sp.Close()
assert.Equal("test", sp.Protocol)
assert.Equal("foo", sp.DatabaseName)