From e44048ed205d4d1027aec8495513605b6fdaf422 Mon Sep 17 00:00:00 2001 From: Brian Hendriks Date: Thu, 11 Dec 2025 11:50:11 -0800 Subject: [PATCH] add ability to skip metrics auth for the localhost --- .../commands/sqlserver/command_line_config.go | 4 + go/cmd/dolt/commands/sqlserver/server.go | 15 ++- .../doltcore/servercfg/serverconfig.go | 84 +++++++++-------- .../doltcore/servercfg/yaml_config.go | 53 ++++++----- go/libraries/utils/http/requests.go | 57 +++++++++++ go/libraries/utils/http/requests_test.go | 94 +++++++++++++++++++ .../go-sql-server-driver/metrics_auth_test.go | 6 ++ 7 files changed, 250 insertions(+), 63 deletions(-) create mode 100644 go/libraries/utils/http/requests.go create mode 100644 go/libraries/utils/http/requests_test.go diff --git a/go/cmd/dolt/commands/sqlserver/command_line_config.go b/go/cmd/dolt/commands/sqlserver/command_line_config.go index 9321d22e1d..0968f22dcd 100755 --- a/go/cmd/dolt/commands/sqlserver/command_line_config.go +++ b/go/cmd/dolt/commands/sqlserver/command_line_config.go @@ -384,6 +384,10 @@ func (cfg *commandLineServerConfig) MetricsJwksConfig() *servercfg.JwksConfig { return nil } +func (cfg *commandLineServerConfig) MetricsJWTRequiredForLocalhost() bool { + return false +} + func (cfg *commandLineServerConfig) RemotesapiPort() *int { return cfg.remotesapiPort } diff --git a/go/cmd/dolt/commands/sqlserver/server.go b/go/cmd/dolt/commands/sqlserver/server.go index 53c3bd75d2..936db5ef04 100644 --- a/go/cmd/dolt/commands/sqlserver/server.go +++ b/go/cmd/dolt/commands/sqlserver/server.go @@ -58,6 +58,7 @@ import ( "github.com/dolthub/dolt/go/libraries/events" "github.com/dolthub/dolt/go/libraries/utils/config" "github.com/dolthub/dolt/go/libraries/utils/filesys" + httputils "github.com/dolthub/dolt/go/libraries/utils/http" "github.com/dolthub/dolt/go/libraries/utils/svcs" "github.com/dolthub/dolt/go/store/chunks" eventsapi "github.com/dolthub/eventsapi_schema/dolt/services/eventsapi/v1alpha1" @@ -621,11 +622,23 @@ func ConfigureServices( metricsHandler := promhttp.Handler() jwksConfig := cfg.ServerConfig.MetricsJwksConfig() enableMetricsAuth := jwksConfig != nil + requireLocalhostAuth := cfg.ServerConfig.MetricsJWTRequiredForLocalhost() - logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s", enableMetricsAuth, addr) + logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s, require_localhost_auth = %t", enableMetricsAuth, addr, requireLocalhostAuth) if enableMetricsAuth { mux.Handle("/metrics", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !requireLocalhostAuth { + isLocal, err := httputils.IsLocalRequest(r) + logrus.Info("Metrics JWT not required for localhost isLocal:", isLocal, "err:", err) + if err != nil { + logrus.Warnf("error checking if request is local for /metrics (assuming remote) request: %v.", err) + } else if isLocal { + metricsHandler.ServeHTTP(w, r) + return + } + } + auth := r.Header.Get("Authorization") if auth == "" || !strings.HasPrefix(auth, "Bearer ") { w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`) diff --git a/go/libraries/doltcore/servercfg/serverconfig.go b/go/libraries/doltcore/servercfg/serverconfig.go index ac7b6e05af..ee0d908cd3 100644 --- a/go/libraries/doltcore/servercfg/serverconfig.go +++ b/go/libraries/doltcore/servercfg/serverconfig.go @@ -209,6 +209,7 @@ type ServerConfig interface { MetricsTLSKey() string MetricsTLSCA() string MetricsJwksConfig() *JwksConfig + MetricsJWTRequiredForLocalhost() bool // PrivilegeFilePath returns the path to the file which contains all needed privilege information in the form of a // JSON string. @@ -333,47 +334,48 @@ func ValidateConfig(config ServerConfig) error { } const ( - HostKey = "host" - PortKey = "port" - UserKey = "user" - PasswordKey = "password" - ReadTimeoutKey = "net_read_timeout" - WriteTimeoutKey = "net_write_timeout" - ReadOnlyKey = "read_only" - LogLevelKey = "log_level" - LogFormatKey = "log_format" - AutoCommitKey = "autocommit" - DoltTransactionCommitKey = "dolt_transaction_commit" - BranchActivityTrackingKey = "branch_activity_tracking" - DataDirKey = "data_dir" - CfgDirKey = "cfg_dir" - MaxConnectionsKey = "max_connections" - MaxWaitConnectionsKey = "back_log" - MaxWaitConnectionsTimeoutKey = "max_connections_timeout" - TLSKeyKey = "tls_key" - TLSCertKey = "tls_cert" - RequireSecureTransportKey = "require_secure_transport" - MaxLoggedQueryLenKey = "max_logged_query_len" - ShouldEncodeLoggedQueryKey = "should_encode_logged_query" - DisableClientMultiStatementsKey = "disable_client_multi_statements" - MetricsLabelsKey = "metrics_labels" - MetricsHostKey = "metrics_host" - MetricsPortKey = "metrics_port" - MetricsTLSCertKey = "metrics_tls_cert" - MetricsTLSKeyKey = "metrics_tls_key" - MetricsTLSCAKey = "metrics_tls_ca" - MetricsJwksConfigKey = "metrics_jwks_config" - PrivilegeFilePathKey = "privilege_file_path" - BranchControlFilePathKey = "branch_control_file_path" - UserVarsKey = "user_vars" - SystemVarsKey = "system_vars" - JwksConfigKey = "jwks_config" - AllowCleartextPasswordsKey = "allow_cleartext_passwords" - SocketKey = "socket" - RemotesapiPortKey = "remotesapi_port" - RemotesapiReadOnlyKey = "remotesapi_read_only" - ClusterConfigKey = "cluster_config" - EventSchedulerKey = "event_scheduler" + HostKey = "host" + PortKey = "port" + UserKey = "user" + PasswordKey = "password" + ReadTimeoutKey = "net_read_timeout" + WriteTimeoutKey = "net_write_timeout" + ReadOnlyKey = "read_only" + LogLevelKey = "log_level" + LogFormatKey = "log_format" + AutoCommitKey = "autocommit" + DoltTransactionCommitKey = "dolt_transaction_commit" + BranchActivityTrackingKey = "branch_activity_tracking" + DataDirKey = "data_dir" + CfgDirKey = "cfg_dir" + MaxConnectionsKey = "max_connections" + MaxWaitConnectionsKey = "back_log" + MaxWaitConnectionsTimeoutKey = "max_connections_timeout" + TLSKeyKey = "tls_key" + TLSCertKey = "tls_cert" + RequireSecureTransportKey = "require_secure_transport" + MaxLoggedQueryLenKey = "max_logged_query_len" + ShouldEncodeLoggedQueryKey = "should_encode_logged_query" + DisableClientMultiStatementsKey = "disable_client_multi_statements" + MetricsLabelsKey = "metrics_labels" + MetricsHostKey = "metrics_host" + MetricsPortKey = "metrics_port" + MetricsTLSCertKey = "metrics_tls_cert" + MetricsTLSKeyKey = "metrics_tls_key" + MetricsTLSCAKey = "metrics_tls_ca" + MetricsJwksConfigKey = "metrics_jwks_config" + MetricsJWTRequiredForLocalhostKey = "metrics_jwt_required_for_localhost" + PrivilegeFilePathKey = "privilege_file_path" + BranchControlFilePathKey = "branch_control_file_path" + UserVarsKey = "user_vars" + SystemVarsKey = "system_vars" + JwksConfigKey = "jwks_config" + AllowCleartextPasswordsKey = "allow_cleartext_passwords" + SocketKey = "socket" + RemotesapiPortKey = "remotesapi_port" + RemotesapiReadOnlyKey = "remotesapi_read_only" + ClusterConfigKey = "cluster_config" + EventSchedulerKey = "event_scheduler" ) type SystemVariableTarget interface { diff --git a/go/libraries/doltcore/servercfg/yaml_config.go b/go/libraries/doltcore/servercfg/yaml_config.go index ebfb38300f..5b39f47bf2 100644 --- a/go/libraries/doltcore/servercfg/yaml_config.go +++ b/go/libraries/doltcore/servercfg/yaml_config.go @@ -111,13 +111,14 @@ type PerformanceYAMLConfig struct { } type MetricsYAMLConfig struct { - Labels map[string]string `yaml:"labels"` - Host *string `yaml:"host,omitempty"` - Port *int `yaml:"port,omitempty"` - TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"` - TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"` - TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"` - Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"TBD"` + Labels map[string]string `yaml:"labels"` + Host *string `yaml:"host,omitempty"` + Port *int `yaml:"port,omitempty"` + TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"` + TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"` + TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"` + Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"TBD"` + JWTRequiredForLocalhost *bool `yaml:"jwt_required_for_localhost,omitempty" minver:"TBD"` } type RemotesapiYAMLConfig struct { @@ -234,13 +235,14 @@ func ServerConfigAsYAMLConfig(cfg ServerConfig) *YAMLConfig { DataDirStr: ptr(cfg.DataDir()), CfgDirStr: ptr(cfg.CfgDir()), MetricsConfig: MetricsYAMLConfig{ - Labels: cfg.MetricsLabels(), - Host: nillableStrPtr(cfg.MetricsHost()), - Port: ptr(cfg.MetricsPort()), - TlsCert: ptr(cfg.MetricsTLSCert()), - TlsKey: ptr(cfg.MetricsTLSKey()), - TlsCa: ptr(cfg.MetricsTLSCA()), - Jwks: cfg.MetricsJwksConfig(), + Labels: cfg.MetricsLabels(), + Host: nillableStrPtr(cfg.MetricsHost()), + Port: ptr(cfg.MetricsPort()), + TlsCert: ptr(cfg.MetricsTLSCert()), + TlsKey: ptr(cfg.MetricsTLSKey()), + TlsCa: ptr(cfg.MetricsTLSCA()), + Jwks: cfg.MetricsJwksConfig(), + JWTRequiredForLocalhost: ptr(cfg.MetricsJWTRequiredForLocalhost()), }, RemotesapiConfig: RemotesapiYAMLConfig{ Port_: cfg.RemotesapiPort(), @@ -311,13 +313,14 @@ func ServerConfigSetValuesAsYAMLConfig(cfg ServerConfig) *YAMLConfig { DataDirStr: zeroIf(ptr(cfg.DataDir()), !cfg.ValueSet(DataDirKey)), CfgDirStr: zeroIf(ptr(cfg.CfgDir()), !cfg.ValueSet(CfgDirKey)), MetricsConfig: MetricsYAMLConfig{ - Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)), - Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)), - Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)), - TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)), - TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)), - TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)), - Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)), + Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)), + Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)), + Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)), + TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)), + TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)), + TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)), + Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)), + JWTRequiredForLocalhost: zeroIf(ptr(cfg.MetricsJWTRequiredForLocalhost()), !cfg.ValueSet(MetricsJWTRequiredForLocalhostKey)), }, RemotesapiConfig: RemotesapiYAMLConfig{ Port_: zeroIf(cfg.RemotesapiPort(), !cfg.ValueSet(RemotesapiPortKey)), @@ -809,6 +812,14 @@ func (cfg YAMLConfig) MetricsJwksConfig() *JwksConfig { return cfg.MetricsConfig.Jwks } +func (cfg YAMLConfig) MetricsJWTRequiredForLocalhost() bool { + if cfg.MetricsConfig.JWTRequiredForLocalhost == nil { + return false + } + + return *cfg.MetricsConfig.JWTRequiredForLocalhost +} + func (cfg YAMLConfig) RemotesapiPort() *int { return cfg.RemotesapiConfig.Port_ } diff --git a/go/libraries/utils/http/requests.go b/go/libraries/utils/http/requests.go new file mode 100644 index 0000000000..1165149764 --- /dev/null +++ b/go/libraries/utils/http/requests.go @@ -0,0 +1,57 @@ +package http + +import ( + "fmt" + "net" + "net/http" + "strings" +) + +// IsLocalRequest determines if an HTTP request originated from a local source. +// It checks for common proxy headers to rule out forwarded requests and +// inspects the RemoteAddr to see if it corresponds to a loopback address +// or a unix socket path. +func IsLocalRequest(r *http.Request) (bool, error) { + // If any common proxy/forwarding headers are present, consider the request forwarded. + proxyHeaders := []string{ + "X-Forwarded-For", + "X-Real-IP", + "Forwarded", + "Via", + "True-Client-IP", + "X-Cluster-Client-Ip", + } + for _, h := range proxyHeaders { + if v := r.Header.Get(h); v != "" { + return false, nil + } + } + + remote := r.RemoteAddr + if remote == "" { + return false, fmt.Errorf("empty RemoteAddr") + } + + // remote can be "host:port" or a raw address. Try SplitHostPort first. + host, _, err := net.SplitHostPort(remote) + if err != nil { + // If SplitHostPort fails, treat the whole value as the host (could be a unix socket path or raw IP). + host = remote + } + + // Treat obvious unix-socket paths as local. + if strings.HasPrefix(host, "/") || strings.HasPrefix(host, "@") { + return true, nil + } + + ip := net.ParseIP(host) + if ip == nil { + return false, fmt.Errorf("invalid remote IP: %s", host) + } + + // Consider loopback addresses local. + if ip.IsLoopback() { + return true, nil + } + return false, nil +} diff --git a/go/libraries/utils/http/requests_test.go b/go/libraries/utils/http/requests_test.go new file mode 100644 index 0000000000..f8d8c9aebc --- /dev/null +++ b/go/libraries/utils/http/requests_test.go @@ -0,0 +1,94 @@ +package http + +import ( + nethttp "net/http" + "testing" +) + +func TestIsLocalRequest(t *testing.T) { + tests := []struct { + name string + remoteAddr string + headers map[string]string + wantLocal bool + wantErr bool + }{ + { + name: "ipv4 loopback", + remoteAddr: "127.0.0.1:12345", + headers: nil, + wantLocal: true, + }, + { + name: "ipv6 loopback", + remoteAddr: "[::1]:54321", + headers: nil, + wantLocal: true, + }, + { + name: "unix socket path", + remoteAddr: "/var/run/socket", + headers: nil, + wantLocal: true, + }, + { + name: "abstract unix socket", + remoteAddr: "@abstractsocket", + headers: nil, + wantLocal: true, + }, + { + name: "non-local ip", + remoteAddr: "192.168.1.10:8080", + headers: nil, + wantLocal: false, + }, + { + name: "forwarded header present", + remoteAddr: "127.0.0.1:1111", + headers: map[string]string{ + "X-Forwarded-For": "203.0.113.1", + }, + wantLocal: false, + }, + { + name: "malformed remote addr (no IP)", + remoteAddr: "not-an-ip", + headers: nil, + wantErr: true, + }, + { + name: "hostname remote addr", + remoteAddr: "localhost:9999", + headers: nil, + wantErr: true, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + req := &nethttp.Request{ + Header: make(nethttp.Header), + RemoteAddr: tc.remoteAddr, + } + for k, v := range tc.headers { + req.Header.Set(k, v) + } + + got, err := IsLocalRequest(req) + if tc.wantErr { + if err == nil { + t.Fatalf("expected error but got nil (got=%v)", got) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.wantLocal { + t.Fatalf("unexpected result: got=%v want=%v", got, tc.wantLocal) + } + }) + } +} diff --git a/integration-tests/go-sql-server-driver/metrics_auth_test.go b/integration-tests/go-sql-server-driver/metrics_auth_test.go index acca54eb84..7c46d7be66 100644 --- a/integration-tests/go-sql-server-driver/metrics_auth_test.go +++ b/integration-tests/go-sql-server-driver/metrics_auth_test.go @@ -209,6 +209,7 @@ listener: metrics: host: localhost port: %d + jwt_required_for_localhost: true `, serverPort, metricsPort) } @@ -236,6 +237,7 @@ metrics: iss: dolthub.com sub: test_sub aud: test_aud + jwt_required_for_localhost: true `, serverPort, metricsPort, jwksPort) } @@ -262,6 +264,7 @@ metrics: iss: dolthub.com sub: test_sub aud: test_aud + jwt_required_for_localhost: true `, serverPort, metricsPort, jwksPort) } @@ -289,6 +292,7 @@ metrics: iss: dolthub.com sub: test_sub aud: test_aud + jwt_required_for_localhost: true `, serverPort, metricsPort, jwksPort) } @@ -316,6 +320,7 @@ metrics: iss: dolthub.com sub: test_sub aud: test_aud + jwt_required_for_localhost: true `, serverPort, metricsPort, jwksPort) } @@ -343,6 +348,7 @@ metrics: iss: dolthub.com sub: test_sub aud: test_aud + jwt_required_for_localhost: true `, serverPort, metricsPort, jwksPort) }