Merge pull request #7151 from dolthub/aaron/flatbuffer-try-usage

go: Migrate to always use Try accessors on flatbuffer submessage access.
Aaron Son
2023-12-12 17:58:03 -08:00
committed by GitHub
8 changed files with 149 additions and 51 deletions
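
Every change below follows the same pattern: each call to a generated flatbuffer accessor for a submessage or vector element is replaced with its Try* counterpart, and the error the Try variant returns is checked and propagated. Per the NOTE comment added in the last file, these errors come from the generated field-number checks, which is how a buffer written by a later version of Dolt gets detected (see the binlog error message in the first hunk). A minimal sketch of the shape of the change, modeled on the BranchControlBinlog loop in that hunk; the wrapper function, its error wrapping, and the import path are assumptions for illustration, not part of this commit:

package example // hypothetical wrapper, for illustration only

import (
    "fmt"

    "github.com/dolthub/dolt/go/gen/fb/serial" // generated flatbuffer types (path assumed from the repo layout)
)

// readBinlogRows mirrors the Deserialize loop changed in the first file below.
func readBinlogRows(fb *serial.BranchControlBinlog) error {
    for i := 0; i < fb.RowsLength(); i++ {
        row := &serial.BranchControlBinlogRow{}
        // Before: fb.Rows(row, i) had no way to report a failed field-number check.
        // After: the Try accessor returns that failure as an error the caller can handle.
        if _, err := fb.TryRows(row, i); err != nil {
            return fmt.Errorf("cannot deserialize binlog row %d: %w", i, err)
        }
        // ... use row ...
    }
    return nil
}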

View File

@@ -126,7 +126,10 @@ func (binlog *Binlog) Deserialize(fb *serial.BranchControlBinlog) error {
// Read the rows
for i := 0; i < fb.RowsLength(); i++ {
serialBinlogRow := &serial.BranchControlBinlogRow{}
-fb.Rows(serialBinlogRow, i)
+_, err := fb.TryRows(serialBinlogRow, i)
+if err != nil {
+return fmt.Errorf("cannot deserialize binlog, it was created with a later version of Dolt")
+}
binlog.rows[i] = BinlogRow{
IsInsert: serialBinlogRow.IsInsert(),
Database: string(serialBinlogRow.Database()),

View File

@@ -233,31 +233,46 @@ func (tbl *Namespace) Deserialize(fb *serial.BranchControlNamespace) error {
// Read the databases
for i := 0; i < fb.DatabasesLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
-fb.Databases(serialMatchExpr, i)
+_, err = fb.TryDatabases(serialMatchExpr, i)
+if err != nil {
+return err
+}
tbl.Databases[i] = deserializeMatchExpression(serialMatchExpr)
}
// Read the branches
for i := 0; i < fb.BranchesLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
-fb.Branches(serialMatchExpr, i)
+_, err = fb.TryBranches(serialMatchExpr, i)
+if err != nil {
+return err
+}
tbl.Branches[i] = deserializeMatchExpression(serialMatchExpr)
}
// Read the users
for i := 0; i < fb.UsersLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
-fb.Users(serialMatchExpr, i)
+_, err = fb.TryUsers(serialMatchExpr, i)
+if err != nil {
+return err
+}
tbl.Users[i] = deserializeMatchExpression(serialMatchExpr)
}
// Read the hosts
for i := 0; i < fb.HostsLength(); i++ {
serialMatchExpr := &serial.BranchControlMatchExpression{}
-fb.Hosts(serialMatchExpr, i)
+_, err = fb.TryHosts(serialMatchExpr, i)
+if err != nil {
+return err
+}
tbl.Hosts[i] = deserializeMatchExpression(serialMatchExpr)
}
// Read the values
for i := 0; i < fb.ValuesLength(); i++ {
serialNamespaceValue := &serial.BranchControlNamespaceValue{}
-fb.Values(serialNamespaceValue, i)
+_, err = fb.TryValues(serialNamespaceValue, i)
+if err != nil {
+return err
+}
tbl.Values[i] = NamespaceValue{
Database: string(serialNamespaceValue.Database()),
Branch: string(serialNamespaceValue.Branch()),

View File

@@ -897,7 +897,10 @@ func (t doltDevTable) SetIndexes(ctx context.Context, indexes IndexSet) (Table,
}
func (t doltDevTable) GetConflicts(ctx context.Context) (conflict.ConflictSchema, ConflictIndex, error) {
-conflicts := t.msg.Conflicts(nil)
+conflicts, err := t.msg.TryConflicts(nil)
+if err != nil {
+return conflict.ConflictSchema{}, nil, err
+}
ouraddr := hash.New(conflicts.OurSchemaBytes())
theiraddr := hash.New(conflicts.TheirSchemaBytes())
@@ -997,7 +1000,10 @@ func (t doltDevTable) SetArtifacts(ctx context.Context, artifacts ArtifactIndex)
func (t doltDevTable) HasConflicts(ctx context.Context) (bool, error) {
-conflicts := t.msg.Conflicts(nil)
+conflicts, err := t.msg.TryConflicts(nil)
+if err != nil {
+return false, err
+}
addr := hash.New(conflicts.OurSchemaBytes())
return !addr.IsEmpty(), nil
}
@@ -1023,7 +1029,10 @@ func (t doltDevTable) SetConflicts(ctx context.Context, sch conflict.ConflictSch
}
msg := t.clone()
-cmsg := msg.Conflicts(nil)
+cmsg, err := msg.TryConflicts(nil)
+if err != nil {
+return nil, err
+}
copy(cmsg.DataBytes(), conflictsAddr[:])
copy(cmsg.OurSchemaBytes(), ouraddr[:])
copy(cmsg.TheirSchemaBytes(), theiraddr[:])
@@ -1034,7 +1043,10 @@ func (t doltDevTable) SetConflicts(ctx context.Context, sch conflict.ConflictSch
func (t doltDevTable) ClearConflicts(ctx context.Context) (Table, error) {
msg := t.clone()
-conflicts := msg.Conflicts(nil)
+conflicts, err := msg.TryConflicts(nil)
+if err != nil {
+return nil, err
+}
var emptyhash hash.Hash
copy(conflicts.DataBytes(), emptyhash[:])
copy(conflicts.OurSchemaBytes(), emptyhash[:])
@@ -1110,7 +1122,10 @@ func (t doltDevTable) fields() (serialTableFields, error) {
}
ns := t.ns
-conflicts := t.msg.Conflicts(nil)
+conflicts, err := t.msg.TryConflicts(nil)
+if err != nil {
+return serialTableFields{}, err
+}
am, err := prolly.NewAddressMap(node, ns)
if err != nil {
return serialTableFields{}, err

View File

@@ -100,7 +100,10 @@ func deserializeFlatbufferForeignKeys(msg types.SerialMessage) (*ForeignKeyColle
var fk serial.ForeignKey
for i := 0; i < c.ForeignKeysLength(); i++ {
-c.ForeignKeys(&fk, i)
+_, err = c.TryForeignKeys(&fk, i)
+if err != nil {
+return nil, err
+}
childCols := make([]uint64, fk.ChildTableColumnsLength())
for j := range childCols {

View File

@@ -92,7 +92,11 @@ func deserializeSchemaFromFlatbuffer(ctx context.Context, buf []byte) (schema.Sc
return nil, err
}
-err = sch.SetPkOrdinals(deserializeClusteredIndex(s))
+dci, err := deserializeClusteredIndex(s)
+if err != nil {
+return nil, err
+}
+err = sch.SetPkOrdinals(dci)
if err != nil {
return nil, err
}
@@ -171,18 +175,25 @@ func serializeClusteredIndex(b *fb.Builder, sch schema.Schema) fb.UOffsetT {
return serial.IndexEnd(b)
}
-func deserializeClusteredIndex(s *serial.TableSchema) []int {
+func deserializeClusteredIndex(s *serial.TableSchema) ([]int, error) {
// check for keyless schema
-if keylessSerialSchema(s) {
-return nil
+kss, err := keylessSerialSchema(s)
+if err != nil {
+return nil, err
}
+if kss {
+return nil, nil
+}
-ci := s.ClusteredIndex(nil)
+ci, err := s.TryClusteredIndex(nil)
+if err != nil {
+return nil, err
+}
pkOrdinals := make([]int, ci.KeyColumnsLength())
for i := range pkOrdinals {
pkOrdinals[i] = int(ci.KeyColumns(i))
}
-return pkOrdinals
+return pkOrdinals, nil
}
func serializeSchemaColumns(b *fb.Builder, sch schema.Schema) fb.UOffsetT {
@@ -281,7 +292,11 @@ func serializeHiddenKeylessColumns(b *fb.Builder) (id, card fb.UOffsetT) {
func deserializeColumns(ctx context.Context, s *serial.TableSchema) ([]schema.Column, error) {
length := s.ColumnsLength()
-if keylessSerialSchema(s) {
+isKeyless, err := keylessSerialSchema(s)
+if err != nil {
+return nil, err
+}
+if isKeyless {
// (6/15/22)
// currently, keyless id and cardinality columns
// do not exist in schema.Schema
@@ -295,7 +310,10 @@ func deserializeColumns(ctx context.Context, s *serial.TableSchema) ([]schema.Co
cols := make([]schema.Column, length)
c := serial.Column{}
for i := range cols {
-s.Columns(&c, i)
+_, err := s.TryColumns(&c, i)
+if err != nil {
+return nil, err
+}
sqlType, err := typeinfoFromSqlType(string(c.SqlType()))
if err != nil {
return nil, err
@@ -395,9 +413,17 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error
idx := serial.Index{}
col := serial.Column{}
for i := 0; i < s.SecondaryIndexesLength(); i++ {
-s.SecondaryIndexes(&idx, i)
+_, err := s.TrySecondaryIndexes(&idx, i)
+if err != nil {
+return err
+}
assertTrue(!idx.PrimaryKey(), "cannot deserialize secondary index with PrimaryKey() == true")
+fti, err := deserializeFullTextInfo(&idx)
+if err != nil {
+return err
+}
name := string(idx.Name())
props := schema.IndexProperties{
IsUnique: idx.UniqueKey(),
@@ -405,13 +431,16 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error
IsFullText: idx.FulltextKey(),
IsUserDefined: !idx.SystemDefined(),
Comment: string(idx.Comment()),
-FullTextProperties: deserializeFullTextInfo(&idx),
+FullTextProperties: fti,
}
tags := make([]uint64, idx.IndexColumnsLength())
for j := range tags {
pos := idx.IndexColumns(j)
-s.Columns(&col, int(pos))
+_, err := s.TryColumns(&col, int(pos))
+if err != nil {
+return err
+}
tags[j] = col.Tag()
}
@@ -424,7 +453,7 @@ func deserializeSecondaryIndexes(sch schema.Schema, s *serial.TableSchema) error
}
}
-_, err := sch.Indexes().AddIndexByColTags(name, tags, prefixLengths, props)
+_, err = sch.Indexes().AddIndexByColTags(name, tags, prefixLengths, props)
if err != nil {
return err
}
@@ -455,7 +484,10 @@ func deserializeChecks(sch schema.Schema, s *serial.TableSchema) error {
coll := sch.Checks()
c := serial.CheckConstraint{}
for i := 0; i < s.ChecksLength(); i++ {
-s.Checks(&c, i)
+_, err := s.TryChecks(&c, i)
+if err != nil {
+return err
+}
n, e := string(c.Name()), string(c.Expression())
if _, err := coll.AddCheck(n, e, c.Enforced()); err != nil {
return err
@@ -493,10 +525,14 @@ func serializeFullTextInfo(b *fb.Builder, idx schema.Index) fb.UOffsetT {
return serial.FulltextInfoEnd(b)
}
-func deserializeFullTextInfo(idx *serial.Index) schema.FullTextProperties {
+func deserializeFullTextInfo(idx *serial.Index) (schema.FullTextProperties, error) {
fulltext := serial.FulltextInfo{}
-if idx.FulltextInfo(&fulltext) == nil {
-return schema.FullTextProperties{}
+has, err := idx.TryFulltextInfo(&fulltext)
+if err != nil {
+return schema.FullTextProperties{}, err
}
+if has == nil {
+return schema.FullTextProperties{}, nil
+}
var keyPositions []uint16
@@ -517,24 +553,27 @@ func deserializeFullTextInfo(idx *serial.Index) schema.FullTextProperties {
KeyType: fulltext.KeyType(),
KeyName: string(fulltext.KeyName()),
KeyPositions: keyPositions,
-}
+}, nil
}
-func keylessSerialSchema(s *serial.TableSchema) bool {
+func keylessSerialSchema(s *serial.TableSchema) (bool, error) {
n := s.ColumnsLength()
if n < 2 {
-return false
+return false, nil
}
// keyless id is the 2nd to last column
// in the columns vector (by convention)
// and the only field in key tuples of
// the clustered index.
id := serial.Column{}
-s.Columns(&id, n-2)
+_, err := s.TryColumns(&id, n-2)
+if err != nil {
+return false, err
+}
ok := id.Generated() && id.Hidden() &&
string(id.Name()) == keylessIdCol
if !ok {
-return false
+return false, nil
}
// keyless cardinality is the last column
@@ -542,9 +581,12 @@ func keylessSerialSchema(s *serial.TableSchema) bool {
// and the first field in value tuples of
// the clustered index.
card := serial.Column{}
-s.Columns(&card, n-1)
+_, err = s.TryColumns(&card, n-1)
+if err != nil {
+return false, err
+}
return card.Generated() && card.Hidden() &&
-string(card.Name()) == keylessCardCol
+string(card.Name()) == keylessCardCol, nil
}
func sqlTypeString(t typeinfo.TypeInfo) string {

View File

@@ -343,8 +343,12 @@ type serialWorkingSetHead struct {
addr hash.Hash
}
-func newSerialWorkingSetHead(bs []byte, addr hash.Hash) serialWorkingSetHead {
-return serialWorkingSetHead{serial.GetRootAsWorkingSet(bs, serial.MessagePrefixSz), addr}
+func newSerialWorkingSetHead(bs []byte, addr hash.Hash) (serialWorkingSetHead, error) {
+fb, err := serial.TryGetRootAsWorkingSet(bs, serial.MessagePrefixSz)
+if err != nil {
+return serialWorkingSetHead{}, err
+}
+return serialWorkingSetHead{fb, addr}, nil
}
func (h serialWorkingSetHead) TypeName() string {
@@ -376,7 +380,10 @@ func (h serialWorkingSetHead) HeadWorkingSet() (*WorkingSetHead, error) {
ret.StagedAddr = new(hash.Hash)
*ret.StagedAddr = hash.New(h.msg.StagedRootAddrBytes())
}
-mergeState := h.msg.MergeState(nil)
+mergeState, err := h.msg.TryMergeState(nil)
+if err != nil {
+return nil, err
+}
if mergeState != nil {
ret.MergeState = &MergeState{
preMergeWorkingAddr: new(hash.Hash),
@@ -503,7 +510,7 @@ func newHead(ctx context.Context, head types.Value, addr hash.Hash) (dsHead, err
return newSerialTagHead(data, addr)
}
if fid == serial.WorkingSetFileID {
-return newSerialWorkingSetHead(data, addr), nil
+return newSerialWorkingSetHead(data, addr)
}
if fid == serial.CommitFileID {
return newSerialCommitHead(sm, addr), nil

View File

@@ -100,12 +100,18 @@ func getCommitClosureSubtrees(msg serial.Message) ([]uint64, error) {
return nil, err
}
counts := make([]uint64, cnt)
-m := serial.GetRootAsCommitClosure(msg, serial.MessagePrefixSz)
+m, err := serial.TryGetRootAsCommitClosure(msg, serial.MessagePrefixSz)
+if err != nil {
+return nil, err
+}
return decodeVarints(m.SubtreeCountsBytes(), counts), nil
}
func walkCommitClosureAddresses(ctx context.Context, msg serial.Message, cb func(ctx context.Context, addr hash.Hash) error) error {
-m := serial.GetRootAsCommitClosure(msg, serial.MessagePrefixSz)
+m, err := serial.TryGetRootAsCommitClosure(msg, serial.MessagePrefixSz)
+if err != nil {
+return err
+}
arr := m.AddressArrayBytes()
for i := 0; i < len(arr)/hash.ByteLen; i++ {
addr := hash.New(arr[i*addrSize : (i+1)*addrSize])

View File

@@ -65,22 +65,23 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string
id := serial.GetFileID(sm)
switch id {
// NOTE: splunk uses a separate path for some printing
+// NOTE: We ignore the errors from field number checks here...
case serial.StoreRootFileID:
-msg := serial.GetRootAsStoreRoot([]byte(sm), serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsStoreRoot([]byte(sm), serial.MessagePrefixSz)
ret := &strings.Builder{}
mapbytes := msg.AddressMapBytes()
printWithIndendationLevel(level, ret, "StoreRoot{%s}",
SerialMessage(mapbytes).humanReadableStringAtIndentationLevel(level+1))
return ret.String()
case serial.StashListFileID:
-msg := serial.GetRootAsStashList([]byte(sm), serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsStashList([]byte(sm), serial.MessagePrefixSz)
ret := &strings.Builder{}
mapbytes := msg.AddressMapBytes()
printWithIndendationLevel(level, ret, "StashList{%s}",
SerialMessage(mapbytes).humanReadableStringAtIndentationLevel(level+1))
return ret.String()
case serial.StashFileID:
-msg := serial.GetRootAsStash(sm, serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsStash(sm, serial.MessagePrefixSz)
ret := &strings.Builder{}
printWithIndendationLevel(level, ret, "{\n")
printWithIndendationLevel(level, ret, "\tBranchName: %s\n", msg.BranchName())
@@ -90,7 +91,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string
printWithIndendationLevel(level, ret, "}")
return ret.String()
case serial.TagFileID:
-msg := serial.GetRootAsTag(sm, serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsTag(sm, serial.MessagePrefixSz)
ret := &strings.Builder{}
printWithIndendationLevel(level, ret, "{\n")
printWithIndendationLevel(level, ret, "\tName: %s\n", msg.Name())
@@ -101,7 +102,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string
printWithIndendationLevel(level, ret, "}")
return ret.String()
case serial.WorkingSetFileID:
-msg := serial.GetRootAsWorkingSet(sm, serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsWorkingSet(sm, serial.MessagePrefixSz)
ret := &strings.Builder{}
printWithIndendationLevel(level, ret, "{\n")
printWithIndendationLevel(level, ret, "\tName: %s\n", msg.Name())
@@ -113,7 +114,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string
printWithIndendationLevel(level, ret, "}")
return ret.String()
case serial.CommitFileID:
-msg := serial.GetRootAsCommit(sm, serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsCommit(sm, serial.MessagePrefixSz)
ret := &strings.Builder{}
printWithIndendationLevel(level, ret, "{\n")
printWithIndendationLevel(level, ret, "\tName: %s\n", msg.Name())
@@ -150,7 +151,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string
printWithIndendationLevel(level, ret, "}")
return ret.String()
case serial.RootValueFileID:
-msg := serial.GetRootAsRootValue(sm, serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsRootValue(sm, serial.MessagePrefixSz)
ret := &strings.Builder{}
printWithIndendationLevel(level, ret, "{\n")
printWithIndendationLevel(level, ret, "\tFeatureVersion: %d\n", msg.FeatureVersion())
@@ -160,7 +161,7 @@ func (sm SerialMessage) humanReadableStringAtIndentationLevel(level int) string
printWithIndendationLevel(level, ret, "}")
return ret.String()
case serial.TableFileID:
-msg := serial.GetRootAsTable(sm, serial.MessagePrefixSz)
+msg, _ := serial.TryGetRootAsTable(sm, serial.MessagePrefixSz)
ret := &strings.Builder{}
printWithIndendationLevel(level, ret, "{\n")
@@ -274,7 +275,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er
return err
}
}
-mergeState := msg.MergeState(nil)
+mergeState, err := msg.TryMergeState(nil)
+if err != nil {
+return err
+}
if mergeState != nil {
if err = cb(hash.New(mergeState.PreWorkingRootAddrBytes())); err != nil {
return err
@@ -310,7 +314,10 @@ func (sm SerialMessage) WalkAddrs(nbf *NomsBinFormat, cb func(addr hash.Hash) er
return err
}
-confs := msg.Conflicts(nil)
+confs, err := msg.TryConflicts(nil)
+if err != nil {
+return err
+}
addr := hash.New(confs.DataBytes())
if !addr.IsEmpty() {
if err = cb(addr); err != nil {