Merge pull request #6864 from dolthub/aaron/dbfactory-preiodically-refresh-aws-file-credentials

dbfactory: aws: For AWS remotes using an explicit file credentials, periodically refresh the credentials the AWS client uses from the file contents.
This commit is contained in:
Aaron Son
2023-10-23 18:30:20 -07:00
committed by GitHub
3 changed files with 183 additions and 1 deletions

View File

@@ -20,6 +20,7 @@ import (
"net/url"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
@@ -27,6 +28,7 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/datas"
"github.com/dolthub/dolt/go/store/nbs"
@@ -49,6 +51,8 @@ const (
AWSCredsProfile = "aws-creds-profile"
)
var AWSFileCredsRefreshDuration = time.Minute
var AWSCredTypes = []string{RoleCS.String(), EnvCS.String(), FileCS.String()}
// AWSCredentialSource is an enum type representing the different credential sources (auto, role, env, file, or invalid)
@@ -210,7 +214,11 @@ func awsConfigFromParams(params map[string]interface{}) (session.Options, error)
if filePath, ok := params[AWSCredsFileParam]; !ok {
return opts, os.ErrNotExist
} else {
creds := credentials.NewSharedCredentials(filePath.(string), profile)
provider := &credentials.SharedCredentialsProvider{
Filename: filePath.(string),
Profile: profile,
}
creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration))
awsConfig = awsConfig.WithCredentials(creds)
}
case AutoCS:

View File

@@ -0,0 +1,57 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Refreshing credentials will periodically refresh credentials from the
// underlying credential provider. This can be used in places where temporary
// credentials are placed into files, for example, and we need profile
// credentials periodically refreshed, for example.
package awsrefreshcreds
import (
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
)
var now func() time.Time = time.Now
type RefreshingCredentialsProvider struct {
provider credentials.Provider
refreshedAt time.Time
refreshInterval time.Duration
}
func NewRefreshingCredentialsProvider(provider credentials.Provider, interval time.Duration) *RefreshingCredentialsProvider {
return &RefreshingCredentialsProvider{
provider: provider,
refreshInterval: interval,
}
}
func (p *RefreshingCredentialsProvider) Retrieve() (credentials.Value, error) {
v, err := p.provider.Retrieve()
if err == nil {
p.refreshedAt = now()
}
return v, err
}
func (p *RefreshingCredentialsProvider) IsExpired() bool {
if now().Sub(p.refreshedAt) > p.refreshInterval {
return true
}
return p.provider.IsExpired()
}

View File

@@ -0,0 +1,117 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package awsrefreshcreds
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type staticProvider struct {
v credentials.Value
}
func (p *staticProvider) Retrieve() (credentials.Value, error) {
return p.v, nil
}
func (p *staticProvider) IsExpired() bool {
return false
}
func TestRefreshingCredentialsProvider(t *testing.T) {
var sp staticProvider
sp.v.AccessKeyID = "ExampleOne"
rp := NewRefreshingCredentialsProvider(&sp, time.Minute)
n := time.Now()
origNow := now
t.Cleanup(func() {
now = origNow
})
now = func() time.Time { return n }
v, err := rp.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "ExampleOne", v.AccessKeyID)
assert.False(t, rp.IsExpired())
sp.v.AccessKeyID = "ExampleTwo"
now = func() time.Time { return n.Add(30 * time.Second) }
v, err = rp.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "ExampleTwo", v.AccessKeyID)
assert.False(t, rp.IsExpired())
now = func() time.Time { return n.Add(91 * time.Second) }
assert.True(t, rp.IsExpired())
v, err = rp.Retrieve()
assert.NoError(t, err)
assert.Equal(t, "ExampleTwo", v.AccessKeyID)
assert.False(t, rp.IsExpired())
}
func TestRefreshingCredentialsProviderShared(t *testing.T) {
d := t.TempDir()
onecontents := `
[backup]
aws_access_key_id = AKIAAAAAAAAAAAAAAAAA
aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
`
twocontents := `
[backup]
aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ
aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
`
configpath := filepath.Join(d, "config")
require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700))
n := time.Now()
origNow := now
t.Cleanup(func() {
now = origNow
})
now = func() time.Time { return n }
creds := credentials.NewCredentials(
NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{
Filename: configpath,
Profile: "backup",
}, time.Minute),
)
v, err := creds.Get()
assert.NoError(t, err)
assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID)
require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700))
now = func() time.Time { return n.Add(61 * time.Second) }
v, err = creds.Get()
assert.NoError(t, err)
assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID)
}