diff --git a/proxy/pkg/middleware/selector_cookie.go b/proxy/pkg/middleware/selector_cookie.go index 1ef8a70c39..04cf8cfbc3 100644 --- a/proxy/pkg/middleware/selector_cookie.go +++ b/proxy/pkg/middleware/selector_cookie.go @@ -37,9 +37,6 @@ func (m selectorCookie) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } - ctx := req.Context() - claims := oidc.FromContext(ctx) - selectorCookieName := "" if m.policySelector.Regex != nil { selectorCookieName = m.policySelector.Regex.SelectorCookieName @@ -50,14 +47,14 @@ func (m selectorCookie) ServeHTTP(w http.ResponseWriter, req *http.Request) { _, err := req.Cookie(selectorCookieName) if err != nil { // no cookie there - try to add one - if claims != nil { + if oidc.FromContext(req.Context()) != nil { selectorFunc, err := policy.LoadSelector(&m.policySelector) if err != nil { m.logger.Err(err) } - selector, err := selectorFunc(ctx, req) + selector, err := selectorFunc(req) if err != nil { m.logger.Err(err) } diff --git a/proxy/pkg/proxy/policy/selector.go b/proxy/pkg/proxy/policy/selector.go index 966f105264..ef5a1b556a 100644 --- a/proxy/pkg/proxy/policy/selector.go +++ b/proxy/pkg/proxy/policy/selector.go @@ -1,7 +1,6 @@ package policy import ( - "context" "fmt" "net/http" "regexp" @@ -52,7 +51,7 @@ const ( // } // ] //} -type Selector func(ctx context.Context, r *http.Request) (string, error) +type Selector func(r *http.Request) (string, error) // LoadSelector constructs a specific policy-selector from a given configuration func LoadSelector(cfg *config.PolicySelector) (Selector, error) { @@ -113,7 +112,7 @@ func LoadSelector(cfg *config.PolicySelector) (Selector, error) { // "static": {"policy" : "ocis"} // }, func NewStaticSelector(cfg *config.StaticSelectorConf) Selector { - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(r *http.Request) (s string, err error) { return cfg.Policy, nil } } @@ -132,9 +131,9 @@ func NewStaticSelector(cfg *config.StaticSelectorConf) Selector { // thus have an entry in ocis-accounts. All users without accounts entry are routed to the legacy ownCloud10 instance. func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.AccountsService) Selector { var acc = ss - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(r *http.Request) (s string, err error) { var claims map[string]interface{} - if claims = oidc.FromContext(ctx); claims == nil { + if claims = oidc.FromContext(r.Context()); claims == nil { return cfg.UnauthenticatedPolicy, nil } @@ -145,7 +144,7 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account return cfg.AccNotFoundPolicy, nil } - if _, err := acc.GetAccount(ctx, &accounts.GetAccountRequest{Id: userID}); err != nil { + if _, err := acc.GetAccount(r.Context(), &accounts.GetAccountRequest{Id: userID}); err != nil { return cfg.AccNotFoundPolicy, nil } return cfg.AccFoundPolicy, nil @@ -164,7 +163,7 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account // // This selector can be used in migration-scenarios where some users have already migrated from ownCloud10 to OCIS and func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector { - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(r *http.Request) (s string, err error) { // use cookie first if provided selectorCookie, err := r.Cookie(cfg.SelectorCookieName) if err == nil { @@ -172,7 +171,7 @@ func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector { } // if no cookie is present, try to route by selector - if claims := oidc.FromContext(ctx); claims != nil { + if claims := oidc.FromContext(r.Context()); claims != nil { if p, ok := claims[oidc.OcisRoutingPolicy].(string); ok && p != "" { // TODO check we know the routing policy? return p, nil @@ -213,7 +212,7 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector { policy: cfg.MatchesPolicies[i].Policy, }) } - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(r *http.Request) (s string, err error) { // use cookie first if provided selectorCookie, err := r.Cookie(cfg.SelectorCookieName) if err == nil { @@ -221,7 +220,7 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector { } // if no cookie is present, try to route by selector - if u, ok := revauser.ContextGetUser(ctx); ok { + if u, ok := revauser.ContextGetUser(r.Context()); ok { for i := range regexRules { switch regexRules[i].property { case "mail": diff --git a/proxy/pkg/proxy/policy/selector_test.go b/proxy/pkg/proxy/policy/selector_test.go index a22dba988e..5e65674c95 100644 --- a/proxy/pkg/proxy/policy/selector_test.go +++ b/proxy/pkg/proxy/policy/selector_test.go @@ -3,6 +3,7 @@ package policy import ( "context" "fmt" + "net/http/httptest" "testing" "github.com/asim/go-micro/v3/client" @@ -47,9 +48,9 @@ func TestLoadSelector(t *testing.T) { func TestStaticSelector(t *testing.T) { sel := NewStaticSelector(&config.StaticSelectorConf{Policy: "ocis"}) - ctx := context.Background() + req := httptest.NewRequest("GET", "https://example.org/foo", nil) want := "ocis" - got, err := sel(ctx) + got, err := sel(req) if got != want { t.Errorf("Expected policy %v got %v", want, got) } @@ -61,7 +62,7 @@ func TestStaticSelector(t *testing.T) { sel = NewStaticSelector(&config.StaticSelectorConf{Policy: "foo"}) want = "foo" - got, err = sel(ctx) + got, err = sel(req) if got != want { t.Errorf("Expected policy %v got %v", want, got) } @@ -93,9 +94,11 @@ func TestMigrationSelector(t *testing.T) { for _, tc := range tests { tc := tc sut := NewMigrationSelector(&cfg, mockAccSvc(tc.AccSvcShouldReturnError)) - ctx := oidc.NewContext(context.Background(), tc.Claims) + r := httptest.NewRequest("GET", "https://example.com", nil) + ctx := oidc.NewContext(r.Context(), tc.Claims) + nr := r.WithContext(ctx) - got, err := sut(ctx) + got, err := sut(nr) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -141,7 +144,9 @@ func TestClaimsSelector(t *testing.T) { {"claim-value", oidc.NewContext(context.Background(), map[string]interface{}{oidc.OcisRoutingPolicy: "ocis.routing.policy-value"}), "ocis.routing.policy-value"}, } for _, tc := range tests { - got, err := sel(tc.Context) + r := httptest.NewRequest("GET", "https://example.com", nil) + nr := r.WithContext(tc.Context) + got, err := sel(nr) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -182,7 +187,9 @@ func TestRegexSelector(t *testing.T) { for _, tc := range tests { tc := tc // capture range variable t.Run(tc.Name, func(t *testing.T) { - got, err := sel(tc.Context) + r := httptest.NewRequest("GET", "https://example.com", nil) + nr := r.WithContext(tc.Context) + got, err := sel(nr) if err != nil { t.Errorf("Unexpected error: %v", err) } diff --git a/proxy/pkg/proxy/proxy.go b/proxy/pkg/proxy/proxy.go index bb8c0c76a9..9193252e13 100644 --- a/proxy/pkg/proxy/proxy.go +++ b/proxy/pkg/proxy/proxy.go @@ -66,7 +66,7 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy { if options.Config.PolicySelector == nil { firstPolicy := options.Config.Policies[0].Name - rp.logger.Warn().Msgf("policy-selector not configured. Will always use first policy: '%v'", firstPolicy) + rp.logger.Warn().Str("policy", firstPolicy).Msg("policy-selector not configured. Will always use first policy") options.Config.PolicySelector = &config.PolicySelector{ Static: &config.StaticSelectorConf{ Policy: firstPolicy, @@ -91,9 +91,10 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy { uri, err := url.Parse(route.Backend) if err != nil { rp.logger. - Fatal(). + Fatal(). // fail early on misconfiguration Err(err). - Msgf("malformed url: %v", route.Backend) + Str("backend", route.Backend). + Msg("malformed url") } rp.logger. @@ -109,16 +110,17 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy { } func (p *MultiHostReverseProxy) directorSelectionDirector(r *http.Request) { - pol, err := p.PolicySelector(r.Context(), r) + pol, err := p.PolicySelector(r) if err != nil { - p.logger.Error().Msgf("Error while selecting pol %v", err) + p.logger.Error().Err(err).Msg("Error while selecting pol") return } if _, ok := p.Directors[pol]; !ok { p.logger. Error(). - Msgf("policy %v is not configured", pol) + Str("policy", pol). + Msg("policy is not configured") return } @@ -247,7 +249,7 @@ func (p MultiHostReverseProxy) queryRouteMatcher(endpoint string, target url.URL func (p *MultiHostReverseProxy) regexRouteMatcher(pattern string, target url.URL) bool { matched, err := regexp.MatchString(pattern, target.String()) if err != nil { - p.logger.Warn().Err(err).Msgf("regex with pattern %s failed", pattern) + p.logger.Warn().Err(err).Str("pattern", pattern).Msg("regex with pattern failed") } return matched }