mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-01-06 04:09:40 -06:00
284 lines
8.7 KiB
Go
284 lines
8.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/MicahParks/keyfunc"
|
|
gOidc "github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/golang-jwt/jwt/v4"
|
|
"github.com/owncloud/ocis/v2/ocis-pkg/log"
|
|
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
|
|
osync "github.com/owncloud/ocis/v2/ocis-pkg/sync"
|
|
"github.com/owncloud/ocis/v2/services/proxy/pkg/config"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// OIDCProvider used to mock the oidc provider during tests
|
|
type OIDCProvider interface {
|
|
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*gOidc.UserInfo, error)
|
|
}
|
|
|
|
// OIDCAuth provides a middleware to check access secured by a static token.
|
|
func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler {
|
|
options := newOptions(optionSetters...)
|
|
tokenCache := osync.NewCache(options.UserinfoCacheSize)
|
|
|
|
h := oidcAuth{
|
|
logger: options.Logger,
|
|
providerFunc: options.OIDCProviderFunc,
|
|
httpClient: options.HTTPClient,
|
|
oidcIss: options.OIDCIss,
|
|
tokenCache: &tokenCache,
|
|
tokenCacheTTL: options.UserinfoCacheTTL,
|
|
accessTokenVerifyMethod: options.AccessTokenVerifyMethod,
|
|
jwksOptions: options.JWKS,
|
|
jwksLock: &sync.Mutex{},
|
|
providerLock: &sync.Mutex{},
|
|
}
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
// there is no bearer token on the request,
|
|
if !h.shouldServe(req) {
|
|
// oidc supported but token not present, add header and handover to the next middleware.
|
|
userAgentAuthenticateLockIn(w, req, options.CredentialsByUserAgent, "bearer")
|
|
next.ServeHTTP(w, req)
|
|
return
|
|
}
|
|
|
|
if h.getProvider() == nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Force init of jwks keyfunc if needed (contacts the .well-known and jwks endpoints on first call)
|
|
if h.accessTokenVerifyMethod == config.AccessTokenVerificationJWT && h.getKeyfunc() == nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
|
|
|
claims, status := h.getClaims(token, req)
|
|
if status != 0 {
|
|
w.WriteHeader(status)
|
|
return
|
|
}
|
|
|
|
// inject claims to the request context for the account_resolver middleware.
|
|
next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), claims)))
|
|
})
|
|
}
|
|
}
|
|
|
|
type oidcAuth struct {
|
|
logger log.Logger
|
|
provider OIDCProvider
|
|
providerLock *sync.Mutex
|
|
jwksOptions config.JWKS
|
|
jwks *keyfunc.JWKS
|
|
jwksLock *sync.Mutex
|
|
providerFunc func() (OIDCProvider, error)
|
|
httpClient *http.Client
|
|
oidcIss string
|
|
tokenCache *osync.Cache
|
|
tokenCacheTTL time.Duration
|
|
accessTokenVerifyMethod string
|
|
}
|
|
|
|
func (m oidcAuth) getClaims(token string, req *http.Request) (claims map[string]interface{}, status int) {
|
|
hit := m.tokenCache.Load(token)
|
|
if hit == nil {
|
|
aClaims, err := m.verifyAccessToken(token)
|
|
if err != nil {
|
|
m.logger.Error().Err(err).Msg("Failed to verify access token")
|
|
status = http.StatusUnauthorized
|
|
return
|
|
}
|
|
|
|
oauth2Token := &oauth2.Token{
|
|
AccessToken: token,
|
|
}
|
|
|
|
userInfo, err := m.getProvider().UserInfo(
|
|
context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient),
|
|
oauth2.StaticTokenSource(oauth2Token),
|
|
)
|
|
if err != nil {
|
|
m.logger.Error().Err(err).Msg("Failed to get userinfo")
|
|
status = http.StatusUnauthorized
|
|
return
|
|
}
|
|
|
|
if err := userInfo.Claims(&claims); err != nil {
|
|
m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims")
|
|
status = http.StatusInternalServerError
|
|
return
|
|
}
|
|
|
|
expiration := m.extractExpiration(aClaims)
|
|
m.tokenCache.Store(token, claims, expiration)
|
|
|
|
m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo")
|
|
return
|
|
}
|
|
|
|
var ok bool
|
|
if claims, ok = hit.V.(map[string]interface{}); !ok {
|
|
status = http.StatusInternalServerError
|
|
return
|
|
}
|
|
m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
|
|
return
|
|
}
|
|
|
|
func (m oidcAuth) verifyAccessToken(token string) (jwt.RegisteredClaims, error) {
|
|
switch m.accessTokenVerifyMethod {
|
|
case config.AccessTokenVerificationJWT:
|
|
return m.verifyAccessTokenJWT(token)
|
|
case config.AccessTokenVerificationNone:
|
|
m.logger.Debug().Msg("Access Token verification disabled")
|
|
return jwt.RegisteredClaims{}, nil
|
|
default:
|
|
m.logger.Error().Str("access_token_verify_method", m.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
|
return jwt.RegisteredClaims{}, errors.New("Unknown Access Token Verification method")
|
|
}
|
|
}
|
|
|
|
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
|
|
func (m oidcAuth) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, error) {
|
|
var claims jwt.RegisteredClaims
|
|
jwks := m.getKeyfunc()
|
|
if jwks == nil {
|
|
return claims, errors.New("Error initializing jwks keyfunc")
|
|
}
|
|
|
|
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
|
|
m.logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
|
if err != nil {
|
|
m.logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
|
return claims, err
|
|
}
|
|
|
|
if !claims.VerifyIssuer(m.oidcIss, true) {
|
|
vErr := jwt.ValidationError{}
|
|
vErr.Inner = jwt.ErrTokenInvalidIssuer
|
|
vErr.Errors |= jwt.ValidationErrorIssuer
|
|
return claims, vErr
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// extractExpiration tries to extract the expriration time from the access token
|
|
// If the access token does not have an exp claim it will fallback to the configured
|
|
// default expiration
|
|
func (m oidcAuth) extractExpiration(aClaims jwt.RegisteredClaims) time.Time {
|
|
defaultExpiration := time.Now().Add(m.tokenCacheTTL)
|
|
if aClaims.ExpiresAt != nil {
|
|
m.logger.Debug().Str("exp", aClaims.ExpiresAt.String()).Msg("Expiration Time from access_token")
|
|
return aClaims.ExpiresAt.Time
|
|
}
|
|
return defaultExpiration
|
|
}
|
|
|
|
func (m oidcAuth) shouldServe(req *http.Request) bool {
|
|
header := req.Header.Get("Authorization")
|
|
|
|
if m.oidcIss == "" {
|
|
return false
|
|
}
|
|
|
|
// todo: looks dirty, check later
|
|
// TODO: make a PR to coreos/go-oidc for exposing userinfo endpoint on provider, see https://github.com/coreos/go-oidc/issues/248
|
|
for _, ignoringPath := range []string{"/konnect/v1/userinfo", "/status.php"} {
|
|
if req.URL.Path == ignoringPath {
|
|
return false
|
|
}
|
|
}
|
|
|
|
return strings.HasPrefix(header, "Bearer ")
|
|
}
|
|
|
|
type jwksJSON struct {
|
|
JWKSURL string `json:"jwks_uri"`
|
|
}
|
|
|
|
func (m *oidcAuth) getKeyfunc() *keyfunc.JWKS {
|
|
m.jwksLock.Lock()
|
|
defer m.jwksLock.Unlock()
|
|
if m.jwks == nil {
|
|
wellKnown := strings.TrimSuffix(m.oidcIss, "/") + "/.well-known/openid-configuration"
|
|
|
|
resp, err := m.httpClient.Get(wellKnown)
|
|
if err != nil {
|
|
m.logger.Error().Err(err).Msg("Failed to set request for .well-known/openid-configuration")
|
|
return nil
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := ioutil.ReadAll(resp.Body)
|
|
if err != nil {
|
|
m.logger.Error().Err(err).Msg("unable to read discovery response body")
|
|
return nil
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
m.logger.Error().Str("status", resp.Status).Str("body", string(body)).Msg("error requesting openid-configuration")
|
|
return nil
|
|
}
|
|
|
|
var j jwksJSON
|
|
err = json.Unmarshal(body, &j)
|
|
if err != nil {
|
|
m.logger.Error().Err(err).Msg("failed to decode provider openid-configuration")
|
|
return nil
|
|
}
|
|
m.logger.Debug().Str("jwks", j.JWKSURL).Msg("discovered jwks endpoint")
|
|
options := keyfunc.Options{
|
|
Client: m.httpClient,
|
|
RefreshErrorHandler: func(err error) {
|
|
m.logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc")
|
|
},
|
|
RefreshInterval: time.Minute * time.Duration(m.jwksOptions.RefreshInterval),
|
|
RefreshRateLimit: time.Second * time.Duration(m.jwksOptions.RefreshRateLimit),
|
|
RefreshTimeout: time.Second * time.Duration(m.jwksOptions.RefreshTimeout),
|
|
RefreshUnknownKID: m.jwksOptions.RefreshUnknownKID,
|
|
}
|
|
m.jwks, err = keyfunc.Get(j.JWKSURL, options)
|
|
if err != nil {
|
|
m.jwks = nil
|
|
m.logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.")
|
|
return nil
|
|
}
|
|
}
|
|
return m.jwks
|
|
}
|
|
|
|
func (m *oidcAuth) getProvider() OIDCProvider {
|
|
m.providerLock.Lock()
|
|
defer m.providerLock.Unlock()
|
|
if m.provider == nil {
|
|
// Lazily initialize a provider
|
|
|
|
// provider needs to be cached as when it is created
|
|
// it will fetch the keys from the issuer using the .well-known
|
|
// endpoint
|
|
provider, err := m.providerFunc()
|
|
if err != nil {
|
|
m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider")
|
|
return nil
|
|
}
|
|
|
|
m.provider = provider
|
|
}
|
|
return m.provider
|
|
}
|