Merge pull request #877 from owncloud/proxy-caching

Caching in Proxy
This commit is contained in:
Benedikt Kulmann
2020-11-18 13:13:25 +01:00
committed by GitHub
14 changed files with 264 additions and 352 deletions

View File

@@ -0,0 +1,7 @@
Enhancement: Cache userinfo in proxy
Tags: proxy
We introduced caching for the userinfo response. The token expiration is used for cache invalidation if available. Otherwise we fall back to a preconfigured TTL (default 10 seconds).
https://github.com/owncloud/ocis/pull/877

View File

@@ -4,9 +4,8 @@
clients:
- id: phoenix
name: ownCloud web app
application_type: web
insecure: yes
trusted: yes
insecure: yes
redirect_uris:
- https://localhost:9200/
- https://localhost:9200/oidc-callback.html
@@ -19,9 +18,8 @@ clients:
- http://localhost:9100
- id: ocis-explorer.js
name: OCIS Graph Explorer
name: oCIS Graph Explorer
trusted: yes
application_type: web
insecure: yes
- id: xdXOt13JKxym1B1QcEncf2XDkLAexMBFwiT9j6EfhhHFJhs2KM9jbjTmf8JBXE69

File diff suppressed because one or more lines are too long

View File

@@ -1,19 +1,19 @@
package cache
import (
"fmt"
"sync"
"time"
)
// Entry represents an entry on the cache. You can type assert on V.
type Entry struct {
V interface{}
Valid bool
V interface{}
expiration time.Time
}
// Cache is a barebones cache implementation.
type Cache struct {
entries map[string]map[string]Entry
entries map[string]*Entry
size int
m sync.Mutex
}
@@ -24,78 +24,55 @@ func NewCache(o ...Option) Cache {
return Cache{
size: opts.size,
entries: map[string]map[string]Entry{},
entries: map[string]*Entry{},
}
}
// Get gets an entry on a service `svcKey` by a give `key`.
func (c *Cache) Get(svcKey, key string) (*Entry, error) {
var value Entry
ok := true
// Get gets a role-bundle by a given `roleID`.
func (c *Cache) Get(k string) *Entry {
c.m.Lock()
defer c.m.Unlock()
if value, ok = c.entries[svcKey][key]; !ok {
return nil, fmt.Errorf("invalid service key: `%v`", key)
if _, ok := c.entries[k]; ok {
if c.expired(c.entries[k]) {
delete(c.entries, k)
return nil
}
return c.entries[k]
}
return &value, nil
return nil
}
// Set sets a key / value. It lets a service add entries on a request basis.
func (c *Cache) Set(svcKey, key string, val interface{}) error {
// Set sets a roleID / role-bundle.
func (c *Cache) Set(k string, val interface{}, expiration time.Time) {
c.m.Lock()
defer c.m.Unlock()
if !c.fits() {
return fmt.Errorf("cache is full")
c.evict()
}
if _, ok := c.entries[svcKey]; !ok {
c.entries[svcKey] = map[string]Entry{}
c.entries[k] = &Entry{
val,
expiration,
}
if _, ok := c.entries[svcKey][key]; ok {
return fmt.Errorf("key `%v` already exists", key)
}
c.entries[svcKey][key] = Entry{
V: val,
Valid: true,
}
return nil
}
// Invalidate invalidates a cache Entry by key.
func (c *Cache) Invalidate(svcKey, key string) error {
r, err := c.Get(svcKey, key)
if err != nil {
return err
}
r.Valid = false
c.entries[svcKey][key] = *r
return nil
}
// Evict frees memory from the cache by removing invalid keys. It is a noop.
func (c *Cache) Evict() {
for _, v := range c.entries {
for k, svcEntry := range v {
if !svcEntry.Valid {
delete(v, k)
}
// evict frees memory from the cache by removing entries that exceeded the cache TTL.
func (c *Cache) evict() {
for i := range c.entries {
if c.expired(c.entries[i]) {
delete(c.entries, i)
}
}
}
// Length returns the amount of entries per service key.
func (c *Cache) Length(k string) int {
return len(c.entries[k])
// expired checks if an entry is expired
func (c *Cache) expired(e *Entry) bool {
return e.expiration.Before(time.Now())
}
// fits returns whether the cache fits more entries.
func (c *Cache) fits() bool {
return c.size >= len(c.entries)
return c.size > len(c.entries)
}

View File

@@ -1,99 +0,0 @@
package cache
import (
"testing"
)
// Prevents from invalid import cycle.
type AccountsCacheEntry struct {
Email string
UUID string
}
func TestSet(t *testing.T) {
c := NewCache(
Size(256),
)
err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
Email: "hello@foo.bar",
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
})
if err != nil {
t.Error(err)
}
if c.Length("accounts") != 1 {
t.Errorf("expected length 1 got `%v`", len(c.entries))
}
item, err := c.Get("accounts", "hello@foo.bar")
if err != nil {
t.Error(err)
}
if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok {
t.Errorf("invalid cached value type")
} else {
if cachedEntry.Email != "hello@foo.bar" {
t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email)
}
}
}
func TestEvict(t *testing.T) {
c := NewCache(
Size(256),
)
if err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
Email: "hello@foo.bar",
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
}); err != nil {
t.Error(err)
}
if err := c.Invalidate("accounts", "hello@foo.bar"); err != nil {
t.Error(err)
}
v, err := c.Get("accounts", "hello@foo.bar")
if err != nil {
t.Error(err)
}
if v.Valid {
t.Errorf("cache key unexpected valid state")
}
c.Evict()
if c.Length("accounts") != 0 {
t.Errorf("expected length 0 got `%v`", len(c.entries))
}
}
func TestGet(t *testing.T) {
svcCache := NewCache(
Size(256),
)
err := svcCache.Set("accounts", "node", "0.0.0.0:1234")
if err != nil {
t.Error(err)
}
raw, err := svcCache.Get("accounts", "node")
if err != nil {
t.Error(err)
}
v, ok := raw.V.(string)
if !ok {
t.Errorf("invalid type on service node key")
}
if v != "0.0.0.0:1234" {
t.Errorf("expected `0.0.0.0:1234` got `%v`", v)
}
}

View File

@@ -15,12 +15,12 @@ import (
"github.com/coreos/go-oidc"
"github.com/justinas/alice"
"github.com/micro/cli/v2"
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
"github.com/oklog/run"
openzipkin "github.com/openzipkin/zipkin-go"
zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
acc "github.com/owncloud/ocis/accounts/pkg/proto/v0"
"github.com/owncloud/ocis/ocis-pkg/log"
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
"github.com/owncloud/ocis/proxy/pkg/config"
"github.com/owncloud/ocis/proxy/pkg/cs3"
"github.com/owncloud/ocis/proxy/pkg/flagset"
@@ -281,6 +281,8 @@ func loadMiddlewares(ctx context.Context, l log.Logger, cfg *config.Config) alic
}),
middleware.HTTPClient(oidcHTTPClient),
middleware.OIDCIss(cfg.OIDC.Issuer),
middleware.TokenCacheSize(cfg.OIDC.UserinfoCache.Size),
middleware.TokenCacheTTL(time.Second*time.Duration(cfg.OIDC.UserinfoCache.TTL)),
),
middleware.BasicAuth(
middleware.Logger(l),

View File

@@ -83,6 +83,12 @@ type Reva struct {
Address string
}
// Cache is a TTL cache configuration.
type Cache struct {
Size int
TTL int
}
// Config combines all available configuration parts.
type Config struct {
File string
@@ -105,8 +111,9 @@ type Config struct {
// OIDC is the config for the OpenID-Connect middleware. If set the proxy will try to authenticate every request
// with the configured oidc-provider
type OIDC struct {
Issuer string
Insecure bool
Issuer string
Insecure bool
UserinfoCache Cache
}
// PolicySelector is the toplevel-configuration for different selectors

View File

@@ -202,6 +202,20 @@ func ServerWithConfig(cfg *config.Config) []cli.Flag {
EnvVars: []string{"PROXY_OIDC_INSECURE"},
Destination: &cfg.OIDC.Insecure,
},
&cli.IntFlag{
Name: "oidc-userinfo-cache-tll",
Value: 10,
Usage: "Fallback TTL in seconds for caching userinfo, when no token lifetime can be identified",
EnvVars: []string{"PROXY_OIDC_USERINFO_CACHE_TTL"},
Destination: &cfg.OIDC.UserinfoCache.TTL,
},
&cli.IntFlag{
Name: "oidc-userinfo-cache-size",
Value: 1024,
Usage: "Max entries for caching userinfo",
EnvVars: []string{"PROXY_OIDC_USERINFO_CACHE_SIZE"},
Destination: &cfg.OIDC.UserinfoCache.Size,
},
&cli.BoolFlag{
Name: "autoprovision-accounts",

View File

@@ -3,33 +3,31 @@ package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/micro/go-micro/v2/client"
"github.com/owncloud/ocis/accounts/pkg/proto/v0"
"github.com/owncloud/ocis/ocis-pkg/log"
"github.com/owncloud/ocis/ocis-pkg/oidc"
"github.com/owncloud/ocis/proxy/pkg/config"
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
"net/http"
"net/http/httptest"
"testing"
)
func TestGetAccountSuccess(t *testing.T) {
svcCache.Invalidate(AccountsKey, "success")
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 {
t.Errorf("expected an account")
}
}
func TestGetAccountInternalError(t *testing.T) {
svcCache.Invalidate(AccountsKey, "failure")
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError {
t.Errorf("expected an internal server error")
}
}
func TestAccountResolverMiddleware(t *testing.T) {
svcCache.Invalidate(AccountsKey, "success")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := AccountResolver(
Logger(log.NewLogger()),
@@ -50,7 +48,6 @@ func TestAccountResolverMiddleware(t *testing.T) {
}
func TestAccountResolverMiddlewareWithDisabledAccount(t *testing.T) {
svcCache.Invalidate(AccountsKey, "failure")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := AccountResolver(
Logger(log.NewLogger()),

View File

@@ -12,5 +12,5 @@ var (
ErrUnauthorized = errors.New("unauthorized")
// ErrInternal is returned if something went wrong
ErrInternal = errors.New("internal error")
ErrInternal = errors.New("internal error")
)

View File

@@ -1,17 +0,0 @@
package middleware
import (
"github.com/owncloud/ocis/proxy/pkg/cache"
)
const (
// AccountsKey declares the svcKey for the Accounts service.
AccountsKey = "accounts"
)
var (
// svcCache caches requests for given services to prevent round trips to the service
svcCache = cache.NewCache(
cache.Size(256),
)
)

View File

@@ -2,12 +2,18 @@ package middleware
import (
"context"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
gOidc "github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
"github.com/owncloud/ocis/ocis-pkg/oidc"
"github.com/owncloud/ocis/proxy/pkg/cache"
"golang.org/x/oauth2"
"net/http"
"strings"
)
// OIDCProvider used to mock the oidc provider during tests
@@ -18,25 +24,30 @@ type OIDCProvider interface {
// OIDCAuth provides a middleware to check access secured by a static token.
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
options := newOptions(optionSetters...)
tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize))
return func(next http.Handler) http.Handler {
return &oidcAuth{
next: next,
logger: options.Logger,
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
next: next,
logger: options.Logger,
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
tokenCache: &tokenCache,
tokenCacheTTL: options.UserinfoCacheTTL,
}
}
}
type oidcAuth struct {
next http.Handler
logger log.Logger
provider OIDCProvider
providerFunc func() (OIDCProvider, error)
httpClient *http.Client
oidcIss string
next http.Handler
logger log.Logger
provider OIDCProvider
providerFunc func() (OIDCProvider, error)
httpClient *http.Client
oidcIss string
tokenCache *cache.Cache
tokenCacheTTL time.Duration
}
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -46,59 +57,95 @@ func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
if m.provider == nil {
// Lazily initialize a provider
// provider needs to be cached as when it is created
// it will fetch the keys from the issuer using the .well-known
// endpoint
provider, err := m.providerFunc()
if err != nil {
m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
w.WriteHeader(http.StatusInternalServerError)
return
}
m.provider = provider
}
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
oauth2Token := &oauth2.Token{
AccessToken: token,
}
userInfo, err := m.provider.UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
return
}
var claims oidc.StandardClaims
if err := userInfo.Claims(&claims); err != nil {
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
if m.getProvider() == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
//TODO: This should be read from the token instead of config
claims.Iss = m.oidcIss
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
claims, status := m.getClaims(token, req)
if status != 0 {
w.WriteHeader(status)
return
}
// inject claims to the request context for the account_uuid middleware.
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
// store claims in context
// uses the original context, not the one with probably reduced security
m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))
}
func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.StandardClaims, status int) {
hit := m.tokenCache.Get(token)
if hit == nil {
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
oauth2Token := &oauth2.Token{
AccessToken: token,
}
userInfo, err := m.getProvider().UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
status = http.StatusUnauthorized
return
}
if err := userInfo.Claims(&claims); err != nil {
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
status = http.StatusInternalServerError
return
}
//TODO: This should be read from the token instead of config
claims.Iss = m.oidcIss
expiration := m.extractExpiration(token)
m.tokenCache.Set(token, claims, expiration)
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
return
}
var ok = false
if claims, ok = hit.V.(oidc.StandardClaims); !ok {
status = http.StatusInternalServerError
return
}
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
return
}
// extractExpiration tries to parse and extract the expiration from the provided token. It might not even be a jwt.
// defaults to the configured fallback TTL.
// TODO: use introspection endpoint if available in the oidc configuration. Still needs a fallback to configured TTL.
func (m oidcAuth) extractExpiration(token string) time.Time {
defaultExpiration := time.Now().Add(m.tokenCacheTTL)
s := strings.SplitN(token, ".", 4)
if len(s) != 3 {
return defaultExpiration
}
b, err := jwt.DecodeSegment(s[1])
if err != nil {
return defaultExpiration
}
at := &jwt.StandardClaims{}
err = json.Unmarshal(b, at)
if err != nil || at.ExpiresAt == 0 {
return defaultExpiration
}
return time.Unix(at.ExpiresAt, 0)
}
func (m oidcAuth) shouldServe(req *http.Request) bool {
header := req.Header.Get("Authorization")
@@ -107,6 +154,7 @@ func (m oidcAuth) shouldServe(req *http.Request) bool {
}
// todo: looks dirty, check later
// TODO: make a PR to coreos/go-oidc for exposing userinfo endpoint on provider, see https://github.com/coreos/go-oidc/issues/248
for _, ignoringPath := range []string{"/konnect/v1/userinfo"} {
if req.URL.Path == ignoringPath {
return false
@@ -115,3 +163,21 @@ func (m oidcAuth) shouldServe(req *http.Request) bool {
return strings.HasPrefix(header, "Bearer ")
}
func (m oidcAuth) getProvider() OIDCProvider {
if m.provider == nil {
// Lazily initialize a provider
// provider needs to be cached as when it is created
// it will fetch the keys from the issuer using the .well-known
// endpoint
provider, err := m.providerFunc()
if err != nil {
m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
return nil
}
m.provider = provider
}
return m.provider
}

View File

@@ -3,17 +3,16 @@ package middleware
import (
"context"
"fmt"
"github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
"golang.org/x/oauth2"
"net/http"
"net/http/httptest"
"testing"
"github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
"golang.org/x/oauth2"
)
func TestOIDCAuthMiddleware(t *testing.T) {
svcCache.Invalidate(AccountsKey, "success")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := OIDCAuth(

View File

@@ -2,6 +2,7 @@ package middleware
import (
"net/http"
"time"
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
@@ -27,7 +28,7 @@ type Options struct {
AccountsClient acc.AccountsService
// SettingsRoleService for the roles API in settings
SettingsRoleService settings.RoleService
// OIDCProviderFunc to lazily initialize a provider, must be set for the oidcProvider middleware
// OIDCProviderFunc to lazily initialize an oidc provider, must be set for the oidc_auth middleware
OIDCProviderFunc func() (OIDCProvider, error)
// OIDCIss is the oidcAuth-issuer
OIDCIss string
@@ -41,6 +42,10 @@ type Options struct {
AutoprovisionAccounts bool
// EnableBasicAuth to allow basic auth
EnableBasicAuth bool
// UserinfoCacheSize defines the max number of entries in the userinfo cache, intended for the oidc_auth middleware
UserinfoCacheSize int
// UserinfoCacheTTL sets the max cache duration for the userinfo cache, intended for the oidc_auth middleware
UserinfoCacheTTL time.Duration
}
// newOptions initializes the available default options.
@@ -89,7 +94,7 @@ func SettingsRoleService(rc settings.RoleService) Option {
}
}
// OIDCProviderFunc provides a function to set the the oidcAuth provider function option.
// OIDCProviderFunc provides a function to set the the oidc provider function option.
func OIDCProviderFunc(f func() (OIDCProvider, error)) Option {
return func(o *Options) {
o.OIDCProviderFunc = f
@@ -137,3 +142,17 @@ func EnableBasicAuth(enableBasicAuth bool) Option {
o.EnableBasicAuth = enableBasicAuth
}
}
// TokenCacheSize provides a function to set the TokenCacheSize
func TokenCacheSize(size int) Option {
return func(o *Options) {
o.UserinfoCacheSize = size
}
}
// TokenCacheTTL provides a function to set the TokenCacheTTL
func TokenCacheTTL(ttl time.Duration) Option {
return func(o *Options) {
o.UserinfoCacheTTL = ttl
}
}