diff --git a/proxy/pkg/cache/cache.go b/proxy/pkg/cache/cache.go index cdc95d1c4e..2c1d7e93ec 100644 --- a/proxy/pkg/cache/cache.go +++ b/proxy/pkg/cache/cache.go @@ -1,20 +1,21 @@ package cache import ( - "fmt" "sync" + "time" ) // Entry represents an entry on the cache. You can type assert on V. type Entry struct { - V interface{} - Valid bool + V interface{} + inserted time.Time } // Cache is a barebones cache implementation. type Cache struct { - entries map[string]map[string]Entry + entries map[string]*Entry size int + ttl time.Duration m sync.Mutex } @@ -24,78 +25,56 @@ func NewCache(o ...Option) Cache { return Cache{ size: opts.size, - entries: map[string]map[string]Entry{}, + ttl: opts.ttl, + entries: map[string]*Entry{}, } } -// Get gets an entry on a service `svcKey` by a give `key`. -func (c *Cache) Get(svcKey, key string) (*Entry, error) { - var value Entry - ok := true - +// Get gets a role-bundle by a given `roleID`. +func (c *Cache) Get(k string) *Entry { c.m.Lock() defer c.m.Unlock() - if value, ok = c.entries[svcKey][key]; !ok { - return nil, fmt.Errorf("invalid service key: `%v`", key) + if _, ok := c.entries[k]; ok { + if c.expired(c.entries[k]) { + delete(c.entries, k) + return nil + } + return c.entries[k] } - - return &value, nil + return nil } -// Set sets a key / value. It lets a service add entries on a request basis. -func (c *Cache) Set(svcKey, key string, val interface{}) error { +// Set sets a roleID / role-bundle. +func (c *Cache) Set(k string, val interface{}) { c.m.Lock() defer c.m.Unlock() if !c.fits() { - return fmt.Errorf("cache is full") + c.evict() } - if _, ok := c.entries[svcKey]; !ok { - c.entries[svcKey] = map[string]Entry{} + c.entries[k] = &Entry{ + val, + time.Now(), } - - if _, ok := c.entries[svcKey][key]; ok { - return fmt.Errorf("key `%v` already exists", key) - } - - c.entries[svcKey][key] = Entry{ - V: val, - Valid: true, - } - - return nil } -// Invalidate invalidates a cache Entry by key. -func (c *Cache) Invalidate(svcKey, key string) error { - r, err := c.Get(svcKey, key) - if err != nil { - return err - } - - r.Valid = false - c.entries[svcKey][key] = *r - return nil -} - -// Evict frees memory from the cache by removing invalid keys. It is a noop. -func (c *Cache) Evict() { - for _, v := range c.entries { - for k, svcEntry := range v { - if !svcEntry.Valid { - delete(v, k) - } +// evict frees memory from the cache by removing entries that exceeded the cache TTL. +func (c *Cache) evict() { + for i := range c.entries { + if c.expired(c.entries[i]) { + delete(c.entries, i) } } } -// Length returns the amount of entries per service key. -func (c *Cache) Length(k string) int { - return len(c.entries[k]) +// expired checks if an entry is expired +func (c *Cache) expired(e *Entry) bool { + return e.inserted.Add(c.ttl).Before(time.Now()) } +// fits returns whether the cache fits more entries. func (c *Cache) fits() bool { - return c.size >= len(c.entries) + return c.size > len(c.entries) } diff --git a/proxy/pkg/cache/cache_test.go b/proxy/pkg/cache/cache_test.go deleted file mode 100644 index 7ac66a9b71..0000000000 --- a/proxy/pkg/cache/cache_test.go +++ /dev/null @@ -1,99 +0,0 @@ -package cache - -import ( - "testing" -) - -// Prevents from invalid import cycle. -type AccountsCacheEntry struct { - Email string - UUID string -} - -func TestSet(t *testing.T) { - c := NewCache( - Size(256), - ) - - err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{ - Email: "hello@foo.bar", - UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05", - }) - if err != nil { - t.Error(err) - } - - if c.Length("accounts") != 1 { - t.Errorf("expected length 1 got `%v`", len(c.entries)) - } - - item, err := c.Get("accounts", "hello@foo.bar") - if err != nil { - t.Error(err) - } - - if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok { - t.Errorf("invalid cached value type") - } else { - if cachedEntry.Email != "hello@foo.bar" { - t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email) - } - } -} - -func TestEvict(t *testing.T) { - c := NewCache( - Size(256), - ) - - if err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{ - Email: "hello@foo.bar", - UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05", - }); err != nil { - t.Error(err) - } - - if err := c.Invalidate("accounts", "hello@foo.bar"); err != nil { - t.Error(err) - } - - v, err := c.Get("accounts", "hello@foo.bar") - if err != nil { - t.Error(err) - } - - if v.Valid { - t.Errorf("cache key unexpected valid state") - } - - c.Evict() - - if c.Length("accounts") != 0 { - t.Errorf("expected length 0 got `%v`", len(c.entries)) - } -} - -func TestGet(t *testing.T) { - svcCache := NewCache( - Size(256), - ) - - err := svcCache.Set("accounts", "node", "0.0.0.0:1234") - if err != nil { - t.Error(err) - } - - raw, err := svcCache.Get("accounts", "node") - if err != nil { - t.Error(err) - } - - v, ok := raw.V.(string) - if !ok { - t.Errorf("invalid type on service node key") - } - - if v != "0.0.0.0:1234" { - t.Errorf("expected `0.0.0.0:1234` got `%v`", v) - } -} diff --git a/proxy/pkg/command/server.go b/proxy/pkg/command/server.go index 0765bc1016..092c85e3f7 100644 --- a/proxy/pkg/command/server.go +++ b/proxy/pkg/command/server.go @@ -15,12 +15,12 @@ import ( "github.com/coreos/go-oidc" "github.com/justinas/alice" "github.com/micro/cli/v2" - "github.com/owncloud/ocis/ocis-pkg/service/grpc" "github.com/oklog/run" openzipkin "github.com/openzipkin/zipkin-go" zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http" acc "github.com/owncloud/ocis/accounts/pkg/proto/v0" "github.com/owncloud/ocis/ocis-pkg/log" + "github.com/owncloud/ocis/ocis-pkg/service/grpc" "github.com/owncloud/ocis/proxy/pkg/config" "github.com/owncloud/ocis/proxy/pkg/cs3" "github.com/owncloud/ocis/proxy/pkg/flagset" @@ -281,6 +281,8 @@ func loadMiddlewares(ctx context.Context, l log.Logger, cfg *config.Config) alic }), middleware.HTTPClient(oidcHTTPClient), middleware.OIDCIss(cfg.OIDC.Issuer), + middleware.TokenCacheSize(1024), + middleware.TokenCacheTTL(time.Second*10), ), middleware.BasicAuth( middleware.Logger(l), diff --git a/proxy/pkg/middleware/account_resolver_test.go b/proxy/pkg/middleware/account_resolver_test.go index ff111829e7..f5ad54ede2 100644 --- a/proxy/pkg/middleware/account_resolver_test.go +++ b/proxy/pkg/middleware/account_resolver_test.go @@ -3,33 +3,31 @@ package middleware import ( "context" "fmt" + "net/http" + "net/http/httptest" + "testing" + "github.com/micro/go-micro/v2/client" "github.com/owncloud/ocis/accounts/pkg/proto/v0" "github.com/owncloud/ocis/ocis-pkg/log" "github.com/owncloud/ocis/ocis-pkg/oidc" "github.com/owncloud/ocis/proxy/pkg/config" settings "github.com/owncloud/ocis/settings/pkg/proto/v0" - "net/http" - "net/http/httptest" - "testing" ) func TestGetAccountSuccess(t *testing.T) { - svcCache.Invalidate(AccountsKey, "success") if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 { t.Errorf("expected an account") } } func TestGetAccountInternalError(t *testing.T) { - svcCache.Invalidate(AccountsKey, "failure") if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError { t.Errorf("expected an internal server error") } } func TestAccountResolverMiddleware(t *testing.T) { - svcCache.Invalidate(AccountsKey, "success") next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) m := AccountResolver( Logger(log.NewLogger()), @@ -50,7 +48,6 @@ func TestAccountResolverMiddleware(t *testing.T) { } func TestAccountResolverMiddlewareWithDisabledAccount(t *testing.T) { - svcCache.Invalidate(AccountsKey, "failure") next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) m := AccountResolver( Logger(log.NewLogger()), diff --git a/proxy/pkg/middleware/middleware.go b/proxy/pkg/middleware/errors.go similarity index 87% rename from proxy/pkg/middleware/middleware.go rename to proxy/pkg/middleware/errors.go index c97da6e2b3..18cfcca0e9 100644 --- a/proxy/pkg/middleware/middleware.go +++ b/proxy/pkg/middleware/errors.go @@ -12,5 +12,5 @@ var ( ErrUnauthorized = errors.New("unauthorized") // ErrInternal is returned if something went wrong - ErrInternal = errors.New("internal error") + ErrInternal = errors.New("internal error") ) diff --git a/proxy/pkg/middleware/middleware_test.go b/proxy/pkg/middleware/middleware_test.go deleted file mode 100644 index 6bb707d179..0000000000 --- a/proxy/pkg/middleware/middleware_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package middleware - -import ( - "github.com/owncloud/ocis/proxy/pkg/cache" -) - -const ( - // AccountsKey declares the svcKey for the Accounts service. - AccountsKey = "accounts" -) - -var ( - // svcCache caches requests for given services to prevent round trips to the service - svcCache = cache.NewCache( - cache.Size(256), - ) -) diff --git a/proxy/pkg/middleware/oidc_auth.go b/proxy/pkg/middleware/oidc_auth.go index 7c805aab22..ca7df7553f 100644 --- a/proxy/pkg/middleware/oidc_auth.go +++ b/proxy/pkg/middleware/oidc_auth.go @@ -2,12 +2,14 @@ package middleware import ( "context" + "net/http" + "strings" + gOidc "github.com/coreos/go-oidc" "github.com/owncloud/ocis/ocis-pkg/log" "github.com/owncloud/ocis/ocis-pkg/oidc" + "github.com/owncloud/ocis/proxy/pkg/cache" "golang.org/x/oauth2" - "net/http" - "strings" ) // OIDCProvider used to mock the oidc provider during tests @@ -18,6 +20,10 @@ type OIDCProvider interface { // OIDCAuth provides a middleware to check access secured by a static token. func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler { options := newOptions(optionSetters...) + tokenCache := cache.NewCache( + cache.Size(options.TokenCacheSize), + cache.TTL(options.TokenCacheTTL), + ) return func(next http.Handler) http.Handler { return &oidcAuth{ @@ -26,6 +32,7 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler { providerFunc: options.OIDCProviderFunc, httpClient: options.HTTPClient, oidcIss: options.OIDCIss, + tokenCache: &tokenCache, } } } @@ -37,6 +44,7 @@ type oidcAuth struct { providerFunc func() (OIDCProvider, error) httpClient *http.Client oidcIss string + tokenCache *cache.Cache } func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) { @@ -64,36 +72,48 @@ func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) { token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") - // TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token) - oauth2Token := &oauth2.Token{ - AccessToken: token, - } - - userInfo, err := m.provider.UserInfo( - context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient), - oauth2.StaticTokenSource(oauth2Token), - ) - if err != nil { - m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo") - http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized) - return - } - + hit := m.tokenCache.Get(token) var claims oidc.StandardClaims - if err := userInfo.Claims(&claims); err != nil { - m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims") - w.WriteHeader(http.StatusInternalServerError) - return - } + if hit == nil { + // TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token) + oauth2Token := &oauth2.Token{ + AccessToken: token, + } - //TODO: This should be read from the token instead of config - claims.Iss = m.oidcIss + userInfo, err := m.provider.UserInfo( + context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient), + oauth2.StaticTokenSource(oauth2Token), + ) + if err != nil { + m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo") + http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized) + return + } + + if err := userInfo.Claims(&claims); err != nil { + m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims") + w.WriteHeader(http.StatusInternalServerError) + return + } + + m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo") + + //TODO: This should be read from the token instead of config + claims.Iss = m.oidcIss + + m.tokenCache.Set(token, claims) + } else { + var ok = false + if claims, ok = hit.V.(oidc.StandardClaims); !ok { + w.WriteHeader(http.StatusInternalServerError) + return + } + m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo") + } // inject claims to the request context for the account_uuid middleware. req = req.WithContext(oidc.NewContext(req.Context(), &claims)) - m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo") - // store claims in context // uses the original context, not the one with probably reduced security m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims))) diff --git a/proxy/pkg/middleware/oidc_auth_test.go b/proxy/pkg/middleware/oidc_auth_test.go index b4b899c1c1..e9fb3f05dd 100644 --- a/proxy/pkg/middleware/oidc_auth_test.go +++ b/proxy/pkg/middleware/oidc_auth_test.go @@ -3,17 +3,16 @@ package middleware import ( "context" "fmt" - "github.com/coreos/go-oidc" - "github.com/owncloud/ocis/ocis-pkg/log" - "golang.org/x/oauth2" "net/http" "net/http/httptest" "testing" + + "github.com/coreos/go-oidc" + "github.com/owncloud/ocis/ocis-pkg/log" + "golang.org/x/oauth2" ) func TestOIDCAuthMiddleware(t *testing.T) { - svcCache.Invalidate(AccountsKey, "success") - next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) m := OIDCAuth( diff --git a/proxy/pkg/middleware/options.go b/proxy/pkg/middleware/options.go index 32ecaa6223..bebc06f77e 100644 --- a/proxy/pkg/middleware/options.go +++ b/proxy/pkg/middleware/options.go @@ -2,6 +2,7 @@ package middleware import ( "net/http" + "time" settings "github.com/owncloud/ocis/settings/pkg/proto/v0" @@ -41,6 +42,10 @@ type Options struct { AutoprovisionAccounts bool // EnableBasicAuth to allow basic auth EnableBasicAuth bool + // TokenCacheSize defines the max number of entries in the token cache + TokenCacheSize int + // TokenCacheTTL sets the max cache duration for the token cache + TokenCacheTTL time.Duration } // newOptions initializes the available default options. @@ -137,3 +142,17 @@ func EnableBasicAuth(enableBasicAuth bool) Option { o.EnableBasicAuth = enableBasicAuth } } + +// TokenCacheSize provides a function to set the TokenCacheSize +func TokenCacheSize(size int) Option { + return func(o *Options) { + o.TokenCacheSize = size + } +} + +// TokenCacheTTL provides a function to set the TokenCacheTTL +func TokenCacheTTL(ttl time.Duration) Option { + return func(o *Options) { + o.TokenCacheTTL = ttl + } +}