add ability to skip metrics auth for the localhost

This commit is contained in:
Brian Hendriks
2025-12-11 11:50:11 -08:00
parent 6832617c61
commit e44048ed20
7 changed files with 250 additions and 63 deletions
@@ -384,6 +384,10 @@ func (cfg *commandLineServerConfig) MetricsJwksConfig() *servercfg.JwksConfig {
return nil
}
func (cfg *commandLineServerConfig) MetricsJWTRequiredForLocalhost() bool {
return false
}
func (cfg *commandLineServerConfig) RemotesapiPort() *int {
return cfg.remotesapiPort
}
+14 -1
View File
@@ -58,6 +58,7 @@ import (
"github.com/dolthub/dolt/go/libraries/events"
"github.com/dolthub/dolt/go/libraries/utils/config"
"github.com/dolthub/dolt/go/libraries/utils/filesys"
httputils "github.com/dolthub/dolt/go/libraries/utils/http"
"github.com/dolthub/dolt/go/libraries/utils/svcs"
"github.com/dolthub/dolt/go/store/chunks"
eventsapi "github.com/dolthub/eventsapi_schema/dolt/services/eventsapi/v1alpha1"
@@ -621,11 +622,23 @@ func ConfigureServices(
metricsHandler := promhttp.Handler()
jwksConfig := cfg.ServerConfig.MetricsJwksConfig()
enableMetricsAuth := jwksConfig != nil
requireLocalhostAuth := cfg.ServerConfig.MetricsJWTRequiredForLocalhost()
logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s", enableMetricsAuth, addr)
logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s, require_localhost_auth = %t", enableMetricsAuth, addr, requireLocalhostAuth)
if enableMetricsAuth {
mux.Handle("/metrics", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !requireLocalhostAuth {
isLocal, err := httputils.IsLocalRequest(r)
logrus.Info("Metrics JWT not required for localhost isLocal:", isLocal, "err:", err)
if err != nil {
logrus.Warnf("error checking if request is local for /metrics (assuming remote) request: %v.", err)
} else if isLocal {
metricsHandler.ServeHTTP(w, r)
return
}
}
auth := r.Header.Get("Authorization")
if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`)
+43 -41
View File
@@ -209,6 +209,7 @@ type ServerConfig interface {
MetricsTLSKey() string
MetricsTLSCA() string
MetricsJwksConfig() *JwksConfig
MetricsJWTRequiredForLocalhost() bool
// PrivilegeFilePath returns the path to the file which contains all needed privilege information in the form of a
// JSON string.
@@ -333,47 +334,48 @@ func ValidateConfig(config ServerConfig) error {
}
const (
HostKey = "host"
PortKey = "port"
UserKey = "user"
PasswordKey = "password"
ReadTimeoutKey = "net_read_timeout"
WriteTimeoutKey = "net_write_timeout"
ReadOnlyKey = "read_only"
LogLevelKey = "log_level"
LogFormatKey = "log_format"
AutoCommitKey = "autocommit"
DoltTransactionCommitKey = "dolt_transaction_commit"
BranchActivityTrackingKey = "branch_activity_tracking"
DataDirKey = "data_dir"
CfgDirKey = "cfg_dir"
MaxConnectionsKey = "max_connections"
MaxWaitConnectionsKey = "back_log"
MaxWaitConnectionsTimeoutKey = "max_connections_timeout"
TLSKeyKey = "tls_key"
TLSCertKey = "tls_cert"
RequireSecureTransportKey = "require_secure_transport"
MaxLoggedQueryLenKey = "max_logged_query_len"
ShouldEncodeLoggedQueryKey = "should_encode_logged_query"
DisableClientMultiStatementsKey = "disable_client_multi_statements"
MetricsLabelsKey = "metrics_labels"
MetricsHostKey = "metrics_host"
MetricsPortKey = "metrics_port"
MetricsTLSCertKey = "metrics_tls_cert"
MetricsTLSKeyKey = "metrics_tls_key"
MetricsTLSCAKey = "metrics_tls_ca"
MetricsJwksConfigKey = "metrics_jwks_config"
PrivilegeFilePathKey = "privilege_file_path"
BranchControlFilePathKey = "branch_control_file_path"
UserVarsKey = "user_vars"
SystemVarsKey = "system_vars"
JwksConfigKey = "jwks_config"
AllowCleartextPasswordsKey = "allow_cleartext_passwords"
SocketKey = "socket"
RemotesapiPortKey = "remotesapi_port"
RemotesapiReadOnlyKey = "remotesapi_read_only"
ClusterConfigKey = "cluster_config"
EventSchedulerKey = "event_scheduler"
HostKey = "host"
PortKey = "port"
UserKey = "user"
PasswordKey = "password"
ReadTimeoutKey = "net_read_timeout"
WriteTimeoutKey = "net_write_timeout"
ReadOnlyKey = "read_only"
LogLevelKey = "log_level"
LogFormatKey = "log_format"
AutoCommitKey = "autocommit"
DoltTransactionCommitKey = "dolt_transaction_commit"
BranchActivityTrackingKey = "branch_activity_tracking"
DataDirKey = "data_dir"
CfgDirKey = "cfg_dir"
MaxConnectionsKey = "max_connections"
MaxWaitConnectionsKey = "back_log"
MaxWaitConnectionsTimeoutKey = "max_connections_timeout"
TLSKeyKey = "tls_key"
TLSCertKey = "tls_cert"
RequireSecureTransportKey = "require_secure_transport"
MaxLoggedQueryLenKey = "max_logged_query_len"
ShouldEncodeLoggedQueryKey = "should_encode_logged_query"
DisableClientMultiStatementsKey = "disable_client_multi_statements"
MetricsLabelsKey = "metrics_labels"
MetricsHostKey = "metrics_host"
MetricsPortKey = "metrics_port"
MetricsTLSCertKey = "metrics_tls_cert"
MetricsTLSKeyKey = "metrics_tls_key"
MetricsTLSCAKey = "metrics_tls_ca"
MetricsJwksConfigKey = "metrics_jwks_config"
MetricsJWTRequiredForLocalhostKey = "metrics_jwt_required_for_localhost"
PrivilegeFilePathKey = "privilege_file_path"
BranchControlFilePathKey = "branch_control_file_path"
UserVarsKey = "user_vars"
SystemVarsKey = "system_vars"
JwksConfigKey = "jwks_config"
AllowCleartextPasswordsKey = "allow_cleartext_passwords"
SocketKey = "socket"
RemotesapiPortKey = "remotesapi_port"
RemotesapiReadOnlyKey = "remotesapi_read_only"
ClusterConfigKey = "cluster_config"
EventSchedulerKey = "event_scheduler"
)
type SystemVariableTarget interface {
+32 -21
View File
@@ -111,13 +111,14 @@ type PerformanceYAMLConfig struct {
}
type MetricsYAMLConfig struct {
Labels map[string]string `yaml:"labels"`
Host *string `yaml:"host,omitempty"`
Port *int `yaml:"port,omitempty"`
TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"`
TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"`
TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"`
Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"TBD"`
Labels map[string]string `yaml:"labels"`
Host *string `yaml:"host,omitempty"`
Port *int `yaml:"port,omitempty"`
TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"`
TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"`
TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"`
Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"TBD"`
JWTRequiredForLocalhost *bool `yaml:"jwt_required_for_localhost,omitempty" minver:"TBD"`
}
type RemotesapiYAMLConfig struct {
@@ -234,13 +235,14 @@ func ServerConfigAsYAMLConfig(cfg ServerConfig) *YAMLConfig {
DataDirStr: ptr(cfg.DataDir()),
CfgDirStr: ptr(cfg.CfgDir()),
MetricsConfig: MetricsYAMLConfig{
Labels: cfg.MetricsLabels(),
Host: nillableStrPtr(cfg.MetricsHost()),
Port: ptr(cfg.MetricsPort()),
TlsCert: ptr(cfg.MetricsTLSCert()),
TlsKey: ptr(cfg.MetricsTLSKey()),
TlsCa: ptr(cfg.MetricsTLSCA()),
Jwks: cfg.MetricsJwksConfig(),
Labels: cfg.MetricsLabels(),
Host: nillableStrPtr(cfg.MetricsHost()),
Port: ptr(cfg.MetricsPort()),
TlsCert: ptr(cfg.MetricsTLSCert()),
TlsKey: ptr(cfg.MetricsTLSKey()),
TlsCa: ptr(cfg.MetricsTLSCA()),
Jwks: cfg.MetricsJwksConfig(),
JWTRequiredForLocalhost: ptr(cfg.MetricsJWTRequiredForLocalhost()),
},
RemotesapiConfig: RemotesapiYAMLConfig{
Port_: cfg.RemotesapiPort(),
@@ -311,13 +313,14 @@ func ServerConfigSetValuesAsYAMLConfig(cfg ServerConfig) *YAMLConfig {
DataDirStr: zeroIf(ptr(cfg.DataDir()), !cfg.ValueSet(DataDirKey)),
CfgDirStr: zeroIf(ptr(cfg.CfgDir()), !cfg.ValueSet(CfgDirKey)),
MetricsConfig: MetricsYAMLConfig{
Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)),
Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)),
Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)),
TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)),
TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)),
TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)),
Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)),
Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)),
Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)),
Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)),
TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)),
TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)),
TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)),
Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)),
JWTRequiredForLocalhost: zeroIf(ptr(cfg.MetricsJWTRequiredForLocalhost()), !cfg.ValueSet(MetricsJWTRequiredForLocalhostKey)),
},
RemotesapiConfig: RemotesapiYAMLConfig{
Port_: zeroIf(cfg.RemotesapiPort(), !cfg.ValueSet(RemotesapiPortKey)),
@@ -809,6 +812,14 @@ func (cfg YAMLConfig) MetricsJwksConfig() *JwksConfig {
return cfg.MetricsConfig.Jwks
}
func (cfg YAMLConfig) MetricsJWTRequiredForLocalhost() bool {
if cfg.MetricsConfig.JWTRequiredForLocalhost == nil {
return false
}
return *cfg.MetricsConfig.JWTRequiredForLocalhost
}
func (cfg YAMLConfig) RemotesapiPort() *int {
return cfg.RemotesapiConfig.Port_
}
+57
View File
@@ -0,0 +1,57 @@
package http
import (
"fmt"
"net"
"net/http"
"strings"
)
// IsLocalRequest determines if an HTTP request originated from a local source.
// It checks for common proxy headers to rule out forwarded requests and
// inspects the RemoteAddr to see if it corresponds to a loopback address
// or a unix socket path.
func IsLocalRequest(r *http.Request) (bool, error) {
// If any common proxy/forwarding headers are present, consider the request forwarded.
proxyHeaders := []string{
"X-Forwarded-For",
"X-Real-IP",
"Forwarded",
"Via",
"True-Client-IP",
"X-Cluster-Client-Ip",
}
for _, h := range proxyHeaders {
if v := r.Header.Get(h); v != "" {
return false, nil
}
}
remote := r.RemoteAddr
if remote == "" {
return false, fmt.Errorf("empty RemoteAddr")
}
// remote can be "host:port" or a raw address. Try SplitHostPort first.
host, _, err := net.SplitHostPort(remote)
if err != nil {
// If SplitHostPort fails, treat the whole value as the host (could be a unix socket path or raw IP).
host = remote
}
// Treat obvious unix-socket paths as local.
if strings.HasPrefix(host, "/") || strings.HasPrefix(host, "@") {
return true, nil
}
ip := net.ParseIP(host)
if ip == nil {
return false, fmt.Errorf("invalid remote IP: %s", host)
}
// Consider loopback addresses local.
if ip.IsLoopback() {
return true, nil
}
return false, nil
}
+94
View File
@@ -0,0 +1,94 @@
package http
import (
nethttp "net/http"
"testing"
)
func TestIsLocalRequest(t *testing.T) {
tests := []struct {
name string
remoteAddr string
headers map[string]string
wantLocal bool
wantErr bool
}{
{
name: "ipv4 loopback",
remoteAddr: "127.0.0.1:12345",
headers: nil,
wantLocal: true,
},
{
name: "ipv6 loopback",
remoteAddr: "[::1]:54321",
headers: nil,
wantLocal: true,
},
{
name: "unix socket path",
remoteAddr: "/var/run/socket",
headers: nil,
wantLocal: true,
},
{
name: "abstract unix socket",
remoteAddr: "@abstractsocket",
headers: nil,
wantLocal: true,
},
{
name: "non-local ip",
remoteAddr: "192.168.1.10:8080",
headers: nil,
wantLocal: false,
},
{
name: "forwarded header present",
remoteAddr: "127.0.0.1:1111",
headers: map[string]string{
"X-Forwarded-For": "203.0.113.1",
},
wantLocal: false,
},
{
name: "malformed remote addr (no IP)",
remoteAddr: "not-an-ip",
headers: nil,
wantErr: true,
},
{
name: "hostname remote addr",
remoteAddr: "localhost:9999",
headers: nil,
wantErr: true,
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req := &nethttp.Request{
Header: make(nethttp.Header),
RemoteAddr: tc.remoteAddr,
}
for k, v := range tc.headers {
req.Header.Set(k, v)
}
got, err := IsLocalRequest(req)
if tc.wantErr {
if err == nil {
t.Fatalf("expected error but got nil (got=%v)", got)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.wantLocal {
t.Fatalf("unexpected result: got=%v want=%v", got, tc.wantLocal)
}
})
}
}
@@ -209,6 +209,7 @@ listener:
metrics:
host: localhost
port: %d
jwt_required_for_localhost: true
`, serverPort, metricsPort)
}
@@ -236,6 +237,7 @@ metrics:
iss: dolthub.com
sub: test_sub
aud: test_aud
jwt_required_for_localhost: true
`, serverPort, metricsPort, jwksPort)
}
@@ -262,6 +264,7 @@ metrics:
iss: dolthub.com
sub: test_sub
aud: test_aud
jwt_required_for_localhost: true
`, serverPort, metricsPort, jwksPort)
}
@@ -289,6 +292,7 @@ metrics:
iss: dolthub.com
sub: test_sub
aud: test_aud
jwt_required_for_localhost: true
`, serverPort, metricsPort, jwksPort)
}
@@ -316,6 +320,7 @@ metrics:
iss: dolthub.com
sub: test_sub
aud: test_aud
jwt_required_for_localhost: true
`, serverPort, metricsPort, jwksPort)
}
@@ -343,6 +348,7 @@ metrics:
iss: dolthub.com
sub: test_sub
aud: test_aud
jwt_required_for_localhost: true
`, serverPort, metricsPort, jwksPort)
}