mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-01-04 03:09:33 -06:00
7
changelog/unreleased/proxy-cache-userinfo.md
Normal file
7
changelog/unreleased/proxy-cache-userinfo.md
Normal file
@@ -0,0 +1,7 @@
|
||||
Enhancement: Cache userinfo in proxy
|
||||
|
||||
Tags: proxy
|
||||
|
||||
We introduced caching for the userinfo response. The token expiration is used for cache invalidation if available. Otherwise we fall back to a preconfigured TTL (default 10 seconds).
|
||||
|
||||
https://github.com/owncloud/ocis/pull/877
|
||||
@@ -4,9 +4,8 @@
|
||||
clients:
|
||||
- id: phoenix
|
||||
name: ownCloud web app
|
||||
application_type: web
|
||||
insecure: yes
|
||||
trusted: yes
|
||||
insecure: yes
|
||||
redirect_uris:
|
||||
- https://localhost:9200/
|
||||
- https://localhost:9200/oidc-callback.html
|
||||
@@ -19,9 +18,8 @@ clients:
|
||||
- http://localhost:9100
|
||||
|
||||
- id: ocis-explorer.js
|
||||
name: OCIS Graph Explorer
|
||||
name: oCIS Graph Explorer
|
||||
trusted: yes
|
||||
application_type: web
|
||||
insecure: yes
|
||||
|
||||
- id: xdXOt13JKxym1B1QcEncf2XDkLAexMBFwiT9j6EfhhHFJhs2KM9jbjTmf8JBXE69
|
||||
|
||||
File diff suppressed because one or more lines are too long
83
proxy/pkg/cache/cache.go
vendored
83
proxy/pkg/cache/cache.go
vendored
@@ -1,19 +1,19 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Entry represents an entry on the cache. You can type assert on V.
|
||||
type Entry struct {
|
||||
V interface{}
|
||||
Valid bool
|
||||
V interface{}
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
// Cache is a barebones cache implementation.
|
||||
type Cache struct {
|
||||
entries map[string]map[string]Entry
|
||||
entries map[string]*Entry
|
||||
size int
|
||||
m sync.Mutex
|
||||
}
|
||||
@@ -24,78 +24,55 @@ func NewCache(o ...Option) Cache {
|
||||
|
||||
return Cache{
|
||||
size: opts.size,
|
||||
entries: map[string]map[string]Entry{},
|
||||
entries: map[string]*Entry{},
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets an entry on a service `svcKey` by a give `key`.
|
||||
func (c *Cache) Get(svcKey, key string) (*Entry, error) {
|
||||
var value Entry
|
||||
ok := true
|
||||
|
||||
// Get gets a role-bundle by a given `roleID`.
|
||||
func (c *Cache) Get(k string) *Entry {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if value, ok = c.entries[svcKey][key]; !ok {
|
||||
return nil, fmt.Errorf("invalid service key: `%v`", key)
|
||||
if _, ok := c.entries[k]; ok {
|
||||
if c.expired(c.entries[k]) {
|
||||
delete(c.entries, k)
|
||||
return nil
|
||||
}
|
||||
return c.entries[k]
|
||||
}
|
||||
|
||||
return &value, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set sets a key / value. It lets a service add entries on a request basis.
|
||||
func (c *Cache) Set(svcKey, key string, val interface{}) error {
|
||||
// Set sets a roleID / role-bundle.
|
||||
func (c *Cache) Set(k string, val interface{}, expiration time.Time) {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if !c.fits() {
|
||||
return fmt.Errorf("cache is full")
|
||||
c.evict()
|
||||
}
|
||||
|
||||
if _, ok := c.entries[svcKey]; !ok {
|
||||
c.entries[svcKey] = map[string]Entry{}
|
||||
c.entries[k] = &Entry{
|
||||
val,
|
||||
expiration,
|
||||
}
|
||||
|
||||
if _, ok := c.entries[svcKey][key]; ok {
|
||||
return fmt.Errorf("key `%v` already exists", key)
|
||||
}
|
||||
|
||||
c.entries[svcKey][key] = Entry{
|
||||
V: val,
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate invalidates a cache Entry by key.
|
||||
func (c *Cache) Invalidate(svcKey, key string) error {
|
||||
r, err := c.Get(svcKey, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Valid = false
|
||||
c.entries[svcKey][key] = *r
|
||||
return nil
|
||||
}
|
||||
|
||||
// Evict frees memory from the cache by removing invalid keys. It is a noop.
|
||||
func (c *Cache) Evict() {
|
||||
for _, v := range c.entries {
|
||||
for k, svcEntry := range v {
|
||||
if !svcEntry.Valid {
|
||||
delete(v, k)
|
||||
}
|
||||
// evict frees memory from the cache by removing entries that exceeded the cache TTL.
|
||||
func (c *Cache) evict() {
|
||||
for i := range c.entries {
|
||||
if c.expired(c.entries[i]) {
|
||||
delete(c.entries, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Length returns the amount of entries per service key.
|
||||
func (c *Cache) Length(k string) int {
|
||||
return len(c.entries[k])
|
||||
// expired checks if an entry is expired
|
||||
func (c *Cache) expired(e *Entry) bool {
|
||||
return e.expiration.Before(time.Now())
|
||||
}
|
||||
|
||||
// fits returns whether the cache fits more entries.
|
||||
func (c *Cache) fits() bool {
|
||||
return c.size >= len(c.entries)
|
||||
return c.size > len(c.entries)
|
||||
}
|
||||
|
||||
99
proxy/pkg/cache/cache_test.go
vendored
99
proxy/pkg/cache/cache_test.go
vendored
@@ -1,99 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Prevents from invalid import cycle.
|
||||
type AccountsCacheEntry struct {
|
||||
Email string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
c := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
|
||||
Email: "hello@foo.bar",
|
||||
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.Length("accounts") != 1 {
|
||||
t.Errorf("expected length 1 got `%v`", len(c.entries))
|
||||
}
|
||||
|
||||
item, err := c.Get("accounts", "hello@foo.bar")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok {
|
||||
t.Errorf("invalid cached value type")
|
||||
} else {
|
||||
if cachedEntry.Email != "hello@foo.bar" {
|
||||
t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvict(t *testing.T) {
|
||||
c := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
if err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
|
||||
Email: "hello@foo.bar",
|
||||
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := c.Invalidate("accounts", "hello@foo.bar"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v, err := c.Get("accounts", "hello@foo.bar")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if v.Valid {
|
||||
t.Errorf("cache key unexpected valid state")
|
||||
}
|
||||
|
||||
c.Evict()
|
||||
|
||||
if c.Length("accounts") != 0 {
|
||||
t.Errorf("expected length 0 got `%v`", len(c.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
svcCache := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
err := svcCache.Set("accounts", "node", "0.0.0.0:1234")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
raw, err := svcCache.Get("accounts", "node")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v, ok := raw.V.(string)
|
||||
if !ok {
|
||||
t.Errorf("invalid type on service node key")
|
||||
}
|
||||
|
||||
if v != "0.0.0.0:1234" {
|
||||
t.Errorf("expected `0.0.0.0:1234` got `%v`", v)
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/justinas/alice"
|
||||
"github.com/micro/cli/v2"
|
||||
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
|
||||
"github.com/oklog/run"
|
||||
openzipkin "github.com/openzipkin/zipkin-go"
|
||||
zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
|
||||
acc "github.com/owncloud/ocis/accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/config"
|
||||
"github.com/owncloud/ocis/proxy/pkg/cs3"
|
||||
"github.com/owncloud/ocis/proxy/pkg/flagset"
|
||||
@@ -281,6 +281,8 @@ func loadMiddlewares(ctx context.Context, l log.Logger, cfg *config.Config) alic
|
||||
}),
|
||||
middleware.HTTPClient(oidcHTTPClient),
|
||||
middleware.OIDCIss(cfg.OIDC.Issuer),
|
||||
middleware.TokenCacheSize(cfg.OIDC.UserinfoCache.Size),
|
||||
middleware.TokenCacheTTL(time.Second*time.Duration(cfg.OIDC.UserinfoCache.TTL)),
|
||||
),
|
||||
middleware.BasicAuth(
|
||||
middleware.Logger(l),
|
||||
|
||||
@@ -83,6 +83,12 @@ type Reva struct {
|
||||
Address string
|
||||
}
|
||||
|
||||
// Cache is a TTL cache configuration.
|
||||
type Cache struct {
|
||||
Size int
|
||||
TTL int
|
||||
}
|
||||
|
||||
// Config combines all available configuration parts.
|
||||
type Config struct {
|
||||
File string
|
||||
@@ -105,8 +111,9 @@ type Config struct {
|
||||
// OIDC is the config for the OpenID-Connect middleware. If set the proxy will try to authenticate every request
|
||||
// with the configured oidc-provider
|
||||
type OIDC struct {
|
||||
Issuer string
|
||||
Insecure bool
|
||||
Issuer string
|
||||
Insecure bool
|
||||
UserinfoCache Cache
|
||||
}
|
||||
|
||||
// PolicySelector is the toplevel-configuration for different selectors
|
||||
|
||||
@@ -202,6 +202,20 @@ func ServerWithConfig(cfg *config.Config) []cli.Flag {
|
||||
EnvVars: []string{"PROXY_OIDC_INSECURE"},
|
||||
Destination: &cfg.OIDC.Insecure,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "oidc-userinfo-cache-tll",
|
||||
Value: 10,
|
||||
Usage: "Fallback TTL in seconds for caching userinfo, when no token lifetime can be identified",
|
||||
EnvVars: []string{"PROXY_OIDC_USERINFO_CACHE_TTL"},
|
||||
Destination: &cfg.OIDC.UserinfoCache.TTL,
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "oidc-userinfo-cache-size",
|
||||
Value: 1024,
|
||||
Usage: "Max entries for caching userinfo",
|
||||
EnvVars: []string{"PROXY_OIDC_USERINFO_CACHE_SIZE"},
|
||||
Destination: &cfg.OIDC.UserinfoCache.Size,
|
||||
},
|
||||
|
||||
&cli.BoolFlag{
|
||||
Name: "autoprovision-accounts",
|
||||
|
||||
@@ -3,33 +3,31 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/go-micro/v2/client"
|
||||
"github.com/owncloud/ocis/accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/config"
|
||||
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetAccountSuccess(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 {
|
||||
t.Errorf("expected an account")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountInternalError(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError {
|
||||
t.Errorf("expected an internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolverMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountResolver(
|
||||
Logger(log.NewLogger()),
|
||||
@@ -50,7 +48,6 @@ func TestAccountResolverMiddleware(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountResolverMiddlewareWithDisabledAccount(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountResolver(
|
||||
Logger(log.NewLogger()),
|
||||
|
||||
@@ -12,5 +12,5 @@ var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
// ErrInternal is returned if something went wrong
|
||||
ErrInternal = errors.New("internal error")
|
||||
ErrInternal = errors.New("internal error")
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/owncloud/ocis/proxy/pkg/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccountsKey declares the svcKey for the Accounts service.
|
||||
AccountsKey = "accounts"
|
||||
)
|
||||
|
||||
var (
|
||||
// svcCache caches requests for given services to prevent round trips to the service
|
||||
svcCache = cache.NewCache(
|
||||
cache.Size(256),
|
||||
)
|
||||
)
|
||||
@@ -2,12 +2,18 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
|
||||
gOidc "github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/cache"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OIDCProvider used to mock the oidc provider during tests
|
||||
@@ -18,25 +24,30 @@ type OIDCProvider interface {
|
||||
// OIDCAuth provides a middleware to check access secured by a static token.
|
||||
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize))
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return &oidcAuth{
|
||||
next: next,
|
||||
logger: options.Logger,
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
next: next,
|
||||
logger: options.Logger,
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
tokenCache: &tokenCache,
|
||||
tokenCacheTTL: options.UserinfoCacheTTL,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type oidcAuth struct {
|
||||
next http.Handler
|
||||
logger log.Logger
|
||||
provider OIDCProvider
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
next http.Handler
|
||||
logger log.Logger
|
||||
provider OIDCProvider
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
tokenCache *cache.Cache
|
||||
tokenCacheTTL time.Duration
|
||||
}
|
||||
|
||||
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
@@ -46,59 +57,95 @@ func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if m.provider == nil {
|
||||
// Lazily initialize a provider
|
||||
|
||||
// provider needs to be cached as when it is created
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
provider, err := m.providerFunc()
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.provider = provider
|
||||
}
|
||||
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
userInfo, err := m.provider.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var claims oidc.StandardClaims
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
if m.getProvider() == nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
|
||||
claims, status := m.getClaims(token, req)
|
||||
if status != 0 {
|
||||
w.WriteHeader(status)
|
||||
return
|
||||
}
|
||||
|
||||
// inject claims to the request context for the account_uuid middleware.
|
||||
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
|
||||
// store claims in context
|
||||
// uses the original context, not the one with probably reduced security
|
||||
m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))
|
||||
}
|
||||
|
||||
func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.StandardClaims, status int) {
|
||||
hit := m.tokenCache.Get(token)
|
||||
if hit == nil {
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
userInfo, err := m.getProvider().UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
status = http.StatusUnauthorized
|
||||
return
|
||||
}
|
||||
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
status = http.StatusInternalServerError
|
||||
return
|
||||
}
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
|
||||
expiration := m.extractExpiration(token)
|
||||
m.tokenCache.Set(token, claims, expiration)
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
|
||||
return
|
||||
}
|
||||
|
||||
var ok = false
|
||||
if claims, ok = hit.V.(oidc.StandardClaims); !ok {
|
||||
status = http.StatusInternalServerError
|
||||
return
|
||||
}
|
||||
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
||||
return
|
||||
}
|
||||
|
||||
// extractExpiration tries to parse and extract the expiration from the provided token. It might not even be a jwt.
|
||||
// defaults to the configured fallback TTL.
|
||||
// TODO: use introspection endpoint if available in the oidc configuration. Still needs a fallback to configured TTL.
|
||||
func (m oidcAuth) extractExpiration(token string) time.Time {
|
||||
defaultExpiration := time.Now().Add(m.tokenCacheTTL)
|
||||
|
||||
s := strings.SplitN(token, ".", 4)
|
||||
if len(s) != 3 {
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
b, err := jwt.DecodeSegment(s[1])
|
||||
if err != nil {
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
at := &jwt.StandardClaims{}
|
||||
err = json.Unmarshal(b, at)
|
||||
if err != nil || at.ExpiresAt == 0 {
|
||||
return defaultExpiration
|
||||
}
|
||||
|
||||
return time.Unix(at.ExpiresAt, 0)
|
||||
}
|
||||
|
||||
func (m oidcAuth) shouldServe(req *http.Request) bool {
|
||||
header := req.Header.Get("Authorization")
|
||||
|
||||
@@ -107,6 +154,7 @@ func (m oidcAuth) shouldServe(req *http.Request) bool {
|
||||
}
|
||||
|
||||
// todo: looks dirty, check later
|
||||
// TODO: make a PR to coreos/go-oidc for exposing userinfo endpoint on provider, see https://github.com/coreos/go-oidc/issues/248
|
||||
for _, ignoringPath := range []string{"/konnect/v1/userinfo"} {
|
||||
if req.URL.Path == ignoringPath {
|
||||
return false
|
||||
@@ -115,3 +163,21 @@ func (m oidcAuth) shouldServe(req *http.Request) bool {
|
||||
|
||||
return strings.HasPrefix(header, "Bearer ")
|
||||
}
|
||||
|
||||
func (m oidcAuth) getProvider() OIDCProvider {
|
||||
if m.provider == nil {
|
||||
// Lazily initialize a provider
|
||||
|
||||
// provider needs to be cached as when it is created
|
||||
// it will fetch the keys from the issuer using the .well-known
|
||||
// endpoint
|
||||
provider, err := m.providerFunc()
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
|
||||
return nil
|
||||
}
|
||||
|
||||
m.provider = provider
|
||||
}
|
||||
return m.provider
|
||||
}
|
||||
|
||||
@@ -3,17 +3,16 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOIDCAuthMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
m := OIDCAuth(
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
|
||||
|
||||
@@ -27,7 +28,7 @@ type Options struct {
|
||||
AccountsClient acc.AccountsService
|
||||
// SettingsRoleService for the roles API in settings
|
||||
SettingsRoleService settings.RoleService
|
||||
// OIDCProviderFunc to lazily initialize a provider, must be set for the oidcProvider middleware
|
||||
// OIDCProviderFunc to lazily initialize an oidc provider, must be set for the oidc_auth middleware
|
||||
OIDCProviderFunc func() (OIDCProvider, error)
|
||||
// OIDCIss is the oidcAuth-issuer
|
||||
OIDCIss string
|
||||
@@ -41,6 +42,10 @@ type Options struct {
|
||||
AutoprovisionAccounts bool
|
||||
// EnableBasicAuth to allow basic auth
|
||||
EnableBasicAuth bool
|
||||
// UserinfoCacheSize defines the max number of entries in the userinfo cache, intended for the oidc_auth middleware
|
||||
UserinfoCacheSize int
|
||||
// UserinfoCacheTTL sets the max cache duration for the userinfo cache, intended for the oidc_auth middleware
|
||||
UserinfoCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
@@ -89,7 +94,7 @@ func SettingsRoleService(rc settings.RoleService) Option {
|
||||
}
|
||||
}
|
||||
|
||||
// OIDCProviderFunc provides a function to set the the oidcAuth provider function option.
|
||||
// OIDCProviderFunc provides a function to set the the oidc provider function option.
|
||||
func OIDCProviderFunc(f func() (OIDCProvider, error)) Option {
|
||||
return func(o *Options) {
|
||||
o.OIDCProviderFunc = f
|
||||
@@ -137,3 +142,17 @@ func EnableBasicAuth(enableBasicAuth bool) Option {
|
||||
o.EnableBasicAuth = enableBasicAuth
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCacheSize provides a function to set the TokenCacheSize
|
||||
func TokenCacheSize(size int) Option {
|
||||
return func(o *Options) {
|
||||
o.UserinfoCacheSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCacheTTL provides a function to set the TokenCacheTTL
|
||||
func TokenCacheTTL(ttl time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.UserinfoCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user