Merge pull request #6100 from rhafer/backchannel

Fix backchannel logout
This commit is contained in:
Christian Richter
2023-04-20 18:42:03 +02:00
committed by GitHub
4 changed files with 32 additions and 14 deletions

View File

@@ -27,7 +27,7 @@ import (
// OIDCClient used to mock the oidc client during tests
type OIDCClient interface {
UserInfo(ctx context.Context, ts oauth2.TokenSource) (*UserInfo, error)
VerifyAccessToken(ctx context.Context, token string) (jwt.RegisteredClaims, []string, error)
VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, []string, error)
VerifyLogoutToken(ctx context.Context, token string) (*LogoutToken, error)
}
@@ -46,6 +46,11 @@ type KeySet interface {
VerifySignature(ctx context.Context, jwt string) (payload []byte, err error)
}
type RegClaimsWithSID struct {
SessionID string `json:"sid"`
jwt.RegisteredClaims
}
type oidcClient struct {
// Logger to use for logging, must be set
Logger log.Logger
@@ -270,26 +275,26 @@ func (c *oidcClient) UserInfo(ctx context.Context, tokenSource oauth2.TokenSourc
}, nil
}
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (jwt.RegisteredClaims, []string, error) {
func (c *oidcClient) VerifyAccessToken(ctx context.Context, token string) (RegClaimsWithSID, []string, error) {
var mapClaims []string
if err := c.lookupWellKnownOpenidConfiguration(ctx); err != nil {
return jwt.RegisteredClaims{}, mapClaims, err
return RegClaimsWithSID{}, mapClaims, err
}
switch c.accessTokenVerifyMethod {
case config.AccessTokenVerificationJWT:
return c.verifyAccessTokenJWT(token)
case config.AccessTokenVerificationNone:
c.Logger.Debug().Msg("Access Token verification disabled")
return jwt.RegisteredClaims{}, mapClaims, nil
return RegClaimsWithSID{}, mapClaims, nil
default:
c.Logger.Error().Str("access_token_verify_method", c.accessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
return jwt.RegisteredClaims{}, mapClaims, errors.New("unknown Access Token Verification method")
return RegClaimsWithSID{}, mapClaims, errors.New("unknown Access Token Verification method")
}
}
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
func (c *oidcClient) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) {
var claims jwt.RegisteredClaims
func (c *oidcClient) verifyAccessTokenJWT(token string) (RegClaimsWithSID, []string, error) {
var claims RegClaimsWithSID
var mapClaims []string
jwks := c.getKeyfunc()
if jwks == nil {

View File

@@ -395,7 +395,7 @@ func (i *LDAP) AddUsersToEducationSchool(ctx context.Context, schoolNumberOrID s
user, err := i.getEducationUserByNameOrID(memberID)
if err != nil {
i.logger.Warn().Str("userid", memberID).Msg("User does not exist")
return err
return errorcode.New(errorcode.ItemNotFound, fmt.Sprintf("user '%s' not found", memberID))
}
userEntries = append(userEntries, user)
}

View File

@@ -220,7 +220,9 @@ type jse struct {
// handle backchannel logout requests as per https://openid.net/specs/openid-connect-backchannel-1_0.html#BCRequest
func (h *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Request) {
// parse the application/x-www-form-urlencoded POST request
logger := h.logger.SubloggerWithRequestID(r.Context())
if err := r.ParseForm(); err != nil {
logger.Warn().Err(err).Msg("ParseForm failed")
render.Status(r, http.StatusBadRequest)
render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()})
return
@@ -228,6 +230,7 @@ func (h *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re
logoutToken, err := h.oidcClient.VerifyLogoutToken(r.Context(), r.PostFormValue("logout_token"))
if err != nil {
logger.Warn().Err(err).Msg("VerifyLogoutToken failed")
render.Status(r, http.StatusBadRequest)
render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()})
return
@@ -240,19 +243,30 @@ func (h *StaticRouteHandler) backchannelLogout(w http.ResponseWriter, r *http.Re
return
}
if err != nil {
logger.Error().Err(err).Msg("Error reading userinfo cache")
render.Status(r, http.StatusBadRequest)
render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()})
return
}
for _, record := range records {
err = h.userInfoCache.Delete(string(record.Value))
if !errors.Is(err, microstore.ErrNotFound) {
if err != nil && !errors.Is(err, microstore.ErrNotFound) {
// Spec requires us to return a 400 BadRequest when the session could not be destroyed
h.logger.Err(err).Msg("could not delete user info from cache")
logger.Err(err).Msg("could not delete user info from cache")
render.Status(r, http.StatusBadRequest)
render.JSON(w, r, jse{Error: "invalid_request", ErrorDescription: err.Error()})
return
}
logger.Debug().Msg("Deleted userinfo from cache")
}
// we can ignore errors when cleaning up the lookup table
_ = h.userInfoCache.Delete(logoutToken.SessionId)
err = h.userInfoCache.Delete(logoutToken.SessionId)
if err != nil {
logger.Debug().Err(err).Msg("Failed to cleanup sessionid lookup entry")
}
render.Status(r, http.StatusOK)
render.JSON(w, r, nil)

View File

@@ -11,7 +11,6 @@ import (
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"github.com/shamaton/msgpack/v2"
store "go-micro.dev/v4/store"
@@ -104,7 +103,7 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
m.Logger.Error().Err(err).Msg("failed to write to userinfo cache")
}
if sid, ok := claims["sid"]; ok {
if sid := aClaims.SessionID; sid != "" {
// reuse user cache for session id lookup
err = m.userInfoCache.Write(&store.Record{
Key: fmt.Sprintf("%s", sid),
@@ -125,7 +124,7 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
// extractExpiration tries to extract the expriration time from the access token
// If the access token does not have an exp claim it will fallback to the configured
// default expiration
func (m OIDCAuthenticator) extractExpiration(aClaims jwt.RegisteredClaims) time.Time {
func (m OIDCAuthenticator) extractExpiration(aClaims oidc.RegClaimsWithSID) time.Time {
defaultExpiration := time.Now().Add(m.DefaultTokenCacheTTL)
if aClaims.ExpiresAt != nil {
m.Logger.Debug().Str("exp", aClaims.ExpiresAt.String()).Msg("Expiration Time from access_token")