diff --git a/ocis-pkg/oidc/context.go b/ocis-pkg/oidc/context.go index 1f4bf2d65..fbb55cdf1 100644 --- a/ocis-pkg/oidc/context.go +++ b/ocis-pkg/oidc/context.go @@ -5,6 +5,9 @@ import "context" // contextKey is the key for oidc claims in a context type contextKey struct{} +// newSessionFlagKey is the key for the new session flag in a context +type newSessionFlagKey struct{} + // NewContext makes a new context that contains the OpenID connect claims in a map. func NewContext(parent context.Context, c map[string]interface{}) context.Context { return context.WithValue(parent, contextKey{}, c) @@ -15,3 +18,14 @@ func FromContext(ctx context.Context) map[string]interface{} { s, _ := ctx.Value(contextKey{}).(map[string]interface{}) return s } + +// NewContextSessionFlag makes a new context that contains the new session flag. +func NewContextSessionFlag(ctx context.Context, flag bool) context.Context { + return context.WithValue(ctx, newSessionFlagKey{}, flag) +} + +// NewSessionFlagFromContext returns the new session flag stored in a context. +func NewSessionFlagFromContext(ctx context.Context) bool { + s, _ := ctx.Value(newSessionFlagKey{}).(bool) + return s +} diff --git a/services/proxy/pkg/command/server.go b/services/proxy/pkg/command/server.go index 4c2fb9a23..ae17b743a 100644 --- a/services/proxy/pkg/command/server.go +++ b/services/proxy/pkg/command/server.go @@ -182,7 +182,7 @@ func Server(cfg *config.Config) *cli.Command { } { - middlewares := loadMiddlewares(logger, cfg, userInfoCache, signingKeyStore, traceProvider, *m, userProvider, gatewaySelector, serviceSelector) + middlewares := loadMiddlewares(logger, cfg, userInfoCache, signingKeyStore, traceProvider, *m, userProvider, publisher, gatewaySelector, serviceSelector) server, err := proxyHTTP.Server( proxyHTTP.Handler(lh.Handler()), @@ -236,8 +236,10 @@ func Server(cfg *config.Config) *cli.Command { } func loadMiddlewares(logger log.Logger, cfg *config.Config, - userInfoCache, signingKeyStore microstore.Store, traceProvider trace.TracerProvider, metrics metrics.Metrics, - userProvider backend.UserBackend, gatewaySelector pool.Selectable[gateway.GatewayAPIClient], serviceSelector selector.Selector) alice.Chain { + userInfoCache, signingKeyStore microstore.Store, + traceProvider trace.TracerProvider, metrics metrics.Metrics, + userProvider backend.UserBackend, publisher events.Publisher, + gatewaySelector pool.Selectable[gateway.GatewayAPIClient], serviceSelector selector.Selector) alice.Chain { rolesClient := settingssvc.NewRoleService("com.owncloud.api.settings", cfg.GrpcClient) policiesProviderClient := policiessvc.NewPoliciesProviderService("com.owncloud.api.policies", cfg.GrpcClient) @@ -354,6 +356,7 @@ func loadMiddlewares(logger log.Logger, cfg *config.Config, middleware.UserOIDCClaim(cfg.UserOIDCClaim), middleware.UserCS3Claim(cfg.UserCS3Claim), middleware.AutoprovisionAccounts(cfg.AutoprovisionAccounts), + middleware.EventsPublisher(publisher), ), middleware.SelectorCookie( middleware.Logger(logger), diff --git a/services/proxy/pkg/middleware/account_resolver.go b/services/proxy/pkg/middleware/account_resolver.go index f560af347..422d067ed 100644 --- a/services/proxy/pkg/middleware/account_resolver.go +++ b/services/proxy/pkg/middleware/account_resolver.go @@ -11,6 +11,8 @@ import ( "github.com/owncloud/ocis/v2/services/proxy/pkg/userroles" revactx "github.com/cs3org/reva/v2/pkg/ctx" + "github.com/cs3org/reva/v2/pkg/events" + "github.com/cs3org/reva/v2/pkg/utils" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" ) @@ -37,6 +39,7 @@ func AccountResolver(optionSetters ...Option) func(next http.Handler) http.Handl userRoleAssigner: options.UserRoleAssigner, autoProvisionAccounts: options.AutoprovisionAccounts, lastGroupSyncCache: lastGroupSyncCache, + eventsPublisher: options.EventsPublisher, } } } @@ -53,6 +56,7 @@ type accountResolver struct { // memberships was done for a specific user. This is used to trigger a sync // with every single request. lastGroupSyncCache *ttlcache.Cache[string, struct{}] + eventsPublisher events.Publisher } func readUserIDClaim(path string, claims map[string]interface{}) (string, error) { @@ -172,6 +176,17 @@ func (m accountResolver) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } + // If this is a new session, publish user login event + if newSession := oidc.NewSessionFlagFromContext(ctx); newSession && m.eventsPublisher != nil { + event := events.UserSignedIn{ + Executant: user.Id, + Timestamp: utils.TimeToTS(time.Now()), + } + if err := events.Publish(req.Context(), m.eventsPublisher, event); err != nil { + m.logger.Error().Err(err).Msg("could not publish user signin event.") + } + } + // add user to context for selectors ctx = revactx.ContextSetUser(ctx, user) req = req.WithContext(ctx) diff --git a/services/proxy/pkg/middleware/oidc_auth.go b/services/proxy/pkg/middleware/oidc_auth.go index 47b9c424d..ed29111c0 100644 --- a/services/proxy/pkg/middleware/oidc_auth.go +++ b/services/proxy/pkg/middleware/oidc_auth.go @@ -53,7 +53,7 @@ type OIDCAuthenticator struct { TimeFunc func() time.Time } -func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, error) { +func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[string]interface{}, bool, error) { var claims map[string]interface{} // use a 64 bytes long hash to have 256-bit collision resistance. @@ -69,16 +69,16 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri if err = msgpack.UnmarshalAsMap(record[0].Value, &claims); err == nil { m.Logger.Debug().Interface("claims", claims).Msg("cache hit for userinfo") if ok := verifyExpiresAt(claims, m.TimeFunc()); !ok { - return nil, jwt.ErrTokenExpired + return nil, false, jwt.ErrTokenExpired } - return claims, nil + return claims, false, nil } m.Logger.Error().Err(err).Msg("could not unmarshal userinfo") } aClaims, claims, err := m.oidcClient.VerifyAccessToken(req.Context(), token) if err != nil { - return nil, errors.Wrap(err, "failed to verify access token") + return nil, false, errors.Wrap(err, "failed to verify access token") } if !m.skipUserInfo { @@ -91,10 +91,10 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri oauth2.StaticTokenSource(oauth2Token), ) if err != nil { - return nil, errors.Wrap(err, "failed to get userinfo") + return nil, false, errors.Wrap(err, "failed to get userinfo") } if err := userInfo.Claims(&claims); err != nil { - return nil, errors.Wrap(err, "failed to unmarshal userinfo claims") + return nil, false, errors.Wrap(err, "failed to unmarshal userinfo claims") } } @@ -128,8 +128,12 @@ func (m *OIDCAuthenticator) getClaims(token string, req *http.Request) (map[stri } }() + // If we get here this was a new login (or a renewal of the token) + // add a flag about that to the claims, to be able to distinguish + // it in the accountresolver middleware + m.Logger.Debug().Interface("claims", claims).Msg("extracted claims") - return claims, nil + return claims, true, nil } // extractExpiration tries to extract the expriration time from the access token @@ -180,7 +184,7 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) return nil, false } - claims, err := m.getClaims(token, r) + claims, newSession, err := m.getClaims(token, r) if err != nil { host, port, _ := net.SplitHostPort(r.RemoteAddr) m.Logger.Error(). @@ -198,5 +202,11 @@ func (m *OIDCAuthenticator) Authenticate(r *http.Request) (*http.Request, bool) Str("authenticator", "oidc"). Str("path", r.URL.Path). Msg("successfully authenticated request") - return r.WithContext(oidc.NewContext(r.Context(), claims)), true + + ctx := r.Context() + if newSession { + ctx = oidc.NewContextSessionFlag(ctx, true) + } + + return r.WithContext(oidc.NewContext(ctx, claims)), true } diff --git a/services/proxy/pkg/middleware/options.go b/services/proxy/pkg/middleware/options.go index 917f5ca31..90b60c45b 100644 --- a/services/proxy/pkg/middleware/options.go +++ b/services/proxy/pkg/middleware/options.go @@ -5,6 +5,7 @@ import ( "time" gateway "github.com/cs3org/go-cs3apis/cs3/gateway/v1beta1" + "github.com/cs3org/reva/v2/pkg/events" "github.com/cs3org/reva/v2/pkg/rgrpc/todo/pool" "github.com/owncloud/ocis/v2/ocis-pkg/log" "github.com/owncloud/ocis/v2/ocis-pkg/oidc" @@ -69,7 +70,8 @@ type Options struct { // TraceProvider sets the tracing provider. TraceProvider trace.TracerProvider // SkipUserInfo prevents the oidc middleware from querying the userinfo endpoint and read any claims directly from the access token instead - SkipUserInfo bool + SkipUserInfo bool + EventsPublisher events.Publisher } // newOptions initializes the available default options. @@ -236,3 +238,10 @@ func SkipUserInfo(val bool) Option { o.SkipUserInfo = val } } + +// EventsPublisher sets the events publisher. +func EventsPublisher(ep events.Publisher) Option { + return func(o *Options) { + o.EventsPublisher = ep + } +}