mirror of
https://github.com/dolthub/dolt.git
synced 2026-04-22 11:29:06 -05:00
dbfactory: aws: For AWS remotes using an explicit file credentials, periodically refresh the credentials the AWS client uses from the file contents.
Some use cases put attenuated, expiring credentials into files. It's nice to pick up the new credentials without needing to recreate the client.
This commit is contained in:
@@ -20,6 +20,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
@@ -27,6 +28,7 @@ import (
|
||||
"github.com/aws/aws-sdk-go/service/dynamodb"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
|
||||
"github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds"
|
||||
"github.com/dolthub/dolt/go/store/chunks"
|
||||
"github.com/dolthub/dolt/go/store/datas"
|
||||
"github.com/dolthub/dolt/go/store/nbs"
|
||||
@@ -49,6 +51,8 @@ const (
|
||||
AWSCredsProfile = "aws-creds-profile"
|
||||
)
|
||||
|
||||
var AWSFileCredsRefreshDuration = time.Minute
|
||||
|
||||
var AWSCredTypes = []string{RoleCS.String(), EnvCS.String(), FileCS.String()}
|
||||
|
||||
// AWSCredentialSource is an enum type representing the different credential sources (auto, role, env, file, or invalid)
|
||||
@@ -210,7 +214,11 @@ func awsConfigFromParams(params map[string]interface{}) (session.Options, error)
|
||||
if filePath, ok := params[AWSCredsFileParam]; !ok {
|
||||
return opts, os.ErrNotExist
|
||||
} else {
|
||||
creds := credentials.NewSharedCredentials(filePath.(string), profile)
|
||||
provider := &credentials.SharedCredentialsProvider{
|
||||
Filename: filePath.(string),
|
||||
Profile: profile,
|
||||
}
|
||||
creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration))
|
||||
awsConfig = awsConfig.WithCredentials(creds)
|
||||
}
|
||||
case AutoCS:
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright 2023 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Refreshing credentials will periodically refresh credentials from the
|
||||
// underlying credential provider. This can be used in places where temporary
|
||||
// credentials are placed into files, for example, and we need profile
|
||||
// credentials periodically refreshed, for example.
|
||||
|
||||
package awsrefreshcreds
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
)
|
||||
|
||||
var now func() time.Time = time.Now
|
||||
|
||||
type RefreshingCredentialsProvider struct {
|
||||
provider credentials.Provider
|
||||
|
||||
refreshedAt time.Time
|
||||
refreshInterval time.Duration
|
||||
}
|
||||
|
||||
func NewRefreshingCredentialsProvider(provider credentials.Provider, interval time.Duration) *RefreshingCredentialsProvider {
|
||||
return &RefreshingCredentialsProvider{
|
||||
provider: provider,
|
||||
refreshInterval: interval,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *RefreshingCredentialsProvider) Retrieve() (credentials.Value, error) {
|
||||
v, err := p.provider.Retrieve()
|
||||
if err == nil {
|
||||
p.refreshedAt = now()
|
||||
}
|
||||
return v, err
|
||||
}
|
||||
|
||||
func (p *RefreshingCredentialsProvider) IsExpired() bool {
|
||||
if now().Sub(p.refreshedAt) > p.refreshInterval {
|
||||
return true
|
||||
}
|
||||
return p.provider.IsExpired()
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
// Copyright 2023 Dolthub, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package awsrefreshcreds
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type staticProvider struct {
|
||||
v credentials.Value
|
||||
}
|
||||
|
||||
func (p *staticProvider) Retrieve() (credentials.Value, error) {
|
||||
return p.v, nil
|
||||
}
|
||||
|
||||
func (p *staticProvider) IsExpired() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func TestRefreshingCredentialsProvider(t *testing.T) {
|
||||
var sp staticProvider
|
||||
sp.v.AccessKeyID = "ExampleOne"
|
||||
rp := NewRefreshingCredentialsProvider(&sp, time.Minute)
|
||||
|
||||
n := time.Now()
|
||||
origNow := now
|
||||
t.Cleanup(func() {
|
||||
now = origNow
|
||||
})
|
||||
now = func() time.Time { return n }
|
||||
|
||||
v, err := rp.Retrieve()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ExampleOne", v.AccessKeyID)
|
||||
assert.False(t, rp.IsExpired())
|
||||
|
||||
sp.v.AccessKeyID = "ExampleTwo"
|
||||
|
||||
now = func() time.Time { return n.Add(30 * time.Second) }
|
||||
|
||||
v, err = rp.Retrieve()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ExampleTwo", v.AccessKeyID)
|
||||
assert.False(t, rp.IsExpired())
|
||||
|
||||
now = func() time.Time { return n.Add(91 * time.Second) }
|
||||
assert.True(t, rp.IsExpired())
|
||||
v, err = rp.Retrieve()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ExampleTwo", v.AccessKeyID)
|
||||
assert.False(t, rp.IsExpired())
|
||||
}
|
||||
|
||||
func TestRefreshingCredentialsProviderShared(t *testing.T) {
|
||||
d := t.TempDir()
|
||||
|
||||
onecontents := `
|
||||
[backup]
|
||||
aws_access_key_id = AKIAAAAAAAAAAAAAAAAA
|
||||
aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
|
||||
`
|
||||
|
||||
twocontents := `
|
||||
[backup]
|
||||
aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ
|
||||
aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
|
||||
`
|
||||
|
||||
configpath := filepath.Join(d, "config")
|
||||
|
||||
require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700))
|
||||
|
||||
n := time.Now()
|
||||
origNow := now
|
||||
t.Cleanup(func() {
|
||||
now = origNow
|
||||
})
|
||||
now = func() time.Time { return n }
|
||||
|
||||
creds := credentials.NewCredentials(
|
||||
NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{
|
||||
Filename: configpath,
|
||||
Profile: "backup",
|
||||
}, time.Minute),
|
||||
)
|
||||
|
||||
v, err := creds.Get()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID)
|
||||
|
||||
require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700))
|
||||
|
||||
now = func() time.Time { return n.Add(61 * time.Second) }
|
||||
v, err = creds.Get()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID)
|
||||
}
|
||||
Reference in New Issue
Block a user