proxy(oidc): Emit a UserSignedIn event on new session

Every time the OIDC middleware sees a new access token (i.e when it needs
to update the userinfo cache) we consider that as a new login. In this case
the middleware add a new flag to the context, which is then used by the
accountresolver middleware to publish a UserSignedIn event.
The event needs to be sent by the accountresolver middleware, because only
at that point we know the user id of the user that just logged in.

(It would probably makes sense to merge the auth and account middleware into a
single component to avoid passing flags around via context)
This commit is contained in:
Ralf Haferkamp
2024-08-29 10:17:07 +02:00
committed by Ralf Haferkamp
parent 3b0ff50bf0
commit cb8934081f
5 changed files with 64 additions and 13 deletions

View File

@@ -5,6 +5,9 @@ import "context"
// contextKey is the key for oidc claims in a context
type contextKey struct{}
// newSessionFlagKey is the key for the new session flag in a context
type newSessionFlagKey struct{}
// NewContext makes a new context that contains the OpenID connect claims in a map.
func NewContext(parent context.Context, c map[string]interface{}) context.Context {
return context.WithValue(parent, contextKey{}, c)
@@ -15,3 +18,14 @@ func FromContext(ctx context.Context) map[string]interface{} {
s, _ := ctx.Value(contextKey{}).(map[string]interface{})
return s
}
// NewContextSessionFlag makes a new context that contains the new session flag.
func NewContextSessionFlag(ctx context.Context, flag bool) context.Context {
return context.WithValue(ctx, newSessionFlagKey{}, flag)
}
// NewSessionFlagFromContext returns the new session flag stored in a context.
func NewSessionFlagFromContext(ctx context.Context) bool {
s, _ := ctx.Value(newSessionFlagKey{}).(bool)
return s
}

View File

@@ -182,7 +182,7 @@ func Server(cfg *config.Config) *cli.Command {
}
{
middlewares := loadMiddlewares(logger, cfg, userInfoCache, signingKeyStore, traceProvider, *m, userProvider, gatewaySelector, serviceSelector)
middlewares := loadMiddlewares(logger, cfg, userInfoCache, signingKeyStore, traceProvider, *m, userProvider, publisher, gatewaySelector, serviceSelector)
server, err := proxyHTTP.Server(
proxyHTTP.Handler(lh.Handler()),
@@ -236,8 +236,10 @@ func Server(cfg *config.Config) *cli.Command {
}
func loadMiddlewares(logger log.Logger, cfg *config.Config,
userInfoCache, signingKeyStore microstore.Store, traceProvider trace.TracerProvider, metrics metrics.Metrics,
userProvider backend.UserBackend, gatewaySelector pool.Selectable[gateway.GatewayAPIClient], serviceSelector selector.Selector) alice.Chain {
userInfoCache, signingKeyStore microstore.Store,
traceProvider trace.TracerProvider, metrics metrics.Metrics,
userProvider backend.UserBackend, publisher events.Publisher,
gatewaySelector pool.Selectable[gateway.GatewayAPIClient], serviceSelector selector.Selector) alice.Chain {
rolesClient := settingssvc.NewRoleService("com.owncloud.api.settings", cfg.GrpcClient)
policiesProviderClient := policiessvc.NewPoliciesProviderService("com.owncloud.api.policies", cfg.GrpcClient)
@@ -354,6 +356,7 @@ func loadMiddlewares(logger log.Logger, cfg *config.Config,
middleware.UserOIDCClaim(cfg.UserOIDCClaim),
middleware.UserCS3Claim(cfg.UserCS3Claim),
middleware.AutoprovisionAccounts(cfg.AutoprovisionAccounts),
middleware.EventsPublisher(publisher),
),
middleware.SelectorCookie(
middleware.Logger(logger),

View File

@@ -11,6 +11,8 @@ import (
"github.com/owncloud/ocis/v2/services/proxy/pkg/userroles"
revactx "github.com/cs3org/reva/v2/pkg/ctx"
"github.com/cs3org/reva/v2/pkg/events"
"github.com/cs3org/reva/v2/pkg/utils"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
)
@@ -37,6 +39,7 @@ func AccountResolver(optionSetters ...Option) func(next http.Handler) http.Handl
userRoleAssigner: options.UserRoleAssigner,
autoProvisionAccounts: options.AutoprovisionAccounts,
lastGroupSyncCache: lastGroupSyncCache,
eventsPublisher: options.EventsPublisher,
}
}
}
@@ -53,6 +56,7 @@ type accountResolver struct {
// memberships was done for a specific user. This is used to trigger a sync
// with every single request.
lastGroupSyncCache *ttlcache.Cache[string, struct{}]
eventsPublisher events.Publisher
}
func readUserIDClaim(path string, claims map[string]interface{}) (string, error) {
@@ -172,6 +176,17 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) {
return
}
// If this is a new session, publish user login event
if newSession := oidc.NewSessionFlagFromContext(ctx); newSession && m.eventsPublisher != nil {
event := events.UserSignedIn{
Executant: user.Id,
Timestamp: utils.TimeToTS(time.Now()),
}
if err := events.Publish(req.Context(), m.eventsPublisher, event); err != nil {
m.logger.Error().Err(err).Msg("could not publish user signin event.")
}
}
// add user to context for selectors
ctx = revactx.ContextSetUser(ctx, user)
req = req.WithContext(ctx)

View File

@@ -53,7 +53,7 @@ type OIDCAuthenticator struct {
TimeFunc func() time.Time
}
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) {
func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, bool, error) {
var claims map[string]interface{}
// use a 64 bytes long hash to have 256-bit collision resistance.
@@ -69,16 +69,16 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil {
m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo")
if ok := verifyExpiresAt(claims, m.TimeFunc()); !ok {
return nil, jwt.ErrTokenExpired
return nil, false, jwt.ErrTokenExpired
}
return claims, nil
return claims, false, nil
}
m.Logger.Error().Err(err).Msg("could not unmarshal userinfo")
}
aClaims, claims, err := m.oidcClient.VerifyAccessToken(req.Context(), token)
if err != nil {
return nil, errors.Wrap(err, "failed to verify access token")
return nil, false, errors.Wrap(err, "failed to verify access token")
}
if !m.skipUserInfo {
@@ -91,10 +91,10 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
oauth2.StaticTokenSource(oauth2Token),
)
if err != nil {
return nil, errors.Wrap(err, "failed to get userinfo")
return nil, false, errors.Wrap(err, "failed to get userinfo")
}
if err := userInfo.Claims(&claims); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal userinfo claims")
return nil, false, errors.Wrap(err, "failed to unmarshal userinfo claims")
}
}
@@ -128,8 +128,12 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri
}
}()
// If we get here this was a new login (or a renewal of the token)
// add a flag about that to the claims, to be able to distinguish
// it in the accountresolver middleware
m.Logger.Debug().Interface("claims", claims).Msg("extracted claims")
return claims, nil
return claims, true, nil
}
// extractExpiration tries to extract the expriration time from the access token
@@ -180,7 +184,7 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool)
return nil, false
}
claims, err := m.getClaims(token, r)
claims, newSession, err := m.getClaims(token, r)
if err != nil {
host, port, _ := net.SplitHostPort(r.RemoteAddr)
m.Logger.Error().
@@ -198,5 +202,11 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool)
Str("authenticator", "oidc").
Str("path", r.URL.Path).
Msg("successfully authenticated request")
return r.WithContext(oidc.NewContext(r.Context(), claims)), true
ctx := r.Context()
if newSession {
ctx = oidc.NewContextSessionFlag(ctx, true)
}
return r.WithContext(oidc.NewContext(ctx, claims)), true
}

View File

@@ -5,6 +5,7 @@ import (
"time"
gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1"
"github.com/cs3org/reva/v2/pkg/events"
"github.com/cs3org/reva/v2/pkg/rgrpc/todo/pool"
"github.com/owncloud/ocis/v2/ocis-pkg/log"
"github.com/owncloud/ocis/v2/ocis-pkg/oidc"
@@ -69,7 +70,8 @@ type Options struct {
// TraceProvider sets the tracing provider.
TraceProvider trace.TracerProvider
// SkipUserInfo prevents the oidc middleware from querying the userinfo endpoint and read any claims directly from the access token instead
SkipUserInfo bool
SkipUserInfo bool
EventsPublisher events.Publisher
}
// newOptions initializes the available default options.
@@ -236,3 +238,10 @@ func SkipUserInfo(val bool) Option {
o.SkipUserInfo = val
}
}
// EventsPublisher sets the events publisher.
func EventsPublisher(ep events.Publisher) Option {
return func(o *Options) {
o.EventsPublisher = ep
}
}