From 319f2abb155e0d201ad67690cb2b5c836f90be29 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Tue, 12 Dec 2023 14:54:30 -0800 Subject: [PATCH 1/2] go: Migrate to always use Try accessors on flatbuffer submessage access. --- .../doltcore/branch_control/binlog.go | 5 +- .../doltcore/branch_control/namespace.go | 25 ++++-- go/libraries/doltcore/doltdb/durable/table.go | 25 ++++-- .../doltdb/foreign_key_serialization.go | 5 +- .../doltcore/schema/encoding/serialization.go | 88 ++++++++++++++----- go/store/datas/dataset.go | 15 +++- go/store/prolly/message/commit_closure.go | 10 ++- go/store/types/serial_message.go | 27 +++--- 8 files changed, 149 insertions(+), 51 deletions(-) diff --git a/go/libraries/doltcore/branch_control/binlog.go b/go/libraries/doltcore/branch_control/binlog.go index 46475f3a93..eed231d953 100644 --- a/go/libraries/doltcore/branch_control/binlog.go +++ b/go/libraries/doltcore/branch_control/binlog.go @@ -126,7 +126,10 @@ func (binlog *Binlog) Deserialize(fb *serial.BranchControlBinlog) error { // Read the rows for i := 0; i < fb.RowsLength(); i++ { serialBinlogRow := &serial.BranchControlBinlogRow{} - fb.Rows(serialBinlogRow, i) + _, err := fb.TryRows(serialBinlogRow, i) + if err != nil { + return fmt.Errorf("cannot deserialize binlog, it was created with a later version of Dolt") + } binlog.rows[i] = BinlogRow{ IsInsert: serialBinlogRow.IsInsert(), Database: string(serialBinlogRow.Database()), diff --git a/go/libraries/doltcore/branch_control/namespace.go b/go/libraries/doltcore/branch_control/namespace.go index 1d4f94cb8f..6a11a52fa5 100644 --- a/go/libraries/doltcore/branch_control/namespace.go +++ b/go/libraries/doltcore/branch_control/namespace.go @@ -233,31 +233,46 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error { // Read the databases for i := 0; i < fb.DatabasesLength(); i++ { serialMatchExpr := &serial.BranchControlMatchExpression{} - fb.Databases(serialMatchExpr, i) + _, err = fb.TryDatabases(serialMatchExpr, i) + if err != nil { + return err + } tbl.Databases[i] = deserializeMatchExpression(serialMatchExpr) } // Read the branches for i := 0; i < fb.BranchesLength(); i++ { serialMatchExpr := &serial.BranchControlMatchExpression{} - fb.Branches(serialMatchExpr, i) + _, err = fb.TryBranches(serialMatchExpr, i) + if err != nil { + return err + } tbl.Branches[i] = deserializeMatchExpression(serialMatchExpr) } // Read the users for i := 0; i < fb.UsersLength(); i++ { serialMatchExpr := &serial.BranchControlMatchExpression{} - fb.Users(serialMatchExpr, i) + _, err = fb.TryUsers(serialMatchExpr, i) + if err != nil { + return err + } tbl.Users[i] = deserializeMatchExpression(serialMatchExpr) } // Read the hosts for i := 0; i < fb.HostsLength(); i++ { serialMatchExpr := &serial.BranchControlMatchExpression{} - fb.Hosts(serialMatchExpr, i) + _, err = fb.TryHosts(serialMatchExpr, i) + if err != nil { + return err + } tbl.Hosts[i] = deserializeMatchExpression(serialMatchExpr) } // Read the values for i := 0; i < fb.ValuesLength(); i++ { serialNamespaceValue := &serial.BranchControlNamespaceValue{} - fb.Values(serialNamespaceValue, i) + _, err = fb.TryValues(serialNamespaceValue, i) + if err != nil { + return err + } tbl.Values[i] = NamespaceValue{ Database: string(serialNamespaceValue.Database()), Branch: string(serialNamespaceValue.Branch()), diff --git a/go/libraries/doltcore/doltdb/durable/table.go b/go/libraries/doltcore/doltdb/durable/table.go index badb50ac46..c940f33391 100644 --- a/go/libraries/doltcore/doltdb/durable/table.go +++ b/go/libraries/doltcore/doltdb/durable/table.go @@ -897,7 +897,10 @@ func (t doltDevTable) SetIndexes(ctx context.Context, indexes IndexSet) (Table, } func (t doltDevTable) GetConflicts(ctx context.Context) (conflict.ConflictSchema, ConflictIndex, error) { - conflicts := t.msg.Conflicts(nil) + conflicts, err := t.msg.TryConflicts(nil) + if err != nil { + return conflict.ConflictSchema{}, nil, err + } ouraddr := hash.New(conflicts.OurSchemaBytes()) theiraddr := hash.New(conflicts.TheirSchemaBytes()) @@ -997,7 +1000,10 @@ func (t doltDevTable) SetArtifacts(ctx context.Context, artifacts ArtifactIndex) func (t doltDevTable) HasConflicts(ctx context.Context) (bool, error) { - conflicts := t.msg.Conflicts(nil) + conflicts, err := t.msg.TryConflicts(nil) + if err != nil { + return false, err + } addr := hash.New(conflicts.OurSchemaBytes()) return !addr.IsEmpty(), nil } @@ -1023,7 +1029,10 @@ func (t doltDevTable) SetConflicts(ctx context.Context, sch conflict.ConflictSch } msg := t.clone() - cmsg := msg.Conflicts(nil) + cmsg, err := msg.TryConflicts(nil) + if err != nil { + return nil, err + } copy(cmsg.DataBytes(), conflictsAddr[:]) copy(cmsg.OurSchemaBytes(), ouraddr[:]) copy(cmsg.TheirSchemaBytes(), theiraddr[:]) @@ -1034,7 +1043,10 @@ func (t doltDevTable) SetConflicts(ctx context.Context, sch conflict.ConflictSch func (t doltDevTable) ClearConflicts(ctx context.Context) (Table, error) { msg := t.clone() - conflicts := msg.Conflicts(nil) + conflicts, err := msg.TryConflicts(nil) + if err != nil { + return nil, err + } var emptyhash hash.Hash copy(conflicts.DataBytes(), emptyhash[:]) copy(conflicts.OurSchemaBytes(), emptyhash[:]) @@ -1110,7 +1122,10 @@ func (t doltDevTable) fields() (serialTableFields, error) { } ns := t.ns - conflicts := t.msg.Conflicts(nil) + conflicts, err := t.msg.TryConflicts(nil) + if err != nil { + return serialTableFields{}, err + } am, err := prolly.NewAddressMap(node, ns) if err != nil { return serialTableFields{}, err diff --git a/go/libraries/doltcore/doltdb/foreign_key_serialization.go b/go/libraries/doltcore/doltdb/foreign_key_serialization.go index 51cd7003d0..d3c977ab3b 100644 --- a/go/libraries/doltcore/doltdb/foreign_key_serialization.go +++ b/go/libraries/doltcore/doltdb/foreign_key_serialization.go @@ -100,7 +100,10 @@ func deserializeFlatbufferForeignKeys(msg types.SerialMessage) (*ForeignKeyColle var fk serial.ForeignKey for i := 0; i < c.ForeignKeysLength(); i++ { - c.ForeignKeys(&fk, i) + _, err = c.TryForeignKeys(&fk, i) + if err != nil { + return nil, err + } childCols := make([]uint64, fk.ChildTableColumnsLength()) for j := range childCols { diff --git a/go/libraries/doltcore/schema/encoding/serialization.go b/go/libraries/doltcore/schema/encoding/serialization.go index 5739c328fa..b07316e68c 100644 --- a/go/libraries/doltcore/schema/encoding/serialization.go +++ b/go/libraries/doltcore/schema/encoding/serialization.go @@ -92,7 +92,11 @@ func deserializeSchemaFromFlatbuffer(ctx context.Context, buf []byte) (schema.Sc return nil, err } - err = sch.SetPkOrdinals(deserializeClusteredIndex(s)) + dci, err := deserializeClusteredIndex(s) + if err != nil { + return nil, err + } + err = sch.SetPkOrdinals(dci) if err != nil { return nil, err } @@ -171,18 +175,25 @@ func serializeClusteredIndex(b *fb.Builder, sch schema.Schema) fb.UOffsetT { return serial.IndexEnd(b) } -func deserializeClusteredIndex(s *serial.TableSchema) []int { +func deserializeClusteredIndex(s *serial.TableSchema) ([]int, error) { // check for keyless schema - if keylessSerialSchema(s) { - return nil + kss, err := keylessSerialSchema(s) + if err != nil { + return nil, err + } + if kss { + return nil, nil } - ci := s.ClusteredIndex(nil) + ci, err := s.TryClusteredIndex(nil) + if err != nil { + return nil, err + } pkOrdinals := make([]int, ci.KeyColumnsLength()) for i := range pkOrdinals { pkOrdinals[i] = int(ci.KeyColumns(i)) } - return pkOrdinals + return pkOrdinals, nil } func serializeSchemaColumns(b *fb.Builder, sch schema.Schema) fb.UOffsetT { @@ -281,7 +292,11 @@ func serializeHiddenKeylessColumns(b *fb.Builder) (id, card fb.UOffsetT) { func deserializeColumns(ctx context.Context, s *serial.TableSchema) ([]schema.Column, error) { length := s.ColumnsLength() - if keylessSerialSchema(s) { + isKeyless, err := keylessSerialSchema(s) + if err != nil { + return nil, err + } + if isKeyless { // (6/15/22) // currently, keyless id and cardinality columns // do not exist in schema.Schema @@ -295,7 +310,10 @@ func deserializeColumns(ctx context.Context, s *serial.TableSchema) ([]schema.Co cols := make([]schema.Column, length) c := serial.Column{} for i := range cols { - s.Columns(&c, i) + _, err := s.TryColumns(&c, i) + if err != nil { + return nil, err + } sqlType, err := typeinfoFromSqlType(string(c.SqlType())) if err != nil { return nil, err @@ -395,9 +413,17 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error idx := serial.Index{} col := serial.Column{} for i := 0; i < s.SecondaryIndexesLength(); i++ { - s.SecondaryIndexes(&idx, i) + _, err := s.TrySecondaryIndexes(&idx, i) + if err != nil { + return err + } assertTrue(!idx.PrimaryKey(), "cannot deserialize secondary index with PrimaryKey() == true") + fti, err := deserializeFullTextInfo(&idx) + if err != nil { + return err + } + name := string(idx.Name()) props := schema.IndexProperties{ IsUnique: idx.UniqueKey(), @@ -405,13 +431,16 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error IsFullText: idx.FulltextKey(), IsUserDefined: !idx.SystemDefined(), Comment: string(idx.Comment()), - FullTextProperties: deserializeFullTextInfo(&idx), + FullTextProperties: fti, } tags := make([]uint64, idx.IndexColumnsLength()) for j := range tags { pos := idx.IndexColumns(j) - s.Columns(&col, int(pos)) + _, err := s.TryColumns(&col, int(pos)) + if err != nil { + return err + } tags[j] = col.Tag() } @@ -424,7 +453,7 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error } } - _, err := sch.Indexes().AddIndexByColTags(name, tags, prefixLengths, props) + _, err = sch.Indexes().AddIndexByColTags(name, tags, prefixLengths, props) if err != nil { return err } @@ -455,7 +484,10 @@ func deserializeChecks(sch schema.Schema, s *serial.TableSchema) error { coll := sch.Checks() c := serial.CheckConstraint{} for i := 0; i < s.ChecksLength(); i++ { - s.Checks(&c, i) + _, err := s.TryChecks(&c, i) + if err != nil { + return err + } n, e := string(c.Name()), string(c.Expression()) if _, err := coll.AddCheck(n, e, c.Enforced()); err != nil { return err @@ -493,10 +525,14 @@ func serializeFullTextInfo(b *fb.Builder, idx schema.Index) fb.UOffsetT { return serial.FulltextInfoEnd(b) } -func deserializeFullTextInfo(idx *serial.Index) schema.FullTextProperties { +func deserializeFullTextInfo(idx *serial.Index) (schema.FullTextProperties, error) { fulltext := serial.FulltextInfo{} - if idx.FulltextInfo(&fulltext) == nil { - return schema.FullTextProperties{} + has, err := idx.TryFulltextInfo(&fulltext) + if err != nil { + return schema.FullTextProperties{}, err + } + if has == nil { + return schema.FullTextProperties{}, nil } var keyPositions []uint16 @@ -517,24 +553,27 @@ func deserializeFullTextInfo(idx *serial.Index) schema.FullTextProperties { KeyType: fulltext.KeyType(), KeyName: string(fulltext.KeyName()), KeyPositions: keyPositions, - } + }, nil } -func keylessSerialSchema(s *serial.TableSchema) bool { +func keylessSerialSchema(s *serial.TableSchema) (bool, error) { n := s.ColumnsLength() if n < 2 { - return false + return false, nil } // keyless id is the 2nd to last column // in the columns vector (by convention) // and the only field in key tuples of // the clustered index. id := serial.Column{} - s.Columns(&id, n-2) + _, err := s.TryColumns(&id, n-2) + if err != nil { + return false, err + } ok := id.Generated() && id.Hidden() && string(id.Name()) == keylessIdCol if !ok { - return false + return false, nil } // keyless cardinality is the last column @@ -542,9 +581,12 @@ func keylessSerialSchema(s *serial.TableSchema) bool { // and the first field in value tuples of // the clustered index. card := serial.Column{} - s.Columns(&card, n-1) + _, err = s.TryColumns(&card, n-1) + if err != nil { + return false, err + } return card.Generated() && card.Hidden() && - string(card.Name()) == keylessCardCol + string(card.Name()) == keylessCardCol, nil } func sqlTypeString(t typeinfo.TypeInfo) string { diff --git a/go/store/datas/dataset.go b/go/store/datas/dataset.go index a27114ab28..18dd9e6988 100644 --- a/go/store/datas/dataset.go +++ b/go/store/datas/dataset.go @@ -343,8 +343,12 @@ type serialWorkingSetHead struct { addr hash.Hash } -func newSerialWorkingSetHead(bs []byte, addr hash.Hash) serialWorkingSetHead { - return serialWorkingSetHead{serial.GetRootAsWorkingSet(bs, serial.MessagePrefixSz), addr} +func newSerialWorkingSetHead(bs []byte, addr hash.Hash) (serialWorkingSetHead, error) { + fb, err := serial.TryGetRootAsWorkingSet(bs, serial.MessagePrefixSz) + if err != nil { + return serialWorkingSetHead{}, err + } + return serialWorkingSetHead{fb, addr}, nil } func (h serialWorkingSetHead) TypeName() string { @@ -376,7 +380,10 @@ func (h serialWorkingSetHead) HeadWorkingSet() (*WorkingSetHead, error) { ret.StagedAddr = new(hash.Hash) *ret.StagedAddr = hash.New(h.msg.StagedRootAddrBytes()) } - mergeState := h.msg.MergeState(nil) + mergeState, err := h.msg.TryMergeState(nil) + if err != nil { + return nil, err + } if mergeState != nil { ret.MergeState = &MergeState{ preMergeWorkingAddr: new(hash.Hash), @@ -503,7 +510,7 @@ func newHead(ctx context.Context, head types.Value, addr hash.Hash) (dsHead, err return newSerialTagHead(data, addr) } if fid == serial.WorkingSetFileID { - return newSerialWorkingSetHead(data, addr), nil + return newSerialWorkingSetHead(data, addr) } if fid == serial.CommitFileID { return newSerialCommitHead(sm, addr), nil diff --git a/go/store/prolly/message/commit_closure.go b/go/store/prolly/message/commit_closure.go index c05eb622fa..cbd6e41855 100644 --- a/go/store/prolly/message/commit_closure.go +++ b/go/store/prolly/message/commit_closure.go @@ -100,12 +100,18 @@ func getCommitClosureSubtrees(msg serial.Message) ([]uint64, error) { return nil, err } counts := make([]uint64, cnt) - m := serial.GetRootAsCommitClosure(msg, serial.MessagePrefixSz) + m, err := serial.TryGetRootAsCommitClosure(msg, serial.MessagePrefixSz) + if err != nil { + return nil, err + } return decodeVarints(m.SubtreeCountsBytes(), counts), nil } func walkCommitClosureAddresses(ctx context.Context, msg serial.Message, cb func(ctx context.Context, addr hash.Hash) error) error { - m := serial.GetRootAsCommitClosure(msg, serial.MessagePrefixSz) + m, err := serial.TryGetRootAsCommitClosure(msg, serial.MessagePrefixSz) + if err != nil { + return err + } arr := m.AddressArrayBytes() for i := 0; i < len(arr)/hash.ByteLen; i++ { addr := hash.New(arr[i*addrSize : (i+1)*addrSize]) diff --git a/go/store/types/serial_message.go b/go/store/types/serial_message.go index 25fb8c860d..329ef7f66c 100644 --- a/go/store/types/serial_message.go +++ b/go/store/types/serial_message.go @@ -65,22 +65,23 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string id := serial.GetFileID(sm) switch id { // NOTE: splunk uses a separate path for some printing + // NOTE: We ignore the errors from field number checks here... case serial.StoreRootFileID: - msg := serial.GetRootAsStoreRoot([]byte(sm), serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsStoreRoot([]byte(sm), serial.MessagePrefixSz) ret := &strings.Builder{} mapbytes := msg.AddressMapBytes() printWithIndendationLevel(level, ret, "StoreRoot{%s}", SerialMessage(mapbytes).humanReadableStringAtIndentationLevel(level+1)) return ret.String() case serial.StashListFileID: - msg := serial.GetRootAsStashList([]byte(sm), serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsStashList([]byte(sm), serial.MessagePrefixSz) ret := &strings.Builder{} mapbytes := msg.AddressMapBytes() printWithIndendationLevel(level, ret, "StashList{%s}", SerialMessage(mapbytes).humanReadableStringAtIndentationLevel(level+1)) return ret.String() case serial.StashFileID: - msg := serial.GetRootAsStash(sm, serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsStash(sm, serial.MessagePrefixSz) ret := &strings.Builder{} printWithIndendationLevel(level, ret, "{\n") printWithIndendationLevel(level, ret, "\tBranchName: %s\n", msg.BranchName()) @@ -90,7 +91,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string printWithIndendationLevel(level, ret, "}") return ret.String() case serial.TagFileID: - msg := serial.GetRootAsTag(sm, serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsTag(sm, serial.MessagePrefixSz) ret := &strings.Builder{} printWithIndendationLevel(level, ret, "{\n") printWithIndendationLevel(level, ret, "\tName: %s\n", msg.Name()) @@ -101,7 +102,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string printWithIndendationLevel(level, ret, "}") return ret.String() case serial.WorkingSetFileID: - msg := serial.GetRootAsWorkingSet(sm, serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsWorkingSet(sm, serial.MessagePrefixSz) ret := &strings.Builder{} printWithIndendationLevel(level, ret, "{\n") printWithIndendationLevel(level, ret, "\tName: %s\n", msg.Name()) @@ -113,7 +114,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string printWithIndendationLevel(level, ret, "}") return ret.String() case serial.CommitFileID: - msg := serial.GetRootAsCommit(sm, serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsCommit(sm, serial.MessagePrefixSz) ret := &strings.Builder{} printWithIndendationLevel(level, ret, "{\n") printWithIndendationLevel(level, ret, "\tName: %s\n", msg.Name()) @@ -150,7 +151,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string printWithIndendationLevel(level, ret, "}") return ret.String() case serial.RootValueFileID: - msg := serial.GetRootAsRootValue(sm, serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsRootValue(sm, serial.MessagePrefixSz) ret := &strings.Builder{} printWithIndendationLevel(level, ret, "{\n") printWithIndendationLevel(level, ret, "\tFeatureVersion: %d\n", msg.FeatureVersion()) @@ -160,7 +161,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string printWithIndendationLevel(level, ret, "}") return ret.String() case serial.TableFileID: - msg := serial.GetRootAsTable(sm, serial.MessagePrefixSz) + msg, _ := serial.TryGetRootAsTable(sm, serial.MessagePrefixSz) ret := &strings.Builder{} printWithIndendationLevel(level, ret, "{\n") @@ -274,7 +275,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er return err } } - mergeState := msg.MergeState(nil) + mergeState, err := msg.TryMergeState(nil) + if err != nil { + return err + } if mergeState != nil { if err = cb(hash.New(mergeState.PreWorkingRootAddrBytes())); err != nil { return err @@ -310,7 +314,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er return err } - confs := msg.Conflicts(nil) + confs, err := msg.TryConflicts(nil) + if err != nil { + return err + } addr := hash.New(confs.DataBytes()) if !addr.IsEmpty() { if err = cb(addr); err != nil { From 0b2b4ffae29b007d68a8d78926c49cda88be4a30 Mon Sep 17 00:00:00 2001 From: reltuk Date: Tue, 12 Dec 2023 23:01:19 +0000 Subject: [PATCH 2/2] [ga-format-pr] Run go/utils/repofmt/format_repo.sh and go/Godeps/update.sh --- go/libraries/doltcore/branch_control/binlog.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/libraries/doltcore/branch_control/binlog.go b/go/libraries/doltcore/branch_control/binlog.go index eed231d953..c8d21945eb 100644 --- a/go/libraries/doltcore/branch_control/binlog.go +++ b/go/libraries/doltcore/branch_control/binlog.go @@ -127,9 +127,9 @@ func (binlog *Binlog) Deserialize(fb *serial.BranchControlBinlog) error { for i := 0; i < fb.RowsLength(); i++ { serialBinlogRow := &serial.BranchControlBinlogRow{} _, err := fb.TryRows(serialBinlogRow, i) - if err != nil { + if err != nil { return fmt.Errorf("cannot deserialize binlog, it was created with a later version of Dolt") - } + } binlog.rows[i] = BinlogRow{ IsInsert: serialBinlogRow.IsInsert(), Database: string(serialBinlogRow.Database()),