pass only request instead of context

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
Jörn Friedrich Dreyer
2021-07-23 09:03:14 +00:00
parent a0dce56480
commit bea986fe26
4 changed files with 34 additions and 29 deletions

View File

@@ -37,9 +37,6 @@ func (m selectorCookie) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
ctx := req.Context()
claims := oidc.FromContext(ctx)
selectorCookieName := ""
if m.policySelector.Regex != nil {
selectorCookieName = m.policySelector.Regex.SelectorCookieName
@@ -50,14 +47,14 @@ func (m selectorCookie) ServeHTTP(w http.ResponseWriter, req *http.Request) {
_, err := req.Cookie(selectorCookieName)
if err != nil {
// no cookie there - try to add one
if claims != nil {
if oidc.FromContext(req.Context()) != nil {
selectorFunc, err := policy.LoadSelector(&m.policySelector)
if err != nil {
m.logger.Err(err)
}
selector, err := selectorFunc(ctx, req)
selector, err := selectorFunc(req)
if err != nil {
m.logger.Err(err)
}

View File

@@ -1,7 +1,6 @@
package policy
import (
"context"
"fmt"
"net/http"
"regexp"
@@ -52,7 +51,7 @@ const (
// }
// ]
//}
type Selector func(ctx context.Context, r *http.Request) (string, error)
type Selector func(r *http.Request) (string, error)
// LoadSelector constructs a specific policy-selector from a given configuration
func LoadSelector(cfg *config.PolicySelector) (Selector, error) {
@@ -113,7 +112,7 @@ func LoadSelector(cfg *config.PolicySelector) (Selector, error) {
// "static": {"policy" : "ocis"}
// },
func NewStaticSelector(cfg *config.StaticSelectorConf) Selector {
return func(ctx context.Context, r *http.Request) (s string, err error) {
return func(r *http.Request) (s string, err error) {
return cfg.Policy, nil
}
}
@@ -132,9 +131,9 @@ func NewStaticSelector(cfg *config.StaticSelectorConf) Selector {
// thus have an entry in ocis-accounts. All users without accounts entry are routed to the legacy ownCloud10 instance.
func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.AccountsService) Selector {
var acc = ss
return func(ctx context.Context, r *http.Request) (s string, err error) {
return func(r *http.Request) (s string, err error) {
var claims map[string]interface{}
if claims = oidc.FromContext(ctx); claims == nil {
if claims = oidc.FromContext(r.Context()); claims == nil {
return cfg.UnauthenticatedPolicy, nil
}
@@ -145,7 +144,7 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account
return cfg.AccNotFoundPolicy, nil
}
if _, err := acc.GetAccount(ctx, &accounts.GetAccountRequest{Id: userID}); err != nil {
if _, err := acc.GetAccount(r.Context(), &accounts.GetAccountRequest{Id: userID}); err != nil {
return cfg.AccNotFoundPolicy, nil
}
return cfg.AccFoundPolicy, nil
@@ -164,7 +163,7 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account
//
// This selector can be used in migration-scenarios where some users have already migrated from ownCloud10 to OCIS and
func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector {
return func(ctx context.Context, r *http.Request) (s string, err error) {
return func(r *http.Request) (s string, err error) {
// use cookie first if provided
selectorCookie, err := r.Cookie(cfg.SelectorCookieName)
if err == nil {
@@ -172,7 +171,7 @@ func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector {
}
// if no cookie is present, try to route by selector
if claims := oidc.FromContext(ctx); claims != nil {
if claims := oidc.FromContext(r.Context()); claims != nil {
if p, ok := claims[oidc.OcisRoutingPolicy].(string); ok && p != "" {
// TODO check we know the routing policy?
return p, nil
@@ -213,7 +212,7 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector {
policy: cfg.MatchesPolicies[i].Policy,
})
}
return func(ctx context.Context, r *http.Request) (s string, err error) {
return func(r *http.Request) (s string, err error) {
// use cookie first if provided
selectorCookie, err := r.Cookie(cfg.SelectorCookieName)
if err == nil {
@@ -221,7 +220,7 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector {
}
// if no cookie is present, try to route by selector
if u, ok := revauser.ContextGetUser(ctx); ok {
if u, ok := revauser.ContextGetUser(r.Context()); ok {
for i := range regexRules {
switch regexRules[i].property {
case "mail":

View File

@@ -3,6 +3,7 @@ package policy
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"github.com/asim/go-micro/v3/client"
@@ -47,9 +48,9 @@ func TestLoadSelector(t *testing.T) {
func TestStaticSelector(t *testing.T) {
sel := NewStaticSelector(&config.StaticSelectorConf{Policy: "ocis"})
ctx := context.Background()
req := httptest.NewRequest("GET", "https://example.org/foo", nil)
want := "ocis"
got, err := sel(ctx)
got, err := sel(req)
if got != want {
t.Errorf("Expected policy %v got %v", want, got)
}
@@ -61,7 +62,7 @@ func TestStaticSelector(t *testing.T) {
sel = NewStaticSelector(&config.StaticSelectorConf{Policy: "foo"})
want = "foo"
got, err = sel(ctx)
got, err = sel(req)
if got != want {
t.Errorf("Expected policy %v got %v", want, got)
}
@@ -93,9 +94,11 @@ func TestMigrationSelector(t *testing.T) {
for _, tc := range tests {
tc := tc
sut := NewMigrationSelector(&cfg, mockAccSvc(tc.AccSvcShouldReturnError))
ctx := oidc.NewContext(context.Background(), tc.Claims)
r := httptest.NewRequest("GET", "https://example.com", nil)
ctx := oidc.NewContext(r.Context(), tc.Claims)
nr := r.WithContext(ctx)
got, err := sut(ctx)
got, err := sut(nr)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
@@ -141,7 +144,9 @@ func TestClaimsSelector(t *testing.T) {
{"claim-value", oidc.NewContext(context.Background(), map[string]interface{}{oidc.OcisRoutingPolicy: "ocis.routing.policy-value"}), "ocis.routing.policy-value"},
}
for _, tc := range tests {
got, err := sel(tc.Context)
r := httptest.NewRequest("GET", "https://example.com", nil)
nr := r.WithContext(tc.Context)
got, err := sel(nr)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}
@@ -182,7 +187,9 @@ func TestRegexSelector(t *testing.T) {
for _, tc := range tests {
tc := tc // capture range variable
t.Run(tc.Name, func(t *testing.T) {
got, err := sel(tc.Context)
r := httptest.NewRequest("GET", "https://example.com", nil)
nr := r.WithContext(tc.Context)
got, err := sel(nr)
if err != nil {
t.Errorf("Unexpected error: %v", err)
}

View File

@@ -66,7 +66,7 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
if options.Config.PolicySelector == nil {
firstPolicy := options.Config.Policies[0].Name
rp.logger.Warn().Msgf("policy-selector not configured. Will always use first policy: '%v'", firstPolicy)
rp.logger.Warn().Str("policy", firstPolicy).Msg("policy-selector not configured. Will always use first policy")
options.Config.PolicySelector = &config.PolicySelector{
Static: &config.StaticSelectorConf{
Policy: firstPolicy,
@@ -91,9 +91,10 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
uri, err := url.Parse(route.Backend)
if err != nil {
rp.logger.
Fatal().
Fatal(). // fail early on misconfiguration
Err(err).
Msgf("malformed url: %v", route.Backend)
Str("backend", route.Backend).
Msg("malformed url")
}
rp.logger.
@@ -109,16 +110,17 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
}
func (p *MultiHostReverseProxy) directorSelectionDirector(r *http.Request) {
pol, err := p.PolicySelector(r.Context(), r)
pol, err := p.PolicySelector(r)
if err != nil {
p.logger.Error().Msgf("Error while selecting pol %v", err)
p.logger.Error().Err(err).Msg("Error while selecting pol")
return
}
if _, ok := p.Directors[pol]; !ok {
p.logger.
Error().
Msgf("policy %v is not configured", pol)
Str("policy", pol).
Msg("policy is not configured")
return
}
@@ -247,7 +249,7 @@ func (p MultiHostReverseProxy) queryRouteMatcher(endpoint string, target url.URL
func (p *MultiHostReverseProxy) regexRouteMatcher(pattern string, target url.URL) bool {
matched, err := regexp.MatchString(pattern, target.String())
if err != nil {
p.logger.Warn().Err(err).Msgf("regex with pattern %s failed", pattern)
p.logger.Warn().Err(err).Str("pattern", pattern).Msg("regex with pattern failed")
}
return matched
}