mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-05-03 01:09:54 -05:00
First implementation for userinfo cache without config
This commit is contained in:
@@ -3,33 +3,31 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/go-micro/v2/client"
|
||||
"github.com/owncloud/ocis/accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/config"
|
||||
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetAccountSuccess(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 {
|
||||
t.Errorf("expected an account")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountInternalError(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError {
|
||||
t.Errorf("expected an internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolverMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountResolver(
|
||||
Logger(log.NewLogger()),
|
||||
@@ -50,7 +48,6 @@ func TestAccountResolverMiddleware(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountResolverMiddlewareWithDisabledAccount(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountResolver(
|
||||
Logger(log.NewLogger()),
|
||||
|
||||
@@ -12,5 +12,5 @@ var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
// ErrInternal is returned if something went wrong
|
||||
ErrInternal = errors.New("internal error")
|
||||
ErrInternal = errors.New("internal error")
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/owncloud/ocis/proxy/pkg/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccountsKey declares the svcKey for the Accounts service.
|
||||
AccountsKey = "accounts"
|
||||
)
|
||||
|
||||
var (
|
||||
// svcCache caches requests for given services to prevent round trips to the service
|
||||
svcCache = cache.NewCache(
|
||||
cache.Size(256),
|
||||
)
|
||||
)
|
||||
@@ -2,12 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
gOidc "github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/cache"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OIDCProvider used to mock the oidc provider during tests
|
||||
@@ -18,6 +20,10 @@ type OIDCProvider interface {
|
||||
// OIDCAuth provides a middleware to check access secured by a static token.
|
||||
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
tokenCache := cache.NewCache(
|
||||
cache.Size(options.TokenCacheSize),
|
||||
cache.TTL(options.TokenCacheTTL),
|
||||
)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return &oidcAuth{
|
||||
@@ -26,6 +32,7 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
tokenCache: &tokenCache,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +44,7 @@ type oidcAuth struct {
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
tokenCache *cache.Cache
|
||||
}
|
||||
|
||||
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
@@ -64,36 +72,48 @@ func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
userInfo, err := m.provider.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hit := m.tokenCache.Get(token)
|
||||
var claims oidc.StandardClaims
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if hit == nil {
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
userInfo, err := m.provider.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
|
||||
m.tokenCache.Set(token, claims)
|
||||
} else {
|
||||
var ok = false
|
||||
if claims, ok = hit.V.(oidc.StandardClaims); !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
||||
}
|
||||
|
||||
// inject claims to the request context for the account_uuid middleware.
|
||||
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
|
||||
// store claims in context
|
||||
// uses the original context, not the one with probably reduced security
|
||||
m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))
|
||||
|
||||
@@ -3,17 +3,16 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOIDCAuthMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
m := OIDCAuth(
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
|
||||
|
||||
@@ -41,6 +42,10 @@ type Options struct {
|
||||
AutoprovisionAccounts bool
|
||||
// EnableBasicAuth to allow basic auth
|
||||
EnableBasicAuth bool
|
||||
// TokenCacheSize defines the max number of entries in the token cache
|
||||
TokenCacheSize int
|
||||
// TokenCacheTTL sets the max cache duration for the token cache
|
||||
TokenCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
@@ -137,3 +142,17 @@ func EnableBasicAuth(enableBasicAuth bool) Option {
|
||||
o.EnableBasicAuth = enableBasicAuth
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCacheSize provides a function to set the TokenCacheSize
|
||||
func TokenCacheSize(size int) Option {
|
||||
return func(o *Options) {
|
||||
o.TokenCacheSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCacheTTL provides a function to set the TokenCacheTTL
|
||||
func TokenCacheTTL(ttl time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.TokenCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user