mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-01-04 19:29:49 -06:00
pass only request instead of context
Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
@@ -37,9 +37,6 @@ func (m selectorCookie) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := req.Context()
|
||||
claims := oidc.FromContext(ctx)
|
||||
|
||||
selectorCookieName := ""
|
||||
if m.policySelector.Regex != nil {
|
||||
selectorCookieName = m.policySelector.Regex.SelectorCookieName
|
||||
@@ -50,14 +47,14 @@ func (m selectorCookie) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
_, err := req.Cookie(selectorCookieName)
|
||||
if err != nil {
|
||||
// no cookie there - try to add one
|
||||
if claims != nil {
|
||||
if oidc.FromContext(req.Context()) != nil {
|
||||
|
||||
selectorFunc, err := policy.LoadSelector(&m.policySelector)
|
||||
if err != nil {
|
||||
m.logger.Err(err)
|
||||
}
|
||||
|
||||
selector, err := selectorFunc(ctx, req)
|
||||
selector, err := selectorFunc(req)
|
||||
if err != nil {
|
||||
m.logger.Err(err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package policy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
@@ -52,7 +51,7 @@ const (
|
||||
// }
|
||||
// ]
|
||||
//}
|
||||
type Selector func(ctx context.Context, r *http.Request) (string, error)
|
||||
type Selector func(r *http.Request) (string, error)
|
||||
|
||||
// LoadSelector constructs a specific policy-selector from a given configuration
|
||||
func LoadSelector(cfg *config.PolicySelector) (Selector, error) {
|
||||
@@ -113,7 +112,7 @@ func LoadSelector(cfg *config.PolicySelector) (Selector, error) {
|
||||
// "static": {"policy" : "ocis"}
|
||||
// },
|
||||
func NewStaticSelector(cfg *config.StaticSelectorConf) Selector {
|
||||
return func(ctx context.Context, r *http.Request) (s string, err error) {
|
||||
return func(r *http.Request) (s string, err error) {
|
||||
return cfg.Policy, nil
|
||||
}
|
||||
}
|
||||
@@ -132,9 +131,9 @@ func NewStaticSelector(cfg *config.StaticSelectorConf) Selector {
|
||||
// thus have an entry in ocis-accounts. All users without accounts entry are routed to the legacy ownCloud10 instance.
|
||||
func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.AccountsService) Selector {
|
||||
var acc = ss
|
||||
return func(ctx context.Context, r *http.Request) (s string, err error) {
|
||||
return func(r *http.Request) (s string, err error) {
|
||||
var claims map[string]interface{}
|
||||
if claims = oidc.FromContext(ctx); claims == nil {
|
||||
if claims = oidc.FromContext(r.Context()); claims == nil {
|
||||
return cfg.UnauthenticatedPolicy, nil
|
||||
}
|
||||
|
||||
@@ -145,7 +144,7 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account
|
||||
return cfg.AccNotFoundPolicy, nil
|
||||
}
|
||||
|
||||
if _, err := acc.GetAccount(ctx, &accounts.GetAccountRequest{Id: userID}); err != nil {
|
||||
if _, err := acc.GetAccount(r.Context(), &accounts.GetAccountRequest{Id: userID}); err != nil {
|
||||
return cfg.AccNotFoundPolicy, nil
|
||||
}
|
||||
return cfg.AccFoundPolicy, nil
|
||||
@@ -164,7 +163,7 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account
|
||||
//
|
||||
// This selector can be used in migration-scenarios where some users have already migrated from ownCloud10 to OCIS and
|
||||
func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector {
|
||||
return func(ctx context.Context, r *http.Request) (s string, err error) {
|
||||
return func(r *http.Request) (s string, err error) {
|
||||
// use cookie first if provided
|
||||
selectorCookie, err := r.Cookie(cfg.SelectorCookieName)
|
||||
if err == nil {
|
||||
@@ -172,7 +171,7 @@ func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector {
|
||||
}
|
||||
|
||||
// if no cookie is present, try to route by selector
|
||||
if claims := oidc.FromContext(ctx); claims != nil {
|
||||
if claims := oidc.FromContext(r.Context()); claims != nil {
|
||||
if p, ok := claims[oidc.OcisRoutingPolicy].(string); ok && p != "" {
|
||||
// TODO check we know the routing policy?
|
||||
return p, nil
|
||||
@@ -213,7 +212,7 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector {
|
||||
policy: cfg.MatchesPolicies[i].Policy,
|
||||
})
|
||||
}
|
||||
return func(ctx context.Context, r *http.Request) (s string, err error) {
|
||||
return func(r *http.Request) (s string, err error) {
|
||||
// use cookie first if provided
|
||||
selectorCookie, err := r.Cookie(cfg.SelectorCookieName)
|
||||
if err == nil {
|
||||
@@ -221,7 +220,7 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector {
|
||||
}
|
||||
|
||||
// if no cookie is present, try to route by selector
|
||||
if u, ok := revauser.ContextGetUser(ctx); ok {
|
||||
if u, ok := revauser.ContextGetUser(r.Context()); ok {
|
||||
for i := range regexRules {
|
||||
switch regexRules[i].property {
|
||||
case "mail":
|
||||
|
||||
@@ -3,6 +3,7 @@ package policy
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/asim/go-micro/v3/client"
|
||||
@@ -47,9 +48,9 @@ func TestLoadSelector(t *testing.T) {
|
||||
|
||||
func TestStaticSelector(t *testing.T) {
|
||||
sel := NewStaticSelector(&config.StaticSelectorConf{Policy: "ocis"})
|
||||
ctx := context.Background()
|
||||
req := httptest.NewRequest("GET", "https://example.org/foo", nil)
|
||||
want := "ocis"
|
||||
got, err := sel(ctx)
|
||||
got, err := sel(req)
|
||||
if got != want {
|
||||
t.Errorf("Expected policy %v got %v", want, got)
|
||||
}
|
||||
@@ -61,7 +62,7 @@ func TestStaticSelector(t *testing.T) {
|
||||
sel = NewStaticSelector(&config.StaticSelectorConf{Policy: "foo"})
|
||||
|
||||
want = "foo"
|
||||
got, err = sel(ctx)
|
||||
got, err = sel(req)
|
||||
if got != want {
|
||||
t.Errorf("Expected policy %v got %v", want, got)
|
||||
}
|
||||
@@ -93,9 +94,11 @@ func TestMigrationSelector(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
sut := NewMigrationSelector(&cfg, mockAccSvc(tc.AccSvcShouldReturnError))
|
||||
ctx := oidc.NewContext(context.Background(), tc.Claims)
|
||||
r := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
ctx := oidc.NewContext(r.Context(), tc.Claims)
|
||||
nr := r.WithContext(ctx)
|
||||
|
||||
got, err := sut(ctx)
|
||||
got, err := sut(nr)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -141,7 +144,9 @@ func TestClaimsSelector(t *testing.T) {
|
||||
{"claim-value", oidc.NewContext(context.Background(), map[string]interface{}{oidc.OcisRoutingPolicy: "ocis.routing.policy-value"}), "ocis.routing.policy-value"},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
got, err := sel(tc.Context)
|
||||
r := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
nr := r.WithContext(tc.Context)
|
||||
got, err := sel(nr)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
@@ -182,7 +187,9 @@ func TestRegexSelector(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
got, err := sel(tc.Context)
|
||||
r := httptest.NewRequest("GET", "https://example.com", nil)
|
||||
nr := r.WithContext(tc.Context)
|
||||
got, err := sel(nr)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
|
||||
|
||||
if options.Config.PolicySelector == nil {
|
||||
firstPolicy := options.Config.Policies[0].Name
|
||||
rp.logger.Warn().Msgf("policy-selector not configured. Will always use first policy: '%v'", firstPolicy)
|
||||
rp.logger.Warn().Str("policy", firstPolicy).Msg("policy-selector not configured. Will always use first policy")
|
||||
options.Config.PolicySelector = &config.PolicySelector{
|
||||
Static: &config.StaticSelectorConf{
|
||||
Policy: firstPolicy,
|
||||
@@ -91,9 +91,10 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
|
||||
uri, err := url.Parse(route.Backend)
|
||||
if err != nil {
|
||||
rp.logger.
|
||||
Fatal().
|
||||
Fatal(). // fail early on misconfiguration
|
||||
Err(err).
|
||||
Msgf("malformed url: %v", route.Backend)
|
||||
Str("backend", route.Backend).
|
||||
Msg("malformed url")
|
||||
}
|
||||
|
||||
rp.logger.
|
||||
@@ -109,16 +110,17 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy {
|
||||
}
|
||||
|
||||
func (p *MultiHostReverseProxy) directorSelectionDirector(r *http.Request) {
|
||||
pol, err := p.PolicySelector(r.Context(), r)
|
||||
pol, err := p.PolicySelector(r)
|
||||
if err != nil {
|
||||
p.logger.Error().Msgf("Error while selecting pol %v", err)
|
||||
p.logger.Error().Err(err).Msg("Error while selecting pol")
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := p.Directors[pol]; !ok {
|
||||
p.logger.
|
||||
Error().
|
||||
Msgf("policy %v is not configured", pol)
|
||||
Str("policy", pol).
|
||||
Msg("policy is not configured")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -247,7 +249,7 @@ func (p MultiHostReverseProxy) queryRouteMatcher(endpoint string, target url.URL
|
||||
func (p *MultiHostReverseProxy) regexRouteMatcher(pattern string, target url.URL) bool {
|
||||
matched, err := regexp.MatchString(pattern, target.String())
|
||||
if err != nil {
|
||||
p.logger.Warn().Err(err).Msgf("regex with pattern %s failed", pattern)
|
||||
p.logger.Warn().Err(err).Str("pattern", pattern).Msg("regex with pattern failed")
|
||||
}
|
||||
return matched
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user