mirror of
https://github.com/opencloud-eu/opencloud.git
synced 2026-04-26 05:58:27 -05:00
extract full claims from jwt token to get session id
Signed-off-by: Christian Richter <crichter@owncloud.com>
This commit is contained in:
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -39,6 +40,7 @@ func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
|
||||
return &OIDCAuthenticator{
|
||||
Logger: options.Logger,
|
||||
userInfoCache: options.Cache,
|
||||
sessionLookupCache: options.Cache,
|
||||
DefaultTokenCacheTTL: options.DefaultAccessTokenTTL,
|
||||
HTTPClient: options.HTTPClient,
|
||||
OIDCIss: options.OIDCIss,
|
||||
@@ -56,6 +58,7 @@ type OIDCAuthenticator struct {
|
||||
HTTPClient *http.Client
|
||||
OIDCIss string
|
||||
userInfoCache store.Store
|
||||
sessionLookupCache store.Store
|
||||
DefaultTokenCacheTTL time.Duration
|
||||
ProviderFunc func() (OIDCProvider, error)
|
||||
AccessTokenVerifyMethod string
|
||||
@@ -87,7 +90,16 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
|
||||
}
|
||||
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
|
||||
}
|
||||
aClaims, err := m.verifyAccessToken(token)
|
||||
// TODO: use mClaims
|
||||
aClaims, mClaims, err := m.verifyAccessToken(token)
|
||||
//fmt.Println(mClaims)
|
||||
vals := make([]string, len(mClaims))
|
||||
for k, v := range mClaims {
|
||||
s, _ := base64.StdEncoding.DecodeString(v)
|
||||
vals[k] = string(s)
|
||||
}
|
||||
fmt.Println(vals)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to verify access token")
|
||||
}
|
||||
@@ -120,6 +132,17 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
|
||||
if err != nil {
|
||||
m.Logger.Error().Err(err).Msg("failed to write to userinfo cache")
|
||||
}
|
||||
|
||||
if sid, ok := claims["sid"]; ok {
|
||||
err = m.sessionLookupCache.Write(&store.Record{
|
||||
Key: fmt.Sprintf("%s", sid),
|
||||
Value: []byte(encodedHash),
|
||||
Expiry: time.Until(expiration),
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
m.Logger.Error().Err(err).Msg("failed to write session lookup cache")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -127,42 +150,52 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, error) {
|
||||
// TODO: update jwt lib to have access to session id, or extract the session id and return it
|
||||
func (m OIDCAuthenticator) verifyAccessToken(token string) (jwt.RegisteredClaims, []string, error) {
|
||||
var mapClaims []string
|
||||
switch m.AccessTokenVerifyMethod {
|
||||
case config.AccessTokenVerificationJWT:
|
||||
return m.verifyAccessTokenJWT(token)
|
||||
case config.AccessTokenVerificationNone:
|
||||
m.Logger.Debug().Msg("Access Token verification disabled")
|
||||
return jwt.RegisteredClaims{}, nil
|
||||
return jwt.RegisteredClaims{}, mapClaims, nil
|
||||
default:
|
||||
m.Logger.Error().Str("access_token_verify_method", m.AccessTokenVerifyMethod).Msg("Unknown Access Token verification setting")
|
||||
return jwt.RegisteredClaims{}, errors.New("Unknown Access Token Verification method")
|
||||
return jwt.RegisteredClaims{}, mapClaims, errors.New("Unknown Access Token Verification method")
|
||||
}
|
||||
}
|
||||
|
||||
// verifyAccessTokenJWT tries to parse and verify the access token as a JWT.
|
||||
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, error) {
|
||||
func (m OIDCAuthenticator) verifyAccessTokenJWT(token string) (jwt.RegisteredClaims, []string, error) {
|
||||
var claims jwt.RegisteredClaims
|
||||
var mapClaims []string
|
||||
jwks := m.getKeyfunc()
|
||||
if jwks == nil {
|
||||
return claims, errors.New("Error initializing jwks keyfunc")
|
||||
return claims, mapClaims, errors.New("Error initializing jwks keyfunc")
|
||||
}
|
||||
|
||||
_, err := jwt.ParseWithClaims(token, &claims, jwks.Keyfunc)
|
||||
_, mapClaims, err = new(jwt.Parser).ParseUnverified(token, jwt.MapClaims{})
|
||||
// TODO: decode mapClaims to sth readable
|
||||
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
||||
if err != nil {
|
||||
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
||||
return claims, err
|
||||
return claims, mapClaims, err
|
||||
}
|
||||
m.Logger.Debug().Interface("access token", &claims).Msg("parsed access token")
|
||||
if err != nil {
|
||||
m.Logger.Info().Err(err).Msg("Failed to parse/verify the access token.")
|
||||
return claims, mapClaims, err
|
||||
}
|
||||
|
||||
if !claims.VerifyIssuer(m.OIDCIss, true) {
|
||||
vErr := jwt.ValidationError{}
|
||||
vErr.Inner = jwt.ErrTokenInvalidIssuer
|
||||
vErr.Errors |= jwt.ValidationErrorIssuer
|
||||
return claims, vErr
|
||||
return claims, mapClaims, vErr
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
return claims, mapClaims, nil
|
||||
}
|
||||
|
||||
// extractExpiration tries to extract the expriration time from the access token
|
||||
|
||||
Reference in New Issue
Block a user