package middleware import ( "context" "net/http" "strings" "time" "github.com/golang-jwt/jwt/v4" gOidc "github.com/coreos/go-oidc/v3/oidc" "github.com/owncloud/ocis/ocis-pkg/log" "github.com/owncloud/ocis/ocis-pkg/oidc" "github.com/owncloud/ocis/ocis-pkg/sync" "github.com/owncloud/ocis/proxy/pkg/config" "golang.org/x/oauth2" ) // OIDCProvider used to mock the oidc provider during tests type OIDCProvider interface { UserInfo(ctx context.Context, ts oauth2.TokenSource) (*gOidc.UserInfo, error) } // OIDCAuth provides a middleware to check access secured by a static token. func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler { options := newOptions(optionSetters...) tokenCache := sync.NewCache(options.UserinfoCacheSize) h := oidcAuth{ logger: options.Logger, providerFunc: options.OIDCProviderFunc, httpClient: options.HTTPClient, oidcIss: options.OIDCIss, TokenManagerConfig: options.TokenManagerConfig, tokenCache: &tokenCache, tokenCacheTTL: options.UserinfoCacheTTL, } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // there is no bearer token on the request, if !h.shouldServe(req) { // oidc supported but token not present, add header and handover to the next middleware. userAgentAuthenticateLockIn(w, req, options.CredentialsByUserAgent, "bearer") next.ServeHTTP(w, req) return } if h.getProvider() == nil { w.WriteHeader(http.StatusInternalServerError) return } token := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") claims, status := h.getClaims(token, req) if status != 0 { w.WriteHeader(status) return } // inject claims to the request context for the account_uuid middleware. req = req.WithContext(oidc.NewContext(req.Context(), claims)) // store claims in context // uses the original context, not the one with probably reduced security next.ServeHTTP(w, req.WithContext(oidc.NewContext(req.Context(), claims))) }) } } type oidcAuth struct { logger log.Logger provider OIDCProvider providerFunc func() (OIDCProvider, error) httpClient *http.Client oidcIss string tokenCache *sync.Cache tokenCacheTTL time.Duration TokenManagerConfig config.TokenManager } func (m oidcAuth) getClaims(token string, req *http.Request) (claims map[string]interface{}, status int) { hit := m.tokenCache.Load(token) if hit == nil { // TODO cache userinfo for access token if we can determine the expiry (which works in case it is a jwt based access token) oauth2Token := &oauth2.Token{ AccessToken: token, } userInfo, err := m.getProvider().UserInfo( context.WithValue(req.Context(), oauth2.HTTPClient, m.httpClient), oauth2.StaticTokenSource(oauth2Token), ) if err != nil { m.logger.Error().Err(err).Msg("Failed to get userinfo") status = http.StatusUnauthorized return } if err := userInfo.Claims(&claims); err != nil { m.logger.Error().Err(err).Interface("userinfo", userInfo).Msg("failed to unmarshal userinfo claims") status = http.StatusInternalServerError return } expiration := m.extractExpiration(token) m.tokenCache.Store(token, claims, expiration) m.logger.Debug().Interface("claims", claims).Interface("userInfo", userInfo).Time("expiration", expiration.UTC()).Msg("unmarshalled and cached userinfo") return } var ok bool if claims, ok = hit.V.(map[string]interface{}); !ok { status = http.StatusInternalServerError return } m.logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo") return } // extractExpiration tries to parse and extract the expiration from the provided token. It might not even be a jwt. // defaults to the configured fallback TTL. // TODO: use introspection endpoint if available in the oidc configuration. Still needs a fallback to configured TTL. func (m oidcAuth) extractExpiration(token string) time.Time { defaultExpiration := time.Now().Add(m.tokenCacheTTL) t, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { return []byte(m.TokenManagerConfig.JWTSecret), nil }) if err != nil { return defaultExpiration } at, ok := t.Claims.(jwt.StandardClaims) if !ok || at.ExpiresAt == 0 { return defaultExpiration } return time.Unix(at.ExpiresAt, 0) } func (m oidcAuth) shouldServe(req *http.Request) bool { header := req.Header.Get("Authorization") if m.oidcIss == "" { return false } // todo: looks dirty, check later // TODO: make a PR to coreos/go-oidc for exposing userinfo endpoint on provider, see https://github.com/coreos/go-oidc/issues/248 for _, ignoringPath := range []string{"/konnect/v1/userinfo", "/status.php"} { if req.URL.Path == ignoringPath { return false } } return strings.HasPrefix(header, "Bearer ") } func (m *oidcAuth) getProvider() OIDCProvider { if m.provider == nil { // Lazily initialize a provider // provider needs to be cached as when it is created // it will fetch the keys from the issuer using the .well-known // endpoint provider, err := m.providerFunc() if err != nil { m.logger.Error().Err(err).Msg("could not initialize oidcAuth provider") return nil } m.provider = provider } return m.provider }