go/libraries/doltcore/sqle/cluster: Add JWT authentication for peer communication to cluster replicas.

A remotesapi server running on a cluster replica publishes a JWKS.

Every outbound GRPC call the cluster replica makes includes a JWT signed with a
private key.

remotesapi servers running on cluster replicas require and validate incoming
JWTs for cluster traffic. The set of valid signing keys is taken from the
JWKSes which are published at /.well-known/jwks.json on the standby replica
hosts.
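
For reference, the sign-and-validate round trip looks roughly like the sketch below, using the same go-jose v2 APIs as the tests in this change; the key ID, issuer, and audience values are illustrative, not the cluster's actual constants.

package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"time"

	jose "gopkg.in/square/go-jose.v2"
	"gopkg.in/square/go-jose.v2/jwt"
)

func main() {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		panic(err)
	}

	// Sign: the replica mints a short-lived token with its private key.
	signer, err := jose.NewSigner(
		jose.SigningKey{Algorithm: jose.EdDSA, Key: priv},
		&jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{"kid": "example-kid"}},
	)
	if err != nil {
		panic(err)
	}
	raw, err := jwt.Signed(signer).Claims(jwt.Claims{
		Issuer:   "example_issuer",
		Audience: jwt.Audience{"example_audience"},
		Expiry:   jwt.NewNumericDate(time.Now().Add(30 * time.Second)),
	}).CompactSerialize()
	if err != nil {
		panic(err)
	}

	// Validate: the receiving server verifies the signature against a key
	// from the published JWKS, then checks the claims against expectations.
	tok, err := jwt.ParseSigned(raw)
	if err != nil {
		panic(err)
	}
	var claims jwt.Claims
	if err := tok.Claims(pub, &claims); err != nil {
		panic(err)
	}
	if err := claims.Validate(jwt.Expected{
		Issuer:   "example_issuer",
		Audience: jwt.Audience{"example_audience"},
		Time:     time.Now(),
	}); err != nil {
		panic(err)
	}
}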

The cluster remotesapi config accepts a tls_ca option, which sets the trusted
roots for outbound TLS connections. Because the JWKSes are served over the
same connection, and because signed JWTs are not replay resistant, TLS is
recommended for all deployment topologies.
Aaron Son
2022-11-10 09:26:36 -08:00
parent 3c484143c6
commit b22fbf11f2
8 changed files with 394 additions and 189 deletions
+6 -4
@@ -285,12 +285,11 @@ func Serve(
lgr.Errorf("error starting remotesapi server listeners for cluster config on port %d: %v", clusterController.RemoteSrvPort(), err)
startError = err
return
} else {
go func() {
clusterRemoteSrv.Serve(listeners)
}()
}
go clusterRemoteSrv.Serve(listeners)
go clusterController.Run()
clusterController.ManageQueryConnections(
mySQLServer.SessionManager().Iter,
sqlEngine.GetUnderlyingEngine().ProcessList.Kill,
@@ -323,6 +322,9 @@ func Serve(
if clusterRemoteSrv != nil {
clusterRemoteSrv.GracefulStop()
}
if clusterController != nil {
clusterController.GracefulStop()
}
return mySQLServer.Close()
})
+182 -18
@@ -18,9 +18,14 @@ import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -36,7 +41,9 @@ import (
"github.com/dolthub/dolt/go/libraries/doltcore/remotesrv"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/clusterdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
"github.com/dolthub/dolt/go/store/types"
)
@@ -59,12 +66,17 @@ type Controller struct {
sinterceptor serverinterceptor
cinterceptor clientinterceptor
lgr *logrus.Logger
grpcCreds credentials.PerRPCCredentials
provider dbProvider
iterSessions IterSessions
killQuery func(uint32)
killConnection func(uint32) error
jwks *jwtauth.MultiJWKS
tlsCfg *tls.Config
grpcCreds credentials.PerRPCCredentials
pub ed25519.PublicKey
priv ed25519.PrivateKey
}
type sqlvars interface {
@@ -112,9 +124,49 @@ func NewController(lgr *logrus.Logger, cfg Config, pCfg config.ReadWriteConfig)
ret.cinterceptor.lgr = lgr.WithFields(logrus.Fields{})
ret.cinterceptor.setRole(role, epoch)
ret.cinterceptor.roleSetter = roleSetter
ret.tlsCfg, err = ret.outboundTlsConfig()
if err != nil {
return nil, err
}
ret.pub, ret.priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
keyID := creds.PubKeyToKID(ret.pub)
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
ret.grpcCreds = &creds.RPCCreds{
PrivKey: ret.priv,
Audience: creds.RemotesAPIAudience,
Issuer: creds.ClientIssuer,
KeyID: keyIDStr,
RequireTLS: false,
}
ret.jwks = ret.standbyRemotesJWKS()
ret.sinterceptor.keyProvider = ret.jwks
ret.sinterceptor.jwtExpected = JWTExpectations()
return ret, nil
}
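
For context, grpc-go calls credentials.PerRPCCredentials before every RPC to collect extra outgoing metadata, which is how the signed JWT rides along on each call. A minimal sketch of an implementation in that shape follows; it is not the actual creds.RPCCreds code, and mintJWT is a hypothetical stand-in for signing a short-lived token with the controller's private key.

package cluster

import "context"

// bearerCreds is a sketch, not the real creds.RPCCreds.
type bearerCreds struct {
	mintJWT    func() (string, error)
	requireTLS bool
}

// GetRequestMetadata is invoked by grpc-go before each RPC; the returned map
// becomes outgoing metadata, so every call carries a signed JWT.
func (c bearerCreds) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
	tok, err := c.mintJWT()
	if err != nil {
		return nil, err
	}
	return map[string]string{"authorization": "Bearer " + tok}, nil
}

// RequireTransportSecurity reports whether the credentials may only be sent
// over a secure transport; the dial provider flips RequireTLS on when a TLS
// config is present.
func (c bearerCreds) RequireTransportSecurity() bool {
	return c.requireTLS
}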
func (c *Controller) Run() {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
c.jwks.Run()
}()
wg.Wait()
}
func (c *Controller) GracefulStop() error {
c.jwks.GracefulStop()
return nil
}
func (c *Controller) ManageSystemVariables(variables sqlvars) {
if c == nil {
return
@@ -198,7 +250,7 @@ func (c *Controller) applyCommitHooks(ctx context.Context, name string, bt *sql.
}
func (c *Controller) gRPCDialProvider(denv *env.DoltEnv) dbfactory.GRPCDialProvider {
return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.cfg, c.grpcCreds}
return grpcDialProvider{env.NewGRPCDialProviderFromDoltEnv(denv), &c.cinterceptor, c.tlsCfg, c.grpcCreds}
}
func (c *Controller) RegisterStoredProcedures(store procedurestore) {
@@ -412,23 +464,9 @@ func (c *Controller) RemoteSrvServerArgs(ctx *sql.Context, args remotesrv.Server
args = sqle.RemoteSrvServerArgs(ctx, args)
args.DBCache = remotesrvStoreCache{args.DBCache, c}
pub, priv, err := ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
keyID := creds.PubKeyToKID(pub)
keyID := creds.PubKeyToKID(c.pub)
keyIDStr := creds.B32CredsEncoding.EncodeToString(keyID)
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, pub)
c.grpcCreds = &creds.RPCCreds{
PrivKey: priv,
Audience: creds.RemotesAPIAudience,
Issuer: creds.ClientIssuer,
KeyID: keyIDStr,
RequireTLS: false,
}
args.HttpInterceptor = JWKSHandlerInterceptor(keyIDStr, c.pub)
return args
}
@@ -565,3 +603,129 @@ func (c *Controller) waitForHooksToReplicate() error {
return errors.New("cluster/controller: failed to transition from primary to standby gracefully; could not replicate databases to standby in a timely manner.")
}
}
// Within a cluster, if remotesapi is configured with a tls_ca, we take the
// following semantics:
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
// all of which are trusted roots for the outbound connections the
// remotestorage client establishes.
// * The certificate chain presented by the server must validate to a root
// which was present in tls_ca. In particular, every certificate in the chain
// must be within its validity window, the signatures must be valid, key usage
// and IsCA must be correctly set for the roots and the intermediates, and the
// leaf must have extended key usage server auth.
// * On the other hand, no verification is done against the SAN or the Subject
// of the certificate.
//
// We use these TLS semantics for both connections to the gRPC endpoint which
// is the actual remotesapi, and for connections to any HTTPS endpoints to
// which the gRPC service returns URLs. For now, this works perfectly for our
// use case, but it's tightly coupled to `cluster:` deployment topologies and
// the like.
//
// If tls_ca is not set then default TLS handling is performed. In particular,
// if the remotesapi endpoint is HTTPS, then the system roots are used and
// ServerName is verified against the presented URL SANs of the certificates.
//
// This tls Config is used for fetching JWKS, for outbound GRPC connections and
// for outbound https connections on the URLs that the GRPC services return.
func (c *Controller) outboundTlsConfig() (*tls.Config, error) {
tlsCA := c.cfg.RemotesAPIConfig().TLSCA()
if tlsCA == "" {
return nil, nil
}
urlmatches := c.cfg.RemotesAPIConfig().ServerNameURLMatches()
dnsmatches := c.cfg.RemotesAPIConfig().ServerNameDNSMatches()
pem, err := os.ReadFile(tlsCA)
if err != nil {
return nil, err
}
roots := x509.NewCertPool()
if ok := roots.AppendCertsFromPEM(pem); !ok {
return nil, errors.New("error loading ca roots from " + tlsCA)
}
verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(rawCerts))
var err error
for i, asn1Data := range rawCerts {
certs[i], err = x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages,
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
return err
}
if len(urlmatches) > 0 {
found := false
for _, n := range urlmatches {
for _, cn := range certs[0].URIs {
if n == cn.String() {
found = true
break
}
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_urls, but it did not")
}
}
if len(dnsmatches) > 0 {
found := false
for _, n := range dnsmatches {
for _, cn := range certs[0].DNSNames {
if n == cn {
found = true
break
}
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_dns, but it did not")
}
}
return nil
}
return &tls.Config{
// We have to InsecureSkipVerify because ServerName is always
// set by the grpc dial provider and golang tls.Config does not
// have good support for performing certificate validation
// without server name validation.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyFunc,
NextProtos: []string{"h2"},
}, nil
}
func (c *Controller) standbyRemotesJWKS() *jwtauth.MultiJWKS {
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: c.tlsCfg,
ForceAttemptHTTP2: true,
},
}
urls := make([]string, len(c.cfg.StandbyRemotes()))
for i, r := range c.cfg.StandbyRemotes() {
urls[i] = strings.Replace(r.RemoteURLTemplate(), dsess.URLTemplateDatabasePlaceholder, ".well-known/jwks.json", -1)
}
return jwtauth.NewMultiJWKS(c.lgr.WithFields(logrus.Fields{"component": "jwks-key-provider"}), urls, client)
}
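
Each standby remote's URL template is reused to locate its JWKS: the database placeholder is swapped for the well-known path. A small sketch with assumed values (the real template comes from the standby remote config, and the placeholder constant is dsess.URLTemplateDatabasePlaceholder, whose exact value is assumed here):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Illustrative template value only.
	template := "https://standby-1.example.com:50051/{database}"
	jwksURL := strings.Replace(template, "{database}", ".well-known/jwks.json", -1)
	fmt.Println(jwksURL) // https://standby-1.example.com:50051/.well-known/jwks.json
}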
@@ -16,15 +16,13 @@ package cluster
import (
"crypto/tls"
"crypto/x509"
"errors"
"io/ioutil"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials"
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
"github.com/dolthub/dolt/go/libraries/doltcore/dbfactory"
"github.com/dolthub/dolt/go/libraries/doltcore/grpcendpoint"
)
@@ -35,24 +33,26 @@ import (
// - client interceptors for transmitting our replication role.
// - do not use environment credentials. (for now).
type grpcDialProvider struct {
orig dbfactory.GRPCDialProvider
ci *clientinterceptor
cfg Config
creds credentials.PerRPCCredentials
orig dbfactory.GRPCDialProvider
ci *clientinterceptor
tlsCfg *tls.Config
creds credentials.PerRPCCredentials
}
func (p grpcDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfactory.GRPCRemoteConfig, error) {
tlsConfig, err := p.tlsConfig()
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
config.TLSConfig = tlsConfig
config.TLSConfig = p.tlsCfg
config.Creds = p.creds
if config.Creds != nil && config.TLSConfig != nil {
if c, ok := config.Creds.(*creds.RPCCreds); ok {
c.RequireTLS = true
}
}
config.WithEnvCreds = false
cfg, err := p.orig.GetGRPCDialParams(config)
if err != nil {
return dbfactory.GRPCRemoteConfig{}, err
}
cfg.DialOptions = append(cfg.DialOptions, p.ci.Options()...)
cfg.DialOptions = append(cfg.DialOptions, grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
@@ -63,114 +63,6 @@ func (p grpcDialProvider) GetGRPCDialParams(config grpcendpoint.Config) (dbfacto
},
MinConnectTimeout: 250 * time.Millisecond,
}))
return cfg, nil
}
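
When tls_ca is set, the resulting client configuration is roughly equivalent to the sketch below (assumed names; the real wiring goes through the wrapped dbfactory provider and adds the cluster interceptors):

package example

import (
	"crypto/tls"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
)

// dial sketches the effective dial options: transport security built from
// the shared tls.Config (the tls_ca roots), plus per-RPC credentials that
// attach the signed JWT to every call.
func dial(addr string, tlsCfg *tls.Config, rpcCreds credentials.PerRPCCredentials) (*grpc.ClientConn, error) {
	return grpc.Dial(addr,
		grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)),
		grpc.WithPerRPCCredentials(rpcCreds),
	)
}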
// Within a cluster, if remotesapi is configured with a tls_ca, we take the
// following semantics:
// * The configured tls_ca file holds a set of PEM encoded x509 certificates,
// all of which are trusted roots for the outbound connections the
// remotestorage client establishes.
// * The certificate chain presented by the server must validate to a root
// which was present in tls_ca. In particular, every certificate in the chain
// must be within its validity window, the signatures must be valid, key usage
// and IsCA must be correctly set for the roots and the intermediates, and the
// leaf must have extended key usage server auth.
// * On the other hand, no verification is done against the SAN or the Subject
// of the certificate.
//
// We use these TLS semantics for both connections to the gRPC endpoint which
// is the actual remotesapi, and for connections to any HTTPS endpoints to
// which the gRPC service returns URLs. For now, this works perfectly for our
// use case, but it's tightly coupled to `cluster:` deployment topologies and
// the like.
//
// If tls_ca is not set then default TLS handling is performed. In particular,
// if the remotesapi endpoint is HTTPS, then the system roots are used and
// ServerName is verified against the presented URL SANs of the certificates.
func (p grpcDialProvider) tlsConfig() (*tls.Config, error) {
tlsCA := p.cfg.RemotesAPIConfig().TLSCA()
if tlsCA == "" {
return nil, nil
}
urlmatches := p.cfg.RemotesAPIConfig().ServerNameURLMatches()
dnsmatches := p.cfg.RemotesAPIConfig().ServerNameDNSMatches()
pem, err := ioutil.ReadFile(tlsCA)
if err != nil {
return nil, err
}
roots := x509.NewCertPool()
if ok := roots.AppendCertsFromPEM(pem); !ok {
return nil, errors.New("error loading ca roots from " + tlsCA)
}
verifyFunc := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
certs := make([]*x509.Certificate, len(rawCerts))
var err error
for i, asn1Data := range rawCerts {
certs[i], err = x509.ParseCertificate(asn1Data)
if err != nil {
return err
}
}
keyUsages := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
opts := x509.VerifyOptions{
Roots: roots,
CurrentTime: time.Now(),
Intermediates: x509.NewCertPool(),
KeyUsages: keyUsages,
}
for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}
_, err = certs[0].Verify(opts)
if err != nil {
return err
}
if len(urlmatches) > 0 {
found := false
for _, n := range urlmatches {
for _, cn := range certs[0].URIs {
if n == cn.String() {
found = true
break
}
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_urls, but it did not")
}
}
if len(dnsmatches) > 0 {
found := false
for _, n := range dnsmatches {
for _, cn := range certs[0].DNSNames {
if n == cn {
found = true
break
}
}
if found {
break
}
}
if !found {
return errors.New("expected certificate to match something in server_name_dns, but it did not")
}
}
return nil
}
return &tls.Config{
// We have to InsecureSkipVerify because ServerName is always
// set by the grpc dial provider and golang tls.Config does not
// have good support for performing certificate validation
// without server name validation.
InsecureSkipVerify: true,
VerifyPeerCertificate: verifyFunc,
NextProtos: []string{"h2"},
}, nil
}
@@ -17,14 +17,18 @@ package cluster
import (
"context"
"strconv"
"strings"
"sync"
"time"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
)
const clusterRoleHeader = "x-dolt-cluster-role"
@@ -158,12 +162,20 @@ func (ci *clientinterceptor) Options() []grpc.DialOption {
// * for incoming requests which are not standby, it will currently fail the
// requests with codes.Unauthenticated. Eventually, it will allow
// authenticated and authorized read-only traffic through.
//
// The serverinterceptor is responsible for authenticating incoming requests
// from standby replicas. It is instantiated with a jwtauth.KeyProvider and
// some jwt.Expected. Incoming requests must have a valid, unexpired, signed
// JWT, signed by a key accessible in the KeyProvider.
type serverinterceptor struct {
lgr *logrus.Entry
role Role
epoch int
mu sync.Mutex
roleSetter func(role string, epoch int)
keyProvider jwtauth.KeyProvider
jwtExpected jwt.Expected
}
func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
@@ -174,6 +186,9 @@ func (si *serverinterceptor) Stream() grpc.StreamServerInterceptor {
fromStandby = si.handleRequestHeaders(md, role, epoch)
}
if fromStandby {
if err := si.authenticate(ss.Context()); err != nil {
return err
}
// After handleRequestHeaders, our role may have changed, so we fetch it again here.
role, epoch := si.getRole()
if err := grpc.SetHeader(ss.Context(), metadata.Pairs(clusterRoleHeader, string(role), clusterRoleEpochHeader, strconv.Itoa(epoch))); err != nil {
@@ -204,6 +219,9 @@ func (si *serverinterceptor) Unary() grpc.UnaryServerInterceptor {
fromStandby = si.handleRequestHeaders(md, role, epoch)
}
if fromStandby {
if err := si.authenticate(ctx); err != nil {
return nil, err
}
// After handleRequestHeaders, our role may have changed, so we fetch it again here.
role, epoch := si.getRole()
if err := grpc.SetHeader(ctx, metadata.Pairs(clusterRoleHeader, string(role), clusterRoleEpochHeader, strconv.Itoa(epoch))); err != nil {
@@ -272,3 +290,26 @@ func (si *serverinterceptor) getRole() (Role, int) {
defer si.mu.Unlock()
return si.role, si.epoch
}
func (si *serverinterceptor) authenticate(ctx context.Context) error {
if md, ok := metadata.FromIncomingContext(ctx); ok {
auths := md.Get("authorization")
if len(auths) != 1 {
si.lgr.Info("incoming standby request had no authorization")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
auth := auths[0]
if !strings.HasPrefix(auth, "Bearer ") {
si.lgr.Info("incoming standby request had malformed authentication header")
return status.Error(codes.Unauthenticated, "unauthenticated")
}
auth = strings.TrimPrefix(auth, "Bearer ")
_, err := jwtauth.ValidateJWT(auth, time.Now(), si.keyProvider, si.jwtExpected)
if err != nil {
si.lgr.Infof("incoming standby request authorization header failed to verify: %v", err)
return status.Error(codes.Unauthenticated, "unauthenticated")
}
return nil
}
return status.Error(codes.Unauthenticated, "unauthenticated")
}
@@ -16,13 +16,15 @@ package cluster
import (
"context"
"crypto/ed25519"
"crypto/rand"
"net"
"strconv"
"sync"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
@@ -30,6 +32,10 @@ import (
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
)
type server struct {
@@ -51,6 +57,53 @@ func noopSetRole(string, int) {
var lgr = logrus.StandardLogger().WithFields(logrus.Fields{})
var kp jwtauth.KeyProvider
var pub ed25519.PublicKey
var priv ed25519.PrivateKey
func init() {
var err error
pub, priv, err = ed25519.GenerateKey(rand.Reader)
if err != nil {
panic(err)
}
kp = keyProvider{pub}
}
type keyProvider struct {
ed25519.PublicKey
}
func (p keyProvider) GetKey(string) ([]jose.JSONWebKey, error) {
return []jose.JSONWebKey{{
Key: p.PublicKey,
KeyID: "1",
}}, nil
}
func newJWT() string {
key := jose.SigningKey{Algorithm: jose.EdDSA, Key: priv}
opts := &jose.SignerOptions{ExtraHeaders: map[jose.HeaderKey]interface{}{
"kid": "1",
}}
signer, err := jose.NewSigner(key, opts)
if err != nil {
panic(err)
}
jwtBuilder := jwt.Signed(signer)
jwtBuilder = jwtBuilder.Claims(jwt.Claims{
Audience: []string{"some_audience"},
Issuer: "some_issuer",
Subject: "some_subject",
Expiry: jwt.NewNumericDate(time.Now().Add(30 * time.Second)),
})
res, err := jwtBuilder.CompactSerialize()
if err != nil {
panic(err)
}
return res
}
func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient), serveropts []grpc.ServerOption, dialopts []grpc.DialOption) *server {
addr, err := net.ResolveUnixAddr("unix", "test_grpc.socket")
require.NoError(t, err)
@@ -93,12 +146,14 @@ func withClient(t *testing.T, cb func(*testing.T, grpc_health_v1.HealthClient),
func outboundCtx(vals ...interface{}) context.Context {
ctx := context.Background()
if len(vals) == 0 {
return ctx
return metadata.AppendToOutgoingContext(ctx,
"authorization", "Bearer "+newJWT())
}
if len(vals) == 2 {
return metadata.AppendToOutgoingContext(ctx,
clusterRoleHeader, string(vals[0].(Role)),
clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)))
clusterRoleEpochHeader, strconv.Itoa(vals[1].(int)),
"authorization", "Bearer "+newJWT())
}
panic("bad test --- outboundCtx must take 0 or 2 values")
}
@@ -108,6 +163,7 @@ func TestServerInterceptorUnauthenticatedWithoutClientHeaders(t *testing.T) {
si.roleSetter = noopSetRole
si.lgr = lgr
si.setRole(RoleStandby, 10)
si.keyProvider = kp
t.Run("Standby", func(t *testing.T) {
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
_, err := client.Check(outboundCtx(), &grpc_health_v1.HealthCheckRequest{})
@@ -136,6 +192,7 @@ func TestServerInterceptorAddsUnaryResponseHeaders(t *testing.T) {
si.setRole(RoleStandby, 10)
si.roleSetter = noopSetRole
si.lgr = lgr
si.keyProvider = kp
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
var md metadata.MD
_, err := client.Check(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
@@ -154,6 +211,7 @@ func TestServerInterceptorAddsStreamResponseHeaders(t *testing.T) {
si.setRole(RoleStandby, 10)
si.roleSetter = noopSetRole
si.lgr = lgr
si.keyProvider = kp
withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
var md metadata.MD
srv, err := client.Watch(outboundCtx(RolePrimary, 10), &grpc_health_v1.HealthCheckRequest{}, grpc.Header(&md))
@@ -174,6 +232,7 @@ func TestServerInterceptorAsPrimaryDoesNotSendRequest(t *testing.T) {
si.setRole(RolePrimary, 10)
si.roleSetter = noopSetRole
si.lgr = lgr
si.keyProvider = kp
srv := withClient(t, func(t *testing.T, client grpc_health_v1.HealthClient) {
ctx := metadata.AppendToOutgoingContext(outboundCtx(RoleStandby, 10), "test-header", "test-header-value")
_, err := client.Check(ctx, &grpc_health_v1.HealthCheckRequest{})
@@ -20,6 +20,9 @@ import (
"net/http"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"github.com/dolthub/dolt/go/libraries/doltcore/creds"
)
type JWKSHandler struct {
@@ -55,3 +58,7 @@ func JWKSHandlerInterceptor(keyID string, pub ed25519.PublicKey) func(http.Handl
})
}
}
func JWTExpectations() jwt.Expected {
return jwt.Expected{Issuer: creds.ClientIssuer, Audience: jwt.Audience{creds.RemotesAPIAudience}}
}
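
Together with a key provider, these expectations drive server-side validation. A sketch of a helper in the shape the interceptor uses (hypothetical, not part of this change; kp is any jwtauth.KeyProvider, e.g. the MultiJWKS of standby JWKS endpoints):

package cluster

import (
	"time"

	"github.com/dolthub/dolt/go/libraries/utils/jwtauth"
)

// checkBearer sketches the validation path: the bearer token must verify
// against a key from the provider and satisfy the issuer/audience
// expectations above.
func checkBearer(token string, kp jwtauth.KeyProvider) error {
	_, err := jwtauth.ValidateJWT(token, time.Now(), kp, JWTExpectations())
	return err
}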
+81 -42
@@ -16,12 +16,14 @@ package jwtauth
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"sync"
"time"
"github.com/sirupsen/logrus"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/json"
)
@@ -111,7 +113,7 @@ func (f *fetchedJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
return jwks.Key(kid), nil
}
// The multiJWKS will source JWKS from multiple URLs and will make them all
// The MultiJWKS will source JWKS from multiple URLs and will make them all
// available through GetKey(). Its GetKey() cannot error, but it can return no
// results.
//
@@ -122,33 +124,36 @@ func (f *fetchedJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
// key set will generally hint that the URLs should be more aggressively
// refreshed, but there is no blocking on refreshing the URLs.
//
// gracefulStop() will shut down any ongoing fetching work and will return when
// GracefulStop() will shut down any ongoing fetching work and will return when
// everything is cleanly shut down.
type multiJWKS struct {
type MultiJWKS struct {
client *http.Client
wg sync.WaitGroup
stop chan struct{}
refresh []chan struct{}
refresh []chan *sync.WaitGroup
urls []string
sets []jose.JSONWebKeySet
agg jose.JSONWebKeySet
mu sync.RWMutex
lgr *logrus.Entry
stopped bool
}
func newMultiJWKS(urls []string, client *http.Client) *multiJWKS {
res := new(multiJWKS)
func NewMultiJWKS(lgr *logrus.Entry, urls []string, client *http.Client) *MultiJWKS {
res := new(MultiJWKS)
res.lgr = lgr
res.client = client
res.urls = urls
res.stop = make(chan struct{})
res.refresh = make([]chan struct{}, len(urls))
res.refresh = make([]chan *sync.WaitGroup, len(urls))
for i := range res.refresh {
res.refresh[i] = make(chan struct{})
res.refresh[i] = make(chan *sync.WaitGroup, 3)
}
res.sets = make([]jose.JSONWebKeySet, len(urls))
return res
}
func (t *multiJWKS) run() {
func (t *MultiJWKS) Run() {
t.wg.Add(len(t.urls))
for i := 0; i < len(t.urls); i++ {
go t.thread(i)
@@ -156,21 +161,32 @@ func (t *multiJWKS) run() {
t.wg.Wait()
}
func (t *multiJWKS) gracefulStop() {
func (t *MultiJWKS) GracefulStop() {
t.mu.Lock()
t.stopped = true
t.mu.Unlock()
close(t.stop)
t.wg.Wait()
// TODO: Potentially clear t.refresh channels, ensure nothing else can call GetKey()...
}
func (t *multiJWKS) needsRefresh() {
func (t *MultiJWKS) needsRefresh() *sync.WaitGroup {
wg := new(sync.WaitGroup)
if t.stopped {
return wg
}
wg.Add(len(t.refresh))
for _, c := range t.refresh {
select {
case c <- struct{}{}:
case c <- wg:
default:
wg.Done()
}
}
return wg
}
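
The refresh signaling here is a small fan-out/fan-in pattern worth seeing in isolation: the caller broadcasts one *sync.WaitGroup over each worker's buffered channel and waits; each worker performs one fetch and calls Done, and if a worker's channel is backed up the caller calls Done on its behalf rather than blocking. A standalone runnable sketch under the same rules (worker count and sleep are arbitrary):

package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	refresh := make([]chan *sync.WaitGroup, 3)
	for i := range refresh {
		refresh[i] = make(chan *sync.WaitGroup, 3)
		go func(i int) {
			for wg := range refresh[i] {
				time.Sleep(10 * time.Millisecond) // stand-in for one JWKS fetch
				fmt.Printf("worker %d refreshed\n", i)
				wg.Done()
			}
		}(i)
	}

	wg := new(sync.WaitGroup)
	wg.Add(len(refresh))
	for _, c := range refresh {
		select {
		case c <- wg:
		default:
			wg.Done() // channel full; skip rather than block the caller
		}
	}
	wg.Wait() // returns once every signaled worker finishes one refresh
}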
func (t *multiJWKS) store(i int, jwks jose.JSONWebKeySet) {
func (t *MultiJWKS) store(i int, jwks jose.JSONWebKeySet) {
t.mu.Lock()
defer t.mu.Unlock()
t.sets[i] = jwks
@@ -184,56 +200,79 @@ func (t *multiJWKS) store(i int, jwks jose.JSONWebKeySet) {
}
}
func (t *multiJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
func (t *MultiJWKS) GetKey(kid string) ([]jose.JSONWebKey, error) {
t.mu.RLock()
defer t.mu.RUnlock()
res := t.agg.Key(kid)
if len(res) == 0 {
t.needsRefresh()
t.lgr.Infof("fetched key %s, found no key, signaling refresh", kid)
refresh := t.needsRefresh()
t.mu.RUnlock()
refresh.Wait()
t.mu.RLock()
res = t.agg.Key(kid)
t.lgr.Infof("refresh for key %s done, found %d keys", kid, len(res))
}
return res, nil
}
func (t *multiJWKS) thread(i int) {
func (t *MultiJWKS) fetch(i int) error {
request, err := http.NewRequest("GET", t.urls[i], nil)
if err != nil {
return err
}
response, err := t.client.Do(request)
if err != nil {
return err
}
defer response.Body.Close()
if response.StatusCode/100 != 2 {
return fmt.Errorf("http request failed: StatusCode: %d", response.StatusCode)
}
contents, err := ioutil.ReadAll(response.Body)
if err != nil {
return err
}
var jwks jose.JSONWebKeySet
err = json.Unmarshal(contents, &jwks)
if err != nil {
return err
}
t.store(i, jwks)
return nil
}
func (t *MultiJWKS) thread(i int) {
defer t.wg.Done()
timer := time.NewTimer(30 * time.Second)
var refresh *sync.WaitGroup
for {
nextRefresh := 30 * time.Second
request, err := http.NewRequest("GET", t.urls[i], nil)
if err == nil {
response, err := t.client.Do(request)
if err == nil && response.StatusCode/100 == 2 {
contents, err := ioutil.ReadAll(response.Body)
if err == nil {
var jwks jose.JSONWebKeySet
err = json.Unmarshal(contents, &jwks)
if err == nil {
t.store(i, jwks)
} else {
// Something bad...
nextRefresh = 1 * time.Second
}
} else {
// Something bad...
nextRefresh = 1 * time.Second
}
response.Body.Close()
} else {
// Something bad...
nextRefresh = 1 * time.Second
}
} else {
err := t.fetch(i)
if err != nil {
// Something bad...
t.lgr.Warnf("error fetching %s: %v", t.urls[i], err)
nextRefresh = 1 * time.Second
}
timer.Reset(nextRefresh)
if refresh != nil {
refresh.Done()
}
refresh = nil
select {
case <-t.stop:
if !timer.Stop() {
<-timer.C
}
return
case <-t.refresh[i]:
for {
select {
case refresh = <-t.refresh[i]:
refresh.Done()
default:
return
}
}
case refresh = <-t.refresh[i]:
if !timer.Stop() {
<-timer.C
}
@@ -562,6 +562,7 @@ tests:
- exec: "create table more_vals (i int primary key)"
error_match: "repo1 is read-only"
- on: server2
retry_attempts: 100
queries:
- query: "SELECT @@GLOBAL.dolt_cluster_role,@@GLOBAL.dolt_cluster_role_epoch"
result: