Files
opencloud/services/proxy/pkg/middleware/oidc_auth.go
2024-05-28 09:29:08 +02:00

197 lines
5.9 KiB
Go

package middleware
import (
"context"
"encoding/base64"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
"github.com/pkg/errors"
"github.com/shamaton/msgpack/v2"
store "go-micro.dev/v4/store"
"golang.org/x/crypto/sha3"
"golang.org/x/oauth2"
)
const (
_headerAuthorization = "Authorization"
_bearerPrefix = "Bearer "
)
// NewOIDCAuthenticator returns a ready to use authenticator which can handle OIDC authentication.
func NewOIDCAuthenticator(opts ...Option) *OIDCAuthenticator {
options := newOptions(opts...)
return &OIDCAuthenticator{
Logger: options.Logger,
userInfoCache: options.UserInfoCache,
DefaultTokenCacheTTL: options.DefaultAccessTokenTTL,
HTTPClient: options.HTTPClient,
OIDCIss: options.OIDCIss,
oidcClient: options.OIDCClient,
AccessTokenVerifyMethod: options.AccessTokenVerifyMethod,
skipUserInfo: options.SkipUserInfo,
TimeFunc: time.Now,
}
}
// OIDCAuthenticator is an authenticator responsible for OIDC authentication.
type OIDCAuthenticator struct {
Logger log.Logger
HTTPClient *http.Client
OIDCIss string
userInfoCache store.Store
DefaultTokenCacheTTL time.Duration
oidcClient oidc.OIDCClient
AccessTokenVerifyMethod string
skipUserInfo bool
TimeFunc func() time.Time
}
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
var claims map[string]interface{}
// use a 64 bytes long hash to have 256-bit collision resistance.
hash := make([]byte, 64)
sha3.ShakeSum256(hash, []byte(token))
encodedHash := base64.URLEncoding.EncodeToString(hash)
record, err := m.userInfoCache.Read(encodedHash)
if err != nil && err != store.ErrNotFound {
m.Logger.Error().Err(err).Msg("could not read from userinfo cache")
}
if len(record) > 0 {
if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil {
m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
if ok := verifyExpiresAt(claims, m.TimeFunc()); !ok {
return nil, jwt.ErrTokenExpired
}
return claims, nil
}
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
}
aClaims, claims, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify access token")
}
if !m.skipUserInfo {
oauth2Token := &oauth2.Token{
AccessToken: token,
}
userInfo, err := m.oidcClient.UserInfo(
context.WithValue(req.Context(), oauth2.HTTPClient, m.HTTPClient),
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
return nil, errors.Wrap(err, "failed to get userinfo")
}
if err := userInfo.Claims(&claims); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
}
}
expiration := m.extractExpiration(aClaims)
// always set an exp claim
claims["exp"] = expiration.Unix()
go func() {
if d, err := msgpack.MarshalAsMap(claims); err != nil {
m.Logger.Error().Err(err).Msg("failed to marshal claims for userinfo cache")
} else {
err = m.userInfoCache.Write(&store.Record{
Key: encodedHash,
Value: d,
Expiry: time.Until(expiration),
})
if err != nil {
m.Logger.Error().Err(err).Msg("failed to write to userinfo cache")
}
if sid := aClaims.SessionID; sid != "" {
// reuse user cache for session id lookup
err = m.userInfoCache.Write(&store.Record{
Key: sid,
Value: []byte(encodedHash),
Expiry: time.Until(expiration),
})
if err != nil {
m.Logger.Error().Err(err).Msg("failed to write session lookup cache")
}
}
}
}()
m.Logger.Debug().Interface("claims", claims).Msg("extracted claims")
return claims, nil
}
// extractExpiration tries to extract the expriration time from the access token
// If the access token does not have an exp claim it will fallback to the configured
// default expiration
func (m OIDCAuthenticator) extractExpiration(aClaims oidc.RegClaimsWithSID) time.Time {
defaultExpiration := time.Now().Add(m.DefaultTokenCacheTTL)
if aClaims.ExpiresAt != nil {
m.Logger.Debug().Str("exp", aClaims.ExpiresAt.String()).Msg("Expiration Time from access_token")
return aClaims.ExpiresAt.Time
}
return defaultExpiration
}
func verifyExpiresAt(claims map[string]interface{}, cmp time.Time) bool {
var expiry time.Time
switch v := claims["exp"].(type) {
case nil:
return false
case int64:
expiry = time.Unix(v, 0)
case uint32:
expiry = time.Unix(int64(v), 0)
}
return cmp.Before(expiry)
}
func (m OIDCAuthenticator) shouldServe(req *http.Request) bool {
if m.OIDCIss == "" {
return false
}
header := req.Header.Get(_headerAuthorization)
return strings.HasPrefix(header, _bearerPrefix)
}
// Authenticate implements the authenticator interface to authenticate requests via oidc auth.
func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) {
// there is no bearer token on the request,
if !m.shouldServe(r) {
// The authentication of public path requests is handled by another authenticator.
// Since we can't guarantee the order of execution of the authenticators, we better
// implement an early return here for paths we can't authenticate in this authenticator.
return nil, false
}
token := strings.TrimPrefix(r.Header.Get(_headerAuthorization), _bearerPrefix)
if token == "" {
return nil, false
}
claims, err := m.getClaims(token, r)
if err != nil {
m.Logger.Error().
Err(err).
Str("authenticator", "oidc").
Str("path", r.URL.Path).
Msg("failed to authenticate the request")
return nil, false
}
m.Logger.Debug().
Str("authenticator", "oidc").
Str("path", r.URL.Path).
Msg("successfully authenticated request")
return r.WithContext(oidc.NewContext(r.Context(), claims)), true
}