jwt auth integration tests

This commit is contained in:
Brian Hendriks
2025-12-10 16:20:08 -08:00
parent 07f8e3767a
commit 6832617c61
6 changed files with 389 additions and 0 deletions

View File

@@ -621,6 +621,9 @@ func ConfigureServices(
metricsHandler := promhttp.Handler()
jwksConfig := cfg.ServerConfig.MetricsJwksConfig()
enableMetricsAuth := jwksConfig != nil
logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s", enableMetricsAuth, addr)
if enableMetricsAuth {
mux.Handle("/metrics", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")

View File

@@ -8,6 +8,7 @@ require (
github.com/google/uuid v1.6.0
github.com/stretchr/testify v1.11.1
golang.org/x/sync v0.16.0
gopkg.in/go-jose/go-jose.v2 v2.6.3
gopkg.in/square/go-jose.v2 v2.5.1
gopkg.in/yaml.v3 v3.0.1
)

View File

@@ -22,6 +22,8 @@ golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/go-jose/go-jose.v2 v2.6.3 h1:nt80fvSDlhKWQgSWyHyy5CfmlQr+asih51R8PTWNKKs=
gopkg.in/go-jose/go-jose.v2 v2.6.3/go.mod h1:zzZDPkNNw/c9IE7Z9jr11mBZQhKQTMzoEEIoEdZlFBI=
gopkg.in/square/go-jose.v2 v2.5.1 h1:7odma5RETjNHWJnR32wx8t+Io4djHE1PqxCFx3iiZ2w=
gopkg.in/square/go-jose.v2 v2.5.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -0,0 +1,353 @@
package main
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver"
"github.com/stretchr/testify/require"
"gopkg.in/go-jose/go-jose.v2"
)
type getConfigFunc func(serverPort, metricsPort int) string
// runJWKSServer starts a local HTTP server that serves the JWKS file at /.well-known/jwks.json.
// The server is started in a goroutine and will be shut down via t.Cleanup.
func runJWKSServer(t *testing.T, jwksFilePath string, port int) {
data, err := os.ReadFile(jwksFilePath)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/jwks.json", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(data)
})
addr := fmt.Sprintf("127.0.0.1:%d", port)
server := &http.Server{
Addr: addr,
Handler: mux,
}
ln, err := net.Listen("tcp", addr)
require.NoError(t, err)
go func() {
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
// Use t.Log instead of t.Fatal because Serve runs in goroutine
t.Logf("runJWKSServer: server error: %v", err)
}
}()
t.Logf("Started test JWKS server on %s serving %s", addr, jwksFilePath)
t.Cleanup(func() {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
t.Logf("runJWKSServer: shutdown error: %v", err)
}
})
}
func createJWT(t *testing.T, issuer, audience, subject string) string {
const kid = "749df841-6e38-48f1-a178-20ecdd0b09f7"
// load jwks from testdata
data, err := os.ReadFile("testdata/test_jwks_private.json")
require.NoError(t, err)
var jwks jose.JSONWebKeySet
err = json.Unmarshal(data, &jwks)
require.NoError(t, err)
require.NotEmpty(t, jwks.Keys)
// choose key by kid or default to first
var jwk *jose.JSONWebKey
if kid == "" {
jwk = &jwks.Keys[0]
} else {
for i := range jwks.Keys {
if jwks.Keys[i].KeyID == kid {
jwk = &jwks.Keys[i]
break
}
}
require.NotNil(t, jwk)
}
// ensure we have a private key
require.False(t, jwk.IsPublic())
// create signer with kid header (if present)
opts := (&jose.SignerOptions{}).WithType("JWT")
if jwk.KeyID != "" {
opts = opts.WithHeader("kid", jwk.KeyID)
}
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: jwk.Key}, opts)
require.NoError(t, err)
// build claims
now := time.Now().UTC()
claims := map[string]interface{}{
"iss": issuer,
"aud": audience,
"sub": subject,
"iat": now.Unix(),
"exp": now.Add(time.Hour).Unix(),
}
payload, err := json.Marshal(claims)
require.NoError(t, err)
jws, err := signer.Sign(payload)
require.NoError(t, err)
compact, err := jws.CompactSerialize()
require.NoError(t, err)
return compact
}
func makeMetricsCall(t *testing.T, metricsPort int, bearerToken string) *http.Response {
url := fmt.Sprintf("http://127.0.0.1:%d/metrics", metricsPort)
req, err := http.NewRequest("GET", url, nil)
require.NoError(t, err)
if bearerToken != "" {
req.Header.Add("Authorization", "Bearer "+bearerToken)
}
client := &http.Client{}
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
return resp
}
var jwksPort int
func TestMetricsAuth(t *testing.T) {
jwksPort = GlobalPorts.GetPort(t)
absPath, err := filepath.Abs("./testdata/test_jwks.json")
require.NoError(t, err)
runJWKSServer(t, absPath, jwksPort)
t.Parallel()
t.Run("No Metrics Auth", testNoMetricsAuth)
t.Run("Missing Metrics Auth", testMissingMetricsAuth)
t.Run("Valid Metrics Auth", testValidMetricsAuth)
t.Run("Bad Audience Claim", testBadAudienceClaim)
t.Run("Bad Issuer Claim", testBadIssuerClaim)
t.Run("Bad Subject Claim", testBadSubjectClaim)
}
func startServerWithMetrics(t *testing.T, getConfig getConfigFunc) int {
var ports DynamicResources
ports.global = &GlobalPorts
ports.t = t
serverPort := ports.GetOrAllocatePort("server")
metricsPort := ports.GetOrAllocatePort("metrics")
config := getConfig(serverPort, metricsPort)
u, err := driver.NewDoltUser()
require.NoError(t, err)
t.Cleanup(func() {
u.Cleanup()
})
rs, err := u.MakeRepoStore()
require.NoError(t, err)
repo, err := rs.MakeRepo("max_conns_test")
require.NoError(t, err)
f, err := os.CreateTemp("", "config-*.yaml")
require.NoError(t, err)
t.Cleanup(func() {
os.Remove(f.Name())
})
_, err = f.WriteString(config)
require.NoError(t, err)
args := []string{"--config", f.Name()}
srvSettings := &driver.Server{
Args: args,
DynamicPort: "server",
}
t.Log("Starting server with config:\n" + config)
MakeServer(t, repo, srvSettings, &ports)
// hack to wait for server to start before making metrics call
time.Sleep(1 * time.Second)
return metricsPort
}
func testNoMetricsAuth(t *testing.T) {
t.Parallel()
getConfig := func(serverPort, metricsPort int) string {
return fmt.Sprintf(`
listener:
host: localhost
port: %d
metrics:
host: localhost
port: %d
`, serverPort, metricsPort)
}
metricsPort := startServerWithMetrics(t, getConfig)
resp := makeMetricsCall(t, metricsPort, "")
require.Equal(t, http.StatusOK, resp.StatusCode)
}
func testMissingMetricsAuth(t *testing.T) {
t.Parallel()
getConfig := func(serverPort, metricsPort int) string {
return fmt.Sprintf(`listener:
host: localhost
port: %d
metrics:
host: localhost
port: %d
jwks:
name: jwksname
location_url: http://127.0.0.1:%d/jwks.json
claims:
alg: RS256
iss: dolthub.com
sub: test_sub
aud: test_aud
`, serverPort, metricsPort, jwksPort)
}
metricsPort := startServerWithMetrics(t, getConfig)
resp := makeMetricsCall(t, metricsPort, "")
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
func testValidMetricsAuth(t *testing.T) {
t.Parallel()
getConfig := func(serverPort, metricsPort int) string {
return fmt.Sprintf(`listener:
host: localhost
port: %d
metrics:
host: localhost
port: %d
jwks:
name: jwksname
location_url: http://127.0.0.1:%d/jwks.json
claims:
iss: dolthub.com
sub: test_sub
aud: test_aud
`, serverPort, metricsPort, jwksPort)
}
metricsPort := startServerWithMetrics(t, getConfig)
jwt := createJWT(t, "dolthub.com", "test_aud", "test_sub")
resp := makeMetricsCall(t, metricsPort, jwt)
require.Equal(t, http.StatusOK, resp.StatusCode)
}
func testBadIssuerClaim(t *testing.T) {
t.Parallel()
getConfig := func(serverPort, metricsPort int) string {
return fmt.Sprintf(`listener:
host: localhost
port: %d
metrics:
host: localhost
port: %d
jwks:
name: jwksname
location_url: http://127.0.0.1:%d/jwks.json
claims:
iss: dolthub.com
sub: test_sub
aud: test_aud
`, serverPort, metricsPort, jwksPort)
}
metricsPort := startServerWithMetrics(t, getConfig)
jwt := createJWT(t, "badissuer.com", "test_aud", "test_sub")
resp := makeMetricsCall(t, metricsPort, jwt)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
func testBadAudienceClaim(t *testing.T) {
t.Parallel()
getConfig := func(serverPort, metricsPort int) string {
return fmt.Sprintf(`listener:
host: localhost
port: %d
metrics:
host: localhost
port: %d
jwks:
name: jwksname
location_url: http://127.0.0.1:%d/jwks.json
claims:
iss: dolthub.com
sub: test_sub
aud: test_aud
`, serverPort, metricsPort, jwksPort)
}
metricsPort := startServerWithMetrics(t, getConfig)
jwt := createJWT(t, "dolthub.com", "bad_aud", "test_sub")
resp := makeMetricsCall(t, metricsPort, jwt)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}
func testBadSubjectClaim(t *testing.T) {
t.Parallel()
getConfig := func(serverPort, metricsPort int) string {
return fmt.Sprintf(`listener:
host: localhost
port: %d
metrics:
host: localhost
port: %d
jwks:
name: jwksname
location_url: http://127.0.0.1:%d/jwks.json
claims:
iss: dolthub.com
sub: test_sub
aud: test_aud
`, serverPort, metricsPort, jwksPort)
}
metricsPort := startServerWithMetrics(t, getConfig)
jwt := createJWT(t, "dolthub.com", "test_aud", "bad_sub")
resp := makeMetricsCall(t, metricsPort, jwt)
require.Equal(t, http.StatusUnauthorized, resp.StatusCode)
}

View File

@@ -0,0 +1,12 @@
{
"keys": [
{
"kty": "RSA",
"e": "AQAB",
"use": "sig",
"kid": "749df841-6e38-48f1-a178-20ecdd0b09f7",
"alg": "RS256",
"n": "uPU3mTbLgki6xF05WLp1xFL3t9ZfrYygcgBq5-qJhz-n5dTicxiCt1X0PvDzomqGJFap7q_yB4z97w9GcoeRXzPrGyeVN324cOfCqgyrhR7pMXfrkAY885eqp3068Wo2V7OgS0yLlzD4XRjj4A9g8ssuZKhwgDcUqdiWk3ar33e3KMnz-GYqRrfNHnHkZ-WPqSwseM5ng9YgmKy8j3JeKLG336yGIRXQGAtZdekWBCBSlT_dW9bxiLdUMzS4ENzJzmYTrfcfQY6mm9C-DTprv3Lwq1h95Yhh7taIhRcBpin6Il5sBD7OLUUGHktDR_frEqkN8a2DeL5c81HOqT9-JQ"
}
]
}

View File

@@ -0,0 +1,18 @@
{
"keys": [
{
"p": "-dh54E1ei0vAT6CR-JjXnpQjBQLpILN1d9CY-hFl0S56oMlt1tU46nHGM5COVsQJv8OKCFWGQJY9BixlGAk9UGKRAxSKjhePsKDFr7vOdwmTg6mrF1JFa02Zuz3f7zRAsFd1rAtDcHTMmxISaIEZuBMeOougMJ1XdbLxqjkZLLM",
"kty": "RSA",
"q": "vYOPShvpX2wHMdLU7WTXPHre4P1sKvdZzEtFbOXnsfRNLeslAxFeTd7FpHc8wRA3cB5PTMzZU7jEyj3pHmZAq-2R_gu63gLjl8erI36cAmIgzB_seFQM8licuS6TbZvlYFM2M1NNjqTASLO5feGAtSoUSJmYvndn9OhvSa-Yxcc",
"d": "FQFsfg8WD8bYx0JbJ_ONOm29yngjR5-H_UqE2a_uTJjzJYwG59Fpzw6I_bj5woFcmLXq-LusviTKFiNi-dDhtrE7y0q0jKfPkasQlaV4uVaoX0DiVOoQdA3OiNUVI6PPZih1VPftho8-NbyE7MZyWUCwFSh4FmerBhseBsNcg7VwBr4CZimmSEKMJXHLMR7jZfrVJAuFGSHJWrdIS04sPcCyBX4V7FLCRBr8BRiM7LznwP4tDHG8-bferqtXH4-XwPY3HjaHRTK7gRQspCKPv-kPdnQl2rH0zJa5kY7MQq4RDtINXSefM_sek3RrhXUXvG2JF-1JiQ3EUi59eY7ZxQ",
"e": "AQAB",
"use": "sig",
"kid": "749df841-6e38-48f1-a178-20ecdd0b09f7",
"qi": "isaLZv3bKV0Osmfl5ay6LABeefMkun1tw2AJD6Ouqo9Z7mDLaTUDP5XxwypKry3ZVd-N1ZMkfgGi6Q-O-xLq2AU5mnI5bHbDAiN3P9Jp-Tex6c6rzT24fIbGQMV-Rul8z10_FsYbFv5CWr9V_jFBNrR8gX4XGtHn1qFVBo_K5Rk",
"dp": "tRlEvmFWdoGiFBW_uQKQyFF4UNmbQijSrNZ3DEwwEUAvgvx-sYo8hzORBy9w_VN7_ZQvKXtUpNxBv4fOf22zE-FeW204QWaysMTYhlkLfx1h373MVks8JltJY3-mIi0t9qRulxZS--Ctrnma_kUV72dsMeOjaZmjG51prolUxiE",
"alg": "RS256",
"dq": "skqLC9WmgLdJLX6EA7LTK3sNI-5HTUTXnnNSJVlF2Q1VbtXCRFiat_fVSR1Ecv2mqjxZro8qBrHVsc78-jSIszcWGkM-0o81Px4By6rZawSWhnOiLLImW_kxuKYw3PXFnhGq9C5y0Lf-jmdHIz57r_SekI6wPMBpdOcXi-M_fxE",
"n": "uPU3mTbLgki6xF05WLp1xFL3t9ZfrYygcgBq5-qJhz-n5dTicxiCt1X0PvDzomqGJFap7q_yB4z97w9GcoeRXzPrGyeVN324cOfCqgyrhR7pMXfrkAY885eqp3068Wo2V7OgS0yLlzD4XRjj4A9g8ssuZKhwgDcUqdiWk3ar33e3KMnz-GYqRrfNHnHkZ-WPqSwseM5ng9YgmKy8j3JeKLG336yGIRXQGAtZdekWBCBSlT_dW9bxiLdUMzS4ENzJzmYTrfcfQY6mm9C-DTprv3Lwq1h95Yhh7taIhRcBpin6Il5sBD7OLUUGHktDR_frEqkN8a2DeL5c81HOqT9-JQ"
}
]
}