Use expiration from access token if available

This commit is contained in:
Benedikt Kulmann
2020-11-18 12:08:23 +01:00
parent a410d40166
commit 08e218aa3e
2 changed files with 55 additions and 28 deletions

View File

@@ -7,15 +7,14 @@ import (
// Entry represents an entry on the cache. You can type assert on V.
type Entry struct {
V interface{}
inserted time.Time
V interface{}
expiration time.Time
}
// Cache is a barebones cache implementation.
type Cache struct {
entries map[string]*Entry
size int
ttl time.Duration
m sync.Mutex
}
@@ -25,7 +24,6 @@ func NewCache(o ...Option) Cache {
return Cache{
size: opts.size,
ttl: opts.ttl,
entries: map[string]*Entry{},
}
}
@@ -46,7 +44,7 @@ func (c *Cache) Get(k string) *Entry {
}
// Set sets a roleID / role-bundle.
func (c *Cache) Set(k string, val interface{}) {
func (c *Cache) Set(k string, val interface{}, expiration time.Time) {
c.m.Lock()
defer c.m.Unlock()
@@ -56,7 +54,7 @@ func (c *Cache) Set(k string, val interface{}) {
c.entries[k] = &Entry{
val,
time.Now(),
expiration,
}
}
@@ -71,7 +69,7 @@ func (c *Cache) evict() {
// expired checks if an entry is expired
func (c *Cache) expired(e *Entry) bool {
return e.inserted.Add(c.ttl).Before(time.Now())
return e.expiration.Before(time.Now())
}
// fits returns whether the cache fits more entries.

View File

@@ -2,8 +2,12 @@ package middleware
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
gOidc "github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
@@ -20,31 +24,30 @@ type OIDCProvider interface {
// OIDCAuth provides a middleware to check access secured by a static token.
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
options := newOptions(optionSetters...)
tokenCache := cache.NewCache(
cache.Size(options.UserinfoCacheSize),
cache.TTL(options.UserinfoCacheTTL),
)
tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize))
return func(next http.Handler) http.Handler {
return &oidcAuth{
next: next,
logger: options.Logger,
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
tokenCache: &tokenCache,
next: next,
logger: options.Logger,
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
tokenCache: &tokenCache,
tokenCacheTTL: options.UserinfoCacheTTL,
}
}
}
type oidcAuth struct {
next http.Handler
logger log.Logger
provider OIDCProvider
providerFunc func() (OIDCProvider, error)
httpClient *http.Client
oidcIss string
tokenCache *cache.Cache
next http.Handler
logger log.Logger
provider OIDCProvider
providerFunc func() (OIDCProvider, error)
httpClient *http.Client
oidcIss string
tokenCache *cache.Cache
tokenCacheTTL time.Duration
}
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -83,7 +86,7 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa
AccessToken: token,
}
userInfo, err := m.provider.UserInfo(
userInfo, err := m.getProvider().UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
oauth2.StaticTokenSource(oauth2Token),
)
@@ -99,12 +102,13 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa
return
}
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
//TODO: This should be read from the token instead of config
claims.Iss = m.oidcIss
m.tokenCache.Set(token, claims)
expiration := m.extractExpiration(token)
m.tokenCache.Set(token, claims, expiration)
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
return
}
@@ -117,6 +121,31 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa
return
}
// extractExpiration tries to parse and extract the expiration from the provided token. It might not even be a jwt.
// defaults to the configured fallback TTL.
// TODO: use introspection endpoint if available in the oidc configuration. Still needs a fallback to configured TTL.
func (m oidcAuth) extractExpiration(token string) time.Time {
defaultExpiration := time.Now().Add(m.tokenCacheTTL)
s := strings.SplitN(token, ".", 4)
if len(s) != 3 {
return defaultExpiration
}
b, err := jwt.DecodeSegment(s[1])
if err != nil {
return defaultExpiration
}
at := &jwt.StandardClaims{}
err = json.Unmarshal(b, at)
if err != nil || at.ExpiresAt == 0 {
return defaultExpiration
}
return time.Unix(at.ExpiresAt, 0)
}
func (m oidcAuth) shouldServe(req *http.Request) bool {
header := req.Header.Get("Authorization")