mirror of
https://github.com/trycua/computer.git
synced 2026-01-05 04:50:08 -06:00
Implement concurrent session management for MCP server with: - SessionManager with computer instance pooling - Concurrent task execution support - New tools: get_session_stats, cleanup_session - Graceful shutdown and resource cleanup - Fix nested asyncio event loop issues - Add comprehensive tests and documentation Enables multiple concurrent clients with proper resource isolation while maintaining backward compatibility.
414 lines
13 KiB
Python
414 lines
13 KiB
Python
"""
|
|
Tests for MCP Server Session Management functionality.
|
|
|
|
This module tests the new concurrent session management and resource lifecycle features.
|
|
"""
|
|
|
|
import asyncio
|
|
import importlib.util
|
|
import sys
|
|
import types
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
|
|
def _install_stub_module(name: str, module: types.ModuleType, registry: dict[str, types.ModuleType | None]) -> None:
|
|
registry[name] = sys.modules.get(name)
|
|
sys.modules[name] = module
|
|
|
|
|
|
@pytest.fixture
|
|
def server_module():
|
|
"""Create a server module with stubbed dependencies for testing."""
|
|
stubbed_modules: dict[str, types.ModuleType | None] = {}
|
|
|
|
# Stub MCP Context primitives
|
|
mcp_module = types.ModuleType("mcp")
|
|
mcp_module.__path__ = [] # mark as package
|
|
|
|
mcp_server_module = types.ModuleType("mcp.server")
|
|
mcp_server_module.__path__ = []
|
|
|
|
fastmcp_module = types.ModuleType("mcp.server.fastmcp")
|
|
|
|
class _StubContext:
|
|
async def yield_message(self, *args, **kwargs):
|
|
return None
|
|
|
|
async def yield_tool_call(self, *args, **kwargs):
|
|
return None
|
|
|
|
async def yield_tool_output(self, *args, **kwargs):
|
|
return None
|
|
|
|
def report_progress(self, *_args, **_kwargs):
|
|
return None
|
|
|
|
def info(self, *_args, **_kwargs):
|
|
return None
|
|
|
|
def error(self, *_args, **_kwargs):
|
|
return None
|
|
|
|
class _StubImage:
|
|
def __init__(self, format: str, data: bytes):
|
|
self.format = format
|
|
self.data = data
|
|
|
|
class _StubFastMCP:
|
|
def __init__(self, name: str):
|
|
self.name = name
|
|
self._tools: dict[str, types.FunctionType] = {}
|
|
|
|
def tool(self, *args, **kwargs):
|
|
def decorator(func):
|
|
self._tools[func.__name__] = func
|
|
return func
|
|
|
|
return decorator
|
|
|
|
def run(self):
|
|
return None
|
|
|
|
fastmcp_module.Context = _StubContext
|
|
fastmcp_module.FastMCP = _StubFastMCP
|
|
fastmcp_module.Image = _StubImage
|
|
|
|
_install_stub_module("mcp", mcp_module, stubbed_modules)
|
|
_install_stub_module("mcp.server", mcp_server_module, stubbed_modules)
|
|
_install_stub_module("mcp.server.fastmcp", fastmcp_module, stubbed_modules)
|
|
|
|
# Stub Computer module
|
|
computer_module = types.ModuleType("computer")
|
|
|
|
class _StubInterface:
|
|
async def screenshot(self) -> bytes:
|
|
return b"test-screenshot-data"
|
|
|
|
class _StubComputer:
|
|
def __init__(self, *args, **kwargs):
|
|
self.interface = _StubInterface()
|
|
|
|
async def run(self):
|
|
return None
|
|
|
|
computer_module.Computer = _StubComputer
|
|
|
|
_install_stub_module("computer", computer_module, stubbed_modules)
|
|
|
|
# Stub agent module
|
|
agent_module = types.ModuleType("agent")
|
|
|
|
class _StubComputerAgent:
|
|
def __init__(self, *args, **kwargs):
|
|
pass
|
|
|
|
async def run(self, *_args, **_kwargs):
|
|
# Simulate agent execution with streaming
|
|
yield {
|
|
"output": [
|
|
{
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"content": [{"type": "output_text", "text": "Task completed"}]
|
|
}
|
|
]
|
|
}
|
|
|
|
agent_module.ComputerAgent = _StubComputerAgent
|
|
|
|
_install_stub_module("agent", agent_module, stubbed_modules)
|
|
|
|
# Stub session manager module
|
|
session_manager_module = types.ModuleType("mcp_server.session_manager")
|
|
|
|
class _StubSessionInfo:
|
|
def __init__(self, session_id: str, computer, created_at: float, last_activity: float):
|
|
self.session_id = session_id
|
|
self.computer = computer
|
|
self.created_at = created_at
|
|
self.last_activity = last_activity
|
|
self.active_tasks = set()
|
|
self.is_shutting_down = False
|
|
|
|
class _StubSessionManager:
|
|
def __init__(self):
|
|
self._sessions = {}
|
|
self._session_lock = asyncio.Lock()
|
|
|
|
async def get_session(self, session_id=None):
|
|
"""Context manager that returns a session."""
|
|
if session_id is None:
|
|
session_id = "test-session-123"
|
|
|
|
async with self._session_lock:
|
|
if session_id not in self._sessions:
|
|
computer = _StubComputer()
|
|
session = _StubSessionInfo(
|
|
session_id=session_id,
|
|
computer=computer,
|
|
created_at=time.time(),
|
|
last_activity=time.time()
|
|
)
|
|
self._sessions[session_id] = session
|
|
|
|
return self._sessions[session_id]
|
|
|
|
async def register_task(self, session_id: str, task_id: str):
|
|
pass
|
|
|
|
async def unregister_task(self, session_id: str, task_id: str):
|
|
pass
|
|
|
|
async def cleanup_session(self, session_id: str):
|
|
async with self._session_lock:
|
|
self._sessions.pop(session_id, None)
|
|
|
|
def get_session_stats(self):
|
|
return {
|
|
"total_sessions": len(self._sessions),
|
|
"max_concurrent": 10,
|
|
"sessions": {sid: {"active_tasks": 0} for sid in self._sessions}
|
|
}
|
|
|
|
_stub_session_manager = _StubSessionManager()
|
|
|
|
def get_session_manager():
|
|
return _stub_session_manager
|
|
|
|
async def initialize_session_manager():
|
|
return _stub_session_manager
|
|
|
|
async def shutdown_session_manager():
|
|
pass
|
|
|
|
session_manager_module.get_session_manager = get_session_manager
|
|
session_manager_module.initialize_session_manager = initialize_session_manager
|
|
session_manager_module.shutdown_session_manager = shutdown_session_manager
|
|
|
|
_install_stub_module("mcp_server.session_manager", session_manager_module, stubbed_modules)
|
|
|
|
# Load the actual server module
|
|
module_name = "mcp_server_server_under_test"
|
|
module_path = Path("libs/python/mcp-server/mcp_server/server.py").resolve()
|
|
spec = importlib.util.spec_from_file_location(module_name, module_path)
|
|
server_module = importlib.util.module_from_spec(spec)
|
|
assert spec and spec.loader
|
|
spec.loader.exec_module(server_module)
|
|
|
|
server_instance = getattr(server_module, "server", None)
|
|
if server_instance is not None and hasattr(server_instance, "_tools"):
|
|
for name, func in server_instance._tools.items():
|
|
setattr(server_module, name, func)
|
|
|
|
try:
|
|
yield server_module
|
|
finally:
|
|
sys.modules.pop(module_name, None)
|
|
for name, original in stubbed_modules.items():
|
|
if original is None:
|
|
sys.modules.pop(name, None)
|
|
else:
|
|
sys.modules[name] = original
|
|
|
|
|
|
class FakeContext:
|
|
"""Fake context for testing."""
|
|
|
|
def __init__(self) -> None:
|
|
self.events: list[tuple] = []
|
|
self.progress_updates: list[float] = []
|
|
|
|
def info(self, message: str) -> None:
|
|
self.events.append(("info", message))
|
|
|
|
def error(self, message: str) -> None:
|
|
self.events.append(("error", message))
|
|
|
|
def report_progress(self, value: float) -> None:
|
|
self.progress_updates.append(value)
|
|
|
|
async def yield_message(self, *, role: str, content):
|
|
timestamp = asyncio.get_running_loop().time()
|
|
self.events.append(("message", role, content, timestamp))
|
|
|
|
async def yield_tool_call(self, *, name: str | None, call_id: str, input):
|
|
timestamp = asyncio.get_running_loop().time()
|
|
self.events.append(("tool_call", name, call_id, input, timestamp))
|
|
|
|
async def yield_tool_output(self, *, call_id: str, output, is_error: bool = False):
|
|
timestamp = asyncio.get_running_loop().time()
|
|
self.events.append(("tool_output", call_id, output, is_error, timestamp))
|
|
|
|
|
|
def test_screenshot_cua_with_session_id(server_module):
|
|
"""Test that screenshot_cua works with session management."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
result = await server_module.screenshot_cua(ctx, session_id="test-session")
|
|
|
|
assert result.format == "png"
|
|
assert result.data == b"test-screenshot-data"
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_screenshot_cua_creates_new_session(server_module):
|
|
"""Test that screenshot_cua creates a new session when none provided."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
result = await server_module.screenshot_cua(ctx)
|
|
|
|
assert result.format == "png"
|
|
assert result.data == b"test-screenshot-data"
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_run_cua_task_with_session_management(server_module):
|
|
"""Test that run_cua_task works with session management."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
task = "Test task"
|
|
session_id = "test-session-456"
|
|
|
|
text_result, image = await server_module.run_cua_task(ctx, task, session_id)
|
|
|
|
assert "Task completed" in text_result
|
|
assert image.format == "png"
|
|
assert image.data == b"test-screenshot-data"
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_run_multi_cua_tasks_sequential(server_module):
|
|
"""Test that run_multi_cua_tasks works sequentially."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
tasks = ["Task 1", "Task 2", "Task 3"]
|
|
|
|
results = await server_module.run_multi_cua_tasks(ctx, tasks, concurrent=False)
|
|
|
|
assert len(results) == 3
|
|
for i, (text, image) in enumerate(results):
|
|
assert "Task completed" in text
|
|
assert image.format == "png"
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_run_multi_cua_tasks_concurrent(server_module):
|
|
"""Test that run_multi_cua_tasks works concurrently."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
tasks = ["Task 1", "Task 2", "Task 3"]
|
|
|
|
results = await server_module.run_multi_cua_tasks(ctx, tasks, concurrent=True)
|
|
|
|
assert len(results) == 3
|
|
for i, (text, image) in enumerate(results):
|
|
assert "Task completed" in text
|
|
assert image.format == "png"
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_get_session_stats(server_module):
|
|
"""Test that get_session_stats returns proper statistics."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
stats = await server_module.get_session_stats()
|
|
|
|
assert "total_sessions" in stats
|
|
assert "max_concurrent" in stats
|
|
assert "sessions" in stats
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_cleanup_session(server_module):
|
|
"""Test that cleanup_session works properly."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
session_id = "test-cleanup-session"
|
|
|
|
result = await server_module.cleanup_session(ctx, session_id)
|
|
|
|
assert f"Session {session_id} cleanup initiated" in result
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_concurrent_sessions_isolation(server_module):
|
|
"""Test that concurrent sessions are properly isolated."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
|
|
# Run multiple tasks with different session IDs concurrently
|
|
task1 = asyncio.create_task(
|
|
server_module.run_cua_task(ctx, "Task for session 1", "session-1")
|
|
)
|
|
task2 = asyncio.create_task(
|
|
server_module.run_cua_task(ctx, "Task for session 2", "session-2")
|
|
)
|
|
|
|
results = await asyncio.gather(task1, task2)
|
|
|
|
assert len(results) == 2
|
|
for text, image in results:
|
|
assert "Task completed" in text
|
|
assert image.format == "png"
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_session_reuse_with_same_id(server_module):
|
|
"""Test that sessions are reused when the same session ID is provided."""
|
|
async def _run_test():
|
|
ctx = FakeContext()
|
|
session_id = "reuse-session"
|
|
|
|
# First call
|
|
result1 = await server_module.screenshot_cua(ctx, session_id)
|
|
|
|
# Second call with same session ID
|
|
result2 = await server_module.screenshot_cua(ctx, session_id)
|
|
|
|
assert result1.format == result2.format
|
|
assert result1.data == result2.data
|
|
|
|
asyncio.run(_run_test())
|
|
|
|
|
|
def test_error_handling_with_session_management(server_module):
|
|
"""Test that errors are handled properly with session management."""
|
|
async def _run_test():
|
|
# Mock an agent that raises an exception
|
|
class _FailingAgent:
|
|
def __init__(self, *args, **kwargs):
|
|
pass
|
|
|
|
async def run(self, *_args, **_kwargs):
|
|
raise RuntimeError("Simulated agent failure")
|
|
|
|
# Replace the ComputerAgent with our failing one
|
|
original_agent = server_module.ComputerAgent
|
|
server_module.ComputerAgent = _FailingAgent
|
|
|
|
try:
|
|
ctx = FakeContext()
|
|
task = "This will fail"
|
|
|
|
text_result, image = await server_module.run_cua_task(ctx, task, "error-session")
|
|
|
|
assert "Error during task execution" in text_result
|
|
assert image.format == "png"
|
|
|
|
finally:
|
|
# Restore original agent
|
|
server_module.ComputerAgent = original_agent
|
|
|
|
asyncio.run(_run_test())
|