dbfactory: aws: For AWS remotes using an explicit file credentials, periodically refresh the credentials the AWS client uses from the file contents.

Some use cases put attenuated, expiring credentials into files. It's nice to
pick up the new credentials without needing to recreate the client.
This commit is contained in:
Aaron Son
2023-10-23 16:53:58 -07:00
parent 8084c49ffe
commit b083c24d69
3 changed files with 183 additions and 1 deletions
+9 -1
View File
@@ -20,6 +20,7 @@ import (
"net/url"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
@@ -27,6 +28,7 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/nbs"
@@ -49,6 +51,8 @@ const (
AWSCredsProfile = "aws-creds-profile"
)
var AWSFileCredsRefreshDuration = time.Minute
var AWSCredTypes = []string{RoleCS.String(), EnvCS.String(), FileCS.String()}
// AWSCredentialSource is an enum type representing the different credential sources (auto, role, env, file, or invalid)
@@ -210,7 +214,11 @@ func awsConfigFromParams(params map[string]interface{}) (session.Options, error)
if filePath, ok := params[AWSCredsFileParam]; !ok {
return opts, os.ErrNotExist
} else {
creds := credentials.NewSharedCredentials(filePath.(string), profile)
provider := &credentials.SharedCredentialsProvider{
Filename: filePath.(string),
Profile: profile,
}
creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration))
awsConfig = awsConfig.WithCredentials(creds)
}
case AutoCS:
@@ -0,0 +1,57 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Refreshing credentials will periodically refresh credentials from the
// underlying credential provider. This can be used in places where temporary
// credentials are placed into files, for example, and we need profile
// credentials periodically refreshed, for example.
package awsrefreshcreds
import (
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var now func() time.Time = time.Now
type RefreshingCredentialsProvider struct {
provider credentials.Provider
refreshedAt time.Time
refreshInterval time.Duration
}
func NewRefreshingCredentialsProvider(provider credentials.Provider, interval time.Duration) *RefreshingCredentialsProvider {
return &RefreshingCredentialsProvider{
provider: provider,
refreshInterval: interval,
}
}
func (p *RefreshingCredentialsProvider) Retrieve() (credentials.Value, error) {
v, err := p.provider.Retrieve()
if err == nil {
p.refreshedAt = now()
}
return v, err
}
func (p *RefreshingCredentialsProvider) IsExpired() bool {
if now().Sub(p.refreshedAt) > p.refreshInterval {
return true
}
return p.provider.IsExpired()
}
@@ -0,0 +1,117 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package awsrefreshcreds
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type staticProvider struct {
v credentials.Value
}
func (p *staticProvider) Retrieve() (credentials.Value, error) {
return p.v, nil
}
func (p *staticProvider) IsExpired() bool {
return false
}
func TestRefreshingCredentialsProvider(t *testing.T) {
var sp staticProvider
sp.v.AccessKeyID = "ExampleOne"
rp := NewRefreshingCredentialsProvider(&sp, time.Minute)
n := time.Now()
origNow := now
t.Cleanup(func() {
now = origNow
})
now = func() time.Time { return n }
v, err := rp.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "ExampleOne", v.AccessKeyID)
assert.False(t, rp.IsExpired())
sp.v.AccessKeyID = "ExampleTwo"
now = func() time.Time { return n.Add(30 * time.Second) }
v, err = rp.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "ExampleTwo", v.AccessKeyID)
assert.False(t, rp.IsExpired())
now = func() time.Time { return n.Add(91 * time.Second) }
assert.True(t, rp.IsExpired())
v, err = rp.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "ExampleTwo", v.AccessKeyID)
assert.False(t, rp.IsExpired())
}
func TestRefreshingCredentialsProviderShared(t *testing.T) {
d := t.TempDir()
onecontents := `
[backup]
aws_access_key_id = AKIAAAAAAAAAAAAAAAAA
aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
`
twocontents := `
[backup]
aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ
aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
`
configpath := filepath.Join(d, "config")
require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700))
n := time.Now()
origNow := now
t.Cleanup(func() {
now = origNow
})
now = func() time.Time { return n }
creds := credentials.NewCredentials(
NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{
Filename: configpath,
Profile: "backup",
}, time.Minute),
)
v, err := creds.Get()
assert.NoError(t, err)
assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID)
require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700))
now = func() time.Time { return n.Add(61 * time.Second) }
v, err = creds.Get()
assert.NoError(t, err)
assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID)
}