allow skipping userinfo call

Signed-off-by: Jörn Friedrich Dreyer <jfd@butonic.de>
This commit is contained in:
Jörn Friedrich Dreyer
2023-08-23 10:36:05 +02:00
parent e49b2c159b
commit 5422586bfa
9 changed files with 167 additions and 37 deletions

View File

@@ -27,7 +27,7 @@ import (
// OIDCClient used to mock the oidc client during tests
type OIDCClient interface {
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, []string, error)
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error)
VerifyLogoutToken(ctx context.Context, token string) (*LogoutToken, error)
}
@@ -271,27 +271,26 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc
}, nil
}
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, []string, error) {
var mapClaims []string
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, jwt.MapClaims, error) {
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
return RegClaimsWithSID{}, mapClaims, err
return RegClaimsWithSID{}, jwt.MapClaims{}, err
}
switch c.accessTokenVerifyMethod {
case config.AccessTokenVerificationJWT:
return c.verifyAccessTokenJWT(token)
case config.AccessTokenVerificationNone:
c.Logger.Debug().Msg("Access Token verification disabled")
return RegClaimsWithSID{}, mapClaims, nil
return RegClaimsWithSID{}, jwt.MapClaims{}, nil
default:
c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
return RegClaimsWithSID{}, mapClaims, errors.New("unknown Access Token Verification method")
return RegClaimsWithSID{}, jwt.MapClaims{}, errors.New("unknown Access Token Verification method")
}
}
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, []string, error) {
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, jwt.MapClaims, error) {
var claims RegClaimsWithSID
var mapClaims []string
mapClaims := jwt.MapClaims{}
jwks := c.getKeyfunc()
if jwks == nil {
return claims, mapClaims, errors.New("error initializing jwks keyfunc")
@@ -301,7 +300,7 @@ func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, []str
if err != nil {
return claims, mapClaims, err
}
_, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
_, _, err = new(jwt.Parser).ParseUnverified(token, mapClaims)
// TODO: decode mapClaims to sth readable
c.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {

View File

@@ -402,6 +402,7 @@ func loadMiddlewares(ctx context.Context, logger log.Logger, cfg *config.Config,
middleware.Logger(logger),
middleware.UserProvider(userProvider),
middleware.UserRoleAssigner(roleAssigner),
middleware.SkipUserInfo(cfg.OIDC.SkipUserInfo),
middleware.UserOIDCClaim(cfg.UserOIDCClaim),
middleware.UserCS3Claim(cfg.UserCS3Claim),
middleware.AutoprovisionAccounts(cfg.AutoprovisionAccounts),

View File

@@ -107,6 +107,7 @@ type OIDC struct {
Issuer string `yaml:"issuer" env:"OCIS_URL;OCIS_OIDC_ISSUER;PROXY_OIDC_ISSUER" desc:"URL of the OIDC issuer. It defaults to URL of the builtin IDP."`
Insecure bool `yaml:"insecure" env:"OCIS_INSECURE;PROXY_OIDC_INSECURE" desc:"Disable TLS certificate validation for connections to the IDP. Note that this is not recommended for production environments."`
AccessTokenVerifyMethod string `yaml:"access_token_verify_method" env:"PROXY_OIDC_ACCESS_TOKEN_VERIFY_METHOD" desc:"Sets how OIDC access tokens should be verified. Possible values are 'none' and 'jwt'. When using 'none', no special validation apart from using it for accessing the IPD's userinfo endpoint will be done. When using 'jwt', it tries to parse the access token as a jwt token and verifies the signature using the keys published on the IDP's 'jwks_uri'."`
SkipUserInfo bool `yaml:"skip_user_info" env:"PROXY_OIDC_SKIP_USER_INFO" desc:"Do not look up user claims at the userinfo endpoint and directly read them from the access token. Incompatible with PROXY_OIDC_ACCESS_TOKEN_VERIFY_METHOD=none"`
UserinfoCache *Cache `yaml:"user_info_cache"`
JWKS JWKS `yaml:"jwks"`
RewriteWellKnown bool `yaml:"rewrite_well_known" env:"PROXY_OIDC_REWRITE_WELLKNOWN" desc:"Enables rewriting the /.well-known/openid-configuration to the configured OIDC issuer. Needed by the Desktop Client, Android Client and iOS Client to discover the OIDC provider."`

View File

@@ -41,6 +41,7 @@ func DefaultConfig() *config.Config {
Issuer: "https://localhost:9200",
AccessTokenVerifyMethod: config.AccessTokenVerificationJWT,
SkipUserInfo: false,
UserinfoCache: &config.Cache{
Store: "memory",
Database: "ocis",
@@ -61,10 +62,10 @@ func DefaultConfig() *config.Config {
OIDCRoleMapper: config.OIDCRoleMapper{
RoleClaim: "roles",
RolesMap: []config.RoleMapping{
config.RoleMapping{RoleName: "admin", ClaimValue: "ocisAdmin"},
config.RoleMapping{RoleName: "spaceadmin", ClaimValue: "ocisSpaceAdmin"},
config.RoleMapping{RoleName: "user", ClaimValue: "ocisUser"},
config.RoleMapping{RoleName: "guest", ClaimValue: "ocisGuest"},
{RoleName: "admin", ClaimValue: "ocisAdmin"},
{RoleName: "spaceadmin", ClaimValue: "ocisSpaceAdmin"},
{RoleName: "user", ClaimValue: "ocisUser"},
{RoleName: "guest", ClaimValue: "ocisGuest"},
},
},
},

View File

@@ -50,6 +50,12 @@ func Validate(cfg *config.Config) error {
config.AccessTokenVerificationJWT, config.AccessTokenVerificationNone,
)
}
if cfg.OIDC.AccessTokenVerifyMethod == "none" && cfg.OIDC.SkipUserInfo {
return fmt.Errorf(
"Incompatible value '%t' for 'skip_user_info' in service %s. Must be false when 'access_token_verify_method' is 'none'.",
cfg.OIDC.SkipUserInfo, cfg.Service.Name,
)
}
return nil
}

View File

@@ -2,7 +2,9 @@ package middleware
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/owncloud/ocis/v2/services/proxy/pkg/user/backend"
"github.com/owncloud/ocis/v2/services/proxy/pkg/userroles"
@@ -41,6 +43,54 @@ type accountResolver struct {
userCS3Claim string
}
// from https://codereview.stackexchange.com/a/280193
func splitWithEscaping(s string, separator string, escapeString string) []string {
a := strings.Split(s, separator)
for i := len(a) - 2; i >= 0; i-- {
if strings.HasSuffix(a[i], escapeString) {
a[i] = a[i][:len(a[i])-len(escapeString)] + separator + a[i+1]
a = append(a[:i+1], a[i+2:]...)
}
}
return a
}
func readUserIDClaim(path string, claims map[string]interface{}) (string, error) {
// happy path
value, _ := claims[path].(string)
if value != "" {
return value, nil
}
// try splitting path at .
segments := splitWithEscaping(path, ".", "\\")
subclaims := claims
lastSegment := len(segments) - 1
for i := range segments {
if i < lastSegment {
if castedClaims, ok := subclaims[segments[i]].(map[string]interface{}); ok {
subclaims = castedClaims
} else if castedClaims, ok := subclaims[segments[i]].(map[interface{}]interface{}); ok {
subclaims = make(map[string]interface{}, len(castedClaims))
for k, v := range castedClaims {
if s, ok := k.(string); ok {
subclaims[s] = v
} else {
return "", fmt.Errorf("could not walk claims path, key '%v' is not a string", k)
}
}
}
} else {
if value, _ = subclaims[segments[i]].(string); value != "" {
return value, nil
}
}
}
return value, fmt.Errorf("claim path '%s' not set or empty", path)
}
// TODO do not use the context to store values: https://medium.com/@cep21/how-to-correctly-use-context-context-in-go-1-7-8f2c0fafdf39
func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx := req.Context()
@@ -55,13 +105,10 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) {
}
if user == nil && claims != nil {
var err error
var value string
var ok bool
if value, ok = claims[m.userOIDCClaim].(string); !ok || value == "" {
m.logger.Error().Str("claim", m.userOIDCClaim).Interface("claims", claims).Msg("claim not set or empty")
w.WriteHeader(http.StatusInternalServerError) // admin needs to make the idp send the right claim
value, err := readUserIDClaim(m.userOIDCClaim, claims)
if err != nil {
m.logger.Error().Err(err).Msg("could not read user id claim")
w.WriteHeader(http.StatusInternalServerError)
return
}

View File

@@ -57,6 +57,68 @@ func TestTokenIsAddedWithUsernameClaim(t *testing.T) {
assert.Contains(t, token, "eyJ")
}
func TestTokenIsAddedWithDotUsernamePathClaim(t *testing.T) {
sut := newMockAccountResolver(&userv1beta1.User{
Id: &userv1beta1.UserId{Idp: "https://idx.example.com", OpaqueId: "123"},
Mail: "foo@example.com",
}, nil, "li.un", "username")
// This is how lico adds the username to the access token
req, rw := mockRequest(map[string]interface{}{
oidc.Iss: "https://idx.example.com",
"li": map[string]interface{}{
"un": "foo",
},
})
sut.ServeHTTP(rw, req)
token := req.Header.Get(revactx.TokenHeader)
assert.NotEmpty(t, token)
assert.Contains(t, token, "eyJ")
}
func TestTokenIsAddedWithDotEscapedUsernameClaim(t *testing.T) {
sut := newMockAccountResolver(&userv1beta1.User{
Id: &userv1beta1.UserId{Idp: "https://idx.example.com", OpaqueId: "123"},
Mail: "foo@example.com",
}, nil, "li\\.un", "username")
// This tests the . escaping of the readUserIDClaim
req, rw := mockRequest(map[string]interface{}{
oidc.Iss: "https://idx.example.com",
"li.un": "foo",
})
sut.ServeHTTP(rw, req)
token := req.Header.Get(revactx.TokenHeader)
assert.NotEmpty(t, token)
assert.Contains(t, token, "eyJ")
}
func TestTokenIsAddedWithDottedUsernameClaimFallback(t *testing.T) {
sut := newMockAccountResolver(&userv1beta1.User{
Id: &userv1beta1.UserId{Idp: "https://idx.example.com", OpaqueId: "123"},
Mail: "foo@example.com",
}, nil, "li.un", "username")
// This tests the . escaping fallback of the readUserIDClaim
req, rw := mockRequest(map[string]interface{}{
oidc.Iss: "https://idx.example.com",
"li.un": "foo",
})
sut.ServeHTTP(rw, req)
token := req.Header.Get(revactx.TokenHeader)
assert.NotEmpty(t, token)
assert.Contains(t, token, "eyJ")
}
func TestNSkipOnNoClaims(t *testing.T) {
sut := newMockAccountResolver(nil, backend.ErrAccountDisabled, oidc.Email, "mail")
req, rw := mockRequest(nil)
@@ -133,6 +195,7 @@ func newMockAccountResolver(userBackendResult *userv1beta1.User, userBackendErr
UserProvider(&ub),
UserRoleAssigner(&ra),
TokenManagerConfig(config.TokenManager{JWTSecret: "secret"}),
SkipUserInfo(false),
UserOIDCClaim(oidcclaim),
UserCS3Claim(cs3claim),
AutoprovisionAccounts(false),

View File

@@ -3,7 +3,6 @@ package middleware
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
"time"
@@ -35,6 +34,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
OIDCIss: options.OIDCIss,
oidcClient: options.OIDCClient,
AccessTokenVerifyMethod: options.AccessTokenVerifyMethod,
skipUserInfo: options.SkipUserInfo,
}
}
@@ -47,6 +47,7 @@ type OIDCAuthenticator struct {
DefaultTokenCacheTTL time.Duration
oidcClient oidc.OIDCClient
AccessTokenVerifyMethod string
skipUserInfo bool
}
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
@@ -62,36 +63,38 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
m.Logger.Error().Err(err).Msg("could not read from userinfo cache")
}
if len(record) > 0 {
if err = msgpack.Unmarshal(record[0].Value, &claims); err == nil {
if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil {
m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
return claims, nil
}
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
}
aClaims, _, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
aClaims, claims, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify access token")
}
oauth2Token := &oauth2.Token{
AccessToken: token,
}
if !m.skipUserInfo {
oauth2Token := &oauth2.Token{
AccessToken: token,
}
userInfo, err := m.oidcClient.UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
return nil, errors.Wrap(err, "failed to get userinfo")
}
if err := userInfo.Claims(&claims); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
userInfo, err := m.oidcClient.UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
return nil, errors.Wrap(err, "failed to get userinfo")
}
if err := userInfo.Claims(&claims); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
}
}
expiration := m.extractExpiration(aClaims)
go func() {
if d, err := msgpack.Marshal(claims); err != nil {
if d, err := msgpack.MarshalAsMap(claims); err != nil {
m.Logger.Error().Err(err).Msg("failed to marshal claims for userinfo cache")
} else {
err = m.userInfoCache.Write(&store.Record{
@@ -106,7 +109,7 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
if sid := aClaims.SessionID; sid != "" {
// reuse user cache for session id lookup
err = m.userInfoCache.Write(&store.Record{
Key: fmt.Sprintf("%s", sid),
Key: sid,
Value: []byte(encodedHash),
Expiry: time.Until(expiration),
})

View File

@@ -75,6 +75,8 @@ type Options struct {
RoleQuotas map[string]uint64
// TraceProvider sets the tracing provider.
TraceProvider trace.TracerProvider
// SkipUserInfo prevents the oidc middleware from querying the userinfo endpoint and read any claims directly from the access token instead
SkipUserInfo bool
}
// newOptions initializes the available default options.
@@ -248,3 +250,10 @@ func TraceProvider(tp trace.TracerProvider) Option {
o.TraceProvider = tp
}
}
// SkipUserInfo sets the skipUserInfo flag.
func SkipUserInfo(val bool) Option {
return func(o *Options) {
o.SkipUserInfo = val
}
}