mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-02-18 03:18:52 -06:00
allow skipping userinfo call
Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
@@ -27,7 +27,7 @@ import (
|
||||
// OIDCClient used to mock the oidc client during tests
|
||||
type OIDCClient interface {
|
||||
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
|
||||
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, []string, error)
|
||||
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error)
|
||||
VerifyLogoutToken(ctx context.Context, token string) (*LogoutToken, error)
|
||||
}
|
||||
|
||||
@@ -271,27 +271,26 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, []string, error) {
|
||||
var mapClaims []string
|
||||
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error) {
|
||||
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
|
||||
return RegClaimsWithSID{}, mapClaims, err
|
||||
return RegClaimsWithSID{}, jwt.MapClaims{}, err
|
||||
}
|
||||
switch c.accessTokenVerifyMethod {
|
||||
case config.AccessTokenVerificationJWT:
|
||||
return c.verifyAccessTokenJWT(token)
|
||||
case config.AccessTokenVerificationNone:
|
||||
c.Logger.Debug().Msg("Access Token verification disabled")
|
||||
return RegClaimsWithSID{}, mapClaims, nil
|
||||
return RegClaimsWithSID{}, jwt.MapClaims{}, nil
|
||||
default:
|
||||
c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
||||
return RegClaimsWithSID{}, mapClaims, errors.New("unknown Access Token Verification method")
|
||||
return RegClaimsWithSID{}, jwt.MapClaims{}, errors.New("unknown Access Token Verification method")
|
||||
}
|
||||
}
|
||||
|
||||
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
|
||||
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, []string, error) {
|
||||
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, jwt.MapClaims, error) {
|
||||
var claims RegClaimsWithSID
|
||||
var mapClaims []string
|
||||
mapClaims := jwt.MapClaims{}
|
||||
jwks := c.getKeyfunc()
|
||||
if jwks == nil {
|
||||
return claims, mapClaims, errors.New("error initializing jwks keyfunc")
|
||||
@@ -301,7 +300,7 @@ func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, []str
|
||||
if err != nil {
|
||||
return claims, mapClaims, err
|
||||
}
|
||||
_, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
_, _, err = new(jwt.Parser).ParseUnverified(token, mapClaims)
|
||||
// TODO: decode mapClaims to sth readable
|
||||
c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
||||
if err != nil {
|
||||
|
||||
@@ -402,6 +402,7 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config,
|
||||
middleware.Logger(logger),
|
||||
middleware.UserProvider(userProvider),
|
||||
middleware.UserRoleAssigner(roleAssigner),
|
||||
middleware.SkipUserInfo(cfg.OIDC.SkipUserInfo),
|
||||
middleware.UserOIDCClaim(cfg.UserOIDCClaim),
|
||||
middleware.UserCS3Claim(cfg.UserCS3Claim),
|
||||
middleware.AutoprovisionAccounts(cfg.AutoprovisionAccounts),
|
||||
|
||||
@@ -107,6 +107,7 @@ type OIDC struct {
|
||||
Issuer string `yaml:"issuer" env:"OCIS_URL;OCIS_OIDC_ISSUER;PROXY_OIDC_ISSUER" desc:"URL of the OIDC issuer. It defaults to URL of the builtin IDP."`
|
||||
Insecure bool `yaml:"insecure" env:"OCIS_INSECURE;PROXY_OIDC_INSECURE" desc:"Disable TLS certificate validation for connections to the IDP. Note that this is not recommended for production environments."`
|
||||
AccessTokenVerifyMethod string `yaml:"access_token_verify_method" env:"PROXY_OIDC_ACCESS_TOKEN_VERIFY_METHOD" desc:"Sets how OIDC access tokens should be verified. Possible values are 'none' and 'jwt'. When using 'none', no special validation apart from using it for accessing the IPD's userinfo endpoint will be done. When using 'jwt', it tries to parse the access token as a jwt token and verifies the signature using the keys published on the IDP's 'jwks_uri'."`
|
||||
SkipUserInfo bool `yaml:"skip_user_info" env:"PROXY_OIDC_SKIP_USER_INFO" desc:"Do not look up user claims at the userinfo endpoint and directly read them from the access token. Incompatible with PROXY_OIDC_ACCESS_TOKEN_VERIFY_METHOD=none"`
|
||||
UserinfoCache *Cache `yaml:"user_info_cache"`
|
||||
JWKS JWKS `yaml:"jwks"`
|
||||
RewriteWellKnown bool `yaml:"rewrite_well_known" env:"PROXY_OIDC_REWRITE_WELLKNOWN" desc:"Enables rewriting the /.well-known/openid-configuration to the configured OIDC issuer. Needed by the Desktop Client, Android Client and iOS Client to discover the OIDC provider."`
|
||||
|
||||
@@ -41,6 +41,7 @@ func DefaultConfig() *config.Config {
|
||||
Issuer: "https://localhost:9200",
|
||||
|
||||
AccessTokenVerifyMethod: config.AccessTokenVerificationJWT,
|
||||
SkipUserInfo: false,
|
||||
UserinfoCache: &config.Cache{
|
||||
Store: "memory",
|
||||
Database: "ocis",
|
||||
@@ -61,10 +62,10 @@ func DefaultConfig() *config.Config {
|
||||
OIDCRoleMapper: config.OIDCRoleMapper{
|
||||
RoleClaim: "roles",
|
||||
RolesMap: []config.RoleMapping{
|
||||
config.RoleMapping{RoleName: "admin", ClaimValue: "ocisAdmin"},
|
||||
config.RoleMapping{RoleName: "spaceadmin", ClaimValue: "ocisSpaceAdmin"},
|
||||
config.RoleMapping{RoleName: "user", ClaimValue: "ocisUser"},
|
||||
config.RoleMapping{RoleName: "guest", ClaimValue: "ocisGuest"},
|
||||
{RoleName: "admin", ClaimValue: "ocisAdmin"},
|
||||
{RoleName: "spaceadmin", ClaimValue: "ocisSpaceAdmin"},
|
||||
{RoleName: "user", ClaimValue: "ocisUser"},
|
||||
{RoleName: "guest", ClaimValue: "ocisGuest"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -50,6 +50,12 @@ func Validate(cfg *config.Config) error {
|
||||
config.AccessTokenVerificationJWT, config.AccessTokenVerificationNone,
|
||||
)
|
||||
}
|
||||
if cfg.OIDC.AccessTokenVerifyMethod == "none" && cfg.OIDC.SkipUserInfo {
|
||||
return fmt.Errorf(
|
||||
"Incompatible value '%t' for 'skip_user_info' in service %s. Must be false when 'access_token_verify_method' is 'none'.",
|
||||
cfg.OIDC.SkipUserInfo, cfg.Service.Name,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
|
||||
"github.com/owncloud/ocis/v2/services/proxy/pkg/userroles"
|
||||
@@ -41,6 +43,54 @@ type accountResolver struct {
|
||||
userCS3Claim string
|
||||
}
|
||||
|
||||
// from https://codereview.stackexchange.com/a/280193
|
||||
func splitWithEscaping(s string, separator string, escapeString string) []string {
|
||||
a := strings.Split(s, separator)
|
||||
|
||||
for i := len(a) - 2; i >= 0; i-- {
|
||||
if strings.HasSuffix(a[i], escapeString) {
|
||||
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
|
||||
a = append(a[:i+1], a[i+2:]...)
|
||||
}
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func readUserIDClaim(path string, claims map[string]interface{}) (string, error) {
|
||||
// happy path
|
||||
value, _ := claims[path].(string)
|
||||
if value != "" {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// try splitting path at .
|
||||
segments := splitWithEscaping(path, ".", "\\")
|
||||
subclaims := claims
|
||||
lastSegment := len(segments) - 1
|
||||
for i := range segments {
|
||||
if i < lastSegment {
|
||||
if castedClaims, ok := subclaims[segments[i]].(map[string]interface{}); ok {
|
||||
subclaims = castedClaims
|
||||
} else if castedClaims, ok := subclaims[segments[i]].(map[interface{}]interface{}); ok {
|
||||
subclaims = make(map[string]interface{}, len(castedClaims))
|
||||
for k, v := range castedClaims {
|
||||
if s, ok := k.(string); ok {
|
||||
subclaims[s] = v
|
||||
} else {
|
||||
return "", fmt.Errorf("could not walk claims path, key '%v' is not a string", k)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if value, _ = subclaims[segments[i]].(string); value != "" {
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return value, fmt.Errorf("claim path '%s' not set or empty", path)
|
||||
}
|
||||
|
||||
// TODO do not use the context to store values: https://medium.com/@cep21/how-to-correctly-use-context-context-in-go-1-7-8f2c0fafdf39
|
||||
func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
ctx := req.Context()
|
||||
@@ -55,13 +105,10 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
|
||||
if user == nil && claims != nil {
|
||||
|
||||
var err error
|
||||
var value string
|
||||
var ok bool
|
||||
if value, ok = claims[m.userOIDCClaim].(string); !ok || value == "" {
|
||||
m.logger.Error().Str("claim", m.userOIDCClaim).Interface("claims", claims).Msg("claim not set or empty")
|
||||
w.WriteHeader(http.StatusInternalServerError) // admin needs to make the idp send the right claim
|
||||
value, err := readUserIDClaim(m.userOIDCClaim, claims)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("could not read user id claim")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -57,6 +57,68 @@ func TestTokenIsAddedWithUsernameClaim(t *testing.T) {
|
||||
assert.Contains(t, token, "eyJ")
|
||||
}
|
||||
|
||||
func TestTokenIsAddedWithDotUsernamePathClaim(t *testing.T) {
|
||||
sut := newMockAccountResolver(&userv1beta1.User{
|
||||
Id: &userv1beta1.UserId{Idp: "https://idx.example.com", OpaqueId: "123"},
|
||||
Mail: "foo@example.com",
|
||||
}, nil, "li.un", "username")
|
||||
|
||||
// This is how lico adds the username to the access token
|
||||
req, rw := mockRequest(map[string]interface{}{
|
||||
oidc.Iss: "https://idx.example.com",
|
||||
"li": map[string]interface{}{
|
||||
"un": "foo",
|
||||
},
|
||||
})
|
||||
|
||||
sut.ServeHTTP(rw, req)
|
||||
|
||||
token := req.Header.Get(revactx.TokenHeader)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
assert.Contains(t, token, "eyJ")
|
||||
}
|
||||
|
||||
func TestTokenIsAddedWithDotEscapedUsernameClaim(t *testing.T) {
|
||||
sut := newMockAccountResolver(&userv1beta1.User{
|
||||
Id: &userv1beta1.UserId{Idp: "https://idx.example.com", OpaqueId: "123"},
|
||||
Mail: "foo@example.com",
|
||||
}, nil, "li\\.un", "username")
|
||||
|
||||
// This tests the . escaping of the readUserIDClaim
|
||||
req, rw := mockRequest(map[string]interface{}{
|
||||
oidc.Iss: "https://idx.example.com",
|
||||
"li.un": "foo",
|
||||
})
|
||||
|
||||
sut.ServeHTTP(rw, req)
|
||||
|
||||
token := req.Header.Get(revactx.TokenHeader)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
assert.Contains(t, token, "eyJ")
|
||||
}
|
||||
|
||||
func TestTokenIsAddedWithDottedUsernameClaimFallback(t *testing.T) {
|
||||
sut := newMockAccountResolver(&userv1beta1.User{
|
||||
Id: &userv1beta1.UserId{Idp: "https://idx.example.com", OpaqueId: "123"},
|
||||
Mail: "foo@example.com",
|
||||
}, nil, "li.un", "username")
|
||||
|
||||
// This tests the . escaping fallback of the readUserIDClaim
|
||||
req, rw := mockRequest(map[string]interface{}{
|
||||
oidc.Iss: "https://idx.example.com",
|
||||
"li.un": "foo",
|
||||
})
|
||||
|
||||
sut.ServeHTTP(rw, req)
|
||||
|
||||
token := req.Header.Get(revactx.TokenHeader)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
assert.Contains(t, token, "eyJ")
|
||||
}
|
||||
|
||||
func TestNSkipOnNoClaims(t *testing.T) {
|
||||
sut := newMockAccountResolver(nil, backend.ErrAccountDisabled, oidc.Email, "mail")
|
||||
req, rw := mockRequest(nil)
|
||||
@@ -133,6 +195,7 @@ func newMockAccountResolver(userBackendResult *userv1beta1.User, userBackendErr
|
||||
UserProvider(&ub),
|
||||
UserRoleAssigner(&ra),
|
||||
TokenManagerConfig(config.TokenManager{JWTSecret: "secret"}),
|
||||
SkipUserInfo(false),
|
||||
UserOIDCClaim(oidcclaim),
|
||||
UserCS3Claim(cs3claim),
|
||||
AutoprovisionAccounts(false),
|
||||
|
||||
@@ -3,7 +3,6 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -35,6 +34,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
|
||||
OIDCIss: options.OIDCIss,
|
||||
oidcClient: options.OIDCClient,
|
||||
AccessTokenVerifyMethod: options.AccessTokenVerifyMethod,
|
||||
skipUserInfo: options.SkipUserInfo,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type OIDCAuthenticator struct {
|
||||
DefaultTokenCacheTTL time.Duration
|
||||
oidcClient oidc.OIDCClient
|
||||
AccessTokenVerifyMethod string
|
||||
skipUserInfo bool
|
||||
}
|
||||
|
||||
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
|
||||
@@ -62,36 +63,38 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
|
||||
m.Logger.Error().Err(err).Msg("could not read from userinfo cache")
|
||||
}
|
||||
if len(record) > 0 {
|
||||
if err = msgpack.Unmarshal(record[0].Value, &claims); err == nil {
|
||||
if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil {
|
||||
m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
||||
return claims, nil
|
||||
}
|
||||
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
|
||||
}
|
||||
|
||||
aClaims, _, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
|
||||
aClaims, claims, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to verify access token")
|
||||
}
|
||||
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
if !m.skipUserInfo {
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
userInfo, err := m.oidcClient.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get userinfo")
|
||||
}
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
|
||||
userInfo, err := m.oidcClient.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get userinfo")
|
||||
}
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
|
||||
}
|
||||
}
|
||||
|
||||
expiration := m.extractExpiration(aClaims)
|
||||
go func() {
|
||||
if d, err := msgpack.Marshal(claims); err != nil {
|
||||
if d, err := msgpack.MarshalAsMap(claims); err != nil {
|
||||
m.Logger.Error().Err(err).Msg("failed to marshal claims for userinfo cache")
|
||||
} else {
|
||||
err = m.userInfoCache.Write(&store.Record{
|
||||
@@ -106,7 +109,7 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
|
||||
if sid := aClaims.SessionID; sid != "" {
|
||||
// reuse user cache for session id lookup
|
||||
err = m.userInfoCache.Write(&store.Record{
|
||||
Key: fmt.Sprintf("%s", sid),
|
||||
Key: sid,
|
||||
Value: []byte(encodedHash),
|
||||
Expiry: time.Until(expiration),
|
||||
})
|
||||
|
||||
@@ -75,6 +75,8 @@ type Options struct {
|
||||
RoleQuotas map[string]uint64
|
||||
// TraceProvider sets the tracing provider.
|
||||
TraceProvider trace.TracerProvider
|
||||
// SkipUserInfo prevents the oidc middleware from querying the userinfo endpoint and read any claims directly from the access token instead
|
||||
SkipUserInfo bool
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
@@ -248,3 +250,10 @@ func TraceProvider(tp trace.TracerProvider) Option {
|
||||
o.TraceProvider = tp
|
||||
}
|
||||
}
|
||||
|
||||
// SkipUserInfo sets the skipUserInfo flag.
|
||||
func SkipUserInfo(val bool) Option {
|
||||
return func(o *Options) {
|
||||
o.SkipUserInfo = val
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user