mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-02-22 05:29:01 -06:00
First implementation for userinfo cache without config
This commit is contained in:
85
proxy/pkg/cache/cache.go
vendored
85
proxy/pkg/cache/cache.go
vendored
@@ -1,20 +1,21 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Entry represents an entry on the cache. You can type assert on V.
|
||||
type Entry struct {
|
||||
V interface{}
|
||||
Valid bool
|
||||
V interface{}
|
||||
inserted time.Time
|
||||
}
|
||||
|
||||
// Cache is a barebones cache implementation.
|
||||
type Cache struct {
|
||||
entries map[string]map[string]Entry
|
||||
entries map[string]*Entry
|
||||
size int
|
||||
ttl time.Duration
|
||||
m sync.Mutex
|
||||
}
|
||||
|
||||
@@ -24,78 +25,56 @@ func NewCache(o ...Option) Cache {
|
||||
|
||||
return Cache{
|
||||
size: opts.size,
|
||||
entries: map[string]map[string]Entry{},
|
||||
ttl: opts.ttl,
|
||||
entries: map[string]*Entry{},
|
||||
}
|
||||
}
|
||||
|
||||
// Get gets an entry on a service `svcKey` by a give `key`.
|
||||
func (c *Cache) Get(svcKey, key string) (*Entry, error) {
|
||||
var value Entry
|
||||
ok := true
|
||||
|
||||
// Get gets a role-bundle by a given `roleID`.
|
||||
func (c *Cache) Get(k string) *Entry {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if value, ok = c.entries[svcKey][key]; !ok {
|
||||
return nil, fmt.Errorf("invalid service key: `%v`", key)
|
||||
if _, ok := c.entries[k]; ok {
|
||||
if c.expired(c.entries[k]) {
|
||||
delete(c.entries, k)
|
||||
return nil
|
||||
}
|
||||
return c.entries[k]
|
||||
}
|
||||
|
||||
return &value, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set sets a key / value. It lets a service add entries on a request basis.
|
||||
func (c *Cache) Set(svcKey, key string, val interface{}) error {
|
||||
// Set sets a roleID / role-bundle.
|
||||
func (c *Cache) Set(k string, val interface{}) {
|
||||
c.m.Lock()
|
||||
defer c.m.Unlock()
|
||||
|
||||
if !c.fits() {
|
||||
return fmt.Errorf("cache is full")
|
||||
c.evict()
|
||||
}
|
||||
|
||||
if _, ok := c.entries[svcKey]; !ok {
|
||||
c.entries[svcKey] = map[string]Entry{}
|
||||
c.entries[k] = &Entry{
|
||||
val,
|
||||
time.Now(),
|
||||
}
|
||||
|
||||
if _, ok := c.entries[svcKey][key]; ok {
|
||||
return fmt.Errorf("key `%v` already exists", key)
|
||||
}
|
||||
|
||||
c.entries[svcKey][key] = Entry{
|
||||
V: val,
|
||||
Valid: true,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate invalidates a cache Entry by key.
|
||||
func (c *Cache) Invalidate(svcKey, key string) error {
|
||||
r, err := c.Get(svcKey, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.Valid = false
|
||||
c.entries[svcKey][key] = *r
|
||||
return nil
|
||||
}
|
||||
|
||||
// Evict frees memory from the cache by removing invalid keys. It is a noop.
|
||||
func (c *Cache) Evict() {
|
||||
for _, v := range c.entries {
|
||||
for k, svcEntry := range v {
|
||||
if !svcEntry.Valid {
|
||||
delete(v, k)
|
||||
}
|
||||
// evict frees memory from the cache by removing entries that exceeded the cache TTL.
|
||||
func (c *Cache) evict() {
|
||||
for i := range c.entries {
|
||||
if c.expired(c.entries[i]) {
|
||||
delete(c.entries, i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Length returns the amount of entries per service key.
|
||||
func (c *Cache) Length(k string) int {
|
||||
return len(c.entries[k])
|
||||
// expired checks if an entry is expired
|
||||
func (c *Cache) expired(e *Entry) bool {
|
||||
return e.inserted.Add(c.ttl).Before(time.Now())
|
||||
}
|
||||
|
||||
// fits returns whether the cache fits more entries.
|
||||
func (c *Cache) fits() bool {
|
||||
return c.size >= len(c.entries)
|
||||
return c.size > len(c.entries)
|
||||
}
|
||||
|
||||
99
proxy/pkg/cache/cache_test.go
vendored
99
proxy/pkg/cache/cache_test.go
vendored
@@ -1,99 +0,0 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Prevents from invalid import cycle.
|
||||
type AccountsCacheEntry struct {
|
||||
Email string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
c := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
|
||||
Email: "hello@foo.bar",
|
||||
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if c.Length("accounts") != 1 {
|
||||
t.Errorf("expected length 1 got `%v`", len(c.entries))
|
||||
}
|
||||
|
||||
item, err := c.Get("accounts", "hello@foo.bar")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok {
|
||||
t.Errorf("invalid cached value type")
|
||||
} else {
|
||||
if cachedEntry.Email != "hello@foo.bar" {
|
||||
t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvict(t *testing.T) {
|
||||
c := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
if err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
|
||||
Email: "hello@foo.bar",
|
||||
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
|
||||
}); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if err := c.Invalidate("accounts", "hello@foo.bar"); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v, err := c.Get("accounts", "hello@foo.bar")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if v.Valid {
|
||||
t.Errorf("cache key unexpected valid state")
|
||||
}
|
||||
|
||||
c.Evict()
|
||||
|
||||
if c.Length("accounts") != 0 {
|
||||
t.Errorf("expected length 0 got `%v`", len(c.entries))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
svcCache := NewCache(
|
||||
Size(256),
|
||||
)
|
||||
|
||||
err := svcCache.Set("accounts", "node", "0.0.0.0:1234")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
raw, err := svcCache.Get("accounts", "node")
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
v, ok := raw.V.(string)
|
||||
if !ok {
|
||||
t.Errorf("invalid type on service node key")
|
||||
}
|
||||
|
||||
if v != "0.0.0.0:1234" {
|
||||
t.Errorf("expected `0.0.0.0:1234` got `%v`", v)
|
||||
}
|
||||
}
|
||||
@@ -15,12 +15,12 @@ import (
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/justinas/alice"
|
||||
"github.com/micro/cli/v2"
|
||||
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
|
||||
"github.com/oklog/run"
|
||||
openzipkin "github.com/openzipkin/zipkin-go"
|
||||
zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
|
||||
acc "github.com/owncloud/ocis/accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/config"
|
||||
"github.com/owncloud/ocis/proxy/pkg/cs3"
|
||||
"github.com/owncloud/ocis/proxy/pkg/flagset"
|
||||
@@ -281,6 +281,8 @@ func loadMiddlewares(ctx context.Context, l log.Logger, cfg *config.Config) alic
|
||||
}),
|
||||
middleware.HTTPClient(oidcHTTPClient),
|
||||
middleware.OIDCIss(cfg.OIDC.Issuer),
|
||||
middleware.TokenCacheSize(1024),
|
||||
middleware.TokenCacheTTL(time.Second*10),
|
||||
),
|
||||
middleware.BasicAuth(
|
||||
middleware.Logger(l),
|
||||
|
||||
@@ -3,33 +3,31 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/micro/go-micro/v2/client"
|
||||
"github.com/owncloud/ocis/accounts/pkg/proto/v0"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/config"
|
||||
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetAccountSuccess(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 {
|
||||
t.Errorf("expected an account")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountInternalError(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError {
|
||||
t.Errorf("expected an internal server error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolverMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountResolver(
|
||||
Logger(log.NewLogger()),
|
||||
@@ -50,7 +48,6 @@ func TestAccountResolverMiddleware(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountResolverMiddlewareWithDisabledAccount(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "failure")
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
m := AccountResolver(
|
||||
Logger(log.NewLogger()),
|
||||
|
||||
@@ -12,5 +12,5 @@ var (
|
||||
ErrUnauthorized = errors.New("unauthorized")
|
||||
|
||||
// ErrInternal is returned if something went wrong
|
||||
ErrInternal = errors.New("internal error")
|
||||
ErrInternal = errors.New("internal error")
|
||||
)
|
||||
@@ -1,17 +0,0 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/owncloud/ocis/proxy/pkg/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
// AccountsKey declares the svcKey for the Accounts service.
|
||||
AccountsKey = "accounts"
|
||||
)
|
||||
|
||||
var (
|
||||
// svcCache caches requests for given services to prevent round trips to the service
|
||||
svcCache = cache.NewCache(
|
||||
cache.Size(256),
|
||||
)
|
||||
)
|
||||
@@ -2,12 +2,14 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
gOidc "github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"github.com/owncloud/ocis/ocis-pkg/oidc"
|
||||
"github.com/owncloud/ocis/proxy/pkg/cache"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OIDCProvider used to mock the oidc provider during tests
|
||||
@@ -18,6 +20,10 @@ type OIDCProvider interface {
|
||||
// OIDCAuth provides a middleware to check access secured by a static token.
|
||||
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
options := newOptions(optionSetters...)
|
||||
tokenCache := cache.NewCache(
|
||||
cache.Size(options.TokenCacheSize),
|
||||
cache.TTL(options.TokenCacheTTL),
|
||||
)
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return &oidcAuth{
|
||||
@@ -26,6 +32,7 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
||||
providerFunc: options.OIDCProviderFunc,
|
||||
httpClient: options.HTTPClient,
|
||||
oidcIss: options.OIDCIss,
|
||||
tokenCache: &tokenCache,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -37,6 +44,7 @@ type oidcAuth struct {
|
||||
providerFunc func() (OIDCProvider, error)
|
||||
httpClient *http.Client
|
||||
oidcIss string
|
||||
tokenCache *cache.Cache
|
||||
}
|
||||
|
||||
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
@@ -64,36 +72,48 @@ func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
userInfo, err := m.provider.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hit := m.tokenCache.Get(token)
|
||||
var claims oidc.StandardClaims
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if hit == nil {
|
||||
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
|
||||
oauth2Token := &oauth2.Token{
|
||||
AccessToken: token,
|
||||
}
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
userInfo, err := m.provider.UserInfo(
|
||||
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
||||
oauth2.StaticTokenSource(oauth2Token),
|
||||
)
|
||||
if err != nil {
|
||||
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
|
||||
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
|
||||
//TODO: This should be read from the token instead of config
|
||||
claims.Iss = m.oidcIss
|
||||
|
||||
m.tokenCache.Set(token, claims)
|
||||
} else {
|
||||
var ok = false
|
||||
if claims, ok = hit.V.(oidc.StandardClaims); !ok {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
||||
}
|
||||
|
||||
// inject claims to the request context for the account_uuid middleware.
|
||||
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
|
||||
|
||||
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
|
||||
|
||||
// store claims in context
|
||||
// uses the original context, not the one with probably reduced security
|
||||
m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))
|
||||
|
||||
@@ -3,17 +3,16 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"golang.org/x/oauth2"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/owncloud/ocis/ocis-pkg/log"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func TestOIDCAuthMiddleware(t *testing.T) {
|
||||
svcCache.Invalidate(AccountsKey, "success")
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
m := OIDCAuth(
|
||||
|
||||
@@ -2,6 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
|
||||
|
||||
@@ -41,6 +42,10 @@ type Options struct {
|
||||
AutoprovisionAccounts bool
|
||||
// EnableBasicAuth to allow basic auth
|
||||
EnableBasicAuth bool
|
||||
// TokenCacheSize defines the max number of entries in the token cache
|
||||
TokenCacheSize int
|
||||
// TokenCacheTTL sets the max cache duration for the token cache
|
||||
TokenCacheTTL time.Duration
|
||||
}
|
||||
|
||||
// newOptions initializes the available default options.
|
||||
@@ -137,3 +142,17 @@ func EnableBasicAuth(enableBasicAuth bool) Option {
|
||||
o.EnableBasicAuth = enableBasicAuth
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCacheSize provides a function to set the TokenCacheSize
|
||||
func TokenCacheSize(size int) Option {
|
||||
return func(o *Options) {
|
||||
o.TokenCacheSize = size
|
||||
}
|
||||
}
|
||||
|
||||
// TokenCacheTTL provides a function to set the TokenCacheTTL
|
||||
func TokenCacheTTL(ttl time.Duration) Option {
|
||||
return func(o *Options) {
|
||||
o.TokenCacheTTL = ttl
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user