From b083c24d69a6218f00b653f7ccf7263873ecb755 Mon Sep 17 00:00:00 2001 From: Aaron Son Date: Mon, 23 Oct 2023 16:53:58 -0700 Subject: [PATCH] dbfactory: aws: For AWS remotes using an explicit file credentials, periodically refresh the credentials the AWS client uses from the file contents. Some use cases put attenuated, expiring credentials into files. It's nice to pick up the new credentials without needing to recreate the client. --- go/libraries/doltcore/dbfactory/aws.go | 10 +- go/libraries/utils/awsrefreshcreds/creds.go | 57 +++++++++ .../utils/awsrefreshcreds/creds_test.go | 117 ++++++++++++++++++ 3 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 go/libraries/utils/awsrefreshcreds/creds.go create mode 100644 go/libraries/utils/awsrefreshcreds/creds_test.go diff --git a/go/libraries/doltcore/dbfactory/aws.go b/go/libraries/doltcore/dbfactory/aws.go index d6a802f425..eb6a580824 100644 --- a/go/libraries/doltcore/dbfactory/aws.go +++ b/go/libraries/doltcore/dbfactory/aws.go @@ -20,6 +20,7 @@ import ( "net/url" "os" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -27,6 +28,7 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/s3" + "github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds" "github.com/dolthub/dolt/go/store/chunks" "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/nbs" @@ -49,6 +51,8 @@ const ( AWSCredsProfile = "aws-creds-profile" ) +var AWSFileCredsRefreshDuration = time.Minute + var AWSCredTypes = []string{RoleCS.String(), EnvCS.String(), FileCS.String()} // AWSCredentialSource is an enum type representing the different credential sources (auto, role, env, file, or invalid) @@ -210,7 +214,11 @@ func awsConfigFromParams(params map[string]interface{}) (session.Options, error) if filePath, ok := params[AWSCredsFileParam]; !ok { return opts, os.ErrNotExist } else { - creds := credentials.NewSharedCredentials(filePath.(string), profile) + provider := &credentials.SharedCredentialsProvider{ + Filename: filePath.(string), + Profile: profile, + } + creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration)) awsConfig = awsConfig.WithCredentials(creds) } case AutoCS: diff --git a/go/libraries/utils/awsrefreshcreds/creds.go b/go/libraries/utils/awsrefreshcreds/creds.go new file mode 100644 index 0000000000..26265285bc --- /dev/null +++ b/go/libraries/utils/awsrefreshcreds/creds.go @@ -0,0 +1,57 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Refreshing credentials will periodically refresh credentials from the +// underlying credential provider. This can be used in places where temporary +// credentials are placed into files, for example, and we need profile +// credentials periodically refreshed, for example. + +package awsrefreshcreds + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +var now func() time.Time = time.Now + +type RefreshingCredentialsProvider struct { + provider credentials.Provider + + refreshedAt time.Time + refreshInterval time.Duration +} + +func NewRefreshingCredentialsProvider(provider credentials.Provider, interval time.Duration) *RefreshingCredentialsProvider { + return &RefreshingCredentialsProvider{ + provider: provider, + refreshInterval: interval, + } +} + +func (p *RefreshingCredentialsProvider) Retrieve() (credentials.Value, error) { + v, err := p.provider.Retrieve() + if err == nil { + p.refreshedAt = now() + } + return v, err +} + +func (p *RefreshingCredentialsProvider) IsExpired() bool { + if now().Sub(p.refreshedAt) > p.refreshInterval { + return true + } + return p.provider.IsExpired() +} diff --git a/go/libraries/utils/awsrefreshcreds/creds_test.go b/go/libraries/utils/awsrefreshcreds/creds_test.go new file mode 100644 index 0000000000..2346fe9d17 --- /dev/null +++ b/go/libraries/utils/awsrefreshcreds/creds_test.go @@ -0,0 +1,117 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package awsrefreshcreds + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type staticProvider struct { + v credentials.Value +} + +func (p *staticProvider) Retrieve() (credentials.Value, error) { + return p.v, nil +} + +func (p *staticProvider) IsExpired() bool { + return false +} + +func TestRefreshingCredentialsProvider(t *testing.T) { + var sp staticProvider + sp.v.AccessKeyID = "ExampleOne" + rp := NewRefreshingCredentialsProvider(&sp, time.Minute) + + n := time.Now() + origNow := now + t.Cleanup(func() { + now = origNow + }) + now = func() time.Time { return n } + + v, err := rp.Retrieve() + assert.NoError(t, err) + assert.Equal(t, "ExampleOne", v.AccessKeyID) + assert.False(t, rp.IsExpired()) + + sp.v.AccessKeyID = "ExampleTwo" + + now = func() time.Time { return n.Add(30 * time.Second) } + + v, err = rp.Retrieve() + assert.NoError(t, err) + assert.Equal(t, "ExampleTwo", v.AccessKeyID) + assert.False(t, rp.IsExpired()) + + now = func() time.Time { return n.Add(91 * time.Second) } + assert.True(t, rp.IsExpired()) + v, err = rp.Retrieve() + assert.NoError(t, err) + assert.Equal(t, "ExampleTwo", v.AccessKeyID) + assert.False(t, rp.IsExpired()) +} + +func TestRefreshingCredentialsProviderShared(t *testing.T) { + d := t.TempDir() + + onecontents := ` +[backup] +aws_access_key_id = AKIAAAAAAAAAAAAAAAAA +aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +` + + twocontents := ` +[backup] +aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ +aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ +` + + configpath := filepath.Join(d, "config") + + require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700)) + + n := time.Now() + origNow := now + t.Cleanup(func() { + now = origNow + }) + now = func() time.Time { return n } + + creds := credentials.NewCredentials( + NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{ + Filename: configpath, + Profile: "backup", + }, time.Minute), + ) + + v, err := creds.Get() + assert.NoError(t, err) + assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID) + + require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700)) + + now = func() time.Time { return n.Add(61 * time.Second) } + v, err = creds.Get() + assert.NoError(t, err) + assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID) +}