extract full claims from jwt token to get session id

Signed-off-by: Christian Richter <crichter@owncloud.com>
This commit is contained in:
Christian Richter
2023-04-11 15:49:24 +02:00
parent e543c8f60d
commit a3640b0565
6 changed files with 655 additions and 22 deletions
+42 -9
View File
@@ -3,6 +3,7 @@ package middleware
import (
"context"
"encoding/base64"
"fmt"
"net/http"
"strings"
"sync"
@@ -39,6 +40,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
return &OIDCAuthenticator{
Logger: options.Logger,
userInfoCache: options.Cache,
sessionLookupCache: options.Cache,
DefaultTokenCacheTTL: options.DefaultAccessTokenTTL,
HTTPClient: options.HTTPClient,
OIDCIss: options.OIDCIss,
@@ -56,6 +58,7 @@ type OIDCAuthenticator struct {
HTTPClient *http.Client
OIDCIss string
userInfoCache store.Store
sessionLookupCache store.Store
DefaultTokenCacheTTL time.Duration
ProviderFunc func() (OIDCProvider, error)
AccessTokenVerifyMethod string
@@ -87,7 +90,16 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
}
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
}
aClaims, err := m.verifyAccessToken(token)
// TODO: use mClaims
aClaims, mClaims, err := m.verifyAccessToken(token)
//fmt.Println(mClaims)
vals := make([]string, len(mClaims))
for k, v := range mClaims {
s, _ := base64.StdEncoding.DecodeString(v)
vals[k] = string(s)
}
fmt.Println(vals)
if err != nil {
return nil, errors.Wrap(err, "failed to verify access token")
}
@@ -120,6 +132,17 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
if err != nil {
m.Logger.Error().Err(err).Msg("failed to write to userinfo cache")
}
if sid, ok := claims["sid"]; ok {
err = m.sessionLookupCache.Write(&store.Record{
Key: fmt.Sprintf("%s", sid),
Value: []byte(encodedHash),
Expiry: time.Until(expiration),
})
}
if err != nil {
m.Logger.Error().Err(err).Msg("failed to write session lookup cache")
}
}
}()
@@ -127,42 +150,52 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
return claims, nil
}
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, error) {
// TODO: update jwt lib to have access to session id, or extract the session id and return it
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, []string, error) {
var mapClaims []string
switch m.AccessTokenVerifyMethod {
case config.AccessTokenVerificationJWT:
return m.verifyAccessTokenJWT(token)
case config.AccessTokenVerificationNone:
m.Logger.Debug().Msg("Access Token verification disabled")
return jwt.RegisteredClaims{}, nil
return jwt.RegisteredClaims{}, mapClaims, nil
default:
m.Logger.Error().Str("access_token_verify_method", m.AccessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
return jwt.RegisteredClaims{}, errors.New("Unknown Access Token Verification method")
return jwt.RegisteredClaims{}, mapClaims, errors.New("Unknown Access Token Verification method")
}
}
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, error) {
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) {
var claims jwt.RegisteredClaims
var mapClaims []string
jwks := m.getKeyfunc()
if jwks == nil {
return claims, errors.New("Error initializing jwks keyfunc")
return claims, mapClaims, errors.New("Error initializing jwks keyfunc")
}
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
_, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
// TODO: decode mapClaims to sth readable
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
return claims, err
return claims, mapClaims, err
}
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
if err != nil {
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
return claims, mapClaims, err
}
if !claims.VerifyIssuer(m.OIDCIss, true) {
vErr := jwt.ValidationError{}
vErr.Inner = jwt.ErrTokenInvalidIssuer
vErr.Errors |= jwt.ValidationErrorIssuer
return claims, vErr
return claims, mapClaims, vErr
}
return claims, nil
return claims, mapClaims, nil
}
// extractExpiration tries to extract the expriration time from the access token