From 07a718dc8ea962faccffa4b1890510b15f559d93 Mon Sep 17 00:00:00 2001 From: Ralf Haferkamp Date: Thu, 17 Aug 2023 15:28:39 +0200 Subject: [PATCH] proxy: User ReverseProxy.Rewrite instead of Director With Go 1.20 the "Rewrite" hook for ReverseProxy was introduced to supersede of the "Director" hook (see: https://github.com/golang/go/commit/a55793835f16d0242be18aff4ec0bd13494175bd) The Rewrite hooks allows for better separation between the incoming and outgoing request. In particular it makes it pretty easy to set the correct X-Forwarded-* Headers on the outgoing request. The need for using "Rewrite" came up when trying to embed authelia. It uses the X-Forwarded-Host and X-Forwared-Proto headers to e.g. compute the correct return values for the various endpoints in .well-known/openid-configuration. --- services/proxy/pkg/proxy/proxy.go | 6 +- services/proxy/pkg/router/router.go | 76 ++++++++++++------------ services/proxy/pkg/router/router_test.go | 12 +++- 3 files changed, 51 insertions(+), 43 deletions(-) diff --git a/services/proxy/pkg/proxy/proxy.go b/services/proxy/pkg/proxy/proxy.go index 0694f6619..4d3894c84 100644 --- a/services/proxy/pkg/proxy/proxy.go +++ b/services/proxy/pkg/proxy/proxy.go @@ -36,9 +36,9 @@ func NewMultiHostReverseProxy(opts ...Option) (*MultiHostReverseProxy, error) { config: options.Config, } - rp.Director = func(r *http.Request) { - ri := router.ContextRoutingInfo(r.Context()) - ri.Director()(r) + rp.Rewrite = func(r *httputil.ProxyRequest) { + ri := router.ContextRoutingInfo(r.In.Context()) + ri.Rewrite()(r) } tlsConf := &tls.Config{ diff --git a/services/proxy/pkg/router/router.go b/services/proxy/pkg/router/router.go index d61e99a72..85734ae5d 100644 --- a/services/proxy/pkg/router/router.go +++ b/services/proxy/pkg/router/router.go @@ -3,6 +3,7 @@ package router import ( "context" "net/http" + "net/http/httputil" "net/url" "regexp" "strings" @@ -57,7 +58,7 @@ func New(policySelector *config.PolicySelector, policies []config.Policy, logger r := Router{ logger: logger, - directors: make(map[string]map[config.RouteType]map[string][]RoutingInfo), + rewriters: make(map[string]map[config.RouteType]map[string][]RoutingInfo), policySelector: selector, } for _, pol := range policies { @@ -83,16 +84,16 @@ func New(policySelector *config.PolicySelector, policies []config.Policy, logger return r } -// RoutingInfo contains the proxy director and some information about the route. +// RoutingInfo contains the proxy rewrite hook and some information about the route. type RoutingInfo struct { - director func(*http.Request) + rewrite func(*httputil.ProxyRequest) endpoint string unprotected bool } -// Director returns the proxy director. -func (r RoutingInfo) Director() func(*http.Request) { - return r.director +// Rewrite returns the proxy rewrite hook. +func (r RoutingInfo) Rewrite() func(*httputil.ProxyRequest) { + return r.rewrite } // IsRouteUnprotected returns true if the route doesn't need to be authenticated. @@ -103,33 +104,33 @@ func (r RoutingInfo) IsRouteUnprotected() bool { // Router handles the routing of HTTP requests according to the given policies. type Router struct { logger log.Logger - directors map[string]map[config.RouteType]map[string][]RoutingInfo + rewriters map[string]map[config.RouteType]map[string][]RoutingInfo policySelector policy.Selector } func (rt Router) addHost(policy string, target *url.URL, route config.Route) { targetQuery := target.RawQuery - if rt.directors[policy] == nil { - rt.directors[policy] = make(map[config.RouteType]map[string][]RoutingInfo) + if rt.rewriters[policy] == nil { + rt.rewriters[policy] = make(map[config.RouteType]map[string][]RoutingInfo) } routeType := config.DefaultRouteType if route.Type != "" { routeType = route.Type } - if rt.directors[policy][routeType] == nil { - rt.directors[policy][routeType] = make(map[string][]RoutingInfo) + if rt.rewriters[policy][routeType] == nil { + rt.rewriters[policy][routeType] = make(map[string][]RoutingInfo) } - if rt.directors[policy][routeType][route.Method] == nil { - rt.directors[policy][routeType][route.Method] = make([]RoutingInfo, 0) + if rt.rewriters[policy][routeType][route.Method] == nil { + rt.rewriters[policy][routeType][route.Method] = make([]RoutingInfo, 0) } reg := registry.GetRegistry() sel := selector.NewSelector(selector.Registry(reg)) - rt.directors[policy][routeType][route.Method] = append(rt.directors[policy][routeType][route.Method], RoutingInfo{ + rt.rewriters[policy][routeType][route.Method] = append(rt.rewriters[policy][routeType][route.Method], RoutingInfo{ endpoint: route.Endpoint, unprotected: route.Unprotected, - director: func(req *http.Request) { + rewrite: func(req *httputil.ProxyRequest) { if route.Service != "" { // select next node next, err := sel.Select(route.Service) @@ -146,32 +147,33 @@ func (rt Router) addHost(policy string, target *url.URL, route config.Route) { Msg("could not select next node") return // TODO error? fallback to target.Host & Scheme? } - req.URL.Host = node.Address - req.URL.Scheme = node.Metadata["protocol"] // TODO check property exists? + req.Out.URL.Host = node.Address + req.Out.URL.Scheme = node.Metadata["protocol"] // TODO check property exists? if node.Metadata["use_tls"] == "true" { - req.URL.Scheme = "https" + req.Out.URL.Scheme = "https" } } else { - req.URL.Host = target.Host - req.URL.Scheme = target.Scheme + req.Out.URL.Host = target.Host + req.Out.URL.Scheme = target.Scheme } - // Apache deployments host addresses need to match on req.Host and req.URL.Host + // Apache deployments host addresses need to match on req.Out.Host and req.Out.URL.Host // see https://stackoverflow.com/questions/34745654/golang-reverseproxy-with-apache2-sni-hostname-error if route.ApacheVHost { - req.Host = target.Host + req.Out.Host = target.Host } - req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) - if targetQuery == "" || req.URL.RawQuery == "" { - req.URL.RawQuery = targetQuery + req.URL.RawQuery + req.Out.URL.Path = singleJoiningSlash(target.Path, req.Out.URL.Path) + if targetQuery == "" || req.Out.URL.RawQuery == "" { + req.Out.URL.RawQuery = targetQuery + req.Out.URL.RawQuery } else { - req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + req.Out.URL.RawQuery = targetQuery + "&" + req.Out.URL.RawQuery } - if _, ok := req.Header["User-Agent"]; !ok { + if _, ok := req.Out.Header["User-Agent"]; !ok { // explicitly disable User-Agent so it's not set to default value - req.Header.Set("User-Agent", "") + req.Out.Header.Set("User-Agent", "") } + req.SetXForwarded() }, }) } @@ -184,7 +186,7 @@ func (rt Router) Route(r *http.Request) (RoutingInfo, bool) { return noInfo, false } - if _, ok := rt.directors[pol]; !ok { + if _, ok := rt.rewriters[pol]; !ok { rt.logger. Error(). Str("policy", pol). @@ -193,7 +195,7 @@ func (rt Router) Route(r *http.Request) (RoutingInfo, bool) { } method := "" - // find matching director + // find matching rewrite hook for _, rtype := range config.RouteTypes { var handler func(string, url.URL) bool switch rtype { @@ -206,12 +208,12 @@ func (rt Router) Route(r *http.Request) (RoutingInfo, bool) { default: handler = prefixRouteMatcher } - if rt.directors[pol][rtype][r.Method] != nil { + if rt.rewriters[pol][rtype][r.Method] != nil { // use specific method method = r.Method } - for _, ri := range rt.directors[pol][rtype][method] { + for _, ri := range rt.rewriters[pol][rtype][method] { if handler(ri.endpoint, *r.URL) { rt.logger.Debug(). Str("policy", pol). @@ -219,17 +221,17 @@ func (rt Router) Route(r *http.Request) (RoutingInfo, bool) { Str("prefix", ri.endpoint). Str("path", r.URL.Path). Str("routeType", string(rtype)). - Msg("director found") + Msg("rewrite hook found") return ri, true } } } - // override default director with root. If any - if ri := rt.directors[pol][config.PrefixRoute][method][0]; ri.endpoint == "/" { // try specific method + // override default rewrite hook with root. If any + if ri := rt.rewriters[pol][config.PrefixRoute][method][0]; ri.endpoint == "/" { // try specific method return ri, true - } else if ri := rt.directors[pol][config.PrefixRoute][""][0]; ri.endpoint == "/" { // fallback to unspecific method + } else if ri := rt.rewriters[pol][config.PrefixRoute][""][0]; ri.endpoint == "/" { // fallback to unspecific method return ri, true } @@ -237,7 +239,7 @@ func (rt Router) Route(r *http.Request) (RoutingInfo, bool) { Warn(). Str("policy", pol). Str("path", r.URL.Path). - Msg("no director found") + Msg("no rewrite hook found") return noInfo, false } diff --git a/services/proxy/pkg/router/router_test.go b/services/proxy/pkg/router/router_test.go index 3bbebc868..2652db56a 100644 --- a/services/proxy/pkg/router/router_test.go +++ b/services/proxy/pkg/router/router_test.go @@ -1,9 +1,11 @@ package router import ( + "context" "fmt" "net/http" "net/http/httptest" + "net/http/httputil" "net/url" "testing" @@ -146,10 +148,14 @@ func TestRouter(t *testing.T) { t.Errorf("TestRouter route flag unprotected expected to be %t got %t", test.unprotected, routingInfo.IsRouteUnprotected()) } - routingInfo.Director()(r) + pr := &httputil.ProxyRequest{ + In: r, + Out: r.Clone(context.Background()), + } + routingInfo.Rewrite()(pr) - if r.URL.Host != test.target { - t.Errorf("TestRouter got host %s expected %s", r.URL.Host, test.target) + if pr.Out.URL.Host != test.target { + t.Errorf("TestRouter got host %s expected %s", pr.Out.URL.Host, test.target) } } }