diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index 8250f834c..66101921e 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -7,9 +7,9 @@ import ( "strings" "time" + "github.com/golang-jwt/jwt/v4" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" - "github.com/pkg/errors" "github.com/shamaton/msgpack/v2" store "go-micro.dev/v4/store" @@ -35,6 +35,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator { oidcClient: options.OIDCClient, AccessTokenVerifyMethod: options.AccessTokenVerifyMethod, skipUserInfo: options.SkipUserInfo, + TimeFunc: time.Now, } } @@ -48,6 +49,7 @@ type OIDCAuthenticator struct { oidcClient oidc.OIDCClient AccessTokenVerifyMethod string skipUserInfo bool + TimeFunc func() time.Time } func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) { @@ -65,6 +67,9 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri if len(record) > 0 { if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil { m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo") + if ok := verifyExpiresAt(claims, m.TimeFunc()); !ok { + return nil, jwt.ErrTokenExpired + } return claims, nil } m.Logger.Error().Err(err).Msg("could not unmarshal userinfo") @@ -93,6 +98,8 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri } expiration := m.extractExpiration(aClaims) + // always set an exp claim + claims["exp"] = expiration.Unix() go func() { if d, err := msgpack.MarshalAsMap(claims); err != nil { m.Logger.Error().Err(err).Msg("failed to marshal claims for userinfo cache") @@ -136,6 +143,19 @@ func (m OIDCAuthenticator) extractExpiration(aClaims oidc.RegClaimsWithSID) time return defaultExpiration } +func verifyExpiresAt(claims map[string]interface{}, cmp time.Time) bool { + var expiry time.Time + switch v := claims["exp"].(type) { + case nil: + return false + case int64: + expiry = time.Unix(v, 0) + case uint32: + expiry = time.Unix(int64(v), 0) + } + return cmp.Before(expiry) +} + func (m OIDCAuthenticator) shouldServe(req *http.Request) bool { if m.OIDCIss == "" { return false