[no-release-notes] go: Migrate to aws-sdk-go-v2.

This commit is contained in:
Aaron Son
2025-03-04 13:34:20 -08:00
parent a751dc70f6
commit 01d9782ffc
19 changed files with 4867 additions and 530 deletions

4384
go/Godeps/LICENSES generated

File diff suppressed because it is too large Load Diff

View File

@@ -7,7 +7,6 @@ require (
github.com/abiosoft/readline v0.0.0-20180607040430-155bce2042db
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883
github.com/attic-labs/kingpin v2.2.7-0.20180312050558-442efcfac769+incompatible
github.com/aws/aws-sdk-go v1.55.6
github.com/bcicen/jstream v1.0.0
github.com/boltdb/bolt v1.3.1
github.com/denisbrodbeck/machineid v1.0.1
@@ -52,9 +51,15 @@ require (
require (
github.com/Shopify/toxiproxy/v2 v2.5.0
github.com/aliyun/aliyun-oss-go-sdk v2.2.5+incompatible
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.8
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.64
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.41.0
github.com/aws/aws-sdk-go-v2/service/s3 v1.78.0
github.com/cenkalti/backoff/v4 v4.1.3
github.com/cespare/xxhash/v2 v2.2.0
github.com/creasty/defaults v1.6.0
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.19.1-0.20250305230031-14a57e076a0a
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63
@@ -107,6 +112,22 @@ require (
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect
github.com/apache/thrift v0.13.1-0.20201008052519-daf620915714 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.61 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.6.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.25.0 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.33.16 // indirect
github.com/aws/smithy-go v1.22.2 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dolthub/go-icu-regex v0.0.0-20250303123116-549b8d7cad00 // indirect
@@ -130,7 +151,6 @@ require (
github.com/googleapis/gax-go/v2 v2.11.0 // indirect
github.com/gorilla/mux v1.8.0 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/klauspost/compress v1.10.5 // indirect
github.com/klauspost/cpuid/v2 v2.0.12 // indirect
github.com/lestrrat-go/strftime v1.0.4 // indirect

View File

@@ -106,18 +106,60 @@ github.com/aws/aws-sdk-go v1.55.6 h1:cSg4pvZ3m8dgYcgqB97MrcdjUmZ1BeMYKUxMMB89IPk
github.com/aws/aws-sdk-go v1.55.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g=
github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10/go.mod h1:qqvMj6gHLR/EXWZw4ZbqlPbQUyenf4h82UQUlKc+l14=
github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA=
github.com/aws/aws-sdk-go-v2/config v1.29.8 h1:RpwAfYcV2lr/yRc4lWhUM9JRPQqKgKWmou3LV7UfWP4=
github.com/aws/aws-sdk-go-v2/config v1.29.8/go.mod h1:t+G7Fq1OcO8cXTPPXzxQSnj/5Xzdc9jAAD3Xrn9/Mgo=
github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc=
github.com/aws/aws-sdk-go-v2/credentials v1.17.61 h1:Hd/uX6Wo2iUW1JWII+rmyCD7MMhOe7ALwQXN6sKDd1o=
github.com/aws/aws-sdk-go-v2/credentials v1.17.61/go.mod h1:L7vaLkwHY1qgW0gG1zG0z/X0sQ5tpIY5iI13+j3qI80=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 h1:x793wxmUWVDhshP8WW2mlnXuFrO4cOd3HLBroh1paFw=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30/go.mod h1:Jpne2tDnYiFascUEs2AWHJL9Yp7A5ZVy3TNyxaAjD6M=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.64 h1:RTko0AQ0i1vWXDM97DkuW6zskgOxFxm4RqC0kmBJFkE=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.64/go.mod h1:ty968MpOa5CoQ/ALWNB8Gmfoehof2nRHDR/DZDPfimE=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34 h1:ZK5jHhnrioRkUNOc+hOgQKlUL5JeC3S6JgLxtQ+Rm0Q=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.34/go.mod h1:p4VfIceZokChbA9FzMbRGz5OV+lekcVtHlPKEO0gSZY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34 h1:SZwFm17ZUNNg5Np0ioo/gq8Mn6u9w19Mri8DnJ15Jf0=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.34/go.mod h1:dFZsC0BLo346mvKQLWmoJxT+Sjp+qcVR1tRVHQGOH9Q=
github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3 h1:bIqFDwgGXXN1Kpp99pDOdKMTTb5d2KyU5X/BZxjOkRo=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.3/go.mod h1:H5O/EsxDWyU+LP/V8i5sm8cxoZgc2fdNR9bxlOFrQTo=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34 h1:ZNTqv4nIdE/DiBfUUfXcLZ/Spcuz+RjeziUtNJackkM=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.34/go.mod h1:zf7Vcd1ViW7cPqYWEHLHJkS50X0JS2IKz9Cgaj6ugrs=
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.41.0 h1:kSMAk72LZ5eIdY/W+tVV6VdokciajcDdVClEBVNWNP0=
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.41.0/go.mod h1:yYaWRnVSPyAmexW5t7G3TcuYoalYfT+xQwzWsvtUQ7M=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3 h1:eAh2A4b5IzM/lum78bZ590jy36+d/aFLgKF/4Vd1xPE=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.3/go.mod h1:0yKJC/kb8sAnmlYa6Zs3QVYqaC8ug2AbnNChv5Ox3uA=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.6.2 h1:t/gZFyrijKuSU0elA5kRngP/oU3mc0I+Dvp8HwRE4c0=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.6.2/go.mod h1:iu6FSzgt+M2/x3Dk8zhycdIcHjEFb36IS8HVUVFoMg0=
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 h1:M1R1rud7HzDrfCdlBQ7NjnRsDNEhXO/vGhuD189Ggmk=
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15/go.mod h1:uvFKBSq9yMPV4LGAi7N4awn4tLY+hKE35f8THes2mzQ=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 h1:dM9/92u2F1JbDaGooxTq18wmmFzbJRfXfVfy96/1CXM=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15/go.mod h1:SwFBy2vjtA0vZbjjaFtfN045boopadnoVPhu4Fv66vY=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91LiqT1nbvzDukyqAlCv89ZmwaHw/ZFlFZg=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA=
github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI=
github.com/aws/aws-sdk-go-v2/service/s3 v1.78.0 h1:EBm8lXevBWe+kK9VOU/IBeOI189WPRwPUc3LvJK9GOs=
github.com/aws/aws-sdk-go-v2/service/s3 v1.78.0/go.mod h1:4qzsZSzB/KiX2EzDjs9D7A8rI/WGJxZceVJIHqtJjIU=
github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM=
github.com/aws/aws-sdk-go-v2/service/sso v1.25.0 h1:2U9sF8nKy7UgyEeLiZTRg6ShBS22z8UnYpV6aRFL0is=
github.com/aws/aws-sdk-go-v2/service/sso v1.25.0/go.mod h1:qs4a9T5EMLl/Cajiw2TcbNt2UNo/Hqlyp+GiuG4CFDI=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.0 h1:wjAdc85cXdQR5uLx5FwWvGIHm4OPJhTyzUHU8craXtE=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.0/go.mod h1:MlYRNmYu/fGPoxBQVvBYr9nyr948aY/WLUvwBMBJubs=
github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.16 h1:BHEK2Q/7CMRMCb3nySi/w8UbIcPhKvYP5s1xf8/izn0=
github.com/aws/aws-sdk-go-v2/service/sts v1.33.16/go.mod h1:cQnB8CUnxbMU82JvlqjKR2HBOm3fe9pWorWBza6MBJ4=
github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ=
github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/bcicen/jstream v1.0.0 h1:gOi+Sn9mHrpePlENynPKA6Dra/PjLaIpqrTevhfvLAA=
github.com/bcicen/jstream v1.0.0/go.mod h1:9ielPxqFry7Y4Tg3j4BfjPocfJ3TbsRtXOAYXYmRuAQ=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
@@ -173,6 +215,8 @@ github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbj
github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8=
github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:IdqX7J8vi/Kn3T3Ee0VzqnLqwFmgA2hr8WZETPcQjfM=
github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2/go.mod h1:mIEZOHnFx4ZMQeawhw9rhsj+0zwQj7adVsnBX7t+eKY=
github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
@@ -414,9 +458,7 @@ github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d/go.mod
github.com/jcmturner/gofork v0.0.0-20180107083740-2aebee971930/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k=
github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jmoiron/sqlx v1.3.4 h1:wv+0IJZfL5z0uZoUjlpKgHkgaFSYD+r9CfrXjEXsO7w=
github.com/jmoiron/sqlx v1.3.4/go.mod h1:2BljVx/86SuTyjE+aPYlHCTNvZrnJXghYGpNiXLBMCQ=

View File

@@ -513,6 +513,8 @@ github.com/influxdata/influxdb1-client v0.0.0-20191209144304-8bf82d3c094d h1:/WZ
github.com/jcmturner/gofork v0.0.0-20180107083740-2aebee971930 h1:v4CYlQ+HeysPHsr2QFiEO60gKqnvn1xwvuKhhAhuEkk=
github.com/jedib0t/go-pretty v4.3.0+incompatible/go.mod h1:XemHduiw8R651AF9Pt4FwCTKeG3oo7hrHJAoznj9nag=
github.com/jedib0t/go-pretty/v6 v6.4.4/go.mod h1:MgmISkTWDSFu0xOqiZ0mKNntMQ2mDgOcwOkwBEkMDJI=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg=
github.com/jonboulle/clockwork v0.1.0 h1:VKV+ZcuP6l3yW9doeqz6ziZGgcynBVQO+obU0+0hcPo=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=

View File

@@ -17,16 +17,16 @@ package dbfactory
import (
"context"
"errors"
"fmt"
"net/url"
"os"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds"
"github.com/dolthub/dolt/go/store/chunks"
@@ -138,7 +138,7 @@ func (fact AWSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinForm
return nil, errors.New("aws url has an invalid format")
}
opts, err := awsConfigFromParams(params)
cfg, err := awsConfigFromParams(ctx, params)
if err != nil {
return nil, err
@@ -150,14 +150,14 @@ func (fact AWSFactory) newChunkStore(ctx context.Context, nbf *types.NomsBinForm
return nil, err
}
sess := session.Must(session.NewSessionWithOptions(opts))
_, err = sess.Config.Credentials.Get()
// Sanity check that we have credentials...
_, err = cfg.Credentials.Retrieve(ctx)
if err != nil {
return nil, err
}
q := nbs.NewUnlimitedMemQuotaProvider()
return nbs.NewAWSStore(ctx, nbf.VersionString(), parts[0], dbName, parts[1], s3.New(sess), dynamodb.New(sess), defaultMemTableSize, q)
return nbs.NewAWSStore(ctx, nbf.VersionString(), parts[0], dbName, parts[1], s3.NewFromConfig(cfg), dynamodb.NewFromConfig(cfg), defaultMemTableSize, q)
}
func validatePath(path string) (string, error) {
@@ -178,28 +178,26 @@ func validatePath(path string) (string, error) {
return path, nil
}
func awsConfigFromParams(params map[string]interface{}) (session.Options, error) {
awsConfig := aws.NewConfig()
func awsConfigFromParams(ctx context.Context, params map[string]interface{}) (aws.Config, error) {
var opts []func(*config.LoadOptions) error
// aws-region always sets the region. Otherwise it comes from AWS_REGION or AWS_DEFAULT_REGION.
if val, ok := params[AWSRegionParam]; ok {
awsConfig = awsConfig.WithRegion(val.(string))
opts = append(opts, config.WithRegion(val.(string)))
}
awsCredsSource := RoleCS
if val, ok := params[AWSCredsTypeParam]; ok {
awsCredsSource = AWSCredentialSourceFromStr(val.(string))
if awsCredsSource == InvalidCS {
return session.Options{}, errors.New("invalid value for aws-creds-source")
return aws.Config{}, errors.New("invalid value for aws-creds-source")
}
}
opts := session.Options{
SharedConfigState: session.SharedConfigEnable,
}
profile := ""
if val, ok := params[AWSCredsProfile]; ok {
profile = val.(string)
opts.Profile = val.(string)
opts = append(opts, config.WithSharedConfigProfile(val.(string)))
}
filePath, ok := params[AWSCredsFileParam]
@@ -209,39 +207,80 @@ func awsConfigFromParams(params map[string]interface{}) (session.Options, error)
switch awsCredsSource {
case EnvCS:
awsConfig = awsConfig.WithCredentials(credentials.NewEnvCredentials())
// Credentials can only come directly from AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY...
creds := awsrefreshcreds.LoadEnvCredentials()
opts = append(opts, config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(context.Context) (aws.Credentials, error) {
if !creds.HasKeys() {
return aws.Credentials{}, errors.New("error loading env creds; did not find AWS_ACCESS_KEY_ID or AWS_SECRET_ACCESS_KEY environment variable.")
} else {
return creds, nil
}
})))
case FileCS:
if filePath, ok := params[AWSCredsFileParam]; !ok {
return opts, os.ErrNotExist
return aws.Config{}, os.ErrNotExist
} else {
provider := &credentials.SharedCredentialsProvider{
Filename: filePath.(string),
Profile: profile,
}
creds := credentials.NewCredentials(awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration))
awsConfig = awsConfig.WithCredentials(creds)
provider := awsrefreshcreds.LoadINICredentialsProvider(filePath.(string), profile)
provider = awsrefreshcreds.NewRefreshingCredentialsProvider(provider, AWSFileCredsRefreshDuration)
opts = append(opts, config.WithCredentialsProvider(provider))
}
case AutoCS:
// start by trying to get the credentials from the environment
envCreds := credentials.NewEnvCredentials()
if _, err := envCreds.Get(); err == nil {
awsConfig = awsConfig.WithCredentials(envCreds)
if envCreds := awsrefreshcreds.LoadEnvCredentials(); envCreds.HasKeys() {
opts = append(opts, config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(context.Context) (aws.Credentials, error) {
return envCreds, nil
})))
} else {
// if env credentials don't exist try looking for a credentials file
if filePath, ok := params[AWSCredsFileParam]; ok {
if _, err := os.Stat(filePath.(string)); err == nil {
creds := credentials.NewSharedCredentials(filePath.(string), profile)
awsConfig = awsConfig.WithCredentials(creds)
provider := awsrefreshcreds.LoadINICredentialsProvider(filePath.(string), profile)
opts = append(opts, config.WithCredentialsProvider(provider))
}
}
// if file and env do not return valid credentials use the default credentials of the box (same as role)
}
// if file and env do not return valid credentials use the default credentials of the box (same as role)
case RoleCS:
default:
}
opts.Config.MergeIn(awsConfig)
return opts, nil
cfg, err := config.LoadDefaultConfig(ctx, opts...)
var profileErr config.SharedConfigProfileNotExistError
if errors.As(err, &profileErr) {
// XXX: Dolt was originally using aws-sdk-go, which was
// happy to load the specified shared profile from
// places like AWS_CONFIG_FILE or $HOME/.aws/config,
// but did not complain if it could not find it.
//
// We preserve that behavior here with this gross
// hack. We write a shared config file with an empty
// profile, and we point to that config file when
// loading the config.
if profile == "" {
profile = os.Getenv("AWS_PROFILE")
}
if profile == "" {
profile = os.Getenv("AWS_DEFAULT_PROFILE")
}
path, ferr := makeTempEmptyProfileConfig(profile)
if path != "" {
defer os.Remove(path)
}
if ferr == nil {
opts = append(opts, config.WithSharedConfigFiles([]string{path}))
cfg, err = config.LoadDefaultConfig(ctx, opts...)
}
}
return cfg, err
}
func makeTempEmptyProfileConfig(profile string) (string, error) {
f, err := os.CreateTemp("", "dolt_aws_empty_profile-*")
if err != nil {
return "", err
}
_, err = fmt.Fprintf(f, "[profile %s]\n", profile)
if err != nil {
return f.Name(), errors.Join(err, f.Close())
}
return f.Name(), f.Close()
}

View File

@@ -15,12 +15,13 @@
package dbfactory
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -121,12 +122,11 @@ func TestAWSConfigFromParams(t *testing.T) {
os.Setenv(k, v)
}
}
getSession := func(t *testing.T, params map[string]interface{}) *session.Session {
opts, err := awsConfigFromParams(params)
getConfig := func(t *testing.T, params map[string]interface{}) aws.Config {
cfg, err := awsConfigFromParams(context.Background(), params)
t.Helper()
require.NoError(t, err)
sess, err := session.NewSessionWithOptions(opts)
require.NoError(t, err)
return sess
return cfg
}
// Do not pick up config from any files in the running user's
@@ -166,15 +166,13 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": expectedSecretAccessKey,
"AWS_REGION": expectedRegion,
})
sess := getSession(t, getRttParams(rtt))
creds, err := sess.Config.Credentials.Get()
cfg := getConfig(t, getRttParams(rtt))
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, expectedRegion, *sess.Config.Region)
}
assert.Equal(t, expectedRegion, cfg.Region)
})
t.Run("CredsInLegacyEnv", func(t *testing.T) {
expectedAccessKeyID := uuid.New().String()
@@ -185,15 +183,13 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SECRET_KEY": expectedSecretAccessKey,
"AWS_DEFAULT_REGION": expectedRegion,
})
sess := getSession(t, getRttParams(rtt))
creds, err := sess.Config.Credentials.Get()
cfg := getConfig(t, getRttParams(rtt))
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, expectedRegion, *sess.Config.Region)
}
assert.Equal(t, expectedRegion, cfg.Region)
})
t.Run("FilesInEnv", func(t *testing.T) {
t.Run("ProfileInEnv", func(t *testing.T) {
@@ -205,15 +201,13 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_CONFIG_FILE": configFile,
"AWS_SHARED_CREDENTIALS_FILE": credsFile,
})
sess := getSession(t, getRttParams(rtt))
creds, err := sess.Config.Credentials.Get()
cfg := getConfig(t, getRttParams(rtt))
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, loadFromFileProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, loadFromFileProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
t.Run("ProfileInLegacyEnv", func(t *testing.T) {
loadedProfile := "load_from_file"
@@ -224,15 +218,13 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_CONFIG_FILE": configFile,
"AWS_SHARED_CREDENTIALS_FILE": credsFile,
})
sess := getSession(t, getRttParams(rtt))
creds, err := sess.Config.Credentials.Get()
cfg := getConfig(t, getRttParams(rtt))
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, loadFromFileProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, loadFromFileProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
t.Run("ProfileInParam", func(t *testing.T) {
loadedProfile := "load_from_file"
@@ -244,15 +236,13 @@ func TestAWSConfigFromParams(t *testing.T) {
})
params := getRttParams(rtt)
params[AWSCredsProfile] = loadedProfile
sess := getSession(t, params)
creds, err := sess.Config.Credentials.Get()
cfg := getConfig(t, params)
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, loadFromFileProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, loadFromFileProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
t.Run("FileParamOverridesCredsTypeRole", func(t *testing.T) {
// If an aws-creds-file parameter is passed,
@@ -274,15 +264,13 @@ func TestAWSConfigFromParams(t *testing.T) {
params := getRttParams(rtt)
params[AWSCredsProfile] = loadedProfile
params[AWSCredsFileParam] = credsFile
sess := getSession(t, params)
creds, err := sess.Config.Credentials.Get()
cfg := getConfig(t, params)
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, loadFromFileProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, loadFromFileProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
// XXX: Currently there are no tests
// here of web identity token
@@ -319,17 +307,15 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": expectedSecretAccessKey,
"AWS_REGION": envRegion,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSRegionParam: expectedRegion,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, expectedRegion, *sess.Config.Region)
}
assert.Equal(t, expectedRegion, cfg.Region)
})
t.Run("CredsTypeEnv", func(t *testing.T) {
t.Run("PopulatedCreds", func(t *testing.T) {
@@ -341,17 +327,15 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": expectedSecretAccessKey,
"AWS_REGION": expectedRegion,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "env",
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, expectedRegion, *sess.Config.Region)
}
assert.Equal(t, expectedRegion, cfg.Region)
})
t.Run("MissingAccessKeyID", func(t *testing.T) {
expectedSecretAccessKey := uuid.New().String()
@@ -360,10 +344,10 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SECRET_ACCESS_KEY": expectedSecretAccessKey,
"AWS_REGION": expectedRegion,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "env",
})
_, err := sess.Config.Credentials.Get()
_, err := cfg.Credentials.Retrieve(context.Background())
require.Error(t, err)
})
t.Run("MissingSecretAccessKey", func(t *testing.T) {
@@ -373,16 +357,16 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_ACCESS_KEY_ID": expectedAccessKeyID,
"AWS_REGION": expectedRegion,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "env",
})
_, err := sess.Config.Credentials.Get()
_, err := cfg.Credentials.Retrieve(context.Background())
require.Error(t, err)
})
})
t.Run("CredsTypeFile", func(t *testing.T) {
t.Run("FileParamDoesNotExist", func(t *testing.T) {
_, err := awsConfigFromParams(map[string]interface{}{
_, err := awsConfigFromParams(context.Background(), map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsProfile: "some_profile",
})
@@ -395,12 +379,12 @@ func TestAWSConfigFromParams(t *testing.T) {
setEnv(t, map[string]string{
"AWS_CONFIG_FILE": configFile,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsProfile: loadedProfile,
AWSCredsFileParam: credsFile,
})
_, err = sess.Config.Credentials.Get()
_, err = cfg.Credentials.Retrieve(context.Background())
require.Error(t, err)
})
t.Run("ProfileFromParamDoesNotExist", func(t *testing.T) {
@@ -410,12 +394,12 @@ func TestAWSConfigFromParams(t *testing.T) {
setEnv(t, map[string]string{
"AWS_CONFIG_FILE": configFile,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsProfile: loadedProfile,
AWSCredsFileParam: credsFile,
})
_, err = sess.Config.Credentials.Get()
_, err = cfg.Credentials.Retrieve(context.Background())
require.Error(t, err)
})
t.Run("ProfileFromEnvDoesNotExist", func(t *testing.T) {
@@ -426,11 +410,11 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_CONFIG_FILE": configFile,
"AWS_PROFILE": loadedProfile,
})
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFile,
})
_, err = sess.Config.Credentials.Get()
_, err = cfg.Credentials.Retrieve(context.Background())
require.Error(t, err)
})
type profileOnlyHasCredsTest struct {
@@ -459,11 +443,11 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_CONFIG_FILE": tt.fileEnv,
})
}
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFile,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, onlyCredsProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, onlyCredsProfileSecretAccessKey, creds.SecretAccessKey)
@@ -481,12 +465,12 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_CONFIG_FILE": tt.fileEnv,
})
}
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFile,
AWSCredsProfile: loadedProfile,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, onlyCredsProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, onlyCredsProfileSecretAccessKey, creds.SecretAccessKey)
@@ -507,19 +491,17 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SHARED_CREDENTIALS_FILE": altCredsFile,
})
credsFile := filepath.Join(cwd, "testdata", "basic_creds_file")
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsProfile: loadedProfile,
AWSCredsFileParam: credsFile,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, loadFromFileProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, loadFromFileProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
t.Run("NoProfileUsesDefault", func(t *testing.T) {
// If no aws-profile parameter is supplied,
@@ -536,18 +518,16 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SHARED_CREDENTIALS_FILE": altCredsFile,
})
credsFile := filepath.Join(cwd, "testdata", "basic_creds_file")
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFile,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, defaultProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, defaultProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, defaultProfileRegion, *sess.Config.Region)
}
assert.Equal(t, defaultProfileRegion, cfg.Region)
})
t.Run("ProfileInEnv", func(t *testing.T) {
// If no aws-profile parameter is supplied,
@@ -566,18 +546,16 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SHARED_CREDENTIALS_FILE": altCredsFile,
})
credsFile := filepath.Join(cwd, "testdata", "basic_creds_file")
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFile,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, loadFromFileProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, loadFromFileProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
t.Run("SplitBrainProfileInLegacyEnv", func(t *testing.T) {
// If no aws-profile parameter is supplied,
@@ -603,18 +581,16 @@ func TestAWSConfigFromParams(t *testing.T) {
"AWS_SHARED_CREDENTIALS_FILE": altCredsFile,
})
credsFile := filepath.Join(cwd, "testdata", "basic_creds_file")
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFile,
})
creds, err := sess.Config.Credentials.Get()
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, defaultProfileAccessKeyID, creds.AccessKeyID)
assert.Equal(t, defaultProfileSecretAccessKey, creds.SecretAccessKey)
}
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, loadFromFileProfileRegion, *sess.Config.Region)
}
assert.Equal(t, loadFromFileProfileRegion, cfg.Region)
})
})
t.Run("CredentialsFileRefresh", func(t *testing.T) {
@@ -642,16 +618,14 @@ aws_access_key_id = new_access_key_id
aws_secret_access_key = new_secret_access_key
`)
require.NoError(t, os.WriteFile(credsFilePath, credsFileContents, 0660))
sess := getSession(t, map[string]interface{}{
cfg := getConfig(t, map[string]interface{}{
AWSCredsTypeParam: "file",
AWSCredsFileParam: credsFilePath,
AWSRegionParam: "us-west-2",
AWSCredsProfile: "some_profile",
})
if assert.NotNil(t, sess.Config.Region) {
assert.Equal(t, "us-west-2", *sess.Config.Region)
}
creds, err := sess.Config.Credentials.Get()
assert.Equal(t, "us-west-2", cfg.Region)
creds, err := cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, "original_access_key_id", creds.AccessKeyID)
assert.Equal(t, "original_secret_access_key", creds.SecretAccessKey)
@@ -659,7 +633,7 @@ aws_secret_access_key = new_secret_access_key
require.NoError(t, os.WriteFile(filepath.Join(dir, "new_creds_file"), newCredsFileContents, 0660))
require.NoError(t, os.Rename(filepath.Join(dir, "new_creds_file"), credsFilePath))
time.Sleep(10 * time.Millisecond)
creds, err = sess.Config.Credentials.Get()
creds, err = cfg.Credentials.Retrieve(context.Background())
if assert.NoError(t, err) {
assert.Equal(t, "new_access_key_id", creds.AccessKeyID)
assert.Equal(t, "new_secret_access_key", creds.SecretAccessKey)

View File

@@ -20,38 +20,94 @@
package awsrefreshcreds
import (
"context"
"fmt"
"os"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go-v2/aws"
ini "github.com/dolthub/aws-sdk-go-ini-parser"
)
var now func() time.Time = time.Now
var _ aws.CredentialsProvider = (*RefreshingCredentialsProvider)(nil)
type RefreshingCredentialsProvider struct {
provider credentials.Provider
refreshedAt time.Time
refreshInterval time.Duration
provider aws.CredentialsProvider
interval time.Duration
}
func NewRefreshingCredentialsProvider(provider credentials.Provider, interval time.Duration) *RefreshingCredentialsProvider {
func NewRefreshingCredentialsProvider(provider aws.CredentialsProvider, interval time.Duration) *RefreshingCredentialsProvider {
return &RefreshingCredentialsProvider{
provider: provider,
refreshInterval: interval,
provider: provider,
interval: interval,
}
}
func (p *RefreshingCredentialsProvider) Retrieve() (credentials.Value, error) {
v, err := p.provider.Retrieve()
if err == nil {
p.refreshedAt = now()
func (p *RefreshingCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
res, err := p.provider.Retrieve(ctx)
if err == nil && res.CanExpire == false {
res.CanExpire = true
res.Expires = now().Add(p.interval)
}
return v, err
return res, err
}
func (p *RefreshingCredentialsProvider) IsExpired() bool {
if now().Sub(p.refreshedAt) > p.refreshInterval {
return true
// Based on the behavior of EnvConfig in aws-sdk-go-v2.
func LoadEnvCredentials() aws.Credentials {
var ret aws.Credentials
ret.AccessKeyID = os.Getenv("AWS_ACCESS_KEY_ID")
if ret.AccessKeyID == "" {
ret.AccessKeyID = os.Getenv("AWS_ACCESS_KEY")
}
return p.provider.IsExpired()
ret.SecretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY")
if ret.SecretAccessKey == "" {
ret.SecretAccessKey = os.Getenv("AWS_SECRET_KEY")
}
if ret.HasKeys() {
ret.SessionToken = os.Getenv("AWS_SESSION_TOKEN")
ret.Source = "EnvironmentVariables"
return ret
}
return aws.Credentials{}
}
// LoadINICredentialsProvider returns an aws.CredentialsProvider that reads a
// static access key from the given shared-credentials INI file on every
// Retrieve call. When |profile| is empty, the AWS_PROFILE environment
// variable is consulted, falling back to "default".
func LoadINICredentialsProvider(filename, profile string) aws.CredentialsProvider {
	switch {
	case profile != "":
		// Caller-supplied profile wins.
	case os.Getenv("AWS_PROFILE") != "":
		profile = os.Getenv("AWS_PROFILE")
	default:
		profile = "default"
	}
	load := func(context.Context) (aws.Credentials, error) {
		sections, err := ini.OpenFile(filename)
		if err != nil {
			return aws.Credentials{}, err
		}
		section, ok := sections.GetSection(profile)
		if !ok {
			return aws.Credentials{}, fmt.Errorf("error loading credentials for profile %s from file %s; profile not found", profile, filename)
		}
		id := section.String("aws_access_key_id")
		if len(id) == 0 {
			return aws.Credentials{}, fmt.Errorf("error loading credentials for profile %s from file %s; no aws_access_key_id", profile, filename)
		}
		secret := section.String("aws_secret_access_key")
		if len(secret) == 0 {
			return aws.Credentials{}, fmt.Errorf("error loading credentials for profile %s from file %s; no aws_secret_access_key", profile, filename)
		}
		// aws_session_token is optional; an absent key yields "".
		return aws.Credentials{
			AccessKeyID:     id,
			SecretAccessKey: secret,
			SessionToken:    section.String("aws_session_token"),
			Source:          "SharedCredentialsFile",
		}, nil
	}
	return aws.CredentialsProviderFunc(load)
}

View File

@@ -1,117 +0,0 @@
// Copyright 2023 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package awsrefreshcreds
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// staticProvider is a test stub credentials.Provider that always returns a
// fixed credentials.Value and never reports itself expired.
type staticProvider struct {
	v credentials.Value
}

// Retrieve returns the canned credentials value; it never fails.
func (p *staticProvider) Retrieve() (credentials.Value, error) {
	return p.v, nil
}

// IsExpired always reports false, so any expiry behavior observed in tests
// comes from the wrapper under test rather than from this stub.
func (p *staticProvider) IsExpired() bool {
	return false
}
// TestRefreshingCredentialsProvider verifies the expiry bookkeeping of
// RefreshingCredentialsProvider against a pinned clock: a successful
// Retrieve records the refresh time, IsExpired stays false until the
// interval has elapsed since the last Retrieve, and retrieving again resets
// the expiry window.
func TestRefreshingCredentialsProvider(t *testing.T) {
	var sp staticProvider
	sp.v.AccessKeyID = "ExampleOne"
	rp := NewRefreshingCredentialsProvider(&sp, time.Minute)
	// Pin the package-level clock so expiry checks are deterministic;
	// restore the real clock when the test finishes.
	n := time.Now()
	origNow := now
	t.Cleanup(func() {
		now = origNow
	})
	now = func() time.Time { return n }
	v, err := rp.Retrieve()
	assert.NoError(t, err)
	assert.Equal(t, "ExampleOne", v.AccessKeyID)
	assert.False(t, rp.IsExpired())
	// The wrapper surfaces whatever the delegate currently returns.
	sp.v.AccessKeyID = "ExampleTwo"
	now = func() time.Time { return n.Add(30 * time.Second) }
	v, err = rp.Retrieve()
	assert.NoError(t, err)
	assert.Equal(t, "ExampleTwo", v.AccessKeyID)
	assert.False(t, rp.IsExpired())
	// 91s after the first Retrieve, but 61s after the second one — the
	// second Retrieve reset the refresh timestamp, so the one-minute
	// interval has now elapsed and the credentials read as expired.
	now = func() time.Time { return n.Add(91 * time.Second) }
	assert.True(t, rp.IsExpired())
	// Retrieving again refreshes the timestamp, clearing the expiry.
	v, err = rp.Retrieve()
	assert.NoError(t, err)
	assert.Equal(t, "ExampleTwo", v.AccessKeyID)
	assert.False(t, rp.IsExpired())
}
// TestRefreshingCredentialsProviderShared exercises the wrapper around the
// SDK's SharedCredentialsProvider end to end: after the refresh interval
// elapses, a Get() re-reads the credentials file from disk and picks up the
// rewritten contents.
func TestRefreshingCredentialsProviderShared(t *testing.T) {
	d := t.TempDir()
	// Two versions of a shared-credentials file for the "backup" profile.
	onecontents := `
[backup]
aws_access_key_id = AKIAAAAAAAAAAAAAAAAA
aws_secret_access_key = oF8x/JQEGchAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
`
	twocontents := `
[backup]
aws_access_key_id = AKIZZZZZZZZZZZZZZZZZ
aws_secret_access_key = oF8x/JQEGchZZZZZZZZZZZZZZZZZZZZZZZZZZZZZ
`
	configpath := filepath.Join(d, "config")
	require.NoError(t, os.WriteFile(configpath, []byte(onecontents), 0700))
	// Pin the package-level clock; restore it when the test finishes.
	n := time.Now()
	origNow := now
	t.Cleanup(func() {
		now = origNow
	})
	now = func() time.Time { return n }
	creds := credentials.NewCredentials(
		NewRefreshingCredentialsProvider(&credentials.SharedCredentialsProvider{
			Filename: configpath,
			Profile:  "backup",
		}, time.Minute),
	)
	v, err := creds.Get()
	assert.NoError(t, err)
	assert.Equal(t, "AKIAAAAAAAAAAAAAAAAA", v.AccessKeyID)
	// Rewrite the file, then advance past the one-minute refresh interval;
	// the next Get() should re-read the file and see the new key.
	require.NoError(t, os.WriteFile(configpath, []byte(twocontents), 0700))
	now = func() time.Time { return n.Add(61 * time.Second) }
	v, err = creds.Get()
	assert.NoError(t, err)
	assert.Equal(t, "AKIZZZZZZZZZZZZZZZZZ", v.AccessKeyID)
}

View File

@@ -32,11 +32,10 @@ import (
"sync"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
"github.com/aws/aws-sdk-go-v2/aws"
s3manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/dolthub/dolt/go/store/atomicerr"
"github.com/dolthub/dolt/go/store/chunks"
@@ -52,8 +51,20 @@ const (
defaultS3PartSize = minS3PartSize // smallest allowed by S3 allows for most throughput
)
// S3APIV2 is the subset of the aws-sdk-go-v2 S3 client's methods used by
// awsTablePersister. Declaring it as a local interface lets tests substitute
// a fake S3 implementation for the real client.
type S3APIV2 interface {
	CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error)
	AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error)
	CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error)
	UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error)
	UploadPartCopy(context.Context, *s3.UploadPartCopyInput, ...func(*s3.Options)) (*s3.UploadPartCopyOutput, error)
	PutObject(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error)
	GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error)
}

// Compile-time assertion that the real v2 client satisfies S3APIV2.
var _ S3APIV2 = (*s3.Client)(nil)
type awsTablePersister struct {
s3 s3iface.S3API
s3 S3APIV2
bucket string
rl chan struct{}
limits awsLimits
@@ -84,13 +95,8 @@ func (s3p awsTablePersister) Open(ctx context.Context, name hash.Hash, chunkCoun
return cs, nil
}
var reqErr awserr.RequestFailure
if errors.As(err, &reqErr) {
if reqErr.Code() != "NoSuchKey" || reqErr.StatusCode() != 404 {
return emptyChunkSource{}, err
}
} else {
// Probably won't ever happen.
var nskErr *s3types.NoSuchKey
if !errors.As(err, &nskErr) {
return emptyChunkSource{}, err
}
@@ -129,7 +135,7 @@ func (s3p awsTablePersister) AccessMode() chunks.ExclusiveAccessMode {
}
// s3UploadedPart records the result of uploading one part of a multipart
// upload: its 1-indexed part number (int32, matching the v2 SDK's
// PartNumber field type) and the ETag S3 returned for it.
type s3UploadedPart struct {
	idx  int32
	etag string
}
@@ -168,10 +174,10 @@ func (s3p awsTablePersister) Persist(ctx context.Context, mt *memTable, haver ch
}
func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, sz uint64, key string) error {
uploader := s3manager.NewUploaderWithClient(s3p.s3, func(u *s3manager.Uploader) {
uploader := s3manager.NewUploader(s3p.s3, func(u *s3manager.Uploader) {
u.PartSize = int64(s3p.limits.partTarget)
})
_, err := uploader.Upload(&s3manager.UploadInput{
_, err := uploader.Upload(ctx, &s3.PutObjectInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
Body: r,
@@ -180,7 +186,7 @@ func (s3p awsTablePersister) multipartUpload(ctx context.Context, r io.Reader, s
}
func (s3p awsTablePersister) startMultipartUpload(ctx context.Context, key string) (string, error) {
result, err := s3p.s3.CreateMultipartUploadWithContext(ctx, &s3.CreateMultipartUploadInput{
result, err := s3p.s3.CreateMultipartUpload(ctx, &s3.CreateMultipartUploadInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
})
@@ -193,7 +199,7 @@ func (s3p awsTablePersister) startMultipartUpload(ctx context.Context, key strin
}
func (s3p awsTablePersister) abortMultipartUpload(ctx context.Context, key, uploadID string) error {
_, abrtErr := s3p.s3.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{
_, abrtErr := s3p.s3.AbortMultipartUpload(ctx, &s3.AbortMultipartUploadInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
UploadId: aws.String(uploadID),
@@ -202,8 +208,8 @@ func (s3p awsTablePersister) abortMultipartUpload(ctx context.Context, key, uplo
return abrtErr
}
func (s3p awsTablePersister) completeMultipartUpload(ctx context.Context, key, uploadID string, mpu *s3.CompletedMultipartUpload) error {
_, err := s3p.s3.CompleteMultipartUploadWithContext(ctx, &s3.CompleteMultipartUploadInput{
func (s3p awsTablePersister) completeMultipartUpload(ctx context.Context, key, uploadID string, mpu *s3types.CompletedMultipartUpload) error {
_, err := s3p.s3.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
MultipartUpload: mpu,
@@ -213,15 +219,15 @@ func (s3p awsTablePersister) completeMultipartUpload(ctx context.Context, key, u
return err
}
// getNumParts returns how many multipart-upload parts |dataLen| bytes yield
// when each part (except possibly the last) is |minPartSize| bytes. The
// division truncates — any remainder is folded into the final part by the
// caller — and the result is never zero, so empty data still gets one part.
// It is returned as uint32 to match the v2 SDK's int32 PartNumber range.
func getNumParts(dataLen, minPartSize uint64) uint32 {
	numParts := dataLen / minPartSize
	if numParts == 0 {
		numParts = 1
	}
	return uint32(numParts)
}
type partsByPartNum []*s3.CompletedPart
type partsByPartNum []s3types.CompletedPart
func (s partsByPartNum) Len() int {
return len(s)
@@ -275,7 +281,7 @@ func (s3p awsTablePersister) executeCompactionPlan(ctx context.Context, plan com
return s3p.completeMultipartUpload(ctx, key, uploadID, multipartUpload)
}
func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionPlan, key, uploadID string) (*s3.CompletedMultipartUpload, error) {
func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionPlan, key, uploadID string) (*s3types.CompletedMultipartUpload, error) {
if len(plan.sources.sws) > maxS3Parts {
return nil, errors.New("exceeded maximum parts")
}
@@ -310,7 +316,7 @@ func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionP
sent, failed, done := make(chan s3UploadedPart), make(chan error), make(chan struct{})
var uploadWg sync.WaitGroup
type uploadFn func() (etag string, err error)
sendPart := func(partNum int64, doUpload uploadFn) {
sendPart := func(partNum int32, doUpload uploadFn) {
if s3p.rl != nil {
s3p.rl <- struct{}{}
defer func() { <-s3p.rl }()
@@ -331,7 +337,7 @@ func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionP
}
// Try to send along part info. In the case that the upload was aborted, reading from done allows this worker to exit correctly.
select {
case sent <- s3UploadedPart{int64(partNum), etag}:
case sent <- s3UploadedPart{partNum, etag}:
case <-done:
return
}
@@ -339,10 +345,10 @@ func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionP
// Concurrently begin sending all parts using sendPart().
// First, kick off sending all the copyable parts.
partNum := int64(1) // Part numbers are 1-indexed
partNum := int32(1) // Part numbers are 1-indexed
for _, cp := range copies {
uploadWg.Add(1)
go func(cp copyPart, partNum int64) {
go func(cp copyPart, partNum int32) {
sendPart(partNum, func() (etag string, err error) {
return s3p.uploadPartCopy(ctx, cp.name, cp.srcOffset, cp.srcLen, key, uploadID, partNum)
})
@@ -352,13 +358,13 @@ func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionP
// Then, split buff (data from |manuals| and index) into parts and upload those concurrently.
numManualParts := getNumParts(uint64(len(buff)), s3p.limits.partTarget) // TODO: What if this is too big?
for i := uint64(0); i < numManualParts; i++ {
start, end := i*s3p.limits.partTarget, (i+1)*s3p.limits.partTarget
for i := uint32(0); i < numManualParts; i++ {
start, end := uint64(i)*s3p.limits.partTarget, uint64(i+1)*s3p.limits.partTarget
if i+1 == numManualParts { // If this is the last part, make sure it includes any overflow
end = uint64(len(buff))
}
uploadWg.Add(1)
go func(data []byte, partNum int64) {
go func(data []byte, partNum int32) {
sendPart(partNum, func() (etag string, err error) {
return s3p.uploadPart(ctx, data, key, uploadID, partNum)
})
@@ -374,15 +380,15 @@ func (s3p awsTablePersister) assembleTable(ctx context.Context, plan compactionP
}()
// Watch |sent| and |failed| for the results of part uploads. If ever one fails, close |done| to stop all the in-progress or pending sendPart() calls and then bail.
multipartUpload := &s3.CompletedMultipartUpload{}
multipartUpload := &s3types.CompletedMultipartUpload{}
var firstFailure error
for cont := true; cont; {
select {
case sentPart, open := <-sent:
if open {
multipartUpload.Parts = append(multipartUpload.Parts, &s3.CompletedPart{
multipartUpload.Parts = append(multipartUpload.Parts, s3types.CompletedPart{
ETag: aws.String(sentPart.etag),
PartNumber: aws.Int64(sentPart.idx),
PartNumber: aws.Int32(sentPart.idx),
})
}
cont = open
@@ -485,13 +491,13 @@ func splitOnMaxSize(dataLen, maxPartSize uint64) []int64 {
return sizes
}
func (s3p awsTablePersister) uploadPartCopy(ctx context.Context, src string, srcStart, srcEnd int64, key, uploadID string, partNum int64) (etag string, err error) {
res, err := s3p.s3.UploadPartCopyWithContext(ctx, &s3.UploadPartCopyInput{
func (s3p awsTablePersister) uploadPartCopy(ctx context.Context, src string, srcStart, srcEnd int64, key, uploadID string, partNum int32) (etag string, err error) {
res, err := s3p.s3.UploadPartCopy(ctx, &s3.UploadPartCopyInput{
CopySource: aws.String(url.PathEscape(s3p.bucket + "/" + s3p.key(src))),
CopySourceRange: aws.String(httpRangeHeader(srcStart, srcEnd)),
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
PartNumber: aws.Int64(int64(partNum)),
PartNumber: aws.Int32(partNum),
UploadId: aws.String(uploadID),
})
if err == nil {
@@ -500,11 +506,11 @@ func (s3p awsTablePersister) uploadPartCopy(ctx context.Context, src string, src
return
}
func (s3p awsTablePersister) uploadPart(ctx context.Context, data []byte, key, uploadID string, partNum int64) (etag string, err error) {
res, err := s3p.s3.UploadPartWithContext(ctx, &s3.UploadPartInput{
func (s3p awsTablePersister) uploadPart(ctx context.Context, data []byte, key, uploadID string, partNum int32) (etag string, err error) {
res, err := s3p.s3.UploadPart(ctx, &s3.UploadPartInput{
Bucket: aws.String(s3p.bucket),
Key: aws.String(s3p.key(key)),
PartNumber: aws.Int64(int64(partNum)),
PartNumber: aws.Int32(partNum),
UploadId: aws.String(uploadID),
Body: bytes.NewReader(data),
})

View File

@@ -29,10 +29,7 @@ import (
"sync"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -191,12 +188,12 @@ type failingFakeS3 struct {
numSuccesses int
}
func (m *failingFakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
func (m *failingFakeS3) UploadPart(ctx context.Context, input *s3.UploadPartInput, opts ...func(*s3.Options)) (*s3.UploadPartOutput, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.numSuccesses > 0 {
m.numSuccesses--
return m.fakeS3.UploadPartWithContext(ctx, input)
return m.fakeS3.UploadPart(ctx, input)
}
return nil, mockAWSError("MalformedXML")
}
@@ -277,7 +274,7 @@ func TestAWSTablePersisterConjoinAll(t *testing.T) {
rl := make(chan struct{}, 8)
defer close(rl)
newPersister := func(s3svc s3iface.S3API) awsTablePersister {
newPersister := func(s3svc S3APIV2) awsTablePersister {
return awsTablePersister{
s3svc,
"bucket",

View File

@@ -29,10 +29,11 @@ import (
"sort"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
ddbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/dustin/go-humanize"
flag "github.com/juju/gnuflag"
"github.com/stretchr/testify/assert"
@@ -136,16 +137,19 @@ func main() {
}
} else if *toAWS != "" {
sess := session.Must(session.NewSession(aws.NewConfig().WithRegion("us-west-2")))
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-west-2"))
d.PanicIfError(err)
open = func() (chunks.ChunkStore, error) {
return nbs.NewAWSStore(context.Background(), types.Format_Default.VersionString(), dynamoTable, *toAWS, s3Bucket, s3.New(sess), dynamodb.New(sess), bufSize, nbs.NewUnlimitedMemQuotaProvider())
return nbs.NewAWSStore(context.Background(), types.Format_Default.VersionString(), dynamoTable, *toAWS, s3Bucket, s3.NewFromConfig(cfg), dynamodb.NewFromConfig(cfg), bufSize, nbs.NewUnlimitedMemQuotaProvider())
}
reset = func() {
ddb := dynamodb.New(sess)
_, err := ddb.DeleteItem(&dynamodb.DeleteItemInput{
ddb := dynamodb.NewFromConfig(cfg)
_, err := ddb.DeleteItem(context.Background(), &dynamodb.DeleteItemInput{
TableName: aws.String(dynamoTable),
Key: map[string]*dynamodb.AttributeValue{
"db": {S: toAWS},
Key: map[string]ddbtypes.AttributeValue{
"db": &ddbtypes.AttributeValueMemberS{
Value: *toAWS,
},
},
})
d.PanicIfError(err)
@@ -163,9 +167,10 @@ func main() {
return nbs.NewLocalStore(context.Background(), types.Format_Default.VersionString(), *useNBS, bufSize, nbs.NewUnlimitedMemQuotaProvider())
}
} else if *useAWS != "" {
sess := session.Must(session.NewSession(aws.NewConfig().WithRegion("us-west-2")))
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-west-2"))
d.PanicIfError(err)
open = func() (chunks.ChunkStore, error) {
return nbs.NewAWSStore(context.Background(), types.Format_Default.VersionString(), dynamoTable, *useAWS, s3Bucket, s3.New(sess), dynamodb.New(sess), bufSize, nbs.NewUnlimitedMemQuotaProvider())
return nbs.NewAWSStore(context.Background(), types.Format_Default.VersionString(), dynamoTable, *useAWS, s3Bucket, s3.NewFromConfig(cfg), dynamodb.NewFromConfig(cfg), bufSize, nbs.NewUnlimitedMemQuotaProvider())
}
}
writeDB = func() {}

View File

@@ -23,12 +23,12 @@ package nbs
import (
"bytes"
"context"
"sync/atomic"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
ddbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/stretchr/testify/assert"
"github.com/dolthub/dolt/go/store/constants"
@@ -52,24 +52,24 @@ func makeFakeDDB(t *testing.T) *fakeDDB {
}
}
func (m *fakeDDB) GetItemWithContext(ctx aws.Context, input *dynamodb.GetItemInput, opts ...request.Option) (*dynamodb.GetItemOutput, error) {
key := input.Key[dbAttr].S
assert.NotNil(m.t, key, "key should have been a String: %+v", input.Key[dbAttr])
func (m *fakeDDB) GetItem(ctx context.Context, input *dynamodb.GetItemInput, opts ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) {
keyM := input.Key[dbAttr].(*ddbtypes.AttributeValueMemberS)
assert.NotNil(m.t, keyM, "key should have been a String: %+v", input.Key[dbAttr])
item := map[string]*dynamodb.AttributeValue{}
if e, present := m.data[*key]; present {
item[dbAttr] = &dynamodb.AttributeValue{S: key}
item := map[string]ddbtypes.AttributeValue{}
if e, present := m.data[keyM.Value]; present {
item[dbAttr] = &ddbtypes.AttributeValueMemberS{Value: keyM.Value}
switch e := e.(type) {
case record:
item[nbsVersAttr] = &dynamodb.AttributeValue{S: aws.String(AWSStorageVersion)}
item[versAttr] = &dynamodb.AttributeValue{S: aws.String(e.vers)}
item[rootAttr] = &dynamodb.AttributeValue{B: e.root}
item[lockAttr] = &dynamodb.AttributeValue{B: e.lock}
item[nbsVersAttr] = &ddbtypes.AttributeValueMemberS{Value: AWSStorageVersion}
item[versAttr] = &ddbtypes.AttributeValueMemberS{Value: e.vers}
item[rootAttr] = &ddbtypes.AttributeValueMemberB{Value: e.root}
item[lockAttr] = &ddbtypes.AttributeValueMemberB{Value: e.lock}
if e.specs != "" {
item[tableSpecsAttr] = &dynamodb.AttributeValue{S: aws.String(e.specs)}
item[tableSpecsAttr] = &ddbtypes.AttributeValueMemberS{Value: e.specs}
}
if e.appendix != "" {
item[appendixAttr] = &dynamodb.AttributeValue{S: aws.String(e.appendix)}
item[appendixAttr] = &ddbtypes.AttributeValueMemberS{Value: e.appendix}
}
}
}
@@ -85,46 +85,46 @@ func (m *fakeDDB) putData(k string, d []byte) {
m.data[k] = d
}
func (m *fakeDDB) PutItemWithContext(ctx aws.Context, input *dynamodb.PutItemInput, opts ...request.Option) (*dynamodb.PutItemOutput, error) {
func (m *fakeDDB) PutItem(ctx context.Context, input *dynamodb.PutItemInput, opts ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) {
assert.NotNil(m.t, input.Item[dbAttr], "%s should have been present", dbAttr)
assert.NotNil(m.t, input.Item[dbAttr].S, "key should have been a String: %+v", input.Item[dbAttr])
key := *input.Item[dbAttr].S
assert.NotNil(m.t, input.Item[dbAttr].(*ddbtypes.AttributeValueMemberS), "key should have been a String: %+v", input.Item[dbAttr])
key := input.Item[dbAttr].(*ddbtypes.AttributeValueMemberS).Value
assert.NotNil(m.t, input.Item[nbsVersAttr], "%s should have been present", nbsVersAttr)
assert.NotNil(m.t, input.Item[nbsVersAttr].S, "nbsVers should have been a String: %+v", input.Item[nbsVersAttr])
assert.Equal(m.t, AWSStorageVersion, *input.Item[nbsVersAttr].S)
assert.NotNil(m.t, input.Item[nbsVersAttr].(*ddbtypes.AttributeValueMemberS), "nbsVers should have been a String: %+v", input.Item[nbsVersAttr])
assert.Equal(m.t, AWSStorageVersion, input.Item[nbsVersAttr].(*ddbtypes.AttributeValueMemberS).Value)
assert.NotNil(m.t, input.Item[versAttr], "%s should have been present", versAttr)
assert.NotNil(m.t, input.Item[versAttr].S, "nbsVers should have been a String: %+v", input.Item[versAttr])
assert.Equal(m.t, constants.FormatLD1String, *input.Item[versAttr].S)
assert.NotNil(m.t, input.Item[versAttr].(*ddbtypes.AttributeValueMemberS), "nbsVers should have been a String: %+v", input.Item[versAttr])
assert.Equal(m.t, constants.FormatLD1String, input.Item[versAttr].(*ddbtypes.AttributeValueMemberS).Value)
assert.NotNil(m.t, input.Item[lockAttr], "%s should have been present", lockAttr)
assert.NotNil(m.t, input.Item[lockAttr].B, "lock should have been a blob: %+v", input.Item[lockAttr])
lock := input.Item[lockAttr].B
assert.NotNil(m.t, input.Item[lockAttr].(*ddbtypes.AttributeValueMemberB), "lock should have been a blob: %+v", input.Item[lockAttr])
lock := input.Item[lockAttr].(*ddbtypes.AttributeValueMemberB).Value
assert.NotNil(m.t, input.Item[rootAttr], "%s should have been present", rootAttr)
assert.NotNil(m.t, input.Item[rootAttr].B, "root should have been a blob: %+v", input.Item[rootAttr])
root := input.Item[rootAttr].B
assert.NotNil(m.t, input.Item[rootAttr].(*ddbtypes.AttributeValueMemberB), "root should have been a blob: %+v", input.Item[rootAttr])
root := input.Item[rootAttr].(*ddbtypes.AttributeValueMemberB).Value
specs := ""
if attr, present := input.Item[tableSpecsAttr]; present {
assert.NotNil(m.t, attr.S, "specs should have been a String: %+v", input.Item[tableSpecsAttr])
specs = *attr.S
assert.NotNil(m.t, attr.(*ddbtypes.AttributeValueMemberS), "specs should have been a String: %+v", input.Item[tableSpecsAttr])
specs = attr.(*ddbtypes.AttributeValueMemberS).Value
}
apps := ""
if attr, present := input.Item[appendixAttr]; present {
assert.NotNil(m.t, attr.S, "appendix specs should have been a String: %+v", input.Item[appendixAttr])
apps = *attr.S
assert.NotNil(m.t, attr.(*ddbtypes.AttributeValueMemberS), "appendix specs should have been a String: %+v", input.Item[appendixAttr])
apps = attr.(*ddbtypes.AttributeValueMemberS).Value
}
mustNotExist := *(input.ConditionExpression) == valueNotExistsOrEqualsExpression
current, present := m.data[key]
if mustNotExist && present {
return nil, mockAWSError("ConditionalCheckFailedException")
return nil, &ddbtypes.ConditionalCheckFailedException{}
} else if !mustNotExist && !checkCondition(current.(record), input.ExpressionAttributeValues) {
return nil, mockAWSError("ConditionalCheckFailedException")
return nil, &ddbtypes.ConditionalCheckFailedException{}
}
m.putRecord(key, lock, root, constants.FormatLD1String, specs, apps)
@@ -133,8 +133,8 @@ func (m *fakeDDB) PutItemWithContext(ctx aws.Context, input *dynamodb.PutItemInp
return &dynamodb.PutItemOutput{}, nil
}
func checkCondition(current record, expressionAttrVals map[string]*dynamodb.AttributeValue) bool {
return current.vers == *expressionAttrVals[versExpressionValuesKey].S && bytes.Equal(current.lock, expressionAttrVals[prevLockExpressionValuesKey].B)
func checkCondition(current record, expressionAttrVals map[string]ddbtypes.AttributeValue) bool {
return current.vers == expressionAttrVals[versExpressionValuesKey].(*ddbtypes.AttributeValueMemberS).Value && bytes.Equal(current.lock, expressionAttrVals[prevLockExpressionValuesKey].(*ddbtypes.AttributeValueMemberB).Value)
}

View File

@@ -28,10 +28,9 @@ import (
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
ddbtypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types"
"github.com/dolthub/dolt/go/store/d"
"github.com/dolthub/dolt/go/store/hash"
@@ -57,18 +56,18 @@ var (
valueNotExistsOrEqualsExpression = fmt.Sprintf("attribute_not_exists("+lockAttr+") or %s", valueEqualsExpression)
)
type ddbsvc interface {
GetItemWithContext(ctx aws.Context, input *dynamodb.GetItemInput, opts ...request.Option) (*dynamodb.GetItemOutput, error)
PutItemWithContext(ctx aws.Context, input *dynamodb.PutItemInput, opts ...request.Option) (*dynamodb.PutItemOutput, error)
type DynamoDBAPIV2 interface {
GetItem(context.Context, *dynamodb.GetItemInput, ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error)
PutItem(context.Context, *dynamodb.PutItemInput, ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error)
}
// dynamoManifest assumes the existence of a DynamoDB table whose primary partition key is in String format and named `db`.
type dynamoManifest struct {
table, db string
ddbsvc ddbsvc
ddbsvc DynamoDBAPIV2
}
func newDynamoManifest(table, namespace string, ddb ddbsvc) manifest {
func newDynamoManifest(table, namespace string, ddb DynamoDBAPIV2) manifest {
d.PanicIfTrue(table == "")
d.PanicIfTrue(namespace == "")
return dynamoManifest{table, namespace, ddb}
@@ -85,11 +84,13 @@ func (dm dynamoManifest) ParseIfExists(ctx context.Context, stats *Stats, readHo
var exists bool
var contents manifestContents
result, err := dm.ddbsvc.GetItemWithContext(ctx, &dynamodb.GetItemInput{
ConsistentRead: aws.Bool(true), // This doubles the cost :-(
result, err := dm.ddbsvc.GetItem(ctx, &dynamodb.GetItemInput{
ConsistentRead: aws.Bool(true),
TableName: aws.String(dm.table),
Key: map[string]*dynamodb.AttributeValue{
dbAttr: {S: aws.String(dm.db)},
Key: map[string]ddbtypes.AttributeValue{
dbAttr: &ddbtypes.AttributeValueMemberS{
Value: dm.db,
},
},
})
@@ -105,18 +106,18 @@ func (dm dynamoManifest) ParseIfExists(ctx context.Context, stats *Stats, readHo
}
exists = true
contents.nbfVers = *result.Item[versAttr].S
contents.root = hash.New(result.Item[rootAttr].B)
copy(contents.lock[:], result.Item[lockAttr].B)
contents.nbfVers = result.Item[versAttr].(*ddbtypes.AttributeValueMemberS).Value
contents.root = hash.New(result.Item[rootAttr].(*ddbtypes.AttributeValueMemberB).Value)
copy(contents.lock[:], result.Item[lockAttr].(*ddbtypes.AttributeValueMemberB).Value)
if hasSpecs {
contents.specs, err = parseSpecs(strings.Split(*result.Item[tableSpecsAttr].S, ":"))
contents.specs, err = parseSpecs(strings.Split(result.Item[tableSpecsAttr].(*ddbtypes.AttributeValueMemberS).Value, ":"))
if err != nil {
return false, manifestContents{}, ErrCorruptManifest
}
}
if hasAppendix {
contents.appendix, err = parseSpecs(strings.Split(*result.Item[appendixAttr].S, ":"))
contents.appendix, err = parseSpecs(strings.Split(result.Item[appendixAttr].(*ddbtypes.AttributeValueMemberS).Value, ":"))
if err != nil {
return false, manifestContents{}, ErrCorruptManifest
}
@@ -126,24 +127,41 @@ func (dm dynamoManifest) ParseIfExists(ctx context.Context, stats *Stats, readHo
return exists, contents, nil
}
func validateManifest(item map[string]*dynamodb.AttributeValue) (valid, hasSpecs, hasAppendix bool) {
if item[nbsVersAttr] != nil && item[nbsVersAttr].S != nil &&
AWSStorageVersion == *item[nbsVersAttr].S &&
item[versAttr] != nil && item[versAttr].S != nil &&
item[lockAttr] != nil && item[lockAttr].B != nil &&
item[rootAttr] != nil && item[rootAttr].B != nil {
if len(item) == 6 || len(item) == 7 {
if item[tableSpecsAttr] != nil && item[tableSpecsAttr].S != nil {
hasSpecs = true
}
if item[appendixAttr] != nil && item[appendixAttr].S != nil {
hasAppendix = true
}
return true, hasSpecs, hasAppendix
}
return len(item) == 5, false, false
func validateManifest(item map[string]ddbtypes.AttributeValue) (valid, hasSpecs, hasAppendix bool) {
if nbsVersA := item[nbsVersAttr]; nbsVersA == nil {
return false, false, false
} else if nbsVers, ok := nbsVersA.(*ddbtypes.AttributeValueMemberS); !ok {
return false, false, false
} else if nbsVers.Value != AWSStorageVersion {
return false, false, false
}
return false, false, false
if versA := item[versAttr]; versA == nil {
return false, false, false
} else if _, ok := versA.(*ddbtypes.AttributeValueMemberS); !ok {
return false, false, false
}
if lockA := item[lockAttr]; lockA == nil {
return false, false, false
} else if _, ok := lockA.(*ddbtypes.AttributeValueMemberB); !ok {
return false, false, false
}
if rootA := item[rootAttr]; rootA == nil {
return false, false, false
} else if _, ok := rootA.(*ddbtypes.AttributeValueMemberB); !ok {
return false, false, false
}
if len(item) == 6 || len(item) == 7 {
if tableSpecsA := item[tableSpecsAttr]; tableSpecsA == nil {
} else if _, ok := tableSpecsA.(*ddbtypes.AttributeValueMemberS); ok {
hasSpecs = true
}
if appendixA := item[appendixAttr]; appendixA == nil {
} else if _, ok := appendixA.(*ddbtypes.AttributeValueMemberS); ok {
hasAppendix = true
}
return true, hasSpecs, hasAppendix
}
return len(item) == 5, false, false
}
func (dm dynamoManifest) Update(ctx context.Context, lastLock hash.Hash, newContents manifestContents, stats *Stats, writeHook func() error) (manifestContents, error) {
@@ -152,25 +170,25 @@ func (dm dynamoManifest) Update(ctx context.Context, lastLock hash.Hash, newCont
putArgs := dynamodb.PutItemInput{
TableName: aws.String(dm.table),
Item: map[string]*dynamodb.AttributeValue{
dbAttr: {S: aws.String(dm.db)},
nbsVersAttr: {S: aws.String(AWSStorageVersion)},
versAttr: {S: aws.String(newContents.nbfVers)},
rootAttr: {B: newContents.root[:]},
lockAttr: {B: newContents.lock[:]},
Item: map[string]ddbtypes.AttributeValue{
dbAttr: &ddbtypes.AttributeValueMemberS{Value: dm.db},
nbsVersAttr: &ddbtypes.AttributeValueMemberS{Value: AWSStorageVersion},
versAttr: &ddbtypes.AttributeValueMemberS{Value: newContents.nbfVers},
rootAttr: &ddbtypes.AttributeValueMemberB{Value: newContents.root[:]},
lockAttr: &ddbtypes.AttributeValueMemberB{Value: newContents.lock[:]},
},
}
if len(newContents.specs) > 0 {
tableInfo := make([]string, 2*len(newContents.specs))
formatSpecs(newContents.specs, tableInfo)
putArgs.Item[tableSpecsAttr] = &dynamodb.AttributeValue{S: aws.String(strings.Join(tableInfo, ":"))}
putArgs.Item[tableSpecsAttr] = &ddbtypes.AttributeValueMemberS{Value: strings.Join(tableInfo, ":")}
}
if len(newContents.appendix) > 0 {
tableInfo := make([]string, 2*len(newContents.appendix))
formatSpecs(newContents.appendix, tableInfo)
putArgs.Item[appendixAttr] = &dynamodb.AttributeValue{S: aws.String(strings.Join(tableInfo, ":"))}
putArgs.Item[appendixAttr] = &ddbtypes.AttributeValueMemberS{Value: strings.Join(tableInfo, ":")}
}
expr := valueEqualsExpression
@@ -179,12 +197,12 @@ func (dm dynamoManifest) Update(ctx context.Context, lastLock hash.Hash, newCont
}
putArgs.ConditionExpression = aws.String(expr)
putArgs.ExpressionAttributeValues = map[string]*dynamodb.AttributeValue{
prevLockExpressionValuesKey: {B: lastLock[:]},
versExpressionValuesKey: {S: aws.String(newContents.nbfVers)},
putArgs.ExpressionAttributeValues = map[string]ddbtypes.AttributeValue{
prevLockExpressionValuesKey: &ddbtypes.AttributeValueMemberB{Value: lastLock[:]},
versExpressionValuesKey: &ddbtypes.AttributeValueMemberS{Value: newContents.nbfVers},
}
_, ddberr := dm.ddbsvc.PutItemWithContext(ctx, &putArgs)
_, ddberr := dm.ddbsvc.PutItem(ctx, &putArgs)
if ddberr != nil {
if errIsConditionalCheckFailed(ddberr) {
exists, upstream, err := dm.ParseIfExists(ctx, stats, nil)
@@ -213,6 +231,6 @@ func (dm dynamoManifest) Update(ctx context.Context, lastLock hash.Hash, newCont
}
func errIsConditionalCheckFailed(err error) bool {
awsErr, ok := err.(awserr.Error)
return ok && awsErr.Code() == "ConditionalCheckFailedException"
var ccfe *ddbtypes.ConditionalCheckFailedException
return errors.As(err, &ccfe)
}

View File

@@ -28,10 +28,9 @@ import (
"os"
"sync"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/dustin/go-humanize"
flag "github.com/juju/gnuflag"
@@ -75,10 +74,9 @@ func main() {
*dbName = *dir
} else if *table != "" && *bucket != "" && *dbName != "" {
sess := session.Must(session.NewSession(aws.NewConfig().WithRegion("us-west-2")))
var err error
store, err = nbs.NewAWSStore(context.Background(), types.Format_Default.VersionString(), *table, *dbName, *bucket, s3.New(sess), dynamodb.New(sess), memTableSize, nbs.NewUnlimitedMemQuotaProvider())
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-west-2"))
d.PanicIfError(err)
store, err = nbs.NewAWSStore(context.Background(), types.Format_Default.VersionString(), *table, *dbName, *bucket, s3.NewFromConfig(cfg), dynamodb.NewFromConfig(cfg), memTableSize, nbs.NewUnlimitedMemQuotaProvider())
d.PanicIfError(err)
} else {
log.Fatalf("Must set either --dir or ALL of --table, --bucket and --db\n")

View File

@@ -32,11 +32,9 @@ import (
"testing"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client/metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/stretchr/testify/assert"
"github.com/dolthub/dolt/go/store/d"
@@ -60,8 +58,6 @@ func makeFakeS3(t *testing.T) *fakeS3 {
}
type fakeS3 struct {
s3iface.S3API
assert *assert.Assertions
mu sync.Mutex
@@ -118,7 +114,7 @@ func (m *fakeS3) readerForTableWithNamespace(ctx context.Context, ns string, nam
return nil, nil
}
func (m *fakeS3) AbortMultipartUploadWithContext(ctx aws.Context, input *s3.AbortMultipartUploadInput, opts ...request.Option) (*s3.AbortMultipartUploadOutput, error) {
func (m *fakeS3) AbortMultipartUpload(ctx context.Context, input *s3.AbortMultipartUploadInput, opts ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
m.assert.NotNil(input.UploadId, "UploadId is a required field")
@@ -133,7 +129,7 @@ func (m *fakeS3) AbortMultipartUploadWithContext(ctx aws.Context, input *s3.Abor
return &s3.AbortMultipartUploadOutput{}, nil
}
func (m *fakeS3) CreateMultipartUploadWithContext(ctx aws.Context, input *s3.CreateMultipartUploadInput, opts ...request.Option) (*s3.CreateMultipartUploadOutput, error) {
func (m *fakeS3) CreateMultipartUpload(ctx context.Context, input *s3.CreateMultipartUploadInput, opts ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
@@ -151,7 +147,7 @@ func (m *fakeS3) CreateMultipartUploadWithContext(ctx aws.Context, input *s3.Cre
return out, nil
}
func (m *fakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInput, opts ...request.Option) (*s3.UploadPartOutput, error) {
func (m *fakeS3) UploadPart(ctx context.Context, input *s3.UploadPartInput, opts ...func(*s3.Options)) (*s3.UploadPartOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
@@ -174,7 +170,7 @@ func (m *fakeS3) UploadPartWithContext(ctx aws.Context, input *s3.UploadPartInpu
return &s3.UploadPartOutput{ETag: aws.String(etag)}, nil
}
func (m *fakeS3) UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPartCopyInput, opts ...request.Option) (*s3.UploadPartCopyOutput, error) {
func (m *fakeS3) UploadPartCopy(ctx context.Context, input *s3.UploadPartCopyInput, opts ...func(*s3.Options)) (*s3.UploadPartCopyOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
m.assert.NotNil(input.PartNumber, "PartNumber is a required field")
@@ -205,10 +201,10 @@ func (m *fakeS3) UploadPartCopyWithContext(ctx aws.Context, input *s3.UploadPart
m.assert.Equal(inProgress.uploadID, *input.UploadId)
inProgress.etags = append(inProgress.etags, etag)
m.inProgress[*input.Key] = inProgress
return &s3.UploadPartCopyOutput{CopyPartResult: &s3.CopyPartResult{ETag: aws.String(etag)}}, nil
return &s3.UploadPartCopyOutput{CopyPartResult: &s3types.CopyPartResult{ETag: aws.String(etag)}}, nil
}
func (m *fakeS3) CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.CompleteMultipartUploadInput, opts ...request.Option) (*s3.CompleteMultipartUploadOutput, error) {
func (m *fakeS3) CompleteMultipartUpload(ctx context.Context, input *s3.CompleteMultipartUploadInput, opts ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
m.assert.NotNil(input.UploadId, "UploadId is a required field")
@@ -228,7 +224,7 @@ func (m *fakeS3) CompleteMultipartUploadWithContext(ctx aws.Context, input *s3.C
return &s3.CompleteMultipartUploadOutput{Bucket: input.Bucket, Key: input.Key}, nil
}
func (m *fakeS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
func (m *fakeS3) GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
@@ -270,7 +266,7 @@ func parseRange(hdr string, total int) (start, end int) {
return start, end + 1 // insanely, the HTTP range header specifies ranges inclusively.
}
func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput, opts ...request.Option) (*s3.PutObjectOutput, error) {
func (m *fakeS3) PutObject(ctx context.Context, input *s3.PutObjectInput, opts ...func(*s3.Options)) (*s3.PutObjectOutput, error) {
m.assert.NotNil(input.Bucket, "Bucket is a required field")
m.assert.NotNil(input.Key, "Key is a required field")
@@ -283,37 +279,3 @@ func (m *fakeS3) PutObjectWithContext(ctx aws.Context, input *s3.PutObjectInput,
return &s3.PutObjectOutput{}, nil
}
func (m *fakeS3) GetObjectRequest(input *s3.GetObjectInput) (*request.Request, *s3.GetObjectOutput) {
out := &s3.GetObjectOutput{}
var handlers request.Handlers
handlers.Send.PushBack(func(r *request.Request) {
res, err := m.GetObjectWithContext(r.Context(), input)
r.Error = err
if res != nil {
*(r.Data.(*s3.GetObjectOutput)) = *res
}
})
return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
Name: "GetObject",
HTTPMethod: "GET",
HTTPPath: "/{Bucket}/{Key+}",
}, input, out), out
}
func (m *fakeS3) PutObjectRequest(input *s3.PutObjectInput) (*request.Request, *s3.PutObjectOutput) {
out := &s3.PutObjectOutput{}
var handlers request.Handlers
handlers.Send.PushBack(func(r *request.Request) {
res, err := m.PutObjectWithContext(r.Context(), input)
r.Error = err
if res != nil {
*(r.Data.(*s3.PutObjectOutput)) = *res
}
})
return request.New(aws.Config{}, metadata.ClientInfo{}, handlers, nil, &request.Operation{
Name: "PutObject",
HTTPMethod: "PUT",
HTTPPath: "/{Bucket}/{Key+}",
}, input, out), out
}

View File

@@ -35,9 +35,8 @@ import (
"syscall"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/jpillora/backoff"
"golang.org/x/sync/errgroup"
)
@@ -45,7 +44,7 @@ import (
// s3ObjectReader is a wrapper for S3 that gives us some nice to haves for reading objects from S3.
// TODO: Bring all the multipart upload and remote-conjoin stuff in.
type s3ObjectReader struct {
s3 s3iface.S3API
s3 S3APIV2
bucket string
readRl chan struct{}
ns string
@@ -91,7 +90,7 @@ func (s3or *s3ObjectReader) reader(ctx context.Context, name string) (io.ReadClo
Bucket: aws.String(s3or.bucket),
Key: aws.String(s3or.key(name)),
}
result, err := s3or.s3.GetObjectWithContext(ctx, input)
result, err := s3or.s3.GetObject(ctx, input)
if err != nil {
return nil, err
}
@@ -117,7 +116,7 @@ func (s3or *s3ObjectReader) readRange(ctx context.Context, name string, p []byte
Range: aws.String(rangeHeader),
}
result, err := s3or.s3.GetObjectWithContext(ctx, input)
result, err := s3or.s3.GetObject(ctx, input)
if err != nil {
return 0, 0, err
}

View File

@@ -29,10 +29,7 @@ import (
"syscall"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -89,16 +86,16 @@ func TestS3TableReaderAtNamespace(t *testing.T) {
}
type flakyS3 struct {
s3iface.S3API
S3APIV2
alreadyFailed map[string]struct{}
}
func makeFlakyS3(svc s3iface.S3API) *flakyS3 {
func makeFlakyS3(svc S3APIV2) *flakyS3 {
return &flakyS3{svc, map[string]struct{}{}}
}
func (fs3 *flakyS3) GetObjectWithContext(ctx aws.Context, input *s3.GetObjectInput, opts ...request.Option) (*s3.GetObjectOutput, error) {
output, err := fs3.S3API.GetObjectWithContext(ctx, input)
func (fs3 *flakyS3) GetObject(ctx context.Context, input *s3.GetObjectInput, opts ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
output, err := fs3.S3APIV2.GetObject(ctx, input)
if err != nil {
return nil, err

View File

@@ -34,7 +34,6 @@ import (
"time"
"cloud.google.com/go/storage"
"github.com/aws/aws-sdk-go/service/s3/s3iface"
"github.com/dustin/go-humanize"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/oracle/oci-go-sdk/v65/common"
@@ -504,22 +503,7 @@ func OverwriteStoreManifest(ctx context.Context, store *NomsBlockStore, root has
return nil
}
func NewAWSStoreWithMMapIndex(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 s3iface.S3API, ddb ddbsvc, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
cacheOnce.Do(makeGlobalCaches)
readRateLimiter := make(chan struct{}, 32)
p := &awsTablePersister{
s3,
bucket,
readRateLimiter,
awsLimits{defaultS3PartSize, minS3PartSize, maxS3PartSize},
ns,
q,
}
mm := makeManifestManager(newDynamoManifest(table, ns, ddb))
return newNomsBlockStore(ctx, nbfVerStr, mm, p, q, inlineConjoiner{defaultMaxTables}, memTableSize)
}
func NewAWSStore(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 s3iface.S3API, ddb ddbsvc, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
func NewAWSStore(ctx context.Context, nbfVerStr string, table, ns, bucket string, s3 S3APIV2, ddb DynamoDBAPIV2, memTableSize uint64, q MemoryQuotaProvider) (*NomsBlockStore, error) {
cacheOnce.Do(makeGlobalCaches)
readRateLimiter := make(chan struct{}, 32)
p := &awsTablePersister{

View File

@@ -34,14 +34,14 @@ import (
"strings"
"cloud.google.com/go/storage"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/dynamodb"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/oracle/oci-go-sdk/v65/common"
"github.com/oracle/oci-go-sdk/v65/objectstorage"
"github.com/dolthub/dolt/go/libraries/utils/awsrefreshcreds"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
"github.com/dolthub/dolt/go/store/chunks"
"github.com/dolthub/dolt/go/store/d"
@@ -341,34 +341,47 @@ func parseAWSSpec(ctx context.Context, awsURL string, options SpecOptions) chunk
parts := strings.SplitN(u.Hostname(), ":", 2) // [table] [, bucket]?
d.PanicIfFalse(len(parts) == 2)
awsConfig := aws.NewConfig().WithRegion(options.AwsRegionOrDefault())
var opts []func(*config.LoadOptions) error
opts = append(opts, config.WithRegion(options.AwsRegionOrDefault()))
switch options.AWSCredSource {
case RoleCS:
// All the default behavior of the SDK.
case EnvCS:
awsConfig = awsConfig.WithCredentials(credentials.NewEnvCredentials())
// Credentials can only come directly from AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY...
creds := awsrefreshcreds.LoadEnvCredentials()
opts = append(opts, config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(context.Context) (aws.Credentials, error) {
return creds, nil
})))
case FileCS:
filePath := options.AwsCredFileOrDefault()
creds := credentials.NewSharedCredentials(filePath, DefaultAWSCredsProfile)
awsConfig = awsConfig.WithCredentials(creds)
provider := awsrefreshcreds.LoadINICredentialsProvider(options.AwsCredFileOrDefault(), DefaultAWSCredsProfile)
opts = append(opts, config.WithCredentialsProvider(provider))
case AutoCS:
envCreds := credentials.NewEnvCredentials()
if _, err := envCreds.Get(); err == nil {
awsConfig = awsConfig.WithCredentials(envCreds)
var opt config.LoadOptionsFunc
if envCreds := awsrefreshcreds.LoadEnvCredentials(); envCreds.HasKeys() {
opt = config.WithCredentialsProvider(aws.CredentialsProviderFunc(func(context.Context) (aws.Credentials, error) {
return envCreds, nil
}))
}
filePath := options.AwsCredFileOrDefault()
if _, err := os.Stat(filePath); err == nil {
creds := credentials.NewSharedCredentials(filePath, DefaultAWSCredsProfile)
awsConfig = awsConfig.WithCredentials(creds)
provider := awsrefreshcreds.LoadINICredentialsProvider(options.AwsCredFileOrDefault(), DefaultAWSCredsProfile)
opt = config.WithCredentialsProvider(provider)
}
if opt != nil {
opts = append(opts, opt)
}
default:
panic("unsupported credential type")
}
sess := session.Must(session.NewSession(awsConfig))
cs, err := nbs.NewAWSStore(ctx, types.Format_Default.VersionString(), parts[0], u.Path, parts[1], s3.New(sess), dynamodb.New(sess), 1<<28, nbs.NewUnlimitedMemQuotaProvider())
cfg, err := config.LoadDefaultConfig(ctx, opts...)
d.PanicIfError(err)
cs, err := nbs.NewAWSStore(ctx, types.Format_Default.VersionString(), parts[0], u.Path, parts[1], s3.NewFromConfig(cfg), dynamodb.NewFromConfig(cfg), 1<<28, nbs.NewUnlimitedMemQuotaProvider())
d.PanicIfError(err)
return cs