From 08e218aa3e6552c09399e4aa2b4b1be38b078dd7 Mon Sep 17 00:00:00 2001 From: Benedikt Kulmann Date: Wed, 18 Nov 2020 12:08:23 +0100 Subject: [PATCH] Use expiration from access token if available --- proxy/pkg/cache/cache.go | 12 +++--- proxy/pkg/middleware/oidc_auth.go | 71 ++++++++++++++++++++++--------- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/proxy/pkg/cache/cache.go b/proxy/pkg/cache/cache.go index 2c1d7e93ec..fcfa7e01b4 100644 --- a/proxy/pkg/cache/cache.go +++ b/proxy/pkg/cache/cache.go @@ -7,15 +7,14 @@ import ( // Entry represents an entry on the cache. You can type assert on V. type Entry struct { - V interface{} - inserted time.Time + V interface{} + expiration time.Time } // Cache is a barebones cache implementation. type Cache struct { entries map[string]*Entry size int - ttl time.Duration m sync.Mutex } @@ -25,7 +24,6 @@ func NewCache(o ...Option) Cache { return Cache{ size: opts.size, - ttl: opts.ttl, entries: map[string]*Entry{}, } } @@ -46,7 +44,7 @@ func (c *Cache) Get(k string) *Entry { } // Set sets a roleID / role-bundle. -func (c *Cache) Set(k string, val interface{}) { +func (c *Cache) Set(k string, val interface{}, expiration time.Time) { c.m.Lock() defer c.m.Unlock() @@ -56,7 +54,7 @@ func (c *Cache) Set(k string, val interface{}) { c.entries[k] = &Entry{ val, - time.Now(), + expiration, } } @@ -71,7 +69,7 @@ func (c *Cache) evict() { // expired checks if an entry is expired func (c *Cache) expired(e *Entry) bool { - return e.inserted.Add(c.ttl).Before(time.Now()) + return e.expiration.Before(time.Now()) } // fits returns whether the cache fits more entries. diff --git a/proxy/pkg/middleware/oidc_auth.go b/proxy/pkg/middleware/oidc_auth.go index ec9aad884e..de44511e7d 100644 --- a/proxy/pkg/middleware/oidc_auth.go +++ b/proxy/pkg/middleware/oidc_auth.go @@ -2,8 +2,12 @@ package middleware import ( "context" + "encoding/json" "net/http" "strings" + "time" + + "github.com/dgrijalva/jwt-go" gOidc "github.com/coreos/go-oidc" "github.com/owncloud/ocis/ocis-pkg/log" @@ -20,31 +24,30 @@ type OIDCProvider interface { // OIDCAuth provides a middleware to check access secured by a static token. func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler { options := newOptions(optionSetters...) - tokenCache := cache.NewCache( - cache.Size(options.UserinfoCacheSize), - cache.TTL(options.UserinfoCacheTTL), - ) + tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize)) return func(next http.Handler) http.Handler { return &oidcAuth{ - next: next, - logger: options.Logger, - providerFunc: options.OIDCProviderFunc, - httpClient: options.HTTPClient, - oidcIss: options.OIDCIss, - tokenCache: &tokenCache, + next: next, + logger: options.Logger, + providerFunc: options.OIDCProviderFunc, + httpClient: options.HTTPClient, + oidcIss: options.OIDCIss, + tokenCache: &tokenCache, + tokenCacheTTL: options.UserinfoCacheTTL, } } } type oidcAuth struct { - next http.Handler - logger log.Logger - provider OIDCProvider - providerFunc func() (OIDCProvider, error) - httpClient *http.Client - oidcIss string - tokenCache *cache.Cache + next http.Handler + logger log.Logger + provider OIDCProvider + providerFunc func() (OIDCProvider, error) + httpClient *http.Client + oidcIss string + tokenCache *cache.Cache + tokenCacheTTL time.Duration } func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -83,7 +86,7 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa AccessToken: token, } - userInfo, err := m.provider.UserInfo( + userInfo, err := m.getProvider().UserInfo( context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient), oauth2.StaticTokenSource(oauth2Token), ) @@ -99,12 +102,13 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa return } - m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo") - //TODO: This should be read from the token instead of config claims.Iss = m.oidcIss - m.tokenCache.Set(token, claims) + expiration := m.extractExpiration(token) + m.tokenCache.Set(token, claims, expiration) + + m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo") return } @@ -117,6 +121,31 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa return } +// extractExpiration tries to parse and extract the expiration from the provided token. It might not even be a jwt. +// defaults to the configured fallback TTL. +// TODO: use introspection endpoint if available in the oidc configuration. Still needs a fallback to configured TTL. +func (m oidcAuth) extractExpiration(token string) time.Time { + defaultExpiration := time.Now().Add(m.tokenCacheTTL) + + s := strings.SplitN(token, ".", 4) + if len(s) != 3 { + return defaultExpiration + } + + b, err := jwt.DecodeSegment(s[1]) + if err != nil { + return defaultExpiration + } + + at := &jwt.StandardClaims{} + err = json.Unmarshal(b, at) + if err != nil || at.ExpiresAt == 0 { + return defaultExpiration + } + + return time.Unix(at.ExpiresAt, 0) +} + func (m oidcAuth) shouldServe(req *http.Request) bool { header := req.Header.Get("Authorization")