Import client sessions into Infinispan concurrently for persistent sessions

Closes #41074

Signed-off-by: Pedro Ruivo <1492066+pruivo@users.noreply.github.com>
Co-authored-by: Pedro Ruivo <1492066+pruivo@users.noreply.github.com>
This commit is contained in:
Pedro Ruivo
2025-08-26 21:16:04 +01:00
committed by GitHub
parent 600f03d1d0
commit a01571c2cc
4 changed files with 197 additions and 117 deletions

View File

@@ -30,7 +30,6 @@ import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -92,6 +91,7 @@ import org.keycloak.models.utils.KeycloakModelUtils;
import org.keycloak.models.utils.UserModelDelegate;
import static org.keycloak.models.Constants.SESSION_NOTE_LIGHTWEIGHT_USER;
import static org.keycloak.models.sessions.infinispan.changes.ClientSessionPersistentChangelogBasedTransaction.createAuthenticatedClientSessionInstance;
import static org.keycloak.utils.StreamsUtil.paginatedStream;
/**
@@ -211,9 +211,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
SessionUpdateTask<AuthenticatedClientSessionEntity> createClientSessionTask = Tasks.addIfAbsentSync();
clientSessionTx.addTask(clientSessionId, createClientSessionTask, entity, persistenceState);
SessionUpdateTask<UserSessionEntity> registerClientSessionTask = new ClientSessionPersistentChangelogBasedTransaction.RegisterClientSessionTask(client.getId(), clientSessionId, userSession.isOffline());
sessionTx.addTask(userSession.getId(), registerClientSessionTask);
sessionTx.registerClientSession(userSession.getId(), client.getId(), clientSessionId, userSession.isOffline());
return adapter;
}
@@ -231,7 +229,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
SessionUpdateTask<UserSessionEntity> createSessionTask = Tasks.addIfAbsentSync();
sessionTx.addTask(id, createSessionTask, entity, persistenceState);
UserSessionAdapter adapter = user instanceof LightweightUserAdapter
UserSessionAdapter<?> adapter = user instanceof LightweightUserAdapter
? wrap(realm, entity, false, user)
: wrap(realm, entity, false);
adapter.setPersistenceState(persistenceState);
@@ -259,7 +257,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
return getUserSession(realm, id, null, false);
}
private UserSessionAdapter getUserSession(RealmModel realm, String id, UserSessionModel userSession, boolean offline) {
private UserSessionAdapter<?> getUserSession(RealmModel realm, String id, UserSessionModel userSession, boolean offline) {
SessionEntityWrapper<UserSessionEntity> entityWrapper = sessionTx.get(realm, id, userSession, offline);
return entityWrapper != null ? wrap(realm, entityWrapper.getEntity(), offline) : null;
}
@@ -526,9 +524,6 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
@Override
public void onClientRemoved(RealmModel realm, ClientModel client) {
// clusterEventsSenderTx.addEvent(
// ClientRemovedSessionEvent.createEvent(ClientRemovedSessionEvent.class, InfinispanUserSessionProviderFactory.CLIENT_REMOVED_SESSION_EVENT, session, realm.getId(), true),
// ClusterProvider.DCNotify.LOCAL_DC_ONLY);
UserSessionPersisterProvider sessionsPersister = session.getProvider(UserSessionPersisterProvider.class);
if (sessionsPersister != null) {
sessionsPersister.onClientRemoved(realm, client);
@@ -561,19 +556,19 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
sessionTx.addTask(sessionEntity.getId(), removeTask);
}
UserSessionAdapter wrap(RealmModel realm, UserSessionEntity entity, boolean offline, UserModel user) {
UserSessionAdapter<?> wrap(RealmModel realm, UserSessionEntity entity, boolean offline, UserModel user) {
if (entity == null) {
return null;
}
return new UserSessionAdapter(session, user, this, sessionTx, clientSessionTx, realm, entity, offline);
return new UserSessionAdapter<>(session, user, this, sessionTx, clientSessionTx, realm, entity, offline);
}
UserSessionAdapter wrap(RealmModel realm, UserSessionEntity entity, boolean offline) {
UserSessionAdapter<?> wrap(RealmModel realm, UserSessionEntity entity, boolean offline) {
UserModel user;
if (Profile.isFeatureEnabled(Feature.TRANSIENT_USERS) && entity.getNotes().containsKey(SESSION_NOTE_LIGHTWEIGHT_USER)) {
LightweightUserAdapter lua = LightweightUserAdapter.fromString(session, realm, entity.getNotes().get(SESSION_NOTE_LIGHTWEIGHT_USER));
final UserSessionAdapter us = wrap(realm, entity, offline, lua);
final UserSessionAdapter<?> us = wrap(realm, entity, offline, lua);
lua.setUpdateHandler(lua1 -> {
if (lua == lua1) { // Ensure there is no conflicting user model, only the latest lightweight user can be used
us.setNote(SESSION_NOTE_LIGHTWEIGHT_USER, lua1.serialize());
@@ -594,11 +589,11 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
}
UserSessionEntity getUserSessionEntity(RealmModel realm, UserSessionModel userSession, boolean offline) {
if (userSession instanceof UserSessionAdapter) {
if (userSession instanceof UserSessionAdapter<?> usa) {
if (!userSession.getRealm().equals(realm)) {
return null;
}
return ((UserSessionAdapter) userSession).getEntity();
return usa.getEntity();
} else {
return getUserSessionEntity(realm, userSession.getId(), offline);
}
@@ -613,7 +608,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
SessionUpdateTask<UserSessionEntity> importTask = Tasks.addIfAbsentSync();
sessionTx.addTask(userSession.getId(), importTask, entity, UserSessionModel.SessionPersistenceState.PERSISTENT);
UserSessionAdapter offlineUserSession = wrap(userSession.getRealm(), entity, true);
UserSessionAdapter<?> offlineUserSession = wrap(userSession.getRealm(), entity, true);
// started and lastSessionRefresh set to current time
int currentTime = Time.currentTime();
@@ -624,7 +619,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
}
@Override
public UserSessionAdapter getOfflineUserSession(RealmModel realm, String userSessionId) {
public UserSessionAdapter<?> getOfflineUserSession(RealmModel realm, String userSessionId) {
return getUserSession(realm, userSessionId, null, true);
}
@@ -643,7 +638,8 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
@Override
public AuthenticatedClientSessionModel createOfflineClientSession(AuthenticatedClientSessionModel clientSession, UserSessionModel offlineUserSession) {
UserSessionAdapter userSessionAdapter = (offlineUserSession instanceof UserSessionAdapter) ? (UserSessionAdapter) offlineUserSession :
UserSessionAdapter<?> userSessionAdapter = offlineUserSession instanceof UserSessionAdapter<?> ousa ?
ousa :
getOfflineUserSession(offlineUserSession.getRealm(), offlineUserSession.getId());
AuthenticatedClientSessionAdapter offlineClientSession = importOfflineClientSession(userSessionAdapter, clientSession);
@@ -679,17 +675,40 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
persistentUserSessions.forEach(userSessionModel -> importUserSession(userSessionModel, offline));
}
/**
* Imports a {@link UserSessionModel} and its {@link AuthenticatedClientSessionModel}.
*
* @param persistentUserSession The {@link UserSessionModel} read from the database.
* @param offline {@code true} if it is an offline user session.
* @return The {@link SessionEntityWrapper} to be used to keep track of any further session changes.
*/
public SessionEntityWrapper<UserSessionEntity> importUserSession(UserSessionModel persistentUserSession, boolean offline) {
Map<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessionsById = new HashMap<>();
UserSessionEntity userSessionEntityToImport = createUserSessionEntityInstance(persistentUserSession);
String realmId = userSessionEntityToImport.getRealmId();
String sessionId = userSessionEntityToImport.getId();
RealmModel realm = session.realms().getRealm(realmId);
long lifespan = offline ?
SessionTimeouts.getOfflineSessionLifespanMs(realm, null, userSessionEntityToImport) :
SessionTimeouts.getUserSessionLifespanMs(realm, null, userSessionEntityToImport);
long maxIdle = offline ?
SessionTimeouts.getOfflineSessionMaxIdleMs(realm, null, userSessionEntityToImport) :
SessionTimeouts.getUserSessionMaxIdleMs(realm, null, userSessionEntityToImport);
if (lifespan == SessionTimeouts.ENTRY_EXPIRED_FLAG || maxIdle == SessionTimeouts.ENTRY_EXPIRED_FLAG) {
log.debugf("Session has expired. Do not import user-session for sessionId=%s offline=%s", sessionId, offline);
return null;
}
Map<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessionsById = new HashMap<>();
for (Map.Entry<String, AuthenticatedClientSessionModel> entry : persistentUserSession.getAuthenticatedClientSessions().entrySet()) {
String clientUUID = entry.getKey();
AuthenticatedClientSessionModel clientSession = entry.getValue();
AuthenticatedClientSessionEntity clientSessionToImport = createAuthenticatedClientSessionInstance(userSessionEntityToImport.getId(), clientSession,
userSessionEntityToImport.getRealmId(), clientUUID, offline);
clientSessionToImport.setUserSessionId(userSessionEntityToImport.getId());
AuthenticatedClientSessionEntity clientSessionToImport = createAuthenticatedClientSessionInstance(sessionId, clientSession,
realmId, clientUUID, offline);
clientSessionToImport.setUserSessionId(sessionId);
if (offline) {
// Update timestamp to the same value as userSession. LastSessionRefresh of userSession from DB will have a correct value.
@@ -707,29 +726,20 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
SessionEntityWrapper<UserSessionEntity> wrappedUserSessionEntity = new SessionEntityWrapper<>(userSessionEntityToImport);
Map<String, SessionEntityWrapper<UserSessionEntity>> sessionsById =
Stream.of(wrappedUserSessionEntity).collect(Collectors.toMap(sessionEntityWrapper -> sessionEntityWrapper.getEntity().getId(), Function.identity()));
Cache<String, SessionEntityWrapper<UserSessionEntity>> cache = getCache(offline);
sessionsById = importSessionsWithExpiration(sessionsById, cache,
offline ? SessionTimeouts::getOfflineSessionLifespanMs : SessionTimeouts::getUserSessionLifespanMs,
offline ? SessionTimeouts::getOfflineSessionMaxIdleMs : SessionTimeouts::getUserSessionMaxIdleMs);
if (sessionsById.isEmpty()) {
return null;
SessionEntityWrapper<UserSessionEntity> existingSession = sessionTx.importSession(realm, sessionId, wrappedUserSessionEntity, offline, lifespan, maxIdle);
if (existingSession != null) {
// skip import the client sessions, they should have been imported too.
log.debugf("The user-session already imported by another transaction for sessionId=%s offline=%s", sessionId, offline);
return existingSession;
}
// Import client sessions
Cache<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> clientSessCache = getClientSessionCache(offline);
importSessionsWithExpiration(clientSessionsById, clientSessCache,
offline ? SessionTimeouts::getOfflineClientSessionLifespanMs : SessionTimeouts::getClientSessionLifespanMs,
offline ? SessionTimeouts::getOfflineClientSessionMaxIdleMs : SessionTimeouts::getClientSessionMaxIdleMs);
return sessionsById.entrySet().stream().findFirst().map(Map.Entry::getValue).orElse(null);
clientSessionTx.importSessionsConcurrently(realm, clientSessionsById, offline);
return wrappedUserSessionEntity;
}
// new import logic has been added to PersistentSessionsChangelogBasedTransaction, no longer in use.
@Deprecated(forRemoval = true, since = "26.4")
public <T extends SessionEntity, K> Map<K, SessionEntityWrapper<T>> importSessionsWithExpiration(Map<K, SessionEntityWrapper<T>> sessionsById,
BasicCache<K, SessionEntityWrapper<T>> cache, SessionFunction<T> lifespanMsCalculator,
SessionFunction<T> maxIdleTimeMsCalculator) {
@@ -769,7 +779,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
}).filter(Objects::nonNull).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
private UserSessionEntity createUserSessionEntityInstance(UserSessionModel userSession) {
private static UserSessionEntity createUserSessionEntityInstance(UserSessionModel userSession) {
UserSessionEntity entity = new UserSessionEntity(userSession.getId());
entity.setRealmId(userSession.getRealm().getId());
@@ -801,7 +811,7 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
}
private AuthenticatedClientSessionAdapter importOfflineClientSession(UserSessionAdapter sessionToImportInto,
private AuthenticatedClientSessionAdapter importOfflineClientSession(UserSessionAdapter<?> sessionToImportInto,
AuthenticatedClientSessionModel clientSession) {
AuthenticatedClientSessionEntity entity = createAuthenticatedClientSessionInstance(sessionToImportInto.getId(), clientSession,
sessionToImportInto.getRealm().getId(), clientSession.getClient().getId(), true);
@@ -817,32 +827,11 @@ public class PersistentUserSessionProvider implements UserSessionProvider, Sessi
AuthenticatedClientSessionStore clientSessions = sessionToImportInto.getEntity().getAuthenticatedClientSessions();
clientSessions.put(clientSession.getClient().getId(), clientSessionId);
SessionUpdateTask<UserSessionEntity> registerClientSessionTask = new ClientSessionPersistentChangelogBasedTransaction.RegisterClientSessionTask(clientSession.getClient().getId(), clientSessionId, true);
sessionTx.addTask(sessionToImportInto.getId(), registerClientSessionTask);
sessionTx.registerClientSession(sessionToImportInto.getId(), clientSession.getClient().getId(), clientSessionId, true);
return new AuthenticatedClientSessionAdapter(session, entity, clientSession.getClient(), sessionToImportInto, clientSessionTx, true);
}
private AuthenticatedClientSessionEntity createAuthenticatedClientSessionInstance(String userSessionId, AuthenticatedClientSessionModel clientSession,
String realmId, String clientId, boolean offline) {
final UUID clientSessionId = PersistentUserSessionProvider.createClientSessionUUID(userSessionId, clientId);
AuthenticatedClientSessionEntity entity = new AuthenticatedClientSessionEntity(clientSessionId);
entity.setRealmId(realmId);
entity.setAction(clientSession.getAction());
entity.setAuthMethod(clientSession.getProtocol());
entity.setNotes(clientSession.getNotes() == null ? new ConcurrentHashMap<>() : clientSession.getNotes());
entity.setClientId(clientId);
entity.setRedirectUri(clientSession.getRedirectUri());
entity.setTimestamp(clientSession.getTimestamp());
entity.setOffline(offline);
return entity;
}
public SessionEntityWrapper<UserSessionEntity> wrapPersistentEntity(RealmModel realm, boolean offline, UserSessionModel persistentUserSession) {
UserSessionEntity userSessionEntity = createUserSessionEntityInstance(persistentUserSession);

View File

@@ -24,16 +24,15 @@ import org.keycloak.models.ClientModel;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.UserSessionProvider;
import org.keycloak.models.session.UserSessionPersisterProvider;
import org.keycloak.models.sessions.infinispan.PersistentUserSessionProvider;
import org.keycloak.models.sessions.infinispan.SessionFunction;
import org.keycloak.models.sessions.infinispan.UserSessionAdapter;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionEntity;
import org.keycloak.models.sessions.infinispan.entities.AuthenticatedClientSessionStore;
import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
import org.keycloak.models.sessions.infinispan.util.SessionTimeouts;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
@@ -132,8 +131,8 @@ public class ClientSessionPersistentChangelogBasedTransaction extends Persistent
return authenticatedClientSessionEntitySessionEntityWrapper;
}
private AuthenticatedClientSessionEntity createAuthenticatedClientSessionInstance(String userSessionId, AuthenticatedClientSessionModel clientSession,
String realmId, String clientId) {
public static AuthenticatedClientSessionEntity createAuthenticatedClientSessionInstance(String userSessionId, AuthenticatedClientSessionModel clientSession,
String realmId, String clientId, boolean offline) {
UUID clientSessionId = PersistentUserSessionProvider.createClientSessionUUID(userSessionId, clientId);
AuthenticatedClientSessionEntity entity = new AuthenticatedClientSessionEntity(clientSessionId);
@@ -146,14 +145,14 @@ public class ClientSessionPersistentChangelogBasedTransaction extends Persistent
entity.setClientId(clientId);
entity.setRedirectUri(clientSession.getRedirectUri());
entity.setTimestamp(clientSession.getTimestamp());
entity.setOffline(clientSession.getUserSession().isOffline());
entity.setOffline(offline);
return entity;
}
private SessionEntityWrapper<AuthenticatedClientSessionEntity> importClientSession(RealmModel realm, ClientModel client, UserSessionModel userSession, AuthenticatedClientSessionModel persistentClientSession, UUID clientSessionId) {
AuthenticatedClientSessionEntity entity = createAuthenticatedClientSessionInstance(userSession.getId(), persistentClientSession,
realm.getId(), client.getId());
realm.getId(), client.getId(), userSession.isOffline());
boolean offline = userSession.isOffline();
entity.setUserSessionId(userSession.getId());
@@ -165,59 +164,38 @@ public class ClientSessionPersistentChangelogBasedTransaction extends Persistent
entity.setTimestamp(userSession.getLastSessionRefresh());
}
SessionEntityWrapper<AuthenticatedClientSessionEntity> wrapper = new SessionEntityWrapper<>(entity);
Map<UUID, SessionEntityWrapper<AuthenticatedClientSessionEntity>> imported = ((PersistentUserSessionProvider) kcSession.getProvider(UserSessionProvider.class)).importSessionsWithExpiration(Map.of(clientSessionId, wrapper), getCache(offline),
getLifespanMsLoader(offline),
getMaxIdleMsLoader(offline));
long lifespan = getLifespanMsLoader(offline).apply(realm, client, entity);
long maxIdle = getMaxIdleMsLoader(offline).apply(realm, client, entity);
if (imported.isEmpty()) {
if (lifespan == SessionTimeouts.ENTRY_EXPIRED_FLAG || maxIdle == SessionTimeouts.ENTRY_EXPIRED_FLAG) {
LOG.debugf("Client-session has expired, not importing it. userSessionId=%s, clientSessionId=%s, clientId=%s, offline=%s",
userSession.getId(), clientSessionId, client.getId(), offline);
return null;
}
SessionEntityWrapper<AuthenticatedClientSessionEntity> wrapper = new SessionEntityWrapper<>(entity);
SessionUpdateTask<AuthenticatedClientSessionEntity> createClientSessionTask = Tasks.addIfAbsentSync();
this.addTask(entity.getId(), createClientSessionTask, entity, UserSessionModel.SessionPersistenceState.PERSISTENT);
SessionEntityWrapper<AuthenticatedClientSessionEntity> imported = importSession(realm, clientSessionId, wrapper, offline, lifespan, maxIdle);
if (imported != null) {
LOG.debugf("Client-session already imported by another transaction. userSessionId=%s, clientSessionId=%s, clientId=%s, offline=%s",
userSession.getId(), clientSessionId, client.getId(), offline);
return imported;
}
// TODO do we need the code below? In theory, if we are importing a client session, it is already mapped in the user session
if (! (userSession instanceof UserSessionAdapter<?> sessionToImportInto)) {
throw new IllegalStateException("UserSessionModel must be instance of UserSessionAdapter");
}
AuthenticatedClientSessionStore clientSessions = sessionToImportInto.getEntity().getAuthenticatedClientSessions();
clientSessions.put(client.getId(), clientSessionId);
UUID existingId = clientSessions.put(client.getId(), clientSessionId);
SessionUpdateTask<UserSessionEntity> registerClientSessionTask = new RegisterClientSessionTask(client.getId(), clientSessionId, offline);
userSessionTx.addTask(sessionToImportInto.getId(), registerClientSessionTask);
if (!Objects.equals(existingId, clientSessionId)) {
userSessionTx.registerClientSession(sessionToImportInto.getId(), client.getClientId(), clientSessionId, offline);
}
return wrapper;
}
public static class RegisterClientSessionTask implements PersistentSessionUpdateTask<UserSessionEntity> {
private final String clientUuid;
private final UUID clientSessionId;
private final boolean offline;
public RegisterClientSessionTask(String clientUuid, UUID clientSessionId, boolean offline) {
this.clientUuid = clientUuid;
this.clientSessionId = clientSessionId;
this.offline = offline;
}
@Override
public void runUpdate(UserSessionEntity session) {
AuthenticatedClientSessionStore clientSessions = session.getAuthenticatedClientSessions();
clientSessions.put(clientUuid, clientSessionId);
}
@Override
public CacheOperation getOperation() {
return CacheOperation.REPLACE;
}
@Override
public boolean isOffline() {
return offline;
}
}
}

View File

@@ -18,6 +18,7 @@
package org.keycloak.models.sessions.infinispan.changes;
import org.infinispan.Cache;
import org.infinispan.commons.util.concurrent.CompletionStages;
import org.jboss.logging.Logger;
import org.keycloak.models.AbstractKeycloakTransaction;
import org.keycloak.models.KeycloakSession;
@@ -25,6 +26,7 @@ import org.keycloak.models.RealmModel;
import org.keycloak.models.UserSessionModel;
import org.keycloak.models.sessions.infinispan.SessionFunction;
import org.keycloak.models.sessions.infinispan.entities.SessionEntity;
import org.keycloak.models.sessions.infinispan.util.SessionTimeouts;
import org.keycloak.models.utils.KeycloakModelUtils;
import java.util.HashMap;
@@ -32,6 +34,8 @@ import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends SessionEntity> extends AbstractKeycloakTransaction implements SessionsChangelogBasedTransaction<K, V> {
@@ -107,7 +111,7 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
}
}
public SessionEntityWrapper<V> get(K key, boolean offline){
public SessionEntityWrapper<V> get(K key, boolean offline) {
SessionUpdatesList<V> myUpdates = getUpdates(offline).get(key);
if (myUpdates == null) {
SessionEntityWrapper<V> wrappedEntity = getCache(offline).get(key);
@@ -123,8 +127,6 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
return wrappedEntity;
} else {
V entity = myUpdates.getEntityWrapper().getEntity();
// If entity is scheduled for remove, we don't return it.
boolean scheduledForRemove = myUpdates.getUpdateTasks().stream()
.map(SessionUpdateTask::getOperation)
@@ -159,7 +161,7 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
}
if (offlineCache != null) {
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(offlineCache, serializerOffline){
changesPerformers.add(new EmbeddedCachesChangesPerformer<>(offlineCache, serializerOffline) {
@Override
public boolean shouldConsumeChange(V entity) {
return entity.isOffline();
@@ -209,11 +211,10 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
@Override
public void addTask(K key, SessionUpdateTask<V> originalTask) {
if (! (originalTask instanceof PersistentSessionUpdateTask)) {
if (!(originalTask instanceof PersistentSessionUpdateTask<V> task)) {
throw new IllegalArgumentException("Task must be instance of PersistentSessionUpdateTask");
}
PersistentSessionUpdateTask<V> task = (PersistentSessionUpdateTask<V>) originalTask;
SessionUpdatesList<V> myUpdates = getUpdates(task.isOffline()).get(key);
if (myUpdates == null) {
// Lookup entity from cache
@@ -253,6 +254,8 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
}
}
// method not currently in use, remove in the next major.
@Deprecated(forRemoval = true, since = "26.4")
public void reloadEntityInCurrentTransaction(RealmModel realm, K key, SessionEntityWrapper<V> entity) {
if (entity == null) {
throw new IllegalArgumentException("Null entity not allowed");
@@ -279,4 +282,95 @@ abstract public class PersistentSessionsChangelogBasedTransaction<K, V extends S
}
/**
* Imports a session from an external source into the {@link Cache}.
* <p>
* If a session already exists in the cache, this method does not insert the {@code session}. The invoker should use
* the session returned by this method invocation. When the session is successfully imported, this method returns
* null and the {@code session} can be used by the transaction.
* <p>
* This transaction will keep track of further changes in the session.
*
* @param realmModel The {@link RealmModel} where the session belong to.
* @param key The cache's key.
* @param session The session to import.
* @param lifespan How long the session stays cached until it is expired and removed.
* @param maxIdle How long the session can be idle (without reading or writing) before being removed.
* @param offline {@code true} if it is an offline session.
* @return The existing cached session. If it returns {@code null}, it means the {@code session} used in the
* parameters was cached.
*/
public SessionEntityWrapper<V> importSession(RealmModel realmModel, K key, SessionEntityWrapper<V> session, boolean offline, long lifespan, long maxIdle) {
var updates = getUpdates(offline);
var updatesList = updates.get(key);
if (updatesList != null) {
// exists in transaction, avoid import operation
return updatesList.getEntityWrapper();
}
SessionEntityWrapper<V> existing = null;
try {
if (getCache(offline) != null) {
existing = getCache(offline).putIfAbsent(key, session, lifespan, TimeUnit.MILLISECONDS, maxIdle, TimeUnit.MILLISECONDS);
}
} catch (RuntimeException exception) {
// If the import fails, the transaction can continue with the data from the database.
LOG.debugf(exception, "Failed to import session %s", session);
}
if (existing == null) {
// keep track of the imported session for updates
updates.put(key, new SessionUpdatesList<>(realmModel, session));
return null;
}
updates.put(key, new SessionUpdatesList<>(realmModel, existing));
return existing;
}
/**
* Imports multiple sessions from an external source into the {@link Cache}.
* <p>
* If one or more sessions already exist in the {@link Cache}, or is expired, it will not be imported.
* <p>
* This transaction will keep track of further changes in the sessions.
*
* @param realmModel The {@link RealmModel} where the sessions belong to.
* @param sessions The {@link Map} with the cache's key/session mapping to be imported.
* @param offline {@code true} if it is an offline session.
*/
public void importSessionsConcurrently(RealmModel realmModel, Map<K, SessionEntityWrapper<V>> sessions, boolean offline) {
var cache = getCache(offline);
if (sessions.isEmpty() || cache == null) {
//nothing to import
return;
}
var stage = CompletionStages.aggregateCompletionStage();
var allSessions = new ConcurrentHashMap<K, SessionEntityWrapper<V>>();
var updates = getUpdates(offline);
var lifespanFunction = getLifespanMsLoader(offline);
var maxIdleFunction = getMaxIdleMsLoader(offline);
sessions.forEach((key, session) -> {
if (updates.containsKey(key)) {
//nothing to import, already exists in transaction
return;
}
var clientModel = session.getClientIfNeeded(realmModel);
var sessionEntity = session.getEntity();
var lifespan = lifespanFunction.apply(realmModel, clientModel, sessionEntity);
var maxIdle = maxIdleFunction.apply(realmModel, clientModel, sessionEntity);
if (lifespan == SessionTimeouts.ENTRY_EXPIRED_FLAG || maxIdle == SessionTimeouts.ENTRY_EXPIRED_FLAG) {
//nothing to import, already expired
return;
}
var future = cache.putIfAbsentAsync(key, session, lifespan, TimeUnit.MILLISECONDS, maxIdle, TimeUnit.MILLISECONDS)
.exceptionally(throwable -> {
// If the import fails, the transaction can continue with the data from the database.
LOG.debugf(throwable, "Failed to import session %s", session);
return null;
});
// write result into concurrent hash map because the consumer is invoked in a different thread each time.
stage.dependsOn(future.thenAccept(existing -> allSessions.put(key, existing == null ? session : existing)));
});
CompletionStages.join(stage.freeze());
allSessions.forEach((key, wrapper) -> updates.put(key, new SessionUpdatesList<>(realmModel, wrapper)));
}
}

View File

@@ -29,6 +29,7 @@ import org.keycloak.models.sessions.infinispan.SessionFunction;
import org.keycloak.models.sessions.infinispan.entities.SessionEntity;
import org.keycloak.models.sessions.infinispan.entities.UserSessionEntity;
import java.util.UUID;
import java.util.concurrent.ArrayBlockingQueue;
import static org.keycloak.connections.infinispan.InfinispanConnectionProvider.USER_SESSION_CACHE_NAME;
@@ -135,6 +136,25 @@ public class UserSessionPersistentChangelogBasedTransaction extends PersistentSe
return isScheduledForRemove(getUpdates(offline).get(key));
}
public void registerClientSession(String userSessionId, String clientId, UUID clientSessionId, boolean offline) {
addTask(userSessionId, new PersistentSessionUpdateTask<>() {
@Override
public boolean isOffline() {
return offline;
}
@Override
public void runUpdate(UserSessionEntity entity) {
entity.getAuthenticatedClientSessions().put(clientId, clientSessionId);
}
@Override
public CacheOperation getOperation() {
return CacheOperation.REPLACE;
}
});
}
private static <V extends SessionEntity> boolean isScheduledForRemove(SessionUpdatesList<V> myUpdates) {
if (myUpdates == null) {
return false;
@@ -145,5 +165,4 @@ public class UserSessionPersistentChangelogBasedTransaction extends PersistentSe
.stream()
.anyMatch(task -> task.getOperation() == SessionUpdateTask.CacheOperation.REMOVE);
}
}