fix oidc middleware provider lazy initialization

This commit is contained in:
Florian Schade
2020-11-25 22:50:11 +01:00
parent c742389ebb
commit ab85245093

View File

@@ -26,21 +26,46 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
options := newOptions(optionSetters...)
tokenCache := cache.NewCache(cache.Size(options.UserinfoCacheSize))
h := oidcAuth{
logger: options.Logger,
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
tokenCache: &tokenCache,
tokenCacheTTL: options.UserinfoCacheTTL,
}
return func(next http.Handler) http.Handler {
return &oidcAuth{
next: next,
logger: options.Logger,
providerFunc: options.OIDCProviderFunc,
httpClient: options.HTTPClient,
oidcIss: options.OIDCIss,
tokenCache: &tokenCache,
tokenCacheTTL: options.UserinfoCacheTTL,
}
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if !h.shouldServe(req) {
next.ServeHTTP(w, req)
return
}
if h.getProvider() == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
claims, status := h.getClaims(token, req)
if status != 0 {
w.WriteHeader(status)
return
}
// inject claims to the request context for the account_uuid middleware.
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
// store claims in context
// uses the original context, not the one with probably reduced security
next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))
})
}
}
type oidcAuth struct {
next http.Handler
logger log.Logger
provider OIDCProvider
providerFunc func() (OIDCProvider, error)
@@ -50,34 +75,6 @@ type oidcAuth struct {
tokenCacheTTL time.Duration
}
func (m oidcAuth) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if !m.shouldServe(req) {
m.next.ServeHTTP(w, req)
return
}
if m.getProvider() == nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
claims, status := m.getClaims(token, req)
if status != 0 {
w.WriteHeader(status)
return
}
// inject claims to the request context for the account_uuid middleware.
req = req.WithContext(oidc.NewContext(req.Context(), &claims))
// store claims in context
// uses the original context, not the one with probably reduced security
m.next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), &claims)))
}
func (m oidcAuth) getClaims(token string, req *http.Request) (claims oidc.StandardClaims, status int) {
hit := m.tokenCache.Get(token)
if hit == nil {
@@ -164,7 +161,7 @@ func (m oidcAuth) shouldServe(req *http.Request) bool {
return strings.HasPrefix(header, "Bearer ")
}
func (m oidcAuth) getProvider() OIDCProvider {
func (m *oidcAuth) getProvider() OIDCProvider {
if m.provider == nil {
// Lazily initialize a provider