mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-04 19:41:26 -05:00
Merge pull request #6994 from nustiueudinastea/concurrent-remotes-map
Concurrent remotes map
This commit is contained in:
+8
-4
@@ -1,4 +1,4 @@
|
||||
name: Enginetest Race
|
||||
name: Race tests
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -8,8 +8,8 @@ on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
enginerace:
|
||||
name: Go tests - race enginetests
|
||||
racetests:
|
||||
name: Go race tests
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -26,9 +26,13 @@ jobs:
|
||||
go-version: ^1.21
|
||||
id: go
|
||||
- uses: actions/checkout@v3
|
||||
- name: Test All
|
||||
- name: Test engine
|
||||
working-directory: ./go
|
||||
run: |
|
||||
DOLT_SKIP_PREPARED_ENGINETESTS=1 go test -vet=off -v -race -timeout 30m github.com/dolthub/dolt/go/libraries/doltcore/sqle/enginetest
|
||||
env:
|
||||
DOLT_DEFAULT_BIN_FORMAT: ${{ matrix.dolt_fmt }}
|
||||
- name: Test concurrentmap
|
||||
working-directory: ./go
|
||||
run: |
|
||||
go test -vet=off -v -race -timeout 1m github.com/dolthub/dolt/go/libraries/utils/concurrentmap
|
||||
@@ -214,7 +214,7 @@ func printRemotes(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.Ver
|
||||
return errhand.BuildDError("Unable to get remotes from the local directory").AddCause(err).Build()
|
||||
}
|
||||
|
||||
for _, r := range remotes {
|
||||
for _, r := range remotes.Snapshot() {
|
||||
if apr.Contains(cli.VerboseFlag) {
|
||||
paramStr := make([]byte, 0)
|
||||
if len(r.Params) > 0 {
|
||||
|
||||
+1
-1
@@ -229,7 +229,7 @@ func validateBranchMergedIntoUpstream(ctx context.Context, dbdata env.DbData, br
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
remote, ok := remotes[remoteName]
|
||||
remote, ok := remotes.Get(remoteName)
|
||||
if !ok {
|
||||
// TODO: skip error?
|
||||
return fmt.Errorf("remote %s not found", remoteName)
|
||||
|
||||
+31
-19
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/store/datas"
|
||||
@@ -113,11 +114,7 @@ func createRepoState(fs filesys.Filesys) (*RepoState, error) {
|
||||
|
||||
// deep copy remotes and backups ¯\_(ツ)_/¯ (see commit c59cbead)
|
||||
if repoState != nil {
|
||||
remotes := make(map[string]Remote, len(repoState.Remotes))
|
||||
for n, r := range repoState.Remotes {
|
||||
remotes[n] = r
|
||||
}
|
||||
repoState.Remotes = remotes
|
||||
repoState.Remotes = repoState.Remotes.DeepCopy()
|
||||
|
||||
backups := make(map[string]Remote, len(repoState.Backups))
|
||||
for n, r := range repoState.Backups {
|
||||
@@ -856,7 +853,7 @@ func (dEnv *DoltEnv) GetGRPCDialParams(config grpcendpoint.Config) (dbfactory.GR
|
||||
return NewGRPCDialProviderFromDoltEnv(dEnv).GetGRPCDialParams(config)
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) GetRemotes() (map[string]Remote, error) {
|
||||
func (dEnv *DoltEnv) GetRemotes() (*concurrentmap.Map[string, Remote], error) {
|
||||
if dEnv.RSLoadErr != nil {
|
||||
return nil, dEnv.RSLoadErr
|
||||
}
|
||||
@@ -866,12 +863,21 @@ func (dEnv *DoltEnv) GetRemotes() (map[string]Remote, error) {
|
||||
|
||||
// CheckRemoteAddressConflict checks whether any backups or remotes share the given URL. Returns the first remote if multiple match.
|
||||
// Returns NoRemote and false if none match.
|
||||
func CheckRemoteAddressConflict(absUrl string, remotes, backups map[string]Remote) (Remote, bool) {
|
||||
for _, r := range remotes {
|
||||
if r.Url == absUrl {
|
||||
return r, true
|
||||
func CheckRemoteAddressConflict(absUrl string, remotes *concurrentmap.Map[string, Remote], backups map[string]Remote) (Remote, bool) {
|
||||
if remotes != nil {
|
||||
var rm *Remote
|
||||
remotes.Iter(func(key string, value Remote) bool {
|
||||
if value.Url == absUrl {
|
||||
rm = &value
|
||||
return false
|
||||
}
|
||||
return true
|
||||
})
|
||||
if rm != nil {
|
||||
return *rm, true
|
||||
}
|
||||
}
|
||||
|
||||
for _, r := range backups {
|
||||
if r.Url == absUrl {
|
||||
return r, true
|
||||
@@ -881,7 +887,7 @@ func CheckRemoteAddressConflict(absUrl string, remotes, backups map[string]Remot
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) AddRemote(r Remote) error {
|
||||
if _, ok := dEnv.RepoState.Remotes[r.Name]; ok {
|
||||
if _, ok := dEnv.RepoState.Remotes.Get(r.Name); ok {
|
||||
return ErrRemoteAlreadyExists
|
||||
}
|
||||
|
||||
@@ -937,7 +943,7 @@ func (dEnv *DoltEnv) AddBackup(r Remote) error {
|
||||
}
|
||||
|
||||
func (dEnv *DoltEnv) RemoveRemote(ctx context.Context, name string) error {
|
||||
remote, ok := dEnv.RepoState.Remotes[name]
|
||||
remote, ok := dEnv.RepoState.Remotes.Get(name)
|
||||
if !ok {
|
||||
return ErrRemoteNotFound
|
||||
}
|
||||
@@ -1048,7 +1054,7 @@ func (dEnv *DoltEnv) FindRef(ctx context.Context, refStr string) (ref.DoltRef, e
|
||||
slashIdx := strings.IndexRune(refStr, '/')
|
||||
if slashIdx > 0 {
|
||||
remoteName := refStr[:slashIdx]
|
||||
if _, ok := dEnv.RepoState.Remotes[remoteName]; ok {
|
||||
if _, ok := dEnv.RepoState.Remotes.Get(remoteName); ok {
|
||||
remoteRef, err := ref.NewRemoteRefFromPathStr(refStr)
|
||||
|
||||
if err != nil {
|
||||
@@ -1079,7 +1085,7 @@ func GetRefSpecs(rsr RepoStateReader, remoteName string) ([]ref.RemoteRefSpec, e
|
||||
}
|
||||
if remoteName == "" {
|
||||
remote, err = GetDefaultRemote(rsr)
|
||||
} else if r, ok := remotes[remoteName]; ok {
|
||||
} else if r, ok := remotes.Get(remoteName); ok {
|
||||
remote = r
|
||||
} else {
|
||||
err = ErrInvalidRepository.New(remoteName)
|
||||
@@ -1122,15 +1128,21 @@ func GetDefaultRemote(rsr RepoStateReader) (Remote, error) {
|
||||
return NoRemote, err
|
||||
}
|
||||
|
||||
if len(remotes) == 0 {
|
||||
remotesLen := remotes.Len()
|
||||
if remotesLen == 0 {
|
||||
return NoRemote, ErrNoRemote
|
||||
} else if len(remotes) == 1 {
|
||||
for _, v := range remotes {
|
||||
return v, nil
|
||||
} else if remotesLen == 1 {
|
||||
var remote *Remote
|
||||
remotes.Iter(func(key string, value Remote) bool {
|
||||
remote = &value
|
||||
return false
|
||||
})
|
||||
if remote != nil {
|
||||
return *remote, nil
|
||||
}
|
||||
}
|
||||
|
||||
if remote, ok := remotes["origin"]; ok {
|
||||
if remote, ok := remotes.Get("origin"); ok {
|
||||
return remote, nil
|
||||
}
|
||||
|
||||
|
||||
+2
-1
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
@@ -52,7 +53,7 @@ func createTestEnv(isInitialized bool, hasLocalConfig bool) (*DoltEnv, *filesys.
|
||||
initialDirs = append(initialDirs, doltDataDir)
|
||||
|
||||
mainRef := ref.NewBranchRef(DefaultInitBranch)
|
||||
repoState := &RepoState{Head: ref.MarshalableRef{Ref: mainRef}}
|
||||
repoState := &RepoState{Head: ref.MarshalableRef{Ref: mainRef}, Remotes: concurrentmap.New[string, Remote]()}
|
||||
repoStateData, err := json.Marshal(repoState)
|
||||
|
||||
if err != nil {
|
||||
|
||||
Vendored
+3
-2
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/store/chunks"
|
||||
"github.com/dolthub/dolt/go/store/datas"
|
||||
@@ -207,8 +208,8 @@ func (m MemoryRepoState) SetCWBHeadRef(_ context.Context, r ref.MarshalableRef)
|
||||
return
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) GetRemotes() (map[string]Remote, error) {
|
||||
return make(map[string]Remote), nil
|
||||
func (m MemoryRepoState) GetRemotes() (*concurrentmap.Map[string, Remote], error) {
|
||||
return &concurrentmap.Map[string, Remote]{}, nil
|
||||
}
|
||||
|
||||
func (m MemoryRepoState) AddRemote(r Remote) error {
|
||||
|
||||
Vendored
+7
-4
@@ -222,7 +222,7 @@ func getRemote(rsr RepoStateReader, name string) (Remote, error) {
|
||||
return NoRemote, err
|
||||
}
|
||||
|
||||
remote, ok := remotes[name]
|
||||
remote, ok := remotes.Get(name)
|
||||
if !ok {
|
||||
return NoRemote, ErrInvalidRepository.New(name)
|
||||
}
|
||||
@@ -383,7 +383,7 @@ func RemoteForFetchArgs(args []string, rsr RepoStateReader) (Remote, []string, e
|
||||
return NoRemote, nil, err
|
||||
}
|
||||
|
||||
if len(remotes) == 0 {
|
||||
if remotes.Len() == 0 {
|
||||
return NoRemote, nil, ErrNoRemote
|
||||
}
|
||||
|
||||
@@ -395,7 +395,7 @@ func RemoteForFetchArgs(args []string, rsr RepoStateReader) (Remote, []string, e
|
||||
args = args[1:]
|
||||
}
|
||||
|
||||
remote, ok := remotes[remName]
|
||||
remote, ok := remotes.Get(remName)
|
||||
if !ok {
|
||||
msg := "does not appear to be a dolt database. could not read from the remote database. please make sure you have the correct access rights and the database exists"
|
||||
return NoRemote, nil, fmt.Errorf("%w; '%s' %s", ErrUnknownRemote, remName, msg)
|
||||
@@ -584,7 +584,10 @@ func NewPullSpec(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
remote := remotes[refSpecs[0].GetRemote()]
|
||||
remote, found := remotes.Get(refSpecs[0].GetRemote())
|
||||
if !found {
|
||||
return nil, ErrPullWithNoRemoteAndNoUpstream
|
||||
}
|
||||
|
||||
var remoteRef ref.DoltRef
|
||||
if remoteRefName == "" {
|
||||
|
||||
+19
-16
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
"github.com/dolthub/dolt/go/store/types"
|
||||
@@ -30,7 +31,7 @@ import (
|
||||
type RepoStateReader interface {
|
||||
CWBHeadRef() (ref.DoltRef, error)
|
||||
CWBHeadSpec() (*doltdb.CommitSpec, error)
|
||||
GetRemotes() (map[string]Remote, error)
|
||||
GetRemotes() (*concurrentmap.Map[string, Remote], error)
|
||||
GetBackups() (map[string]Remote, error)
|
||||
GetBranches() (map[string]BranchConfig, error)
|
||||
}
|
||||
@@ -68,10 +69,10 @@ type BranchConfig struct {
|
||||
}
|
||||
|
||||
type RepoState struct {
|
||||
Head ref.MarshalableRef `json:"head"`
|
||||
Remotes map[string]Remote `json:"remotes"`
|
||||
Backups map[string]Remote `json:"backups"`
|
||||
Branches map[string]BranchConfig `json:"branches"`
|
||||
Head ref.MarshalableRef `json:"head"`
|
||||
Remotes *concurrentmap.Map[string, Remote] `json:"remotes"`
|
||||
Backups map[string]Remote `json:"backups"`
|
||||
Branches map[string]BranchConfig `json:"branches"`
|
||||
// |staged|, |working|, and |merge| are legacy fields left over from when Dolt repos stored this info in the repo
|
||||
// state file, not in the DB directly. They're still here so that we can migrate existing repositories forward to the
|
||||
// new storage format, but they should be used only for this purpose and are no longer written.
|
||||
@@ -83,13 +84,13 @@ type RepoState struct {
|
||||
// repoStateLegacy only exists to unmarshall legacy repo state files, since the JSON marshaller can't work with
|
||||
// unexported fields
|
||||
type repoStateLegacy struct {
|
||||
Head ref.MarshalableRef `json:"head"`
|
||||
Remotes map[string]Remote `json:"remotes"`
|
||||
Backups map[string]Remote `json:"backups"`
|
||||
Branches map[string]BranchConfig `json:"branches"`
|
||||
Staged string `json:"staged,omitempty"`
|
||||
Working string `json:"working,omitempty"`
|
||||
Merge *mergeState `json:"merge,omitempty"`
|
||||
Head ref.MarshalableRef `json:"head"`
|
||||
Remotes *concurrentmap.Map[string, Remote] `json:"remotes"`
|
||||
Backups map[string]Remote `json:"backups"`
|
||||
Branches map[string]BranchConfig `json:"branches"`
|
||||
Staged string `json:"staged,omitempty"`
|
||||
Working string `json:"working,omitempty"`
|
||||
Merge *mergeState `json:"merge,omitempty"`
|
||||
}
|
||||
|
||||
// repoStateLegacyFromRepoState creates a new repoStateLegacy from a RepoState file. Only for testing.
|
||||
@@ -153,11 +154,13 @@ func LoadRepoState(fs filesys.ReadWriteFS) (*RepoState, error) {
|
||||
func CloneRepoState(fs filesys.ReadWriteFS, r Remote) (*RepoState, error) {
|
||||
init := ref.NewBranchRef(DefaultInitBranch) // best effort
|
||||
hashStr := hash.Hash{}.String()
|
||||
remotes := concurrentmap.New[string, Remote]()
|
||||
remotes.Set(r.Name, r)
|
||||
rs := &RepoState{
|
||||
Head: ref.MarshalableRef{Ref: init},
|
||||
staged: hashStr,
|
||||
working: hashStr,
|
||||
Remotes: map[string]Remote{r.Name: r},
|
||||
Remotes: remotes,
|
||||
Branches: make(map[string]BranchConfig),
|
||||
Backups: make(map[string]Remote),
|
||||
}
|
||||
@@ -179,7 +182,7 @@ func CreateRepoState(fs filesys.ReadWriteFS, br string) (*RepoState, error) {
|
||||
|
||||
rs := &RepoState{
|
||||
Head: ref.MarshalableRef{Ref: headRef},
|
||||
Remotes: make(map[string]Remote),
|
||||
Remotes: concurrentmap.New[string, Remote](),
|
||||
Branches: make(map[string]BranchConfig),
|
||||
Backups: make(map[string]Remote),
|
||||
}
|
||||
@@ -213,11 +216,11 @@ func (rs *RepoState) CWBHeadSpec() *doltdb.CommitSpec {
|
||||
}
|
||||
|
||||
func (rs *RepoState) AddRemote(r Remote) {
|
||||
rs.Remotes[r.Name] = r
|
||||
rs.Remotes.Set(r.Name, r)
|
||||
}
|
||||
|
||||
func (rs *RepoState) RemoveRemote(r Remote) {
|
||||
delete(rs.Remotes, r.Name)
|
||||
rs.Remotes.Delete(r.Name)
|
||||
}
|
||||
|
||||
func (rs *RepoState) AddBackup(r Remote) {
|
||||
|
||||
@@ -299,7 +299,7 @@ func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.
|
||||
dialprovider := c.gRPCDialProvider(denv)
|
||||
var hooks []*commithook
|
||||
for _, r := range c.cfg.StandbyRemotes() {
|
||||
remote, ok := remotes[r.Name()]
|
||||
remote, ok := remotes.Get(r.Name())
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("sqle: cluster: standby replication: destination remote %s does not exist on database %s", r.Name(), name)
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads, orig
|
||||
|
||||
var er env.Remote
|
||||
var ok bool
|
||||
if er, ok = remotes[r.Name()]; ok {
|
||||
if er, ok = remotes.Get(r.Name()); ok {
|
||||
if er.Url != remoteUrl {
|
||||
return fmt.Errorf("invalid remote (%s) for cluster replication found in database %s: expect url %s but the existing remote had url %s", r.Name(), name, remoteUrl, er.Url)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
)
|
||||
|
||||
const DoltClusterDbName = "dolt_cluster"
|
||||
@@ -111,6 +112,7 @@ func (db database) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error
|
||||
Rsw: noopRepoStateWriter{},
|
||||
},
|
||||
ReadOnly: true,
|
||||
Remotes: concurrentmap.New[string, env.Remote](),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/store/hash"
|
||||
)
|
||||
|
||||
@@ -416,7 +417,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
|
||||
sess := dsess.DSessFromSess(ctx.Session)
|
||||
adapter := dsess.NewSessionStateAdapter(
|
||||
sess, db.RevisionQualifiedName(),
|
||||
map[string]env.Remote{},
|
||||
concurrentmap.New[string, env.Remote](),
|
||||
map[string]env.BranchConfig{},
|
||||
map[string]env.Remote{})
|
||||
ws, err := sess.WorkingSet(ctx, db.RevisionQualifiedName())
|
||||
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/lockutil"
|
||||
"github.com/dolthub/dolt/go/store/datas"
|
||||
@@ -1443,6 +1444,7 @@ func initialStateForTagDb(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.In
|
||||
Rsw: srcDb.DbData().Rsw,
|
||||
Rsr: srcDb.DbData().Rsr,
|
||||
},
|
||||
Remotes: concurrentmap.New[string, env.Remote](),
|
||||
// todo: should we initialize
|
||||
// - Remotes
|
||||
// - Branches
|
||||
@@ -1493,6 +1495,7 @@ func initialStateForCommit(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.I
|
||||
Rsw: srcDb.DbData().Rsw,
|
||||
Rsr: srcDb.DbData().Rsr,
|
||||
},
|
||||
Remotes: concurrentmap.New[string, env.Remote](),
|
||||
// todo: should we initialize
|
||||
// - Remotes
|
||||
// - Branches
|
||||
|
||||
@@ -118,7 +118,7 @@ func removeRemote(ctx *sql.Context, dbd env.DbData, apr *argparser.ArgParseResul
|
||||
return err
|
||||
}
|
||||
|
||||
remote, ok := remotes[old]
|
||||
remote, ok := remotes.Get(old)
|
||||
if !ok {
|
||||
return fmt.Errorf("error: unknown remote: '%s'", old)
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
)
|
||||
|
||||
// InitialDbState is the initial state of a database, as returned by SessionDatabase.InitialDBState. It is used to
|
||||
@@ -40,7 +41,7 @@ type InitialDbState struct {
|
||||
HeadRoot *doltdb.RootValue
|
||||
ReadOnly bool
|
||||
DbData env.DbData
|
||||
Remotes map[string]env.Remote
|
||||
Remotes *concurrentmap.Map[string, env.Remote]
|
||||
Branches map[string]env.BranchConfig
|
||||
Backups map[string]env.Remote
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
)
|
||||
|
||||
// SessionStateAdapter is an adapter for env.RepoStateReader in SQL contexts, getting information about the repo state
|
||||
@@ -31,7 +32,7 @@ import (
|
||||
type SessionStateAdapter struct {
|
||||
session *DoltSession
|
||||
dbName string
|
||||
remotes map[string]env.Remote
|
||||
remotes *concurrentmap.Map[string, env.Remote]
|
||||
backups map[string]env.Remote
|
||||
branches map[string]env.BranchConfig
|
||||
}
|
||||
@@ -44,7 +45,7 @@ var _ env.RepoStateReader = SessionStateAdapter{}
|
||||
var _ env.RepoStateWriter = SessionStateAdapter{}
|
||||
var _ env.RootsProvider = SessionStateAdapter{}
|
||||
|
||||
func NewSessionStateAdapter(session *DoltSession, dbName string, remotes map[string]env.Remote, branches map[string]env.BranchConfig, backups map[string]env.Remote) SessionStateAdapter {
|
||||
func NewSessionStateAdapter(session *DoltSession, dbName string, remotes *concurrentmap.Map[string, env.Remote], branches map[string]env.BranchConfig, backups map[string]env.Remote) SessionStateAdapter {
|
||||
if branches == nil {
|
||||
branches = make(map[string]env.BranchConfig)
|
||||
}
|
||||
@@ -87,7 +88,7 @@ func (s SessionStateAdapter) CWBHeadSpec() (*doltdb.CommitSpec, error) {
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) GetRemotes() (map[string]env.Remote, error) {
|
||||
func (s SessionStateAdapter) GetRemotes() (*concurrentmap.Map[string, env.Remote], error) {
|
||||
return s.remotes, nil
|
||||
}
|
||||
|
||||
@@ -117,7 +118,7 @@ func (s SessionStateAdapter) UpdateBranch(name string, new env.BranchConfig) err
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
|
||||
if _, ok := s.remotes[remote.Name]; ok {
|
||||
if _, ok := s.remotes.Get(remote.Name); ok {
|
||||
return env.ErrRemoteAlreadyExists
|
||||
}
|
||||
|
||||
@@ -140,7 +141,7 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
|
||||
return fmt.Errorf("%w: '%s' -> %s", env.ErrRemoteAddressConflict, rem.Name, rem.Url)
|
||||
}
|
||||
|
||||
s.remotes[remote.Name] = remote
|
||||
s.remotes.Set(remote.Name, remote)
|
||||
repoState.AddRemote(remote)
|
||||
return repoState.Save(fs)
|
||||
}
|
||||
@@ -175,11 +176,11 @@ func (s SessionStateAdapter) AddBackup(backup env.Remote) error {
|
||||
}
|
||||
|
||||
func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error {
|
||||
remote, ok := s.remotes[name]
|
||||
remote, ok := s.remotes.Get(name)
|
||||
if !ok {
|
||||
return env.ErrRemoteNotFound
|
||||
}
|
||||
delete(s.remotes, remote.Name)
|
||||
s.remotes.Delete(remote.Name)
|
||||
|
||||
fs, err := s.session.Provider().FileSystemForDatabase(s.dbName)
|
||||
if err != nil {
|
||||
@@ -191,12 +192,12 @@ func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error
|
||||
return err
|
||||
}
|
||||
|
||||
remote, ok = repoState.Remotes[name]
|
||||
remote, ok = repoState.Remotes.Get(name)
|
||||
if !ok {
|
||||
// sanity check
|
||||
return env.ErrRemoteNotFound
|
||||
}
|
||||
delete(repoState.Remotes, name)
|
||||
repoState.Remotes.Delete(name)
|
||||
return repoState.Save(fs)
|
||||
}
|
||||
|
||||
|
||||
@@ -121,12 +121,12 @@ func NewRemoteItr(ctx *sql.Context, ddb *doltdb.DoltDB) (*RemoteItr, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
remotes := make([]env.Remote, len(remoteMap))
|
||||
i := 0
|
||||
for _, r := range remoteMap {
|
||||
remotes[i] = r
|
||||
i++
|
||||
}
|
||||
remotes := []env.Remote{}
|
||||
|
||||
remoteMap.Iter(func(key string, val env.Remote) bool {
|
||||
remotes = append(remotes, val)
|
||||
return true
|
||||
})
|
||||
|
||||
return &RemoteItr{remotes, 0}, nil
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ func NewReadReplicaDatabase(ctx context.Context, db Database, remoteName string,
|
||||
return EmptyReadReplica, err
|
||||
}
|
||||
|
||||
remote, ok := remotes[remoteName]
|
||||
remote, ok := remotes.Get(remoteName)
|
||||
if !ok {
|
||||
return EmptyReadReplica, fmt.Errorf("%w: '%s'", env.ErrRemoteNotFound, remoteName)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ func getPushOnWriteHook(ctx context.Context, bThreads *sql.BackgroundThreads, dE
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rem, ok := remotes[remoteName]
|
||||
rem, ok := remotes.Get(remoteName)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: '%s'", env.ErrRemoteNotFound, remoteName)
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/env"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
|
||||
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
|
||||
)
|
||||
|
||||
// UserSpaceDatabase in an implementation of sql.Database for root values. Does not expose any of the internal dolt tables.
|
||||
@@ -84,6 +85,7 @@ func (db *UserSpaceDatabase) InitialDBState(ctx *sql.Context) (dsess.InitialDbSt
|
||||
DbData: env.DbData{
|
||||
Rsw: noopRepoStateWriter{},
|
||||
},
|
||||
Remotes: concurrentmap.New[string, env.Remote](),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
// Copyright 2021 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package concurrentmap
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
)
|
||||
|
||||
func New[K comparable, V any]() *Map[K, V] {
|
||||
return &Map[K, V]{m: make(map[K]V)}
|
||||
}
|
||||
|
||||
type Map[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
m map[K]V
|
||||
}
|
||||
|
||||
// Get returns the value for the given key. If the key does not exist, the zero value for the value type will be returned.
|
||||
func (cm *Map[K, V]) Get(key K) (V, bool) {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
if value, found := cm.m[key]; found {
|
||||
return value, true
|
||||
}
|
||||
var zero V
|
||||
return zero, false
|
||||
}
|
||||
|
||||
// Set sets the value for the given key. If the key already exists, it will be overwritten.
|
||||
func (cm *Map[K, V]) Set(key K, value V) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
cm.m[key] = value
|
||||
}
|
||||
|
||||
// Delete removes the key from the map if it exists. If the key does not exist, this is a no-op.
|
||||
func (cm *Map[K, V]) Delete(key K) {
|
||||
cm.mu.Lock()
|
||||
defer cm.mu.Unlock()
|
||||
delete(cm.m, key)
|
||||
}
|
||||
|
||||
// Len returns the number of items in the map at the time of the call.
|
||||
func (cm *Map[K, V]) Len() int {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
return len(cm.m)
|
||||
}
|
||||
|
||||
// DeepCopy returns a deep copy of the concurrent map.
|
||||
func (cm *Map[K, V]) DeepCopy() *Map[K, V] {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
newMap := make(map[K]V, len(cm.m))
|
||||
for k, v := range cm.m {
|
||||
newMap[k] = v
|
||||
}
|
||||
return &Map[K, V]{m: newMap}
|
||||
}
|
||||
|
||||
// Iter iterates over the map, calling the provided function for each key/value pair. If the function returns false, the iteration stops.
|
||||
func (cm *Map[K, V]) Iter(f func(key K, value V) bool) {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
for k, v := range cm.m {
|
||||
if !f(k, v) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Snapshot returns a copy of the internal map at the time of the call. Returns a native map, not a concurrent one.
|
||||
func (cm *Map[K, V]) Snapshot() map[K]V {
|
||||
cm.mu.RLock()
|
||||
defer cm.mu.RUnlock()
|
||||
newMap := make(map[K]V, len(cm.m))
|
||||
for k, v := range cm.m {
|
||||
newMap[k] = v
|
||||
}
|
||||
return newMap
|
||||
}
|
||||
|
||||
func (cm *Map[K, V]) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(cm.Snapshot())
|
||||
}
|
||||
|
||||
func (cm *Map[K, V]) UnmarshalJSON(data []byte) error {
|
||||
return json.Unmarshal(data, &cm.m)
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
// Copyright 2021 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package concurrentmap
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConcurrentMapConstructor(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
if m == nil {
|
||||
t.Fatal("New concurrent map is nil")
|
||||
}
|
||||
if m.m == nil {
|
||||
t.Error("New concurrent map's underlying map is nil")
|
||||
}
|
||||
if len(m.m) != 0 {
|
||||
t.Error("New concurrent map's underlying map is not empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMapSetAndGet(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(1, "a")
|
||||
|
||||
// Test that the value is set
|
||||
if val, found := m.Get(1); !found || val != "a" {
|
||||
t.Errorf("Got %s, want %s", val, "a")
|
||||
}
|
||||
// Test that the value is not set for a different key
|
||||
if val, found := m.Get(2); found || val != "" {
|
||||
t.Errorf("Got %s, want an empty value and a not found", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMapDelete(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(1, "a")
|
||||
m.Delete(1)
|
||||
|
||||
// Test that the value is deleted
|
||||
if _, found := m.Get(1); found {
|
||||
t.Errorf("Expected key 1 to be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMapLen(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(1, "a")
|
||||
m.Set(2, "b")
|
||||
m.Set(3, "b")
|
||||
|
||||
// Test that the length is correct
|
||||
if m.Len() != 3 {
|
||||
t.Errorf("Expected length 3, got %d", m.Len())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMapDeepCopy(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(1, "a")
|
||||
copy := m.DeepCopy()
|
||||
m.Set(1, "b")
|
||||
|
||||
// Test that the copy is not affected by the original
|
||||
if val, _ := copy.Get(1); val != "a" {
|
||||
t.Errorf("DeepCopy failed, expected 'a', got '%s'", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMapIter(t *testing.T) {
|
||||
m := New[int, string]()
|
||||
m.Set(1, "a")
|
||||
m.Set(2, "b")
|
||||
m.Set(3, "c")
|
||||
|
||||
counter := 0
|
||||
elements := make(map[int]string)
|
||||
m.Iter(func(key int, value string) bool {
|
||||
counter++
|
||||
elements[key] = value
|
||||
return true
|
||||
})
|
||||
|
||||
// Test that the iterator iterates over all elements
|
||||
if counter != 3 {
|
||||
t.Errorf("Iter failed, expected to iterate 3 times, iterated %d times", counter)
|
||||
}
|
||||
|
||||
// Test that iteration yeilds all elements
|
||||
if len(elements) != 3 {
|
||||
t.Errorf("Iter failed, there should be 3 elements in the map, got %d", len(elements))
|
||||
}
|
||||
if elements[1] != "a" || elements[2] != "b" || elements[3] != "c" {
|
||||
t.Errorf("Iter failed, expected to have 3 elements in the map, with correct values: %v", elements)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentMapSetAndGetWithConcurrency(t *testing.T) {
|
||||
m := New[int, int]()
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Set 100 elements concurrently
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
m.Set(i, i)
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for al goroutines to finish
|
||||
wg.Wait()
|
||||
|
||||
// Test that all elements are set
|
||||
for i := 0; i < 100; i++ {
|
||||
if val, found := m.Get(i); !found || val != i {
|
||||
t.Errorf("Got %d, want %d", val, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user