Merge pull request #6994 from nustiueudinastea/concurrent-remotes-map

Concurrent remotes map
This commit is contained in:
Zach Musgrave
2023-11-15 11:10:59 -08:00
committed by GitHub
22 changed files with 340 additions and 70 deletions
@@ -1,4 +1,4 @@
name: Enginetest Race
name: Race tests
on:
push:
@@ -8,8 +8,8 @@ on:
workflow_dispatch:
jobs:
enginerace:
name: Go tests - race enginetests
racetests:
name: Go race tests
defaults:
run:
shell: bash
@@ -26,9 +26,13 @@ jobs:
go-version: ^1.21
id: go
- uses: actions/checkout@v3
- name: Test All
- name: Test engine
working-directory: ./go
run: |
DOLT_SKIP_PREPARED_ENGINETESTS=1 go test -vet=off -v -race -timeout 30m github.com/dolthub/dolt/go/libraries/doltcore/sqle/enginetest
env:
DOLT_DEFAULT_BIN_FORMAT: ${{ matrix.dolt_fmt }}
- name: Test concurrentmap
working-directory: ./go
run: |
go test -vet=off -v -race -timeout 1m github.com/dolthub/dolt/go/libraries/utils/concurrentmap
+1 -1
View File
@@ -214,7 +214,7 @@ func printRemotes(dEnv *env.DoltEnv, apr *argparser.ArgParseResults) errhand.Ver
return errhand.BuildDError("Unable to get remotes from the local directory").AddCause(err).Build()
}
for _, r := range remotes {
for _, r := range remotes.Snapshot() {
if apr.Contains(cli.VerboseFlag) {
paramStr := make([]byte, 0)
if len(r.Params) > 0 {
+1 -1
View File
@@ -229,7 +229,7 @@ func validateBranchMergedIntoUpstream(ctx context.Context, dbdata env.DbData, br
if err != nil {
return err
}
remote, ok := remotes[remoteName]
remote, ok := remotes.Get(remoteName)
if !ok {
// TODO: skip error?
return fmt.Errorf("remote %s not found", remoteName)
+31 -19
View File
@@ -35,6 +35,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/datas"
@@ -113,11 +114,7 @@ func createRepoState(fs filesys.Filesys) (*RepoState, error) {
// deep copy remotes and backups ¯\_(ツ)_/¯ (see commit c59cbead)
if repoState != nil {
remotes := make(map[string]Remote, len(repoState.Remotes))
for n, r := range repoState.Remotes {
remotes[n] = r
}
repoState.Remotes = remotes
repoState.Remotes = repoState.Remotes.DeepCopy()
backups := make(map[string]Remote, len(repoState.Backups))
for n, r := range repoState.Backups {
@@ -856,7 +853,7 @@ func (dEnv *DoltEnv) GetGRPCDialParams(config grpcendpoint.Config) (dbfactory.GR
return NewGRPCDialProviderFromDoltEnv(dEnv).GetGRPCDialParams(config)
}
func (dEnv *DoltEnv) GetRemotes() (map[string]Remote, error) {
func (dEnv *DoltEnv) GetRemotes() (*concurrentmap.Map[string, Remote], error) {
if dEnv.RSLoadErr != nil {
return nil, dEnv.RSLoadErr
}
@@ -866,12 +863,21 @@ func (dEnv *DoltEnv) GetRemotes() (map[string]Remote, error) {
// CheckRemoteAddressConflict checks whether any backups or remotes share the given URL. Returns the first remote if multiple match.
// Returns NoRemote and false if none match.
func CheckRemoteAddressConflict(absUrl string, remotes, backups map[string]Remote) (Remote, bool) {
for _, r := range remotes {
if r.Url == absUrl {
return r, true
func CheckRemoteAddressConflict(absUrl string, remotes *concurrentmap.Map[string, Remote], backups map[string]Remote) (Remote, bool) {
if remotes != nil {
var rm *Remote
remotes.Iter(func(key string, value Remote) bool {
if value.Url == absUrl {
rm = &value
return false
}
return true
})
if rm != nil {
return *rm, true
}
}
for _, r := range backups {
if r.Url == absUrl {
return r, true
@@ -881,7 +887,7 @@ func CheckRemoteAddressConflict(absUrl string, remotes, backups map[string]Remot
}
func (dEnv *DoltEnv) AddRemote(r Remote) error {
if _, ok := dEnv.RepoState.Remotes[r.Name]; ok {
if _, ok := dEnv.RepoState.Remotes.Get(r.Name); ok {
return ErrRemoteAlreadyExists
}
@@ -937,7 +943,7 @@ func (dEnv *DoltEnv) AddBackup(r Remote) error {
}
func (dEnv *DoltEnv) RemoveRemote(ctx context.Context, name string) error {
remote, ok := dEnv.RepoState.Remotes[name]
remote, ok := dEnv.RepoState.Remotes.Get(name)
if !ok {
return ErrRemoteNotFound
}
@@ -1048,7 +1054,7 @@ func (dEnv *DoltEnv) FindRef(ctx context.Context, refStr string) (ref.DoltRef, e
slashIdx := strings.IndexRune(refStr, '/')
if slashIdx > 0 {
remoteName := refStr[:slashIdx]
if _, ok := dEnv.RepoState.Remotes[remoteName]; ok {
if _, ok := dEnv.RepoState.Remotes.Get(remoteName); ok {
remoteRef, err := ref.NewRemoteRefFromPathStr(refStr)
if err != nil {
@@ -1079,7 +1085,7 @@ func GetRefSpecs(rsr RepoStateReader, remoteName string) ([]ref.RemoteRefSpec, e
}
if remoteName == "" {
remote, err = GetDefaultRemote(rsr)
} else if r, ok := remotes[remoteName]; ok {
} else if r, ok := remotes.Get(remoteName); ok {
remote = r
} else {
err = ErrInvalidRepository.New(remoteName)
@@ -1122,15 +1128,21 @@ func GetDefaultRemote(rsr RepoStateReader) (Remote, error) {
return NoRemote, err
}
if len(remotes) == 0 {
remotesLen := remotes.Len()
if remotesLen == 0 {
return NoRemote, ErrNoRemote
} else if len(remotes) == 1 {
for _, v := range remotes {
return v, nil
} else if remotesLen == 1 {
var remote *Remote
remotes.Iter(func(key string, value Remote) bool {
remote = &value
return false
})
if remote != nil {
return *remote, nil
}
}
if remote, ok := remotes["origin"]; ok {
if remote, ok := remotes.Get("origin"); ok {
return remote, nil
}
+2 -1
View File
@@ -27,6 +27,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/types"
@@ -52,7 +53,7 @@ func createTestEnv(isInitialized bool, hasLocalConfig bool) (*DoltEnv, *filesys.
initialDirs = append(initialDirs, doltDataDir)
mainRef := ref.NewBranchRef(DefaultInitBranch)
repoState := &RepoState{Head: ref.MarshalableRef{Ref: mainRef}}
repoState := &RepoState{Head: ref.MarshalableRef{Ref: mainRef}, Remotes: concurrentmap.New[string, Remote]()}
repoStateData, err := json.Marshal(repoState)
if err != nil {
+3 -2
View File
@@ -22,6 +22,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/datas"
@@ -207,8 +208,8 @@ func (m MemoryRepoState) SetCWBHeadRef(_ context.Context, r ref.MarshalableRef)
return
}
func (m MemoryRepoState) GetRemotes() (map[string]Remote, error) {
return make(map[string]Remote), nil
func (m MemoryRepoState) GetRemotes() (*concurrentmap.Map[string, Remote], error) {
return &concurrentmap.Map[string, Remote]{}, nil
}
func (m MemoryRepoState) AddRemote(r Remote) error {
+7 -4
View File
@@ -222,7 +222,7 @@ func getRemote(rsr RepoStateReader, name string) (Remote, error) {
return NoRemote, err
}
remote, ok := remotes[name]
remote, ok := remotes.Get(name)
if !ok {
return NoRemote, ErrInvalidRepository.New(name)
}
@@ -383,7 +383,7 @@ func RemoteForFetchArgs(args []string, rsr RepoStateReader) (Remote, []string, e
return NoRemote, nil, err
}
if len(remotes) == 0 {
if remotes.Len() == 0 {
return NoRemote, nil, ErrNoRemote
}
@@ -395,7 +395,7 @@ func RemoteForFetchArgs(args []string, rsr RepoStateReader) (Remote, []string, e
args = args[1:]
}
remote, ok := remotes[remName]
remote, ok := remotes.Get(remName)
if !ok {
msg := "does not appear to be a dolt database. could not read from the remote database. please make sure you have the correct access rights and the database exists"
return NoRemote, nil, fmt.Errorf("%w; '%s' %s", ErrUnknownRemote, remName, msg)
@@ -584,7 +584,10 @@ func NewPullSpec(
if err != nil {
return nil, err
}
remote := remotes[refSpecs[0].GetRemote()]
remote, found := remotes.Get(refSpecs[0].GetRemote())
if !found {
return nil, ErrPullWithNoRemoteAndNoUpstream
}
var remoteRef ref.DoltRef
if remoteRefName == "" {
+19 -16
View File
@@ -21,6 +21,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/hash"
"github.com/dolthub/dolt/go/store/types"
@@ -30,7 +31,7 @@ import (
type RepoStateReader interface {
CWBHeadRef() (ref.DoltRef, error)
CWBHeadSpec() (*doltdb.CommitSpec, error)
GetRemotes() (map[string]Remote, error)
GetRemotes() (*concurrentmap.Map[string, Remote], error)
GetBackups() (map[string]Remote, error)
GetBranches() (map[string]BranchConfig, error)
}
@@ -68,10 +69,10 @@ type BranchConfig struct {
}
type RepoState struct {
Head ref.MarshalableRef `json:"head"`
Remotes map[string]Remote `json:"remotes"`
Backups map[string]Remote `json:"backups"`
Branches map[string]BranchConfig `json:"branches"`
Head ref.MarshalableRef `json:"head"`
Remotes *concurrentmap.Map[string, Remote] `json:"remotes"`
Backups map[string]Remote `json:"backups"`
Branches map[string]BranchConfig `json:"branches"`
// |staged|, |working|, and |merge| are legacy fields left over from when Dolt repos stored this info in the repo
// state file, not in the DB directly. They're still here so that we can migrate existing repositories forward to the
// new storage format, but they should be used only for this purpose and are no longer written.
@@ -83,13 +84,13 @@ type RepoState struct {
// repoStateLegacy only exists to unmarshall legacy repo state files, since the JSON marshaller can't work with
// unexported fields
type repoStateLegacy struct {
Head ref.MarshalableRef `json:"head"`
Remotes map[string]Remote `json:"remotes"`
Backups map[string]Remote `json:"backups"`
Branches map[string]BranchConfig `json:"branches"`
Staged string `json:"staged,omitempty"`
Working string `json:"working,omitempty"`
Merge *mergeState `json:"merge,omitempty"`
Head ref.MarshalableRef `json:"head"`
Remotes *concurrentmap.Map[string, Remote] `json:"remotes"`
Backups map[string]Remote `json:"backups"`
Branches map[string]BranchConfig `json:"branches"`
Staged string `json:"staged,omitempty"`
Working string `json:"working,omitempty"`
Merge *mergeState `json:"merge,omitempty"`
}
// repoStateLegacyFromRepoState creates a new repoStateLegacy from a RepoState file. Only for testing.
@@ -153,11 +154,13 @@ func LoadRepoState(fs filesys.ReadWriteFS) (*RepoState, error) {
func CloneRepoState(fs filesys.ReadWriteFS, r Remote) (*RepoState, error) {
init := ref.NewBranchRef(DefaultInitBranch) // best effort
hashStr := hash.Hash{}.String()
remotes := concurrentmap.New[string, Remote]()
remotes.Set(r.Name, r)
rs := &RepoState{
Head: ref.MarshalableRef{Ref: init},
staged: hashStr,
working: hashStr,
Remotes: map[string]Remote{r.Name: r},
Remotes: remotes,
Branches: make(map[string]BranchConfig),
Backups: make(map[string]Remote),
}
@@ -179,7 +182,7 @@ func CreateRepoState(fs filesys.ReadWriteFS, br string) (*RepoState, error) {
rs := &RepoState{
Head: ref.MarshalableRef{Ref: headRef},
Remotes: make(map[string]Remote),
Remotes: concurrentmap.New[string, Remote](),
Branches: make(map[string]BranchConfig),
Backups: make(map[string]Remote),
}
@@ -213,11 +216,11 @@ func (rs *RepoState) CWBHeadSpec() *doltdb.CommitSpec {
}
func (rs *RepoState) AddRemote(r Remote) {
rs.Remotes[r.Name] = r
rs.Remotes.Set(r.Name, r)
}
func (rs *RepoState) RemoveRemote(r Remote) {
delete(rs.Remotes, r.Name)
rs.Remotes.Delete(r.Name)
}
func (rs *RepoState) AddBackup(r Remote) {
@@ -299,7 +299,7 @@ func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.
dialprovider := c.gRPCDialProvider(denv)
var hooks []*commithook
for _, r := range c.cfg.StandbyRemotes() {
remote, ok := remotes[r.Name()]
remote, ok := remotes.Get(r.Name())
if !ok {
return nil, fmt.Errorf("sqle: cluster: standby replication: destination remote %s does not exist on database %s", r.Name(), name)
}
@@ -59,7 +59,7 @@ func NewInitDatabaseHook(controller *Controller, bt *sql.BackgroundThreads, orig
var er env.Remote
var ok bool
if er, ok = remotes[r.Name()]; ok {
if er, ok = remotes.Get(r.Name()); ok {
if er.Url != remoteUrl {
return fmt.Errorf("invalid remote (%s) for cluster replication found in database %s: expect url %s but the existing remote had url %s", r.Name(), name, remoteUrl, er.Url)
}
@@ -26,6 +26,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
)
const DoltClusterDbName = "dolt_cluster"
@@ -111,6 +112,7 @@ func (db database) InitialDBState(ctx *sql.Context) (dsess.InitialDbState, error
Rsw: noopRepoStateWriter{},
},
ReadOnly: true,
Remotes: concurrentmap.New[string, env.Remote](),
}, nil
}
+2 -1
View File
@@ -43,6 +43,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/store/hash"
)
@@ -416,7 +417,7 @@ func (db Database) getTableInsensitive(ctx *sql.Context, head *doltdb.Commit, ds
sess := dsess.DSessFromSess(ctx.Session)
adapter := dsess.NewSessionStateAdapter(
sess, db.RevisionQualifiedName(),
map[string]env.Remote{},
concurrentmap.New[string, env.Remote](),
map[string]env.BranchConfig{},
map[string]env.Remote{})
ws, err := sess.WorkingSet(ctx, db.RevisionQualifiedName())
@@ -37,6 +37,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqlserver"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/libraries/utils/lockutil"
"github.com/dolthub/dolt/go/store/datas"
@@ -1443,6 +1444,7 @@ func initialStateForTagDb(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.In
Rsw: srcDb.DbData().Rsw,
Rsr: srcDb.DbData().Rsr,
},
Remotes: concurrentmap.New[string, env.Remote](),
// todo: should we initialize
// - Remotes
// - Branches
@@ -1493,6 +1495,7 @@ func initialStateForCommit(ctx context.Context, srcDb ReadOnlyDatabase) (dsess.I
Rsw: srcDb.DbData().Rsw,
Rsr: srcDb.DbData().Rsr,
},
Remotes: concurrentmap.New[string, env.Remote](),
// todo: should we initialize
// - Remotes
// - Branches
@@ -118,7 +118,7 @@ func removeRemote(ctx *sql.Context, dbd env.DbData, apr *argparser.ArgParseResul
return err
}
remote, ok := remotes[old]
remote, ok := remotes.Get(old)
if !ok {
return fmt.Errorf("error: unknown remote: '%s'", old)
}
@@ -24,6 +24,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/globalstate"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/writer"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
)
// InitialDbState is the initial state of a database, as returned by SessionDatabase.InitialDBState. It is used to
@@ -40,7 +41,7 @@ type InitialDbState struct {
HeadRoot *doltdb.RootValue
ReadOnly bool
DbData env.DbData
Remotes map[string]env.Remote
Remotes *concurrentmap.Map[string, env.Remote]
Branches map[string]env.BranchConfig
Backups map[string]env.Remote
@@ -24,6 +24,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/ref"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
)
// SessionStateAdapter is an adapter for env.RepoStateReader in SQL contexts, getting information about the repo state
@@ -31,7 +32,7 @@ import (
type SessionStateAdapter struct {
session *DoltSession
dbName string
remotes map[string]env.Remote
remotes *concurrentmap.Map[string, env.Remote]
backups map[string]env.Remote
branches map[string]env.BranchConfig
}
@@ -44,7 +45,7 @@ var _ env.RepoStateReader = SessionStateAdapter{}
var _ env.RepoStateWriter = SessionStateAdapter{}
var _ env.RootsProvider = SessionStateAdapter{}
func NewSessionStateAdapter(session *DoltSession, dbName string, remotes map[string]env.Remote, branches map[string]env.BranchConfig, backups map[string]env.Remote) SessionStateAdapter {
func NewSessionStateAdapter(session *DoltSession, dbName string, remotes *concurrentmap.Map[string, env.Remote], branches map[string]env.BranchConfig, backups map[string]env.Remote) SessionStateAdapter {
if branches == nil {
branches = make(map[string]env.BranchConfig)
}
@@ -87,7 +88,7 @@ func (s SessionStateAdapter) CWBHeadSpec() (*doltdb.CommitSpec, error) {
return spec, nil
}
func (s SessionStateAdapter) GetRemotes() (map[string]env.Remote, error) {
func (s SessionStateAdapter) GetRemotes() (*concurrentmap.Map[string, env.Remote], error) {
return s.remotes, nil
}
@@ -117,7 +118,7 @@ func (s SessionStateAdapter) UpdateBranch(name string, new env.BranchConfig) err
}
func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
if _, ok := s.remotes[remote.Name]; ok {
if _, ok := s.remotes.Get(remote.Name); ok {
return env.ErrRemoteAlreadyExists
}
@@ -140,7 +141,7 @@ func (s SessionStateAdapter) AddRemote(remote env.Remote) error {
return fmt.Errorf("%w: '%s' -> %s", env.ErrRemoteAddressConflict, rem.Name, rem.Url)
}
s.remotes[remote.Name] = remote
s.remotes.Set(remote.Name, remote)
repoState.AddRemote(remote)
return repoState.Save(fs)
}
@@ -175,11 +176,11 @@ func (s SessionStateAdapter) AddBackup(backup env.Remote) error {
}
func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error {
remote, ok := s.remotes[name]
remote, ok := s.remotes.Get(name)
if !ok {
return env.ErrRemoteNotFound
}
delete(s.remotes, remote.Name)
s.remotes.Delete(remote.Name)
fs, err := s.session.Provider().FileSystemForDatabase(s.dbName)
if err != nil {
@@ -191,12 +192,12 @@ func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error
return err
}
remote, ok = repoState.Remotes[name]
remote, ok = repoState.Remotes.Get(name)
if !ok {
// sanity check
return env.ErrRemoteNotFound
}
delete(repoState.Remotes, name)
repoState.Remotes.Delete(name)
return repoState.Save(fs)
}
@@ -121,12 +121,12 @@ func NewRemoteItr(ctx *sql.Context, ddb *doltdb.DoltDB) (*RemoteItr, error) {
if err != nil {
return nil, err
}
remotes := make([]env.Remote, len(remoteMap))
i := 0
for _, r := range remoteMap {
remotes[i] = r
i++
}
remotes := []env.Remote{}
remoteMap.Iter(func(key string, val env.Remote) bool {
remotes = append(remotes, val)
return true
})
return &RemoteItr{remotes, 0}, nil
}
@@ -61,7 +61,7 @@ func NewReadReplicaDatabase(ctx context.Context, db Database, remoteName string,
return EmptyReadReplica, err
}
remote, ok := remotes[remoteName]
remote, ok := remotes.Get(remoteName)
if !ok {
return EmptyReadReplica, fmt.Errorf("%w: '%s'", env.ErrRemoteNotFound, remoteName)
}
+1 -1
View File
@@ -48,7 +48,7 @@ func getPushOnWriteHook(ctx context.Context, bThreads *sql.BackgroundThreads, dE
return nil, err
}
rem, ok := remotes[remoteName]
rem, ok := remotes.Get(remoteName)
if !ok {
return nil, fmt.Errorf("%w: '%s'", env.ErrRemoteNotFound, remoteName)
}
@@ -21,6 +21,7 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/env"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/table/editor"
"github.com/dolthub/dolt/go/libraries/utils/concurrentmap"
)
// UserSpaceDatabase in an implementation of sql.Database for root values. Does not expose any of the internal dolt tables.
@@ -84,6 +85,7 @@ func (db *UserSpaceDatabase) InitialDBState(ctx *sql.Context) (dsess.InitialDbSt
DbData: env.DbData{
Rsw: noopRepoStateWriter{},
},
Remotes: concurrentmap.New[string, env.Remote](),
}, nil
}
@@ -0,0 +1,102 @@
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package concurrentmap
import (
"encoding/json"
"sync"
)
func New[K comparable, V any]() *Map[K, V] {
return &Map[K, V]{m: make(map[K]V)}
}
type Map[K comparable, V any] struct {
mu sync.RWMutex
m map[K]V
}
// Get returns the value for the given key. If the key does not exist, the zero value for the value type will be returned.
func (cm *Map[K, V]) Get(key K) (V, bool) {
cm.mu.RLock()
defer cm.mu.RUnlock()
if value, found := cm.m[key]; found {
return value, true
}
var zero V
return zero, false
}
// Set sets the value for the given key. If the key already exists, it will be overwritten.
func (cm *Map[K, V]) Set(key K, value V) {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.m[key] = value
}
// Delete removes the key from the map if it exists. If the key does not exist, this is a no-op.
func (cm *Map[K, V]) Delete(key K) {
cm.mu.Lock()
defer cm.mu.Unlock()
delete(cm.m, key)
}
// Len returns the number of items in the map at the time of the call.
func (cm *Map[K, V]) Len() int {
cm.mu.RLock()
defer cm.mu.RUnlock()
return len(cm.m)
}
// DeepCopy returns a deep copy of the concurrent map.
func (cm *Map[K, V]) DeepCopy() *Map[K, V] {
cm.mu.RLock()
defer cm.mu.RUnlock()
newMap := make(map[K]V, len(cm.m))
for k, v := range cm.m {
newMap[k] = v
}
return &Map[K, V]{m: newMap}
}
// Iter iterates over the map, calling the provided function for each key/value pair. If the function returns false, the iteration stops.
func (cm *Map[K, V]) Iter(f func(key K, value V) bool) {
cm.mu.RLock()
defer cm.mu.RUnlock()
for k, v := range cm.m {
if !f(k, v) {
break
}
}
}
// Snapshot returns a copy of the internal map at the time of the call. Returns a native map, not a concurrent one.
func (cm *Map[K, V]) Snapshot() map[K]V {
cm.mu.RLock()
defer cm.mu.RUnlock()
newMap := make(map[K]V, len(cm.m))
for k, v := range cm.m {
newMap[k] = v
}
return newMap
}
func (cm *Map[K, V]) MarshalJSON() ([]byte, error) {
return json.Marshal(cm.Snapshot())
}
func (cm *Map[K, V]) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &cm.m)
}
@@ -0,0 +1,134 @@
// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package concurrentmap
import (
"sync"
"testing"
)
func TestConcurrentMapConstructor(t *testing.T) {
m := New[int, string]()
if m == nil {
t.Fatal("New concurrent map is nil")
}
if m.m == nil {
t.Error("New concurrent map's underlying map is nil")
}
if len(m.m) != 0 {
t.Error("New concurrent map's underlying map is not empty")
}
}
func TestConcurrentMapSetAndGet(t *testing.T) {
m := New[int, string]()
m.Set(1, "a")
// Test that the value is set
if val, found := m.Get(1); !found || val != "a" {
t.Errorf("Got %s, want %s", val, "a")
}
// Test that the value is not set for a different key
if val, found := m.Get(2); found || val != "" {
t.Errorf("Got %s, want an empty value and a not found", val)
}
}
func TestConcurrentMapDelete(t *testing.T) {
m := New[int, string]()
m.Set(1, "a")
m.Delete(1)
// Test that the value is deleted
if _, found := m.Get(1); found {
t.Errorf("Expected key 1 to be deleted")
}
}
func TestConcurrentMapLen(t *testing.T) {
m := New[int, string]()
m.Set(1, "a")
m.Set(2, "b")
m.Set(3, "b")
// Test that the length is correct
if m.Len() != 3 {
t.Errorf("Expected length 3, got %d", m.Len())
}
}
func TestConcurrentMapDeepCopy(t *testing.T) {
m := New[int, string]()
m.Set(1, "a")
copy := m.DeepCopy()
m.Set(1, "b")
// Test that the copy is not affected by the original
if val, _ := copy.Get(1); val != "a" {
t.Errorf("DeepCopy failed, expected 'a', got '%s'", val)
}
}
func TestConcurrentMapIter(t *testing.T) {
m := New[int, string]()
m.Set(1, "a")
m.Set(2, "b")
m.Set(3, "c")
counter := 0
elements := make(map[int]string)
m.Iter(func(key int, value string) bool {
counter++
elements[key] = value
return true
})
// Test that the iterator iterates over all elements
if counter != 3 {
t.Errorf("Iter failed, expected to iterate 3 times, iterated %d times", counter)
}
// Test that iteration yeilds all elements
if len(elements) != 3 {
t.Errorf("Iter failed, there should be 3 elements in the map, got %d", len(elements))
}
if elements[1] != "a" || elements[2] != "b" || elements[3] != "c" {
t.Errorf("Iter failed, expected to have 3 elements in the map, with correct values: %v", elements)
}
}
func TestConcurrentMapSetAndGetWithConcurrency(t *testing.T) {
m := New[int, int]()
var wg sync.WaitGroup
// Set 100 elements concurrently
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
m.Set(i, i)
}(i)
}
// Wait for al goroutines to finish
wg.Wait()
// Test that all elements are set
for i := 0; i < 100; i++ {
if val, found := m.Get(i); !found || val != i {
t.Errorf("Got %d, want %d", val, i)
}
}
}