mirror of
https://github.com/trycua/lume.git
synced 2026-01-06 12:29:56 -06:00
remove duplicate mcp server code
This commit is contained in:
@@ -1,458 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import anyio
|
||||
|
||||
# Configure logging to output to stderr for debug visibility
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, # Changed to DEBUG
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
stream=sys.stderr,
|
||||
)
|
||||
logger = logging.getLogger("mcp-server")
|
||||
|
||||
# More visible startup message
|
||||
logger.debug("MCP Server module loading...")
|
||||
|
||||
try:
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
|
||||
# Use the canonical Image type
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
|
||||
logger.debug("Successfully imported FastMCP")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import FastMCP: {e}")
|
||||
logger.info("Attempting to install missing dependencies...")
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
subprocess.check_call([sys.executable, "setup.py"])
|
||||
# Try importing again
|
||||
from mcp.server.fastmcp import Context, FastMCP
|
||||
from mcp.server.fastmcp.utilities.types import Image
|
||||
|
||||
logger.info("Dependencies installed successfully, retrying...")
|
||||
except Exception as setup_error:
|
||||
logger.error(f"Failed to install dependencies: {setup_error}")
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from agent import ComputerAgent
|
||||
|
||||
logger.debug("Successfully imported Agent module")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import Computer/Agent modules: {e}")
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from session_manager import (
|
||||
get_session_manager,
|
||||
initialize_session_manager,
|
||||
shutdown_session_manager,
|
||||
)
|
||||
|
||||
logger.debug("Successfully imported session manager")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import session manager: {e}")
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_env_bool(key: str, default: bool = False) -> bool:
|
||||
"""Get boolean value from environment variable."""
|
||||
return os.getenv(key, str(default)).lower() in ("true", "1", "yes")
|
||||
|
||||
|
||||
async def _maybe_call_ctx_method(ctx: Context, method_name: str, *args, **kwargs) -> None:
|
||||
"""Call a context helper if it exists, awaiting the result when necessary."""
|
||||
method = getattr(ctx, method_name, None)
|
||||
if not callable(method):
|
||||
return
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
|
||||
|
||||
def _normalise_message_content(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
"""Normalise message content to a list of structured parts."""
|
||||
if isinstance(content, list):
|
||||
return content
|
||||
if content is None:
|
||||
return []
|
||||
return [{"type": "output_text", "text": str(content)}]
|
||||
|
||||
|
||||
def _extract_text_from_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract textual content for inclusion in the aggregated result string."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
texts: List[str] = []
|
||||
for part in content or []:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
if part.get("type") in {"output_text", "text"} and part.get("text"):
|
||||
texts.append(str(part["text"]))
|
||||
return "\n".join(texts)
|
||||
|
||||
|
||||
def _serialise_tool_content(content: Any) -> str:
|
||||
"""Convert tool outputs into a string for aggregation."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
texts: List[str] = []
|
||||
for part in content:
|
||||
if (
|
||||
isinstance(part, dict)
|
||||
and part.get("type") in {"output_text", "text"}
|
||||
and part.get("text")
|
||||
):
|
||||
texts.append(str(part["text"]))
|
||||
if texts:
|
||||
return "\n".join(texts)
|
||||
if content is None:
|
||||
return ""
|
||||
return str(content)
|
||||
|
||||
|
||||
def serve() -> FastMCP:
|
||||
"""Create and configure the MCP server."""
|
||||
# NOTE: Do not pass model_config here; FastMCP 2.12.x doesn't support it.
|
||||
server = FastMCP(name="cua-agent")
|
||||
|
||||
@server.tool(structured_output=False)
|
||||
async def screenshot_cua(ctx: Context, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Take a screenshot of the current MacOS VM screen and return the image.
|
||||
|
||||
Args:
|
||||
session_id: Optional session ID for multi-client support.
|
||||
If not provided, a new session will be created.
|
||||
"""
|
||||
session_manager = get_session_manager()
|
||||
|
||||
async with session_manager.get_session(session_id) as session:
|
||||
screenshot = await session.computer.interface.screenshot()
|
||||
# Returning Image object is fine when structured_output=False
|
||||
return Image(format="png", data=screenshot)
|
||||
|
||||
@server.tool(structured_output=False)
|
||||
async def run_cua_task(ctx: Context, task: str, session_id: Optional[str] = None) -> Any:
|
||||
"""
|
||||
Run a Computer-Use Agent (CUA) task in a MacOS VM and return
|
||||
(combined text, final screenshot).
|
||||
|
||||
Args:
|
||||
task: The task description for the agent to execute
|
||||
session_id: Optional session ID for multi-client support.
|
||||
If not provided, a new session will be created.
|
||||
"""
|
||||
session_manager = get_session_manager()
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
logger.info(f"Starting CUA task: {task} (task_id: {task_id})")
|
||||
|
||||
async with session_manager.get_session(session_id) as session:
|
||||
# Register this task with the session
|
||||
await session_manager.register_task(session.session_id, task_id)
|
||||
|
||||
try:
|
||||
# Get API key from user config and set for CUA library
|
||||
api_key = os.getenv("API_KEY")
|
||||
if api_key:
|
||||
os.environ["ANTHROPIC_API_KEY"] = api_key
|
||||
logger.info("API key configured successfully")
|
||||
else:
|
||||
logger.warning(
|
||||
"No API key provided. Please configure your API key in Claude Desktop."
|
||||
)
|
||||
|
||||
# Get model name from user config or environment
|
||||
model_name = os.getenv("CUA_MODEL_NAME", "anthropic/claude-3-5-sonnet-20241022")
|
||||
logger.info(f"Using model: {model_name}")
|
||||
|
||||
# Create agent with the new v0.4.x API
|
||||
agent = ComputerAgent(
|
||||
model=model_name,
|
||||
only_n_most_recent_images=int(os.getenv("CUA_MAX_IMAGES", "3")),
|
||||
verbosity=logging.INFO,
|
||||
tools=[session.computer],
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": task}]
|
||||
|
||||
# Collect all results
|
||||
aggregated_messages: List[str] = []
|
||||
async for result in agent.run(messages):
|
||||
logger.info("Agent processing step")
|
||||
ctx.info("Agent processing step")
|
||||
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
|
||||
if output_type == "message":
|
||||
logger.debug("Streaming assistant message: %s", output)
|
||||
content = _normalise_message_content(output.get("content"))
|
||||
aggregated_text = _extract_text_from_content(content)
|
||||
if aggregated_text:
|
||||
aggregated_messages.append(aggregated_text)
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_message",
|
||||
role=output.get("role", "assistant"),
|
||||
content=content,
|
||||
)
|
||||
|
||||
elif output_type in {"tool_use", "computer_call", "function_call"}:
|
||||
logger.debug("Streaming tool call: %s", output)
|
||||
call_id = output.get("id") or output.get("call_id")
|
||||
tool_name = output.get("name") or output.get("action", {}).get(
|
||||
"type"
|
||||
)
|
||||
tool_input = (
|
||||
output.get("input")
|
||||
or output.get("arguments")
|
||||
or output.get("action")
|
||||
)
|
||||
if call_id:
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_tool_call",
|
||||
name=tool_name,
|
||||
call_id=call_id,
|
||||
input=tool_input,
|
||||
)
|
||||
|
||||
elif output_type in {
|
||||
"tool_result",
|
||||
"computer_call_output",
|
||||
"function_call_output",
|
||||
}:
|
||||
logger.debug("Streaming tool output: %s", output)
|
||||
call_id = output.get("call_id") or output.get("id")
|
||||
content = output.get("content") or output.get("output")
|
||||
aggregated_text = _serialise_tool_content(content)
|
||||
if aggregated_text:
|
||||
aggregated_messages.append(aggregated_text)
|
||||
if call_id:
|
||||
await _maybe_call_ctx_method(
|
||||
ctx,
|
||||
"yield_tool_output",
|
||||
call_id=call_id,
|
||||
output=content,
|
||||
is_error=output.get("status") == "failed"
|
||||
or output.get("is_error", False),
|
||||
)
|
||||
|
||||
logger.info("CUA task completed successfully")
|
||||
ctx.info("CUA task completed successfully")
|
||||
|
||||
screenshot_image = Image(
|
||||
format="png",
|
||||
data=await session.computer.interface.screenshot(),
|
||||
)
|
||||
|
||||
return (
|
||||
"\n".join(aggregated_messages).strip()
|
||||
or "Task completed with no text output.",
|
||||
screenshot_image,
|
||||
)
|
||||
|
||||
finally:
|
||||
# Unregister the task from the session
|
||||
await session_manager.unregister_task(session.session_id, task_id)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error running CUA task: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
ctx.error(error_msg)
|
||||
|
||||
# Try to get a screenshot from the session if available
|
||||
try:
|
||||
if session_id:
|
||||
async with session_manager.get_session(session_id) as session:
|
||||
screenshot = await session.computer.interface.screenshot()
|
||||
return (
|
||||
f"Error during task execution: {str(e)}",
|
||||
Image(format="png", data=screenshot),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If we can't get a screenshot, return a placeholder
|
||||
return (
|
||||
f"Error during task execution: {str(e)}",
|
||||
Image(format="png", data=b""),
|
||||
)
|
||||
|
||||
@server.tool(structured_output=False)
|
||||
async def run_multi_cua_tasks(
|
||||
ctx: Context, tasks: List[str], session_id: Optional[str] = None, concurrent: bool = False
|
||||
) -> Any:
|
||||
"""
|
||||
Run multiple CUA tasks and return a list of (combined text, screenshot).
|
||||
|
||||
Args:
|
||||
tasks: List of task descriptions to execute
|
||||
session_id: Optional session ID for multi-client support.
|
||||
If not provided, a new session will be created.
|
||||
concurrent: If True, run tasks concurrently. If False, run sequentially (default).
|
||||
"""
|
||||
total_tasks = len(tasks)
|
||||
if total_tasks == 0:
|
||||
ctx.report_progress(1.0)
|
||||
return []
|
||||
|
||||
if concurrent and total_tasks > 1:
|
||||
# Run tasks concurrently
|
||||
logger.info(f"Running {total_tasks} tasks concurrently")
|
||||
ctx.info(f"Running {total_tasks} tasks concurrently")
|
||||
|
||||
# Create tasks with progress tracking
|
||||
async def run_task_with_progress(
|
||||
task_index: int, task: str
|
||||
) -> Tuple[int, Tuple[str, Image]]:
|
||||
ctx.report_progress(task_index / total_tasks)
|
||||
result = await run_cua_task(ctx, task, session_id)
|
||||
ctx.report_progress((task_index + 1) / total_tasks)
|
||||
return task_index, result
|
||||
|
||||
# Create all task coroutines
|
||||
task_coroutines = [run_task_with_progress(i, task) for i, task in enumerate(tasks)]
|
||||
|
||||
# Wait for all tasks to complete
|
||||
results_with_indices = await asyncio.gather(*task_coroutines, return_exceptions=True)
|
||||
|
||||
# Sort results by original task order and handle exceptions
|
||||
results: List[Tuple[str, Image]] = []
|
||||
for result in results_with_indices:
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Task failed with exception: {result}")
|
||||
ctx.error(f"Task failed: {str(result)}")
|
||||
results.append((f"Task failed: {str(result)}", Image(format="png", data=b"")))
|
||||
else:
|
||||
_, task_result = result
|
||||
results.append(task_result)
|
||||
|
||||
return results
|
||||
else:
|
||||
# Run tasks sequentially (original behavior)
|
||||
logger.info(f"Running {total_tasks} tasks sequentially")
|
||||
ctx.info(f"Running {total_tasks} tasks sequentially")
|
||||
|
||||
results: List[Tuple[str, Image]] = []
|
||||
for i, task in enumerate(tasks):
|
||||
logger.info(f"Running task {i+1}/{total_tasks}: {task}")
|
||||
ctx.info(f"Running task {i+1}/{total_tasks}: {task}")
|
||||
|
||||
ctx.report_progress(i / total_tasks)
|
||||
task_result = await run_cua_task(ctx, task, session_id)
|
||||
results.append(task_result)
|
||||
ctx.report_progress((i + 1) / total_tasks)
|
||||
|
||||
return results
|
||||
|
||||
@server.tool(structured_output=False)
|
||||
async def get_session_stats(ctx: Context) -> Dict[str, Any]:
|
||||
"""
|
||||
Get statistics about active sessions and resource usage.
|
||||
"""
|
||||
session_manager = get_session_manager()
|
||||
return session_manager.get_session_stats()
|
||||
|
||||
@server.tool(structured_output=False)
|
||||
async def cleanup_session(ctx: Context, session_id: str) -> str:
|
||||
"""
|
||||
Cleanup a specific session and release its resources.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to cleanup
|
||||
"""
|
||||
session_manager = get_session_manager()
|
||||
await session_manager.cleanup_session(session_id)
|
||||
return f"Session {session_id} cleanup initiated"
|
||||
|
||||
return server
|
||||
|
||||
|
||||
server = serve()
|
||||
|
||||
|
||||
async def run_server():
|
||||
"""Run the MCP server with proper lifecycle management."""
|
||||
session_manager = None
|
||||
try:
|
||||
logger.debug("Starting MCP server...")
|
||||
|
||||
# Initialize session manager
|
||||
session_manager = await initialize_session_manager()
|
||||
logger.info("Session manager initialized")
|
||||
|
||||
# Set up signal handlers for graceful shutdown
|
||||
def signal_handler(signum, frame):
|
||||
logger.info(f"Received signal {signum}, initiating graceful shutdown...")
|
||||
# Create a task to shutdown gracefully
|
||||
asyncio.create_task(graceful_shutdown())
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Start the server
|
||||
logger.info("Starting FastMCP server...")
|
||||
# Use run_stdio_async directly instead of server.run() to avoid nested event loops
|
||||
await server.run_stdio_async()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting server: {e}")
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
raise
|
||||
finally:
|
||||
# Ensure cleanup happens
|
||||
if session_manager:
|
||||
logger.info("Shutting down session manager...")
|
||||
await shutdown_session_manager()
|
||||
|
||||
|
||||
async def graceful_shutdown():
|
||||
"""Gracefully shutdown the server and all sessions."""
|
||||
logger.info("Initiating graceful shutdown...")
|
||||
try:
|
||||
await shutdown_session_manager()
|
||||
logger.info("Graceful shutdown completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during graceful shutdown: {e}")
|
||||
finally:
|
||||
# Exit the process
|
||||
import os
|
||||
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the MCP server with proper async lifecycle management."""
|
||||
try:
|
||||
# Use anyio.run instead of asyncio.run to avoid nested event loop issues
|
||||
anyio.run(run_server)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Server interrupted by user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting server: {e}")
|
||||
traceback.print_exc(file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,321 +0,0 @@
|
||||
"""
|
||||
Session Manager for MCP Server - Handles concurrent client sessions with proper resource isolation.
|
||||
|
||||
This module provides:
|
||||
- Per-session computer instance management
|
||||
- Resource pooling and lifecycle management
|
||||
- Graceful session cleanup
|
||||
- Concurrent task execution support
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger("mcp-server.session_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""Information about an active session."""
|
||||
|
||||
session_id: str
|
||||
computer: Any # Computer instance
|
||||
created_at: float
|
||||
last_activity: float
|
||||
active_tasks: Set[str] = field(default_factory=set)
|
||||
is_shutting_down: bool = False
|
||||
|
||||
|
||||
class ComputerPool:
|
||||
"""Pool of computer instances for efficient resource management."""
|
||||
|
||||
def __init__(self, max_size: int = 5, idle_timeout: float = 300.0):
|
||||
self.max_size = max_size
|
||||
self.idle_timeout = idle_timeout
|
||||
self._available: List[Any] = []
|
||||
self._in_use: Set[Any] = set()
|
||||
self._creation_lock = asyncio.Lock()
|
||||
|
||||
async def acquire(self) -> Any:
|
||||
"""Acquire a computer instance from the pool."""
|
||||
# Try to get an available instance
|
||||
if self._available:
|
||||
computer = self._available.pop()
|
||||
self._in_use.add(computer)
|
||||
logger.debug("Reusing computer instance from pool")
|
||||
return computer
|
||||
|
||||
# Check if we can create a new one
|
||||
async with self._creation_lock:
|
||||
if len(self._in_use) < self.max_size:
|
||||
logger.debug("Creating new computer instance")
|
||||
from computer import Computer
|
||||
|
||||
computer = Computer(verbosity=logging.INFO)
|
||||
await computer.run()
|
||||
self._in_use.add(computer)
|
||||
return computer
|
||||
|
||||
# Wait for an instance to become available
|
||||
logger.debug("Waiting for computer instance to become available")
|
||||
while not self._available:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
computer = self._available.pop()
|
||||
self._in_use.add(computer)
|
||||
return computer
|
||||
|
||||
async def release(self, computer: Any) -> None:
|
||||
"""Release a computer instance back to the pool."""
|
||||
if computer in self._in_use:
|
||||
self._in_use.remove(computer)
|
||||
self._available.append(computer)
|
||||
logger.debug("Released computer instance back to pool")
|
||||
|
||||
async def cleanup_idle(self) -> None:
|
||||
"""Clean up idle computer instances."""
|
||||
current_time = time.time()
|
||||
idle_instances = []
|
||||
|
||||
for computer in self._available[:]:
|
||||
# Check if computer has been idle too long
|
||||
# Note: We'd need to track last use time per instance for this
|
||||
# For now, we'll keep instances in the pool
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Shutdown all computer instances in the pool."""
|
||||
logger.info("Shutting down computer pool")
|
||||
|
||||
# Close all available instances
|
||||
for computer in self._available:
|
||||
try:
|
||||
if hasattr(computer, "close"):
|
||||
await computer.close()
|
||||
elif hasattr(computer, "stop"):
|
||||
await computer.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing computer instance: {e}")
|
||||
|
||||
# Close all in-use instances
|
||||
for computer in self._in_use:
|
||||
try:
|
||||
if hasattr(computer, "close"):
|
||||
await computer.close()
|
||||
elif hasattr(computer, "stop"):
|
||||
await computer.stop()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error closing computer instance: {e}")
|
||||
|
||||
self._available.clear()
|
||||
self._in_use.clear()
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages concurrent client sessions with proper resource isolation."""
|
||||
|
||||
def __init__(self, max_concurrent_sessions: int = 10):
|
||||
self.max_concurrent_sessions = max_concurrent_sessions
|
||||
self._sessions: Dict[str, SessionInfo] = {}
|
||||
self._computer_pool = ComputerPool()
|
||||
self._session_lock = asyncio.Lock()
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the session manager and cleanup task."""
|
||||
logger.info("Starting session manager")
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the session manager and cleanup all resources."""
|
||||
logger.info("Stopping session manager")
|
||||
self._shutdown_event.set()
|
||||
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Force cleanup all sessions
|
||||
async with self._session_lock:
|
||||
session_ids = list(self._sessions.keys())
|
||||
|
||||
for session_id in session_ids:
|
||||
await self._force_cleanup_session(session_id)
|
||||
|
||||
await self._computer_pool.shutdown()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self, session_id: Optional[str] = None) -> Any:
|
||||
"""Get or create a session with proper resource management."""
|
||||
if session_id is None:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Check if session exists and is not shutting down
|
||||
async with self._session_lock:
|
||||
if session_id in self._sessions:
|
||||
session = self._sessions[session_id]
|
||||
if session.is_shutting_down:
|
||||
raise RuntimeError(f"Session {session_id} is shutting down")
|
||||
session.last_activity = time.time()
|
||||
computer = session.computer
|
||||
else:
|
||||
# Create new session
|
||||
if len(self._sessions) >= self.max_concurrent_sessions:
|
||||
raise RuntimeError(
|
||||
f"Maximum concurrent sessions ({self.max_concurrent_sessions}) reached"
|
||||
)
|
||||
|
||||
computer = await self._computer_pool.acquire()
|
||||
session = SessionInfo(
|
||||
session_id=session_id,
|
||||
computer=computer,
|
||||
created_at=time.time(),
|
||||
last_activity=time.time(),
|
||||
)
|
||||
self._sessions[session_id] = session
|
||||
logger.info(f"Created new session: {session_id}")
|
||||
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Update last activity
|
||||
async with self._session_lock:
|
||||
if session_id in self._sessions:
|
||||
self._sessions[session_id].last_activity = time.time()
|
||||
|
||||
async def register_task(self, session_id: str, task_id: str) -> None:
|
||||
"""Register a task for a session."""
|
||||
async with self._session_lock:
|
||||
if session_id in self._sessions:
|
||||
self._sessions[session_id].active_tasks.add(task_id)
|
||||
logger.debug(f"Registered task {task_id} for session {session_id}")
|
||||
|
||||
async def unregister_task(self, session_id: str, task_id: str) -> None:
|
||||
"""Unregister a task from a session."""
|
||||
async with self._session_lock:
|
||||
if session_id in self._sessions:
|
||||
self._sessions[session_id].active_tasks.discard(task_id)
|
||||
logger.debug(f"Unregistered task {task_id} from session {session_id}")
|
||||
|
||||
async def cleanup_session(self, session_id: str) -> None:
|
||||
"""Cleanup a specific session."""
|
||||
async with self._session_lock:
|
||||
if session_id not in self._sessions:
|
||||
return
|
||||
|
||||
session = self._sessions[session_id]
|
||||
|
||||
# Check if session has active tasks
|
||||
if session.active_tasks:
|
||||
logger.info(f"Session {session_id} has active tasks, marking for shutdown")
|
||||
session.is_shutting_down = True
|
||||
return
|
||||
|
||||
# Actually cleanup the session
|
||||
await self._force_cleanup_session(session_id)
|
||||
|
||||
async def _force_cleanup_session(self, session_id: str) -> None:
|
||||
"""Force cleanup a session regardless of active tasks."""
|
||||
async with self._session_lock:
|
||||
if session_id not in self._sessions:
|
||||
return
|
||||
|
||||
session = self._sessions[session_id]
|
||||
logger.info(f"Cleaning up session: {session_id}")
|
||||
|
||||
# Release computer back to pool
|
||||
await self._computer_pool.release(session.computer)
|
||||
|
||||
# Remove session
|
||||
del self._sessions[session_id]
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
"""Background task to cleanup idle sessions."""
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
await asyncio.sleep(60) # Run cleanup every minute
|
||||
|
||||
current_time = time.time()
|
||||
idle_timeout = 600.0 # 10 minutes
|
||||
|
||||
async with self._session_lock:
|
||||
idle_sessions = []
|
||||
for session_id, session in self._sessions.items():
|
||||
if not session.is_shutting_down and not session.active_tasks:
|
||||
if current_time - session.last_activity > idle_timeout:
|
||||
idle_sessions.append(session_id)
|
||||
|
||||
# Cleanup idle sessions
|
||||
for session_id in idle_sessions:
|
||||
await self._force_cleanup_session(session_id)
|
||||
logger.info(f"Cleaned up idle session: {session_id}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in cleanup loop: {e}")
|
||||
|
||||
def get_session_stats(self) -> Dict[str, Any]:
|
||||
"""Get statistics about active sessions."""
|
||||
|
||||
async def _get_stats():
|
||||
async with self._session_lock:
|
||||
return {
|
||||
"total_sessions": len(self._sessions),
|
||||
"max_concurrent": self.max_concurrent_sessions,
|
||||
"sessions": {
|
||||
session_id: {
|
||||
"created_at": session.created_at,
|
||||
"last_activity": session.last_activity,
|
||||
"active_tasks": len(session.active_tasks),
|
||||
"is_shutting_down": session.is_shutting_down,
|
||||
}
|
||||
for session_id, session in self._sessions.items()
|
||||
},
|
||||
}
|
||||
|
||||
# Run in current event loop or create new one
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
return asyncio.run_coroutine_threadsafe(_get_stats(), loop).result()
|
||||
except RuntimeError:
|
||||
# No event loop running, create a new one
|
||||
return asyncio.run(_get_stats())
|
||||
|
||||
|
||||
# Global session manager instance
|
||||
_session_manager: Optional[SessionManager] = None
|
||||
|
||||
|
||||
def get_session_manager() -> SessionManager:
|
||||
"""Get the global session manager instance."""
|
||||
global _session_manager
|
||||
if _session_manager is None:
|
||||
_session_manager = SessionManager()
|
||||
return _session_manager
|
||||
|
||||
|
||||
async def initialize_session_manager() -> None:
|
||||
"""Initialize the global session manager."""
|
||||
global _session_manager
|
||||
if _session_manager is None:
|
||||
_session_manager = SessionManager()
|
||||
await _session_manager.start()
|
||||
return _session_manager
|
||||
|
||||
|
||||
async def shutdown_session_manager() -> None:
|
||||
"""Shutdown the global session manager."""
|
||||
global _session_manager
|
||||
if _session_manager is not None:
|
||||
await _session_manager.stop()
|
||||
_session_manager = None
|
||||
Reference in New Issue
Block a user