mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-02-18 03:18:52 -06:00
Use expiration from access token if available
This commit is contained in:
12
proxy/pkg/cache/cache.go
vendored
12
proxy/pkg/cache/cache.go
vendored
@@ -7,15 +7,14 @@ import (
|
||||
|
||||
// Entry represents an entry on the cache. You can type assert on V.
|
||||
type Entry struct {
|
||||
V interface{}
|
||||
inserted time.Time
|
||||
V interface{}
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
// Cache is a barebones cache implementation.
|
||||
type Cache struct {
|
||||
entries map[string]*Entry
|
||||
size int
|
||||
ttl time.Duration
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
@@ -25,7 +24,6 @@ func NewCache(o ...Option) Cache {
|
||||
|
||||
return Cache{
|
||||
size: opts.size,
|
||||
ttl: opts.ttl,
|
||||
entries: map[string]*Entry{},
|
||||
}
|
||||
}
|
||||
@@ -46,7 +44,7 @@ func (c *Cache) Get(k string) *Entry {
|
||||
}
|
||||
|
||||
// Set sets a roleID / role-bundle.
|
||||
func (c *Cache) Set(k string, val interface{}) {
|
||||
func (c *Cache) Set(k string, val interface{}, expiration time.Time) {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
@@ -56,7 +54,7 @@ func (c *Cache) Set(k string, val interface{}) {
|
||||
|
||||
c.entries[k] = &Entry{
|
||||
val,
|
||||
time.Now(),
|
||||
expiration,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +69,7 @@ func (c *Cache) evict() {
|
||||
|
||||
// expired checks if an entry is expired
|
||||
func (c *Cache) expired(e *Entry) bool {
|
||||
return e.inserted.Add(c.ttl).Before(time.Now())
|
||||
return e.expiration.Before(time.Now())
|
||||
}
|
||||
|
||||
// fits returns whether the cache fits more entries.
|
||||
|
||||
@@ -2,8 +2,12 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
|
||||
gOidc "github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
@@ -20,31 +24,30 @@ type OIDCProvider interface {
|
||||
// OIDCAuth provides a middleware to check access secured by a static token.
|
||||
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
tokenCache := cache.NewCache(
|
||||
cache.Size(options.UserinfoCacheSize),
|
||||
cache.TTL(options.UserinfoCacheTTL),
|
||||
)
|
||||
tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize))
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return &oidcAuth{
|
||||
next: next,
|
||||
logger: options.Logger,
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
tokenCache: &tokenCache,
|
||||
next: next,
|
||||
logger: options.Logger,
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
tokenCache: &tokenCache,
|
||||
tokenCacheTTL: options.UserinfoCacheTTL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type oidcAuth struct {
|
||||
next http.Handler
|
||||
logger log.Logger
|
||||
provider OIDCProvider
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
tokenCache *cache.Cache
|
||||
next http.Handler
|
||||
logger log.Logger
|
||||
provider OIDCProvider
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
tokenCache *cache.Cache
|
||||
tokenCacheTTL time.Duration
|
||||
}
|
||||
|
||||
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
@@ -83,7 +86,7 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
userInfo, err := m.provider.UserInfo(
|
||||
userInfo, err := m.getProvider().UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
@@ -99,12 +102,13 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
|
||||
m.tokenCache.Set(token, claims)
|
||||
expiration := m.extractExpiration(token)
|
||||
m.tokenCache.Set(token, claims, expiration)
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -117,6 +121,31 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.Standa
|
||||
return
|
||||
}
|
||||
|
||||
// extractExpiration tries to parse and extract the expiration from the provided token. It might not even be a jwt.
|
||||
// defaults to the configured fallback TTL.
|
||||
// TODO: use introspection endpoint if available in the oidc configuration. Still needs a fallback to configured TTL.
|
||||
func (m oidcAuth) extractExpiration(token string) time.Time {
|
||||
defaultExpiration := time.Now().Add(m.tokenCacheTTL)
|
||||
|
||||
s := strings.SplitN(token, ".", 4)
|
||||
if len(s) != 3 {
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
b, err := jwt.DecodeSegment(s[1])
|
||||
if err != nil {
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
at := &jwt.StandardClaims{}
|
||||
err = json.Unmarshal(b, at)
|
||||
if err != nil || at.ExpiresAt == 0 {
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
return time.Unix(at.ExpiresAt, 0)
|
||||
}
|
||||
|
||||
func (m oidcAuth) shouldServe(req *http.Request) bool {
|
||||
header := req.Header.Get("Authorization")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user