diff --git a/go.mod b/go.mod index 02be5b7a9b..83b0c7b716 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.18 require ( github.com/CiscoM31/godata v1.0.5 github.com/Masterminds/semver v1.5.0 + github.com/MicahParks/keyfunc v1.1.0 github.com/ReneKroon/ttlcache/v2 v2.11.0 github.com/blevesearch/bleve/v2 v2.3.3 github.com/blevesearch/bleve_index_api v1.0.2 diff --git a/go.sum b/go.sum index 0ff39116bd..fd04a1ad9a 100644 --- a/go.sum +++ b/go.sum @@ -100,6 +100,8 @@ github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3Q github.com/Masterminds/semver v1.5.0/go.mod h1:MB6lktGJrhw8PrUyiEoblNEGEQ+RzHPF078ddwwvV3Y= github.com/Masterminds/sprig v2.22.0+incompatible h1:z4yfnGrZ7netVz+0EDJ0Wi+5VZCSYp4Z0m2dk6cEM60= github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuNhlNS5hqE0NB0E6fgfo2Br3o= +github.com/MicahParks/keyfunc v1.1.0 h1:9NcnRwS0ciuVeVNi+vTdYVMTmk62OID7VlG6y9BgLK0= +github.com/MicahParks/keyfunc v1.1.0/go.mod h1:a4yfunv77gZ0RgTNw7tOYS+bjtHk5565e+1dPz+YJI8= github.com/Microsoft/go-winio v0.4.14/go.mod h1:qXqCSQ3Xa7+6tgxaGTIe4Kpcdsi+P8jBhyzoq1bpyYA= github.com/Microsoft/go-winio v0.4.16/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0= github.com/Microsoft/go-winio v0.5.0/go.mod h1:JPGBdM1cNvN/6ISo+n8V5iA4v8pBzdOpzfwIujj1a84= @@ -520,6 +522,7 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.3.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= +github.com/golang-jwt/jwt/v4 v4.4.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v4 v4.4.2 h1:rcc4lwaZgFMCZ5jxF9ABolDcIHdBytAFgqFPbSJQAYs= github.com/golang-jwt/jwt/v4 v4.4.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/geo v0.0.0-20210211234256-740aa86cb551 h1:gtexQ/VGyN+VVFRXSFiguSNcXmS6rkKT+X7FdIrTtfo= diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index 39a1fc839a..fbb13ec65b 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -2,11 +2,15 @@ package middleware import ( "context" + "encoding/json" + "io/ioutil" "net/http" "strings" "time" + "github.com/MicahParks/keyfunc" gOidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/golang-jwt/jwt/v4" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" "github.com/owncloud/ocis/v2/ocis-pkg/sync" @@ -64,6 +68,7 @@ func OIDCAuth(optionSetters ...Option) func(next http.Handler) http.Handler { type oidcAuth struct { logger log.Logger provider OIDCProvider + jwks *keyfunc.JWKS providerFunc func() (OIDCProvider, error) httpClient *http.Client oidcIss string @@ -110,20 +115,29 @@ func (m oidcAuth) getClaims(token string, req *http.Request) (claims map[string] return } -// extractExpiration currently just returns a hardcoded default for now. It was -// supposed to parse and extract the expiration time from the provided -// access_token. -// As the access_token is defined as an opaque string. Validating and parsing it -// can be tricky: -// 1. Try to treat it as a JWT: -// - Verifying the validity of the token requires downloading the propoer public -// key from the IDP (uri in "jwks_uri" in ".well-known/openid-configuration" -// 2. Verify and extract it via the introspection endpoint of the IDP (RFC7662) for -// IDPs that provide that feature -// 3. Other IDP implementation specific methods. -// 4. Fallback to default value +// extractExpiration tries to extract the expriration time from the access token +// It tries so by parsing (and verifying the signature) the access_token as JWT. +// If it is a valid JWT the `exp` claim will be used that the token expiration time. +// If it is not a valid JWT we fallback to the configured cache TTL. +// This could still be enhanced by trying a to use the introspection endpoint (RFC7662), +// to validate the token. If it exists. func (m oidcAuth) extractExpiration(token string) time.Time { defaultExpiration := time.Now().Add(m.tokenCacheTTL) + jwks := m.getKeyfunc() + if jwks == nil { + return defaultExpiration + } + + claims := jwt.RegisteredClaims{} + _, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc) + if err != nil { + m.logger.Info().Err(err).Msg("Error parsing access_token as JWT") + return defaultExpiration + } + if claims.ExpiresAt != nil { + m.logger.Debug().Str("exp", claims.ExpiresAt.String()).Msg("Expiration Time from access_token") + return claims.ExpiresAt.Time + } return defaultExpiration } @@ -145,6 +159,58 @@ func (m oidcAuth) shouldServe(req *http.Request) bool { return strings.HasPrefix(header, "Bearer ") } +type jwksJSON struct { + JWKSURL string `json:"jwks_uri"` +} + +func (m *oidcAuth) getKeyfunc() *keyfunc.JWKS { + if m.jwks == nil { + wellKnown := strings.TrimSuffix(m.oidcIss, "/") + "/.well-known/openid-configuration" + resp, err := m.httpClient.Get(wellKnown) + if err != nil { + return nil + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + m.logger.Error().Err(err).Msg("unable to read discovery response body") + return nil + } + + if resp.StatusCode != http.StatusOK { + m.logger.Error().Str("status", resp.Status).Str("body", string(body)).Msg("error requesting openid-configuration") + return nil + } + + var j jwksJSON + err = json. + Unmarshal(body, &j) + if err != nil { + m.logger.Error().Err(err).Msg("failed to decode provider discovered openid-configuration") + return nil + } + m.logger.Debug().Str("jwks", j.JWKSURL).Msg("discovered jwks endpoint") + // FIXME: make configurable + options := keyfunc.Options{ + RefreshErrorHandler: func(err error) { + m.logger.Error().Err(err).Msg("There was an error with the jwt.Keyfunc") + }, + RefreshInterval: time.Hour, + RefreshRateLimit: time.Minute * 5, + RefreshTimeout: time.Second * 10, + RefreshUnknownKID: true, + } + m.jwks, err = keyfunc.Get(j.JWKSURL, options) + if err != nil { + m.jwks = nil + m.logger.Error().Err(err).Msg("Failed to create JWKS from resource at the given URL.") + return nil + } + } + return m.jwks +} + func (m *oidcAuth) getProvider() OIDCProvider { if m.provider == nil { // Lazily initialize a provider