Files
computer/tests/test_mcp_server_session_management.py
Adam 3274cfafe7 add concurrent session management and resource isolation
Implement concurrent session management for MCP server with:

- SessionManager with computer instance pooling
- Concurrent task execution support
- New tools: get_session_stats, cleanup_session
- Graceful shutdown and resource cleanup
- Fix nested asyncio event loop issues
- Add comprehensive tests and documentation

Enables multiple concurrent clients with proper resource isolation
while maintaining backward compatibility.
2025-10-06 18:37:10 -04:00

414 lines
13 KiB
Python

"""
Tests for MCP Server Session Management functionality.
This module tests the new concurrent session management and resource lifecycle features.
"""
import asyncio
import importlib.util
import sys
import types
import time
from pathlib import Path
import pytest
def _install_stub_module(name: str, module: types.ModuleType, registry: dict[str, types.ModuleType | None]) -> None:
registry[name] = sys.modules.get(name)
sys.modules[name] = module
@pytest.fixture
def server_module():
    """Load the real MCP server module against stubbed dependencies.

    Stand-ins for ``mcp``, ``mcp.server``, ``mcp.server.fastmcp``,
    ``computer``, ``agent`` and ``mcp_server.session_manager`` are placed in
    ``sys.modules`` before ``libs/python/mcp-server/mcp_server/server.py`` is
    executed, so the server's imports of those names resolve to the stubs.
    Every tool the module registers on its ``server`` object is then copied
    onto the module as an attribute so tests can call it directly.

    Yields:
        The freshly executed server module.

    Teardown restores every ``sys.modules`` entry this fixture replaced (the
    stubs and the loaded server module), also when loading the server fails.
    """
    # Stub MCP Context primitives
    mcp_module = types.ModuleType("mcp")
    mcp_module.__path__ = []  # mark as package
    mcp_server_module = types.ModuleType("mcp.server")
    mcp_server_module.__path__ = []
    fastmcp_module = types.ModuleType("mcp.server.fastmcp")

    class _StubContext:
        async def yield_message(self, *args, **kwargs):
            return None

        async def yield_tool_call(self, *args, **kwargs):
            return None

        async def yield_tool_output(self, *args, **kwargs):
            return None

        def report_progress(self, *_args, **_kwargs):
            return None

        def info(self, *_args, **_kwargs):
            return None

        def error(self, *_args, **_kwargs):
            return None

    class _StubImage:
        def __init__(self, format: str, data: bytes):
            self.format = format
            self.data = data

    class _StubFastMCP:
        def __init__(self, name: str):
            self.name = name
            # Functions registered via ``@server.tool(...)``, keyed by name.
            self._tools: dict[str, types.FunctionType] = {}

        def tool(self, *args, **kwargs):
            def decorator(func):
                self._tools[func.__name__] = func
                return func

            return decorator

        def run(self):
            return None

    fastmcp_module.Context = _StubContext
    fastmcp_module.FastMCP = _StubFastMCP
    fastmcp_module.Image = _StubImage

    # Stub Computer module
    computer_module = types.ModuleType("computer")

    class _StubInterface:
        async def screenshot(self) -> bytes:
            return b"test-screenshot-data"

    class _StubComputer:
        def __init__(self, *args, **kwargs):
            self.interface = _StubInterface()

        async def run(self):
            return None

    computer_module.Computer = _StubComputer

    # Stub agent module
    agent_module = types.ModuleType("agent")

    class _StubComputerAgent:
        def __init__(self, *args, **kwargs):
            pass

        async def run(self, *_args, **_kwargs):
            # Simulate agent execution with streaming (this is an async generator).
            yield {
                "output": [
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": "Task completed"}],
                    }
                ]
            }

    agent_module.ComputerAgent = _StubComputerAgent

    # Stub session manager module
    session_manager_module = types.ModuleType("mcp_server.session_manager")

    class _StubSessionInfo:
        def __init__(self, session_id: str, computer, created_at: float, last_activity: float):
            self.session_id = session_id
            self.computer = computer
            self.created_at = created_at
            self.last_activity = last_activity
            self.active_tasks = set()
            self.is_shutting_down = False

    class _StubSessionHandle:
        """Result of ``get_session``: usable with ``await`` or ``async with``.

        Supporting both protocols lets the stub serve either call style.
        NOTE(review): confirm which style server.py actually uses.
        """

        def __init__(self, manager, session_id):
            self._manager = manager
            self._session_id = session_id

        def __await__(self):
            return self._manager._get_or_create(self._session_id).__await__()

        async def __aenter__(self):
            return await self._manager._get_or_create(self._session_id)

        async def __aexit__(self, exc_type, exc, tb):
            # Never suppress exceptions raised inside the ``async with`` body.
            return False

    class _StubSessionManager:
        def __init__(self):
            self._sessions = {}
            self._session_lock = asyncio.Lock()

        def get_session(self, session_id=None):
            """Return a handle resolving to the session for ``session_id``.

            The handle may be awaited or entered as an async context manager;
            both yield the same ``_StubSessionInfo`` (created on first use).
            """
            return _StubSessionHandle(self, session_id)

        async def _get_or_create(self, session_id=None):
            """Return the session for ``session_id``, creating it if missing."""
            if session_id is None:
                session_id = "test-session-123"
            async with self._session_lock:
                if session_id not in self._sessions:
                    computer = _StubComputer()
                    session = _StubSessionInfo(
                        session_id=session_id,
                        computer=computer,
                        created_at=time.time(),
                        last_activity=time.time(),
                    )
                    self._sessions[session_id] = session
                return self._sessions[session_id]

        async def register_task(self, session_id: str, task_id: str):
            pass

        async def unregister_task(self, session_id: str, task_id: str):
            pass

        async def cleanup_session(self, session_id: str):
            async with self._session_lock:
                self._sessions.pop(session_id, None)

        def get_session_stats(self):
            return {
                "total_sessions": len(self._sessions),
                "max_concurrent": 10,
                "sessions": {sid: {"active_tasks": 0} for sid in self._sessions},
            }

    _stub_session_manager = _StubSessionManager()

    def get_session_manager():
        return _stub_session_manager

    async def initialize_session_manager():
        return _stub_session_manager

    async def shutdown_session_manager():
        pass

    session_manager_module.get_session_manager = get_session_manager
    session_manager_module.initialize_session_manager = initialize_session_manager
    session_manager_module.shutdown_session_manager = shutdown_session_manager

    stubs = {
        "mcp": mcp_module,
        "mcp.server": mcp_server_module,
        "mcp.server.fastmcp": fastmcp_module,
        "computer": computer_module,
        "agent": agent_module,
        "mcp_server.session_manager": session_manager_module,
    }
    # sys.modules key -> entry it held before this fixture (None if absent).
    stubbed_modules: dict[str, types.ModuleType | None] = {}
    # The try starts before the first sys.modules mutation so that a failure
    # while loading server.py still restores every replaced entry.
    try:
        for stub_name, stub in stubs.items():
            _install_stub_module(stub_name, stub, stubbed_modules)

        # Load the actual server module. The path is relative, so resolve()
        # anchors it at the current working directory (expected: repo root).
        module_name = "mcp_server_server_under_test"
        module_path = Path("libs/python/mcp-server/mcp_server/server.py").resolve()
        spec = importlib.util.spec_from_file_location(module_name, module_path)
        # Validate before use: module_from_spec(None) fails with an unhelpful error.
        assert spec is not None and spec.loader is not None, f"cannot load {module_path}"
        server_module = importlib.util.module_from_spec(spec)
        # Register before executing, per the importlib "import a source file
        # directly" recipe; teardown removes or restores this entry.
        _install_stub_module(module_name, server_module, stubbed_modules)
        spec.loader.exec_module(server_module)

        server_instance = getattr(server_module, "server", None)
        if server_instance is not None and hasattr(server_instance, "_tools"):
            for tool_name, func in server_instance._tools.items():
                setattr(server_module, tool_name, func)

        yield server_module
    finally:
        for name, original in stubbed_modules.items():
            if original is None:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
class FakeContext:
"""Fake context for testing."""
def __init__(self) -> None:
self.events: list[tuple] = []
self.progress_updates: list[float] = []
def info(self, message: str) -> None:
self.events.append(("info", message))
def error(self, message: str) -> None:
self.events.append(("error", message))
def report_progress(self, value: float) -> None:
self.progress_updates.append(value)
async def yield_message(self, *, role: str, content):
timestamp = asyncio.get_running_loop().time()
self.events.append(("message", role, content, timestamp))
async def yield_tool_call(self, *, name: str | None, call_id: str, input):
timestamp = asyncio.get_running_loop().time()
self.events.append(("tool_call", name, call_id, input, timestamp))
async def yield_tool_output(self, *, call_id: str, output, is_error: bool = False):
timestamp = asyncio.get_running_loop().time()
self.events.append(("tool_output", call_id, output, is_error, timestamp))
def test_screenshot_cua_with_session_id(server_module):
    """screenshot_cua with an explicit session id returns the stubbed PNG image."""

    async def scenario():
        context = FakeContext()
        return await server_module.screenshot_cua(context, session_id="test-session")

    shot = asyncio.run(scenario())
    assert shot.format == "png"
    assert shot.data == b"test-screenshot-data"
def test_screenshot_cua_creates_new_session(server_module):
    """screenshot_cua without a session id still returns the stubbed PNG image."""

    async def scenario():
        context = FakeContext()
        return await server_module.screenshot_cua(context)

    shot = asyncio.run(scenario())
    assert shot.format == "png"
    assert shot.data == b"test-screenshot-data"
def test_run_cua_task_with_session_management(server_module):
    """run_cua_task on a named session returns the agent text and a screenshot."""

    async def scenario():
        context = FakeContext()
        return await server_module.run_cua_task(context, "Test task", "test-session-456")

    summary, snapshot = asyncio.run(scenario())
    assert "Task completed" in summary
    assert snapshot.format == "png"
    assert snapshot.data == b"test-screenshot-data"
def test_run_multi_cua_tasks_sequential(server_module):
    """run_multi_cua_tasks(concurrent=False) yields one (text, image) per task."""

    async def _run_test():
        ctx = FakeContext()
        tasks = ["Task 1", "Task 2", "Task 3"]
        results = await server_module.run_multi_cua_tasks(ctx, tasks, concurrent=False)
        # One result per submitted task.
        assert len(results) == len(tasks)
        for text, image in results:
            assert "Task completed" in text
            assert image.format == "png"

    asyncio.run(_run_test())
def test_run_multi_cua_tasks_concurrent(server_module):
    """run_multi_cua_tasks(concurrent=True) yields one (text, image) per task."""

    async def _run_test():
        ctx = FakeContext()
        tasks = ["Task 1", "Task 2", "Task 3"]
        results = await server_module.run_multi_cua_tasks(ctx, tasks, concurrent=True)
        # One result per submitted task, even when they run concurrently.
        assert len(results) == len(tasks)
        for text, image in results:
            assert "Task completed" in text
            assert image.format == "png"

    asyncio.run(_run_test())
def test_get_session_stats(server_module):
    """get_session_stats (called without a context) exposes the expected keys."""

    async def _run_test():
        stats = await server_module.get_session_stats()
        assert "total_sessions" in stats
        assert "max_concurrent" in stats
        assert "sessions" in stats

    asyncio.run(_run_test())
def test_cleanup_session(server_module):
    """cleanup_session reports that cleanup of the named session has begun."""
    target = "test-cleanup-session"

    async def scenario():
        return await server_module.cleanup_session(FakeContext(), target)

    message = asyncio.run(scenario())
    assert f"Session {target} cleanup initiated" in message
def test_concurrent_sessions_isolation(server_module):
    """Tasks on two different session ids run side by side and both finish."""
    workload = (
        ("Task for session 1", "session-1"),
        ("Task for session 2", "session-2"),
    )

    async def scenario():
        shared_ctx = FakeContext()
        # Schedule both runs as tasks so they are in flight at the same time.
        pending = [
            asyncio.create_task(server_module.run_cua_task(shared_ctx, prompt, sid))
            for prompt, sid in workload
        ]
        return await asyncio.gather(*pending)

    outcomes = asyncio.run(scenario())
    assert len(outcomes) == 2
    for summary, snapshot in outcomes:
        assert "Task completed" in summary
        assert snapshot.format == "png"
def test_session_reuse_with_same_id(server_module):
    """Test that sessions are reused when the same session ID is provided.

    The stub computer always returns identical screenshot bytes, so comparing
    the two results alone cannot detect reuse. The stub session manager's
    registry is therefore inspected as well: the second call with the same id
    must neither add nor replace a session.
    """

    async def _run_test():
        ctx = FakeContext()
        session_id = "reuse-session"
        # The fixture installs the stub session manager under this key; the
        # server is expected to obtain the same singleton via get_session_manager().
        manager = sys.modules["mcp_server.session_manager"].get_session_manager()

        # First call
        result1 = await server_module.screenshot_cua(ctx, session_id)
        sessions_after_first = dict(manager._sessions)
        assert sessions_after_first, "first call should have created a session"

        # Second call with same session ID
        result2 = await server_module.screenshot_cua(ctx, session_id)
        # Stub session objects define no __eq__, so dict equality compares
        # identity: same ids mapped to the very same session objects.
        assert manager._sessions == sessions_after_first

        assert result1.format == result2.format
        assert result1.data == result2.data

    asyncio.run(_run_test())
def test_error_handling_with_session_management(server_module):
    """Test that an agent failure is reported in the result instead of raised."""

    async def _run_test():
        # Mock an agent whose run() fails as soon as it is iterated. It must be
        # an async generator, like the stub ComputerAgent.run: assuming the
        # server consumes run() with ``async for`` (as the streaming stub
        # implies), a plain ``async def`` would make ``async for`` fail with a
        # TypeError and the simulated RuntimeError would never be raised.
        class _FailingAgent:
            def __init__(self, *args, **kwargs):
                pass

            async def run(self, *_args, **_kwargs):
                raise RuntimeError("Simulated agent failure")
                yield  # unreachable; turns run() into an async generator

        # Replace the ComputerAgent with our failing one
        original_agent = server_module.ComputerAgent
        server_module.ComputerAgent = _FailingAgent
        try:
            ctx = FakeContext()
            task = "This will fail"
            text_result, image = await server_module.run_cua_task(ctx, task, "error-session")
            assert "Error during task execution" in text_result
            assert image.format == "png"
        finally:
            # Restore original agent
            server_module.ComputerAgent = original_agent

    asyncio.run(_run_test())