Parse role claims (#7713)

* extract and test role claim parsing

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add failing test

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* read segmented roles claim as array and string

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* reuse more code by extracting WalkSegments

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add TestSplitWithEscaping

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* docs and error for unhandled case

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add claims test

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

* add missing ReadStringClaim docs

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>

---------

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
Jörn Friedrich Dreyer
2023-12-04 12:18:52 +01:00
committed by GitHub
parent 81ace6dd1d
commit 23e59b5ded
5 changed files with 409 additions and 32 deletions
@@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
"github.com/owncloud/ocis/v2/services/proxy/pkg/userroles"
@@ -43,19 +42,6 @@ type accountResolver struct {
userCS3Claim string
}
// from https://codereview.stackexchange.com/a/280193
func splitWithEscaping(s string, separator string, escapeString string) []string {
a := strings.Split(s, separator)
for i := len(a) - 2; i >= 0; i-- {
if strings.HasSuffix(a[i], escapeString) {
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
a = append(a[:i+1], a[i+2:]...)
}
}
return a
}
func readUserIDClaim(path string, claims map[string]interface{}) (string, error) {
// happy path
value, _ := claims[path].(string)
@@ -64,7 +50,7 @@ func readUserIDClaim(path string, claims map[string]interface{}) (string, error)
}
// try splitting path at .
segments := splitWithEscaping(path, ".", "\\")
segments := oidc.SplitWithEscaping(path, ".", "\\")
subclaims := claims
lastSegment := len(segments) - 1
for i := range segments {
+44 -17
View File
@@ -9,6 +9,7 @@ import (
cs3 "github.com/cs3org/go-cs3apis/cs3/identity/user/v1beta1"
"github.com/cs3org/reva/v2/pkg/utils"
"github.com/owncloud/ocis/v2/ocis-pkg/middleware"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
settingssvc "github.com/owncloud/ocis/v2/protogen/gen/ocis/services/settings/v0"
"go-micro.dev/v4/metadata"
)
@@ -29,6 +30,45 @@ func NewOIDCRoleAssigner(opts ...Option) UserRoleAssigner {
}
}
func extractRoles(rolesClaim string, claims map[string]interface{}) (map[string]struct{}, error) {
claimRoles := map[string]struct{}{}
// happy path
value, _ := claims[rolesClaim].(string)
if value != "" {
claimRoles[value] = struct{}{}
return claimRoles, nil
}
claim, err := oidc.WalkSegments(oidc.SplitWithEscaping(rolesClaim, ".", "\\"), claims)
if err != nil {
return nil, err
}
switch v := claim.(type) {
case []string:
for _, cr := range v {
claimRoles[cr] = struct{}{}
}
case []interface{}:
for _, cri := range v {
cr, ok := cri.(string)
if !ok {
err := errors.New("invalid role in claims")
return nil, err
}
claimRoles[cr] = struct{}{}
}
case string:
claimRoles[v] = struct{}{}
default:
return nil, errors.New("no roles in user claims")
}
return claimRoles, nil
}
// UpdateUserRoleAssignment assigns the role "User" to the supplied user. Unless the user
// already has a different role assigned.
func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *cs3.User, claims map[string]interface{}) (*cs3.User, error) {
@@ -39,23 +79,10 @@ func (ra oidcRoleAssigner) UpdateUserRoleAssignment(ctx context.Context, user *c
return nil, err
}
claimRolesRaw, ok := claims[ra.rolesClaim].([]interface{})
if !ok {
logger.Error().Str("rolesClaim", ra.rolesClaim).Msg("No roles in user claims")
return nil, errors.New("no roles in user claims")
}
logger.Debug().Str("rolesClaim", ra.rolesClaim).Interface("rolesInClaim", claims[ra.rolesClaim]).Msg("got roles in claim")
claimRoles := map[string]struct{}{}
for _, cri := range claimRolesRaw {
cr, ok := cri.(string)
if !ok {
err := errors.New("invalid role in claims")
logger.Error().Err(err).Interface("claimValue", cri).Msg("Is not a valid string.")
return nil, err
}
claimRoles[cr] = struct{}{}
claimRoles, err := extractRoles(ra.rolesClaim, claims)
if err != nil {
logger.Error().Err(err).Msg("Error mapping role names to role ids")
return nil, err
}
if len(claimRoles) == 0 {
@@ -0,0 +1,120 @@
package userroles
import (
"encoding/json"
"testing"
)
func TestExtractRolesArray(t *testing.T) {
byt := []byte(`{"roles":["a","b"]}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
if _, ok := roles["b"]; !ok {
t.Fatal("must contain 'b'")
}
}
func TestExtractRolesString(t *testing.T) {
byt := []byte(`{"roles":"a"}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
}
func TestExtractRolesPathArray(t *testing.T) {
byt := []byte(`{"sub":{"roles":["a","b"]}}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub.roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
if _, ok := roles["b"]; !ok {
t.Fatal("must contain 'b'")
}
}
func TestExtractRolesPathString(t *testing.T) {
byt := []byte(`{"sub":{"roles":"a"}}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub.roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
}
func TestExtractEscapedRolesPathString(t *testing.T) {
byt := []byte(`{"sub.roles":"a"}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub\\.roles", claims)
if err != nil {
t.Fatal(err)
}
if _, ok := roles["a"]; !ok {
t.Fatal("must contain 'a'")
}
}
func TestNoRoles(t *testing.T) {
byt := []byte(`{"sub":{"foo":"a"}}`)
claims := map[string]interface{}{}
err := json.Unmarshal(byt, &claims)
if err != nil {
t.Fatal(err)
}
roles, err := extractRoles("sub.roles", claims)
if err == nil {
t.Fatal("must not find a role")
}
if len(roles) != 0 {
t.Fatal("length of roles mut be 0")
}
}