diff --git a/proxy/pkg/config/config.go b/proxy/pkg/config/config.go index b1c61664d8..4015daa949 100644 --- a/proxy/pkg/config/config.go +++ b/proxy/pkg/config/config.go @@ -1,6 +1,8 @@ package config -import "context" +import ( + "context" +) // Log defines the available logging configuration. type Log struct { @@ -176,9 +178,15 @@ type ClaimsSelectorConf struct { // RegexSelectorConf is the config for the regex-selector type RegexSelectorConf struct { - DefaultPolicy string `mapstructure:"default_policy"` - MatchesPolicies map[string]map[string]string `mapstructure:"matches_policies"` - UnauthenticatedPolicy string `mapstructure:"unauthenticated_policy"` + DefaultPolicy string `mapstructure:"default_policy"` + MatchesPolicies []RegexRuleConf `mapstructure:"matches_policies"` + UnauthenticatedPolicy string `mapstructure:"unauthenticated_policy"` +} +type RegexRuleConf struct { + Priority int `mapstructure:"priority"` + Property string `mapstructure:"property"` + Match string `mapstructure:"match"` + Policy string `mapstructure:"policy"` } // New initializes a new configuration diff --git a/proxy/pkg/proxy/policy/selector.go b/proxy/pkg/proxy/policy/selector.go index ac0d3f5366..3ced3d18af 100644 --- a/proxy/pkg/proxy/policy/selector.go +++ b/proxy/pkg/proxy/policy/selector.go @@ -3,8 +3,8 @@ package policy import ( "context" "fmt" - "net/http" "regexp" + "sort" "github.com/asim/go-micro/plugins/client/grpc/v3" revauser "github.com/cs3org/reva/pkg/user" @@ -15,9 +15,9 @@ import ( var ( // ErrMultipleSelectors in case there is more then one selector configured. - ErrMultipleSelectors = fmt.Errorf("only one type of policy-selector (static or migration) can be configured") + ErrMultipleSelectors = fmt.Errorf("only one type of policy-selector (static, migration, claim or regex) can be configured") // ErrSelectorConfigIncomplete if policy_selector conf is missing - ErrSelectorConfigIncomplete = fmt.Errorf("missing either \"static\" or \"migration\" configuration in policy_selector config ") + ErrSelectorConfigIncomplete = fmt.Errorf("missing either \"static\", \"migration\", \"claim\" or \"regex\" configuration in policy_selector config ") // ErrUnexpectedConfigError unexpected config error ErrUnexpectedConfigError = fmt.Errorf("could not initialize policy-selector for given config") ) @@ -47,15 +47,29 @@ var ( // } // ] //} -type Selector func(ctx context.Context, r *http.Request) (string, error) +type Selector func(ctx context.Context) (string, error) // LoadSelector constructs a specific policy-selector from a given configuration func LoadSelector(cfg *config.PolicySelector) (Selector, error) { - if cfg.Migration != nil && cfg.Static != nil { + selCount := 0 + + if cfg.Migration != nil { + selCount++ + } + if cfg.Static != nil { + selCount++ + } + if cfg.Claims != nil { + selCount++ + } + if cfg.Regex != nil { + selCount++ + } + if selCount > 1 { return nil, ErrMultipleSelectors } - if cfg.Migration == nil && cfg.Static == nil { + if cfg.Migration == nil && cfg.Static == nil && cfg.Claims == nil && cfg.Regex == nil { return nil, ErrSelectorConfigIncomplete } @@ -88,7 +102,7 @@ func LoadSelector(cfg *config.PolicySelector) (Selector, error) { // "static": {"policy" : "ocis"} // }, func NewStaticSelector(cfg *config.StaticSelectorConf) Selector { - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(ctx context.Context) (s string, err error) { return cfg.Policy, nil } } @@ -107,9 +121,9 @@ func NewStaticSelector(cfg *config.StaticSelectorConf) Selector { // thus have an entry in ocis-accounts. All users without accounts entry are routed to the legacy ownCloud10 instance. func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.AccountsService) Selector { var acc = ss - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(ctx context.Context) (s string, err error) { var claims map[string]interface{} - if claims = oidc.FromContext(r.Context()); claims == nil { + if claims = oidc.FromContext(ctx); claims == nil { return cfg.UnauthenticatedPolicy, nil } @@ -139,8 +153,8 @@ func NewMigrationSelector(cfg *config.MigrationSelectorConf, ss accounts.Account // // This selector can be used in migration-scenarios where some users have already migrated from ownCloud10 to OCIS and func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector { - return func(ctx context.Context, r *http.Request) (s string, err error) { - if claims := oidc.FromContext(r.Context()); claims != nil { + return func(ctx context.Context) (s string, err error) { + if claims := oidc.FromContext(ctx); claims != nil { if p, ok := claims[oidc.OcisRoutingPolicy].(string); ok && p != "" { // TODO check we know the routing policy? return p, nil @@ -156,58 +170,46 @@ func NewClaimsSelector(cfg *config.ClaimsSelectorConf) Selector { // The policy for each case is configurable: // "policy_selector": { // "migration": { -// "matches_policies": { -// "mail": { -// "marie@example.com": "oc10" -// "[^@]+@example.com": "ocis" -// }, -// "username": { -// "(einstein|feynman)": "ocis" -// "marie": "oc10" -// }, -// "id": { -// "4c510ada-c86b-4815-8820-42cdf82c3d51": "ocis" -// "f7fbf8c8-139b-4376-b307-cf0a8c2d0d9c": "oc10" -// }, -// }, +// "matches_policies": [ +// {"priority": 10, "property": "mail", "match": "marie@example.org", "policy": "ocis"}, +// {"priority": 20, "property": "mail", "match": "[^@]+@example.org", "policy": "oc10"}, +// {"priority": 30, "property": "username", "match": "(einstein|feynman)", "policy": "ocis"}, +// {"priority": 40, "property": "username", "match": ".+", "policy": "oc10"}, +// {"priority": 50, "property": "id", "match": "4c510ada-c86b-4815-8820-42cdf82c3d51", "policy": "ocis"}, +// {"priority": 60, "property": "id", "match": "f7fbf8c8-139b-4376-b307-cf0a8c2d0d9c", "policy": "oc10"}, +// ], // "unauthenticated_policy": "oc10" // } // }, // // This selector can be used in migration-scenarios where some users have already migrated from ownCloud10 to OCIS and func NewRegexSelector(cfg *config.RegexSelectorConf) Selector { - var mailRegexPolicies map[*regexp.Regexp]string - for m, p := range cfg.MatchesPolicies["mail"] { - mailRegexPolicies[regexp.MustCompile(m)] = p + regexRules := []*regexRule{} + sort.Slice(cfg.MatchesPolicies, func(i, j int) bool { + return cfg.MatchesPolicies[i].Priority < cfg.MatchesPolicies[j].Priority + }) + for i := range cfg.MatchesPolicies { + regexRules = append(regexRules, ®exRule{ + property: cfg.MatchesPolicies[i].Property, + rule: regexp.MustCompile(cfg.MatchesPolicies[i].Match), + policy: cfg.MatchesPolicies[i].Policy, + }) } - var usernameRegexPolicies map[*regexp.Regexp]string - for m, p := range cfg.MatchesPolicies["username"] { - usernameRegexPolicies[regexp.MustCompile(m)] = p - } - var idRegexPolicies map[*regexp.Regexp]string - for m, p := range cfg.MatchesPolicies["id"] { - usernameRegexPolicies[regexp.MustCompile(m)] = p - } - return func(ctx context.Context, r *http.Request) (s string, err error) { + return func(ctx context.Context) (s string, err error) { if u, ok := revauser.ContextGetUser(ctx); ok { - if u.Mail != "" { - for r, p := range mailRegexPolicies { - if r.MatchString(u.Mail) { - return p, nil + for i := range regexRules { + switch regexRules[i].property { + case "mail": + if regexRules[i].rule.MatchString(u.Mail) { + return regexRules[i].policy, nil } - } - } - if u.Username != "" { - for r, p := range usernameRegexPolicies { - if r.MatchString(u.Username) { - return p, nil + case "username": + if regexRules[i].rule.MatchString(u.Username) { + return regexRules[i].policy, nil } - } - } - if u.Id != nil && u.Id.OpaqueId != "" { - for r, p := range idRegexPolicies { - if r.MatchString(u.Id.OpaqueId) { - return p, nil + case "id": + if u.Id != nil && regexRules[i].rule.MatchString(u.Id.OpaqueId) { + return regexRules[i].policy, nil } } } @@ -217,3 +219,9 @@ func NewRegexSelector(cfg *config.RegexSelectorConf) Selector { return cfg.UnauthenticatedPolicy, nil } } + +type regexRule struct { + property string + rule *regexp.Regexp + policy string +} diff --git a/proxy/pkg/proxy/policy/selector_test.go b/proxy/pkg/proxy/policy/selector_test.go index 09127a9dcc..a22dba988e 100644 --- a/proxy/pkg/proxy/policy/selector_test.go +++ b/proxy/pkg/proxy/policy/selector_test.go @@ -3,10 +3,11 @@ package policy import ( "context" "fmt" - "net/http/httptest" "testing" "github.com/asim/go-micro/v3/client" + userv1beta1 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1" + revauser "github.com/cs3org/reva/pkg/user" "github.com/owncloud/ocis/accounts/pkg/proto/v0" "github.com/owncloud/ocis/ocis-pkg/oidc" "github.com/owncloud/ocis/proxy/pkg/config" @@ -23,29 +24,32 @@ func TestLoadSelector(t *testing.T) { AccNotFoundPolicy: "not_found", UnauthenticatedPolicy: "unauth", } + ccfg := &config.ClaimsSelectorConf{} + rcfg := &config.RegexSelectorConf{} table := []test{ {cfg: &config.PolicySelector{Static: sCfg, Migration: mcfg}, expectedErr: ErrMultipleSelectors}, + {cfg: &config.PolicySelector{Static: sCfg, Claims: ccfg, Regex: rcfg}, expectedErr: ErrMultipleSelectors}, {cfg: &config.PolicySelector{}, expectedErr: ErrSelectorConfigIncomplete}, {cfg: &config.PolicySelector{Static: sCfg}, expectedErr: nil}, {cfg: &config.PolicySelector{Migration: mcfg}, expectedErr: nil}, + {cfg: &config.PolicySelector{Claims: ccfg}, expectedErr: nil}, + {cfg: &config.PolicySelector{Regex: rcfg}, expectedErr: nil}, } for _, test := range table { _, err := LoadSelector(test.cfg) if err != test.expectedErr { - t.Fail() + t.Errorf("Unexpected error %v", err) } } } func TestStaticSelector(t *testing.T) { - ctx := context.Background() - req := httptest.NewRequest("GET", "https://example.org/foo", nil) sel := NewStaticSelector(&config.StaticSelectorConf{Policy: "ocis"}) - + ctx := context.Background() want := "ocis" - got, err := sel(ctx, req) + got, err := sel(ctx) if got != want { t.Errorf("Expected policy %v got %v", want, got) } @@ -57,7 +61,7 @@ func TestStaticSelector(t *testing.T) { sel = NewStaticSelector(&config.StaticSelectorConf{Policy: "foo"}) want = "foo" - got, err = sel(ctx, req) + got, err = sel(ctx) if got != want { t.Errorf("Expected policy %v got %v", want, got) } @@ -67,7 +71,7 @@ func TestStaticSelector(t *testing.T) { } } -type testCase struct { +type migrationTestCase struct { AccSvcShouldReturnError bool Claims map[string]interface{} Expected string @@ -79,7 +83,7 @@ func TestMigrationSelector(t *testing.T) { AccNotFoundPolicy: "not_found", UnauthenticatedPolicy: "unauth", } - var tests = []testCase{ + var tests = []migrationTestCase{ {true, map[string]interface{}{oidc.PreferredUsername: "Hans"}, "not_found"}, {true, map[string]interface{}{oidc.Email: "hans@example.test"}, "not_found"}, {false, map[string]interface{}{oidc.PreferredUsername: "Hans"}, "found"}, @@ -87,15 +91,11 @@ func TestMigrationSelector(t *testing.T) { } for _, tc := range tests { - //t.Run(fmt.Sprintf("#%v", k), func(t *testing.T) { - // t.Parallel() tc := tc sut := NewMigrationSelector(&cfg, mockAccSvc(tc.AccSvcShouldReturnError)) - r := httptest.NewRequest("GET", "https://example.com", nil) - ctx := oidc.NewContext(r.Context(), tc.Claims) - nr := r.WithContext(ctx) + ctx := oidc.NewContext(context.Background(), tc.Claims) - got, err := sut(ctx, nr) + got, err := sut(ctx) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -103,7 +103,6 @@ func TestMigrationSelector(t *testing.T) { if got != tc.Expected { t.Errorf("Expected Policy %v got %v", tc.Expected, got) } - //}) } } @@ -123,3 +122,74 @@ func mockAccSvc(retErr bool) proto.AccountsService { } } + +type testCase struct { + Name string + Context context.Context + Expected string +} + +func TestClaimsSelector(t *testing.T) { + sel := NewClaimsSelector(&config.ClaimsSelectorConf{ + DefaultPolicy: "default", + UnauthenticatedPolicy: "unauthenticated", + }) + + var tests = []testCase{ + {"unatuhenticated", context.Background(), "unauthenticated"}, + {"default", oidc.NewContext(context.Background(), map[string]interface{}{oidc.OcisRoutingPolicy: ""}), "default"}, + {"claim-value", oidc.NewContext(context.Background(), map[string]interface{}{oidc.OcisRoutingPolicy: "ocis.routing.policy-value"}), "ocis.routing.policy-value"}, + } + for _, tc := range tests { + got, err := sel(tc.Context) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if got != tc.Expected { + t.Errorf("Expected Policy %v got %v", tc.Expected, got) + } + } +} + +func TestRegexSelector(t *testing.T) { + sel := NewRegexSelector(&config.RegexSelectorConf{ + DefaultPolicy: "default", + MatchesPolicies: []config.RegexRuleConf{ + {Priority: 10, Property: "mail", Match: "marie@example.org", Policy: "ocis"}, + {Priority: 20, Property: "mail", Match: "[^@]+@example.org", Policy: "oc10"}, + {Priority: 30, Property: "username", Match: "(einstein|feynman)", Policy: "ocis"}, + {Priority: 40, Property: "username", Match: ".+", Policy: "oc10"}, + {Priority: 50, Property: "id", Match: "4c510ada-c86b-4815-8820-42cdf82c3d51", Policy: "ocis"}, + {Priority: 60, Property: "id", Match: "f7fbf8c8-139b-4376-b307-cf0a8c2d0d9c", Policy: "oc10"}, + }, + UnauthenticatedPolicy: "unauthenticated", + }) + + var tests = []testCase{ + {"unauthenticated", context.Background(), "unauthenticated"}, + {"default", revauser.ContextSetUser(context.Background(), &userv1beta1.User{}), "default"}, + {"mail-ocis", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Mail: "marie@example.org"}), "ocis"}, + {"mail-oc10", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Mail: "einstein@example.org"}), "oc10"}, + {"username-einstein", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Username: "einstein"}), "ocis"}, + {"username-feynman", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Username: "feynman"}), "ocis"}, + {"username-marie", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Username: "marie"}), "oc10"}, + {"id-nil", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Id: &userv1beta1.UserId{}}), "default"}, + {"id-1", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Id: &userv1beta1.UserId{OpaqueId: "4c510ada-c86b-4815-8820-42cdf82c3d51"}}), "ocis"}, + {"id-2", revauser.ContextSetUser(context.Background(), &userv1beta1.User{Id: &userv1beta1.UserId{OpaqueId: "f7fbf8c8-139b-4376-b307-cf0a8c2d0d9c"}}), "oc10"}, + } + + for _, tc := range tests { + tc := tc // capture range variable + t.Run(tc.Name, func(t *testing.T) { + got, err := sel(tc.Context) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + if got != tc.Expected { + t.Errorf("Expected Policy %v got %v", tc.Expected, got) + } + }) + } +} diff --git a/proxy/pkg/proxy/proxy.go b/proxy/pkg/proxy/proxy.go index d722439e05..b4871617a1 100644 --- a/proxy/pkg/proxy/proxy.go +++ b/proxy/pkg/proxy/proxy.go @@ -110,7 +110,7 @@ func NewMultiHostReverseProxy(opts ...Option) *MultiHostReverseProxy { } func (p *MultiHostReverseProxy) directorSelectionDirector(r *http.Request) { - pol, err := p.PolicySelector(r.Context(), r) + pol, err := p.PolicySelector(r.Context()) if err != nil { p.logger.Error().Msgf("Error while selecting pol %v", err) return