diff --git a/go/libraries/doltcore/dbfactory/aws.go b/go/libraries/doltcore/dbfactory/aws.go index d6a802f425..eb6a580824 100644 --- a/go/libraries/doltcore/dbfactory/aws.go +++ b/go/libraries/doltcore/dbfactory/aws.go @@ -20,6 +20,7 @@ import ( "net/url" "os" "strings" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" @@ -27,6 +28,7 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/s3" + "github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds" "github.com/dolthub/dolt/go/store/chunks" "github.com/dolthub/dolt/go/store/datas" "github.com/dolthub/dolt/go/store/nbs" @@ -49,6 +51,8 @@ const ( AWSCredsProfile = "aws-creds-profile" ) +var AWSFileCredsRefreshDuration = time.Minute + var AWSCredTypes = []string{RoleCS.String(), EnvCS.String(), FileCS.String()} // AWSCredentialSource is an enum type representing the different credential sources (auto, role, env, file, or invalid) @@ -210,7 +214,11 @@ func awsConfigFromParams(params map[string]interface{}) (session.Options, error) if filePath, ok := params[AWSCredsFileParam]; !ok { return opts, os.ErrNotExist } else { - creds := credentials.NewSharedCredentials(filePath.(string), profile) + provider := &credentials.SharedCredentialsProvider{ + Filename: filePath.(string), + Profile: profile, + } + creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration)) awsConfig = awsConfig.WithCredentials(creds) } case AutoCS: diff --git a/go/libraries/utils/awsrefreshcreds/creds.go b/go/libraries/utils/awsrefreshcreds/creds.go new file mode 100644 index 0000000000..26265285bc --- /dev/null +++ b/go/libraries/utils/awsrefreshcreds/creds.go @@ -0,0 +1,57 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Refreshing credentials will periodically refresh credentials from the +// underlying credential provider. This can be used in places where temporary +// credentials are placed into files, for example, and we need profile +// credentials periodically refreshed, for example. + +package awsrefreshcreds + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" +) + +var now func() time.Time = time.Now + +type RefreshingCredentialsProvider struct { + provider credentials.Provider + + refreshedAt time.Time + refreshInterval time.Duration +} + +func NewRefreshingCredentialsProvider(provider credentials.Provider, interval time.Duration) *RefreshingCredentialsProvider { + return &RefreshingCredentialsProvider{ + provider: provider, + refreshInterval: interval, + } +} + +func (p *RefreshingCredentialsProvider) Retrieve() (credentials.Value, error) { + v, err := p.provider.Retrieve() + if err == nil { + p.refreshedAt = now() + } + return v, err +} + +func (p *RefreshingCredentialsProvider) IsExpired() bool { + if now().Sub(p.refreshedAt) > p.refreshInterval { + return true + } + return p.provider.IsExpired() +} diff --git a/go/libraries/utils/awsrefreshcreds/creds_test.go b/go/libraries/utils/awsrefreshcreds/creds_test.go new file mode 100644 index 0000000000..2346fe9d17 --- /dev/null +++ b/go/libraries/utils/awsrefreshcreds/creds_test.go @@ -0,0 +1,117 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package awsrefreshcreds + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type staticProvider struct { + v credentials.Value +} + +func (p *staticProvider) Retrieve() (credentials.Value, error) { + return p.v, nil +} + +func (p *staticProvider) IsExpired() bool { + return false +} + +func TestRefreshingCredentialsProvider(t *testing.T) { + var sp staticProvider + sp.v.AccessKeyID = "ExampleOne" + rp := NewRefreshingCredentialsProvider(&sp, time.Minute) + + n := time.Now() + origNow := now + t.Cleanup(func() { + now = origNow + }) + now = func() time.Time { return n } + + v, err := rp.Retrieve() + assert.NoError(t, err) + assert.Equal(t, "ExampleOne", v.AccessKeyID) + assert.False(t, rp.IsExpired()) + + sp.v.AccessKeyID = "ExampleTwo" + + now = func() time.Time { return n.Add(30 * time.Second) } + + v, err = rp.Retrieve() + assert.NoError(t, err) + assert.Equal(t, "ExampleTwo", v.AccessKeyID) + assert.False(t, rp.IsExpired()) + + now = func() time.Time { return n.Add(91 * time.Second) } + assert.True(t, rp.IsExpired()) + v, err = rp.Retrieve() + assert.NoError(t, err) + assert.Equal(t, "ExampleTwo", v.AccessKeyID) + assert.False(t, rp.IsExpired()) +} + +func TestRefreshingCredentialsProviderShared(t *testing.T) { + d := t.TempDir() + + onecontents := ` +[backup] +aws_access_key_id = AKIAAAAAAAAAAAAAAAAA +aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA +` + + twocontents := ` +[backup] +aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ +aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ +` + + configpath := filepath.Join(d, "config") + + require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700)) + + n := time.Now() + origNow := now + t.Cleanup(func() { + now = origNow + }) + now = func() time.Time { return n } + + creds := credentials.NewCredentials( + NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{ + Filename: configpath, + Profile: "backup", + }, time.Minute), + ) + + v, err := creds.Get() + assert.NoError(t, err) + assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID) + + require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700)) + + now = func() time.Time { return n.Add(61 * time.Second) } + v, err = creds.Get() + assert.NoError(t, err) + assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID) +}