First implementation for userinfo cache without config

This commit is contained in:
Benedikt Kulmann
2020-11-17 17:10:14 +01:00
parent 1034548ecd
commit a5c09453b9
9 changed files with 109 additions and 209 deletions

View File

@@ -1,20 +1,21 @@
package cache
import (
"fmt"
"sync"
"time"
)
// Entry represents an entry on the cache. You can type assert on V.
type Entry struct {
V interface{}
Valid bool
V interface{}
inserted time.Time
}
// Cache is a barebones cache implementation.
type Cache struct {
entries map[string]map[string]Entry
entries map[string]*Entry
size int
ttl time.Duration
m sync.Mutex
}
@@ -24,78 +25,56 @@ func NewCache(o ...Option) Cache {
return Cache{
size: opts.size,
entries: map[string]map[string]Entry{},
ttl: opts.ttl,
entries: map[string]*Entry{},
}
}
// Get gets an entry on a service `svcKey` by a give `key`.
func (c *Cache) Get(svcKey, key string) (*Entry, error) {
var value Entry
ok := true
// Get gets a role-bundle by a given `roleID`.
func (c *Cache) Get(k string) *Entry {
c.m.Lock()
defer c.m.Unlock()
if value, ok = c.entries[svcKey][key]; !ok {
return nil, fmt.Errorf("invalid service key: `%v`", key)
if _, ok := c.entries[k]; ok {
if c.expired(c.entries[k]) {
delete(c.entries, k)
return nil
}
return c.entries[k]
}
return &value, nil
return nil
}
// Set sets a key / value. It lets a service add entries on a request basis.
func (c *Cache) Set(svcKey, key string, val interface{}) error {
// Set sets a roleID / role-bundle.
func (c *Cache) Set(k string, val interface{}) {
c.m.Lock()
defer c.m.Unlock()
if !c.fits() {
return fmt.Errorf("cache is full")
c.evict()
}
if _, ok := c.entries[svcKey]; !ok {
c.entries[svcKey] = map[string]Entry{}
c.entries[k] = &Entry{
val,
time.Now(),
}
if _, ok := c.entries[svcKey][key]; ok {
return fmt.Errorf("key `%v` already exists", key)
}
c.entries[svcKey][key] = Entry{
V: val,
Valid: true,
}
return nil
}
// Invalidate invalidates a cache Entry by key.
func (c *Cache) Invalidate(svcKey, key string) error {
r, err := c.Get(svcKey, key)
if err != nil {
return err
}
r.Valid = false
c.entries[svcKey][key] = *r
return nil
}
// Evict frees memory from the cache by removing invalid keys. It is a noop.
func (c *Cache) Evict() {
for _, v := range c.entries {
for k, svcEntry := range v {
if !svcEntry.Valid {
delete(v, k)
}
// evict frees memory from the cache by removing entries that exceeded the cache TTL.
func (c *Cache) evict() {
for i := range c.entries {
if c.expired(c.entries[i]) {
delete(c.entries, i)
}
}
}
// Length returns the amount of entries per service key.
func (c *Cache) Length(k string) int {
return len(c.entries[k])
// expired checks if an entry is expired
func (c *Cache) expired(e *Entry) bool {
return e.inserted.Add(c.ttl).Before(time.Now())
}
// fits returns whether the cache fits more entries.
func (c *Cache) fits() bool {
return c.size >= len(c.entries)
return c.size > len(c.entries)
}

View File

@@ -1,99 +0,0 @@
package cache
import (
"testing"
)
// Prevents from invalid import cycle.
type AccountsCacheEntry struct {
Email string
UUID string
}
func TestSet(t *testing.T) {
c := NewCache(
Size(256),
)
err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
Email: "hello@foo.bar",
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
})
if err != nil {
t.Error(err)
}
if c.Length("accounts") != 1 {
t.Errorf("expected length 1 got `%v`", len(c.entries))
}
item, err := c.Get("accounts", "hello@foo.bar")
if err != nil {
t.Error(err)
}
if cachedEntry, ok := item.V.(AccountsCacheEntry); !ok {
t.Errorf("invalid cached value type")
} else {
if cachedEntry.Email != "hello@foo.bar" {
t.Errorf("invalid value. Expected `hello@foo.bar` got: `%v`", cachedEntry.Email)
}
}
}
func TestEvict(t *testing.T) {
c := NewCache(
Size(256),
)
if err := c.Set("accounts", "hello@foo.bar", AccountsCacheEntry{
Email: "hello@foo.bar",
UUID: "9c31b040-59e2-4a2b-926b-334d9e3fbd05",
}); err != nil {
t.Error(err)
}
if err := c.Invalidate("accounts", "hello@foo.bar"); err != nil {
t.Error(err)
}
v, err := c.Get("accounts", "hello@foo.bar")
if err != nil {
t.Error(err)
}
if v.Valid {
t.Errorf("cache key unexpected valid state")
}
c.Evict()
if c.Length("accounts") != 0 {
t.Errorf("expected length 0 got `%v`", len(c.entries))
}
}
func TestGet(t *testing.T) {
svcCache := NewCache(
Size(256),
)
err := svcCache.Set("accounts", "node", "0.0.0.0:1234")
if err != nil {
t.Error(err)
}
raw, err := svcCache.Get("accounts", "node")
if err != nil {
t.Error(err)
}
v, ok := raw.V.(string)
if !ok {
t.Errorf("invalid type on service node key")
}
if v != "0.0.0.0:1234" {
t.Errorf("expected `0.0.0.0:1234` got `%v`", v)
}
}

View File

@@ -15,12 +15,12 @@ import (
"github.com/coreos/go-oidc"
"github.com/justinas/alice"
"github.com/micro/cli/v2"
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
"github.com/oklog/run"
openzipkin "github.com/openzipkin/zipkin-go"
zipkinhttp "github.com/openzipkin/zipkin-go/reporter/http"
acc "github.com/owncloud/ocis/accounts/pkg/proto/v0"
"github.com/owncloud/ocis/ocis-pkg/log"
"github.com/owncloud/ocis/ocis-pkg/service/grpc"
"github.com/owncloud/ocis/proxy/pkg/config"
"github.com/owncloud/ocis/proxy/pkg/cs3"
"github.com/owncloud/ocis/proxy/pkg/flagset"
@@ -281,6 +281,8 @@ func loadMiddlewares(ctx context.Context, l log.Logger, cfg *config.Config) alic
}),
middleware.HTTPClient(oidcHTTPClient),
middleware.OIDCIss(cfg.OIDC.Issuer),
middleware.TokenCacheSize(1024),
middleware.TokenCacheTTL(time.Second*10),
),
middleware.BasicAuth(
middleware.Logger(l),

View File

@@ -3,33 +3,31 @@ package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/micro/go-micro/v2/client"
"github.com/owncloud/ocis/accounts/pkg/proto/v0"
"github.com/owncloud/ocis/ocis-pkg/log"
"github.com/owncloud/ocis/ocis-pkg/oidc"
"github.com/owncloud/ocis/proxy/pkg/config"
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
"net/http"
"net/http/httptest"
"testing"
)
func TestGetAccountSuccess(t *testing.T) {
svcCache.Invalidate(AccountsKey, "success")
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(false, true), "mail eq 'success'"); status != 0 {
t.Errorf("expected an account")
}
}
func TestGetAccountInternalError(t *testing.T) {
svcCache.Invalidate(AccountsKey, "failure")
if _, status := getAccount(log.NewLogger(), mockAccountResolverMiddlewareAccSvc(true, false), "mail eq 'failure'"); status != http.StatusInternalServerError {
t.Errorf("expected an internal server error")
}
}
func TestAccountResolverMiddleware(t *testing.T) {
svcCache.Invalidate(AccountsKey, "success")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := AccountResolver(
Logger(log.NewLogger()),
@@ -50,7 +48,6 @@ func TestAccountResolverMiddleware(t *testing.T) {
}
func TestAccountResolverMiddlewareWithDisabledAccount(t *testing.T) {
svcCache.Invalidate(AccountsKey, "failure")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := AccountResolver(
Logger(log.NewLogger()),

View File

@@ -12,5 +12,5 @@ var (
ErrUnauthorized = errors.New("unauthorized")
// ErrInternal is returned if something went wrong
ErrInternal = errors.New("internal error")
ErrInternal = errors.New("internal error")
)

View File

@@ -1,17 +0,0 @@
package middleware
import (
"github.com/owncloud/ocis/proxy/pkg/cache"
)
const (
// AccountsKey declares the svcKey for the Accounts service.
AccountsKey = "accounts"
)
var (
// svcCache caches requests for given services to prevent round trips to the service
svcCache = cache.NewCache(
cache.Size(256),
)
)

View File

@@ -2,12 +2,14 @@ package middleware
import (
"context"
"net/http"
"strings"
gOidc "github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
"github.com/owncloud/ocis/ocis-pkg/oidc"
"github.com/owncloud/ocis/proxy/pkg/cache"
"golang.org/x/oauth2"
"net/http"
"strings"
)
// OIDCProvider used to mock the oidc provider during tests
@@ -18,6 +20,10 @@ type OIDCProvider interface {
// OIDCAuth provides a middleware to check access secured by a static token.
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
options := newOptions(optionSetters...)
tokenCache := cache.NewCache(
cache.Size(options.TokenCacheSize),
cache.TTL(options.TokenCacheTTL),
)
return func(next http.Handler) http.Handler {
return &oidcAuth{
@@ -26,6 +32,7 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
tokenCache: &tokenCache,
}
}
}
@@ -37,6 +44,7 @@ type oidcAuth struct {
providerFunc func() (OIDCProvider, error)
httpClient *http.Client
oidcIss string
tokenCache *cache.Cache
}
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
@@ -64,36 +72,48 @@ func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
oauth2Token := &oauth2.Token{
AccessToken: token,
}
userInfo, err := m.provider.UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
return
}
hit := m.tokenCache.Get(token)
var claims oidc.StandardClaims
if err := userInfo.Claims(&claims); err != nil {
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
w.WriteHeader(http.StatusInternalServerError)
return
}
if hit == nil {
// TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token)
oauth2Token := &oauth2.Token{
AccessToken: token,
}
//TODO: This should be read from the token instead of config
claims.Iss = m.oidcIss
userInfo, err := m.provider.UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
m.logger.Error().Err(err).Str("token", token).Msg("Failed to get userinfo")
http.Error(w, ErrInvalidToken.Error(), http.StatusUnauthorized)
return
}
if err := userInfo.Claims(&claims); err != nil {
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
w.WriteHeader(http.StatusInternalServerError)
return
}
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
//TODO: This should be read from the token instead of config
claims.Iss = m.oidcIss
m.tokenCache.Set(token, claims)
} else {
var ok = false
if claims, ok = hit.V.(oidc.StandardClaims); !ok {
w.WriteHeader(http.StatusInternalServerError)
return
}
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
}
// inject claims to the request context for the account_uuid middleware.
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Msg("unmarshalled userinfo")
// store claims in context
// uses the original context, not the one with probably reduced security
m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))

View File

@@ -3,17 +3,16 @@ package middleware
import (
"context"
"fmt"
"github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
"golang.org/x/oauth2"
"net/http"
"net/http/httptest"
"testing"
"github.com/coreos/go-oidc"
"github.com/owncloud/ocis/ocis-pkg/log"
"golang.org/x/oauth2"
)
func TestOIDCAuthMiddleware(t *testing.T) {
svcCache.Invalidate(AccountsKey, "success")
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
m := OIDCAuth(

View File

@@ -2,6 +2,7 @@ package middleware
import (
"net/http"
"time"
settings "github.com/owncloud/ocis/settings/pkg/proto/v0"
@@ -41,6 +42,10 @@ type Options struct {
AutoprovisionAccounts bool
// EnableBasicAuth to allow basic auth
EnableBasicAuth bool
// TokenCacheSize defines the max number of entries in the token cache
TokenCacheSize int
// TokenCacheTTL sets the max cache duration for the token cache
TokenCacheTTL time.Duration
}
// newOptions initializes the available default options.
@@ -137,3 +142,17 @@ func EnableBasicAuth(enableBasicAuth bool) Option {
o.EnableBasicAuth = enableBasicAuth
}
}
// TokenCacheSize provides a function to set the TokenCacheSize
func TokenCacheSize(size int) Option {
return func(o *Options) {
o.TokenCacheSize = size
}
}
// TokenCacheTTL provides a function to set the TokenCacheTTL
func TokenCacheTTL(ttl time.Duration) Option {
return func(o *Options) {
o.TokenCacheTTL = ttl
}
}