mirror of
https://github.com/dolthub/dolt.git
synced 2026-05-07 03:05:59 -05:00
add ability to skip metrics auth for the localhost
This commit is contained in:
@@ -384,6 +384,10 @@ func (cfg *commandLineServerConfig) MetricsJwksConfig() *servercfg.JwksConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cfg *commandLineServerConfig) MetricsJWTRequiredForLocalhost() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (cfg *commandLineServerConfig) RemotesapiPort() *int {
|
||||
return cfg.remotesapiPort
|
||||
}
|
||||
|
||||
@@ -58,6 +58,7 @@ import (
|
||||
"github.com/dolthub/dolt/go/libraries/events"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/config"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/filesys"
|
||||
httputils "github.com/dolthub/dolt/go/libraries/utils/http"
|
||||
"github.com/dolthub/dolt/go/libraries/utils/svcs"
|
||||
"github.com/dolthub/dolt/go/store/chunks"
|
||||
eventsapi "github.com/dolthub/eventsapi_schema/dolt/services/eventsapi/v1alpha1"
|
||||
@@ -621,11 +622,23 @@ func ConfigureServices(
|
||||
metricsHandler := promhttp.Handler()
|
||||
jwksConfig := cfg.ServerConfig.MetricsJwksConfig()
|
||||
enableMetricsAuth := jwksConfig != nil
|
||||
requireLocalhostAuth := cfg.ServerConfig.MetricsJWTRequiredForLocalhost()
|
||||
|
||||
logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s", enableMetricsAuth, addr)
|
||||
logrus.Infof("Starting metrics server. auth_enabled = %t, addr = %s, require_localhost_auth = %t", enableMetricsAuth, addr, requireLocalhostAuth)
|
||||
|
||||
if enableMetricsAuth {
|
||||
mux.Handle("/metrics", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !requireLocalhostAuth {
|
||||
isLocal, err := httputils.IsLocalRequest(r)
|
||||
logrus.Info("Metrics JWT not required for localhost isLocal:", isLocal, "err:", err)
|
||||
if err != nil {
|
||||
logrus.Warnf("error checking if request is local for /metrics (assuming remote) request: %v.", err)
|
||||
} else if isLocal {
|
||||
metricsHandler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
|
||||
w.Header().Set("WWW-Authenticate", `Bearer realm="metrics"`)
|
||||
|
||||
@@ -209,6 +209,7 @@ type ServerConfig interface {
|
||||
MetricsTLSKey() string
|
||||
MetricsTLSCA() string
|
||||
MetricsJwksConfig() *JwksConfig
|
||||
MetricsJWTRequiredForLocalhost() bool
|
||||
|
||||
// PrivilegeFilePath returns the path to the file which contains all needed privilege information in the form of a
|
||||
// JSON string.
|
||||
@@ -333,47 +334,48 @@ func ValidateConfig(config ServerConfig) error {
|
||||
}
|
||||
|
||||
const (
|
||||
HostKey = "host"
|
||||
PortKey = "port"
|
||||
UserKey = "user"
|
||||
PasswordKey = "password"
|
||||
ReadTimeoutKey = "net_read_timeout"
|
||||
WriteTimeoutKey = "net_write_timeout"
|
||||
ReadOnlyKey = "read_only"
|
||||
LogLevelKey = "log_level"
|
||||
LogFormatKey = "log_format"
|
||||
AutoCommitKey = "autocommit"
|
||||
DoltTransactionCommitKey = "dolt_transaction_commit"
|
||||
BranchActivityTrackingKey = "branch_activity_tracking"
|
||||
DataDirKey = "data_dir"
|
||||
CfgDirKey = "cfg_dir"
|
||||
MaxConnectionsKey = "max_connections"
|
||||
MaxWaitConnectionsKey = "back_log"
|
||||
MaxWaitConnectionsTimeoutKey = "max_connections_timeout"
|
||||
TLSKeyKey = "tls_key"
|
||||
TLSCertKey = "tls_cert"
|
||||
RequireSecureTransportKey = "require_secure_transport"
|
||||
MaxLoggedQueryLenKey = "max_logged_query_len"
|
||||
ShouldEncodeLoggedQueryKey = "should_encode_logged_query"
|
||||
DisableClientMultiStatementsKey = "disable_client_multi_statements"
|
||||
MetricsLabelsKey = "metrics_labels"
|
||||
MetricsHostKey = "metrics_host"
|
||||
MetricsPortKey = "metrics_port"
|
||||
MetricsTLSCertKey = "metrics_tls_cert"
|
||||
MetricsTLSKeyKey = "metrics_tls_key"
|
||||
MetricsTLSCAKey = "metrics_tls_ca"
|
||||
MetricsJwksConfigKey = "metrics_jwks_config"
|
||||
PrivilegeFilePathKey = "privilege_file_path"
|
||||
BranchControlFilePathKey = "branch_control_file_path"
|
||||
UserVarsKey = "user_vars"
|
||||
SystemVarsKey = "system_vars"
|
||||
JwksConfigKey = "jwks_config"
|
||||
AllowCleartextPasswordsKey = "allow_cleartext_passwords"
|
||||
SocketKey = "socket"
|
||||
RemotesapiPortKey = "remotesapi_port"
|
||||
RemotesapiReadOnlyKey = "remotesapi_read_only"
|
||||
ClusterConfigKey = "cluster_config"
|
||||
EventSchedulerKey = "event_scheduler"
|
||||
HostKey = "host"
|
||||
PortKey = "port"
|
||||
UserKey = "user"
|
||||
PasswordKey = "password"
|
||||
ReadTimeoutKey = "net_read_timeout"
|
||||
WriteTimeoutKey = "net_write_timeout"
|
||||
ReadOnlyKey = "read_only"
|
||||
LogLevelKey = "log_level"
|
||||
LogFormatKey = "log_format"
|
||||
AutoCommitKey = "autocommit"
|
||||
DoltTransactionCommitKey = "dolt_transaction_commit"
|
||||
BranchActivityTrackingKey = "branch_activity_tracking"
|
||||
DataDirKey = "data_dir"
|
||||
CfgDirKey = "cfg_dir"
|
||||
MaxConnectionsKey = "max_connections"
|
||||
MaxWaitConnectionsKey = "back_log"
|
||||
MaxWaitConnectionsTimeoutKey = "max_connections_timeout"
|
||||
TLSKeyKey = "tls_key"
|
||||
TLSCertKey = "tls_cert"
|
||||
RequireSecureTransportKey = "require_secure_transport"
|
||||
MaxLoggedQueryLenKey = "max_logged_query_len"
|
||||
ShouldEncodeLoggedQueryKey = "should_encode_logged_query"
|
||||
DisableClientMultiStatementsKey = "disable_client_multi_statements"
|
||||
MetricsLabelsKey = "metrics_labels"
|
||||
MetricsHostKey = "metrics_host"
|
||||
MetricsPortKey = "metrics_port"
|
||||
MetricsTLSCertKey = "metrics_tls_cert"
|
||||
MetricsTLSKeyKey = "metrics_tls_key"
|
||||
MetricsTLSCAKey = "metrics_tls_ca"
|
||||
MetricsJwksConfigKey = "metrics_jwks_config"
|
||||
MetricsJWTRequiredForLocalhostKey = "metrics_jwt_required_for_localhost"
|
||||
PrivilegeFilePathKey = "privilege_file_path"
|
||||
BranchControlFilePathKey = "branch_control_file_path"
|
||||
UserVarsKey = "user_vars"
|
||||
SystemVarsKey = "system_vars"
|
||||
JwksConfigKey = "jwks_config"
|
||||
AllowCleartextPasswordsKey = "allow_cleartext_passwords"
|
||||
SocketKey = "socket"
|
||||
RemotesapiPortKey = "remotesapi_port"
|
||||
RemotesapiReadOnlyKey = "remotesapi_read_only"
|
||||
ClusterConfigKey = "cluster_config"
|
||||
EventSchedulerKey = "event_scheduler"
|
||||
)
|
||||
|
||||
type SystemVariableTarget interface {
|
||||
|
||||
@@ -111,13 +111,14 @@ type PerformanceYAMLConfig struct {
|
||||
}
|
||||
|
||||
type MetricsYAMLConfig struct {
|
||||
Labels map[string]string `yaml:"labels"`
|
||||
Host *string `yaml:"host,omitempty"`
|
||||
Port *int `yaml:"port,omitempty"`
|
||||
TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"`
|
||||
TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"`
|
||||
TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"`
|
||||
Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"TBD"`
|
||||
Labels map[string]string `yaml:"labels"`
|
||||
Host *string `yaml:"host,omitempty"`
|
||||
Port *int `yaml:"port,omitempty"`
|
||||
TlsCert *string `yaml:"tls_cert,omitempty" minver:"1.78.2"`
|
||||
TlsKey *string `yaml:"tls_key,omitempty" minver:"1.78.2"`
|
||||
TlsCa *string `yaml:"tls_ca,omitempty" minver:"1.78.2"`
|
||||
Jwks *JwksConfig `yaml:"jwks,omitempty" minver:"TBD"`
|
||||
JWTRequiredForLocalhost *bool `yaml:"jwt_required_for_localhost,omitempty" minver:"TBD"`
|
||||
}
|
||||
|
||||
type RemotesapiYAMLConfig struct {
|
||||
@@ -234,13 +235,14 @@ func ServerConfigAsYAMLConfig(cfg ServerConfig) *YAMLConfig {
|
||||
DataDirStr: ptr(cfg.DataDir()),
|
||||
CfgDirStr: ptr(cfg.CfgDir()),
|
||||
MetricsConfig: MetricsYAMLConfig{
|
||||
Labels: cfg.MetricsLabels(),
|
||||
Host: nillableStrPtr(cfg.MetricsHost()),
|
||||
Port: ptr(cfg.MetricsPort()),
|
||||
TlsCert: ptr(cfg.MetricsTLSCert()),
|
||||
TlsKey: ptr(cfg.MetricsTLSKey()),
|
||||
TlsCa: ptr(cfg.MetricsTLSCA()),
|
||||
Jwks: cfg.MetricsJwksConfig(),
|
||||
Labels: cfg.MetricsLabels(),
|
||||
Host: nillableStrPtr(cfg.MetricsHost()),
|
||||
Port: ptr(cfg.MetricsPort()),
|
||||
TlsCert: ptr(cfg.MetricsTLSCert()),
|
||||
TlsKey: ptr(cfg.MetricsTLSKey()),
|
||||
TlsCa: ptr(cfg.MetricsTLSCA()),
|
||||
Jwks: cfg.MetricsJwksConfig(),
|
||||
JWTRequiredForLocalhost: ptr(cfg.MetricsJWTRequiredForLocalhost()),
|
||||
},
|
||||
RemotesapiConfig: RemotesapiYAMLConfig{
|
||||
Port_: cfg.RemotesapiPort(),
|
||||
@@ -311,13 +313,14 @@ func ServerConfigSetValuesAsYAMLConfig(cfg ServerConfig) *YAMLConfig {
|
||||
DataDirStr: zeroIf(ptr(cfg.DataDir()), !cfg.ValueSet(DataDirKey)),
|
||||
CfgDirStr: zeroIf(ptr(cfg.CfgDir()), !cfg.ValueSet(CfgDirKey)),
|
||||
MetricsConfig: MetricsYAMLConfig{
|
||||
Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)),
|
||||
Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)),
|
||||
Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)),
|
||||
TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)),
|
||||
TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)),
|
||||
TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)),
|
||||
Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)),
|
||||
Labels: zeroIf(cfg.MetricsLabels(), !cfg.ValueSet(MetricsLabelsKey)),
|
||||
Host: zeroIf(ptr(cfg.MetricsHost()), !cfg.ValueSet(MetricsHostKey)),
|
||||
Port: zeroIf(ptr(cfg.MetricsPort()), !cfg.ValueSet(MetricsPortKey)),
|
||||
TlsCert: zeroIf(ptr(cfg.MetricsTLSCert()), !cfg.ValueSet(MetricsTLSCertKey)),
|
||||
TlsKey: zeroIf(ptr(cfg.MetricsTLSKey()), !cfg.ValueSet(MetricsTLSKeyKey)),
|
||||
TlsCa: zeroIf(ptr(cfg.MetricsTLSCA()), !cfg.ValueSet(MetricsTLSCAKey)),
|
||||
Jwks: zeroIf(cfg.MetricsJwksConfig(), !cfg.ValueSet(MetricsJwksConfigKey)),
|
||||
JWTRequiredForLocalhost: zeroIf(ptr(cfg.MetricsJWTRequiredForLocalhost()), !cfg.ValueSet(MetricsJWTRequiredForLocalhostKey)),
|
||||
},
|
||||
RemotesapiConfig: RemotesapiYAMLConfig{
|
||||
Port_: zeroIf(cfg.RemotesapiPort(), !cfg.ValueSet(RemotesapiPortKey)),
|
||||
@@ -809,6 +812,14 @@ func (cfg YAMLConfig) MetricsJwksConfig() *JwksConfig {
|
||||
return cfg.MetricsConfig.Jwks
|
||||
}
|
||||
|
||||
func (cfg YAMLConfig) MetricsJWTRequiredForLocalhost() bool {
|
||||
if cfg.MetricsConfig.JWTRequiredForLocalhost == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return *cfg.MetricsConfig.JWTRequiredForLocalhost
|
||||
}
|
||||
|
||||
func (cfg YAMLConfig) RemotesapiPort() *int {
|
||||
return cfg.RemotesapiConfig.Port_
|
||||
}
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IsLocalRequest determines if an HTTP request originated from a local source.
|
||||
// It checks for common proxy headers to rule out forwarded requests and
|
||||
// inspects the RemoteAddr to see if it corresponds to a loopback address
|
||||
// or a unix socket path.
|
||||
func IsLocalRequest(r *http.Request) (bool, error) {
|
||||
// If any common proxy/forwarding headers are present, consider the request forwarded.
|
||||
proxyHeaders := []string{
|
||||
"X-Forwarded-For",
|
||||
"X-Real-IP",
|
||||
"Forwarded",
|
||||
"Via",
|
||||
"True-Client-IP",
|
||||
"X-Cluster-Client-Ip",
|
||||
}
|
||||
for _, h := range proxyHeaders {
|
||||
if v := r.Header.Get(h); v != "" {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
remote := r.RemoteAddr
|
||||
if remote == "" {
|
||||
return false, fmt.Errorf("empty RemoteAddr")
|
||||
}
|
||||
|
||||
// remote can be "host:port" or a raw address. Try SplitHostPort first.
|
||||
host, _, err := net.SplitHostPort(remote)
|
||||
if err != nil {
|
||||
// If SplitHostPort fails, treat the whole value as the host (could be a unix socket path or raw IP).
|
||||
host = remote
|
||||
}
|
||||
|
||||
// Treat obvious unix-socket paths as local.
|
||||
if strings.HasPrefix(host, "/") || strings.HasPrefix(host, "@") {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return false, fmt.Errorf("invalid remote IP: %s", host)
|
||||
}
|
||||
|
||||
// Consider loopback addresses local.
|
||||
if ip.IsLoopback() {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
nethttp "net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsLocalRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
headers map[string]string
|
||||
wantLocal bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "ipv4 loopback",
|
||||
remoteAddr: "127.0.0.1:12345",
|
||||
headers: nil,
|
||||
wantLocal: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback",
|
||||
remoteAddr: "[::1]:54321",
|
||||
headers: nil,
|
||||
wantLocal: true,
|
||||
},
|
||||
{
|
||||
name: "unix socket path",
|
||||
remoteAddr: "/var/run/socket",
|
||||
headers: nil,
|
||||
wantLocal: true,
|
||||
},
|
||||
{
|
||||
name: "abstract unix socket",
|
||||
remoteAddr: "@abstractsocket",
|
||||
headers: nil,
|
||||
wantLocal: true,
|
||||
},
|
||||
{
|
||||
name: "non-local ip",
|
||||
remoteAddr: "192.168.1.10:8080",
|
||||
headers: nil,
|
||||
wantLocal: false,
|
||||
},
|
||||
{
|
||||
name: "forwarded header present",
|
||||
remoteAddr: "127.0.0.1:1111",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "203.0.113.1",
|
||||
},
|
||||
wantLocal: false,
|
||||
},
|
||||
{
|
||||
name: "malformed remote addr (no IP)",
|
||||
remoteAddr: "not-an-ip",
|
||||
headers: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "hostname remote addr",
|
||||
remoteAddr: "localhost:9999",
|
||||
headers: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := &nethttp.Request{
|
||||
Header: make(nethttp.Header),
|
||||
RemoteAddr: tc.remoteAddr,
|
||||
}
|
||||
for k, v := range tc.headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
got, err := IsLocalRequest(req)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error but got nil (got=%v)", got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if got != tc.wantLocal {
|
||||
t.Fatalf("unexpected result: got=%v want=%v", got, tc.wantLocal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -209,6 +209,7 @@ listener:
|
||||
metrics:
|
||||
host: localhost
|
||||
port: %d
|
||||
jwt_required_for_localhost: true
|
||||
`, serverPort, metricsPort)
|
||||
}
|
||||
|
||||
@@ -236,6 +237,7 @@ metrics:
|
||||
iss: dolthub.com
|
||||
sub: test_sub
|
||||
aud: test_aud
|
||||
jwt_required_for_localhost: true
|
||||
`, serverPort, metricsPort, jwksPort)
|
||||
}
|
||||
|
||||
@@ -262,6 +264,7 @@ metrics:
|
||||
iss: dolthub.com
|
||||
sub: test_sub
|
||||
aud: test_aud
|
||||
jwt_required_for_localhost: true
|
||||
`, serverPort, metricsPort, jwksPort)
|
||||
}
|
||||
|
||||
@@ -289,6 +292,7 @@ metrics:
|
||||
iss: dolthub.com
|
||||
sub: test_sub
|
||||
aud: test_aud
|
||||
jwt_required_for_localhost: true
|
||||
`, serverPort, metricsPort, jwksPort)
|
||||
}
|
||||
|
||||
@@ -316,6 +320,7 @@ metrics:
|
||||
iss: dolthub.com
|
||||
sub: test_sub
|
||||
aud: test_aud
|
||||
jwt_required_for_localhost: true
|
||||
`, serverPort, metricsPort, jwksPort)
|
||||
}
|
||||
|
||||
@@ -343,6 +348,7 @@ metrics:
|
||||
iss: dolthub.com
|
||||
sub: test_sub
|
||||
aud: test_aud
|
||||
jwt_required_for_localhost: true
|
||||
`, serverPort, metricsPort, jwksPort)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user