Files
computer/libs/python/computer-server/computer_server/main.py
2025-08-28 13:59:44 -04:00

689 lines
26 KiB
Python

from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
from fastapi.responses import StreamingResponse, JSONResponse
from typing import List, Dict, Any, Optional, Union, Literal, cast
import uvicorn
import logging
import asyncio
import json
import traceback
import inspect
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
import hashlib
import time
import platform
from fastapi.middleware.cors import CORSMiddleware
# Module logger; its threshold is INFO, so debug messages are dropped.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Intended maximum WebSocket message size: 10 MiB.
WEBSOCKET_MAX_SIZE = 1024 * 1024 * 10 # 10MB limit
# FastAPI application object served by this module.
# NOTE(review): FastAPI() does not appear to define a `websocket_max_size`
# parameter (unknown keyword arguments end up in `app.extra`), so this setting
# likely has no effect; the limit actually enforced is presumably the ASGI
# server's (uvicorn `ws_max_size`). Confirm before relying on it.
# NOTE(review): debug=True makes Starlette include tracebacks in 500 responses;
# confirm this is acceptable outside local development.
app = FastAPI(
    title="Computer API",
    description="API for the Computer project",
    version="0.1.0",
    websocket_max_size=WEBSOCKET_MAX_SIZE,
    debug=True
)
# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive for a server that drives the machine; confirm this is intended.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Protocol version reported by the "version" command.
protocol_version = 1

# Installed version of the cua-computer-server distribution, or "unknown"
# when it cannot be determined (e.g. the package is not installed).
package_version = "unknown"
try:
    from importlib.metadata import version
    package_version = version("cua-computer-server")
except Exception:
    # importlib.metadata is unavailable or could not find the distribution:
    # give the legacy pkg_resources API a chance before settling on "unknown".
    try:
        import pkg_resources
        package_version = pkg_resources.get_distribution("cua-computer-server").version
    except Exception:
        pass  # keep the "unknown" default assigned above
accessibility_handler, automation_handler, diorama_handler, file_handler = HandlerFactory.create_handlers()

# Dispatch table shared by the /ws and /cmd endpoints: command name -> callable.
# Apart from "version", every command is named exactly like the handler method
# that implements it, so the table is generated from (handler, method names)
# groups. Groups are listed in registration order, which the dict preserves.
handlers = {
    "version": lambda: {"protocol": protocol_version, "package": package_version},
    **{
        method_name: getattr(owner, method_name)
        for owner, method_names in (
            # App-Use commands
            (diorama_handler, ("diorama_cmd",)),
            # Accessibility commands
            (accessibility_handler, ("get_accessibility_tree", "find_element")),
            # Shell commands
            (automation_handler, ("run_command",)),
            # File system commands
            (file_handler, (
                "file_exists", "directory_exists", "list_dir",
                "read_text", "write_text", "read_bytes", "write_bytes",
                "get_file_size", "delete_file", "create_dir", "delete_dir",
            )),
            # Mouse commands
            (automation_handler, (
                "mouse_down", "mouse_up", "left_click", "right_click",
                "double_click", "move_cursor", "drag_to", "drag",
            )),
            # Keyboard commands
            (automation_handler, ("key_down", "key_up", "type_text", "press_key", "hotkey")),
            # Scrolling actions
            (automation_handler, ("scroll", "scroll_down", "scroll_up")),
            # Screen actions
            (automation_handler, ("screenshot", "get_cursor_position", "get_screen_size")),
            # Clipboard actions
            (automation_handler, ("copy_to_clipboard", "set_clipboard")),
        )
        for method_name in method_names
    },
}
class AuthenticationManager:
    """Validate (container name, API key) pairs for cloud deployments.

    If the ``CONTAINER_NAME`` environment variable is unset, every request is
    allowed (local development). Otherwise the client's container name must
    equal it and the API key is verified against the TryCUA API. Each outcome,
    successful or not, is cached for ``session_ttl`` seconds under a SHA-256
    hash of the credentials, so repeated requests with the same credentials
    reuse the cached answer instead of calling the API again.
    """

    # Cap on a single TryCUA auth request so a stalled upstream cannot hang the
    # calling endpoint indefinitely.
    REQUEST_TIMEOUT_SECONDS = 10.0

    def __init__(self, session_ttl: float = 5.0):
        """
        Args:
            session_ttl: Seconds an authentication result (positive or
                negative) stays cached. Defaults to 5 seconds.
        """
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.container_name = os.environ.get("CONTAINER_NAME")
        self.session_ttl = session_ttl

    def _hash_credentials(self, container_name: str, api_key: str) -> str:
        """Create a hash of container name and API key for session identification"""
        combined = f"{container_name}:{api_key}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def _is_session_fresh(self, session_data: Dict[str, Any]) -> bool:
        """Return True while a cache entry, whatever its outcome, has not expired."""
        return time.time() < session_data.get('expires_at', 0)

    def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
        """Return True only for an unexpired entry that records a successful auth."""
        return bool(session_data.get('valid', False)) and self._is_session_fresh(session_data)

    def _prune_expired_sessions(self) -> None:
        """Drop every cache entry whose TTL has elapsed.

        Without this, entries for credentials that are never presented again
        (e.g. a stream of bogus API keys) would accumulate forever.
        """
        now = time.time()
        expired = [h for h, data in self.sessions.items() if now >= data.get('expires_at', 0)]
        for session_hash in expired:
            del self.sessions[session_hash]

    def _cache_result(self, session_hash: str, is_valid: bool) -> None:
        """Record an authentication outcome for ``session_ttl`` seconds."""
        self.sessions[session_hash] = {
            'valid': is_valid,
            'expires_at': time.time() + self.session_ttl,
        }

    async def auth(self, container_name: str, api_key: str) -> bool:
        """Authenticate container name and API key, using cached results when possible.

        Args:
            container_name: Container name supplied by the client.
            api_key: API key supplied by the client.

        Returns:
            True if access is allowed. Network or unexpected errors while
            contacting the TryCUA API are logged and treated as a failure.
        """
        # If no CONTAINER_NAME is set, always allow access (local development)
        if not self.container_name:
            logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
            return True

        # Layer 1: VM Identity Verification
        if container_name != self.container_name:
            logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
            return False

        self._prune_expired_sessions()
        session_hash = self._hash_credentials(container_name, api_key)

        # Reuse any unexpired result -- including failures, so rejected keys do
        # not trigger an upstream request on every attempt.
        cached = self.sessions.get(session_hash)
        if cached is not None and self._is_session_fresh(cached):
            logger.info(f"Using cached authentication for container: {container_name}")
            return cached['valid']

        # No usable cached result, authenticate with API
        logger.info(f"Authenticating with TryCUA API for container: {container_name}")
        try:
            timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECONDS)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(
                    "https://www.trycua.com/api/vm/auth",
                    # Let aiohttp URL-encode the query value.
                    params={"container_name": container_name},
                    headers={"Authorization": f"Bearer {api_key}"},
                ) as resp:
                    status = resp.status
                    is_valid = status == 200 and bool((await resp.text()).strip())
        except aiohttp.ClientError as e:
            logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
            # Cache failed result to avoid repeated requests
            self._cache_result(session_hash, False)
            return False
        except Exception as e:
            logger.error(f"Unexpected error during authentication: {str(e)}")
            # Cache failed result to avoid repeated requests
            self._cache_result(session_hash, False)
            return False

        self._cache_result(session_hash, is_valid)
        if is_valid:
            logger.info(f"Authentication successful for container: {container_name}")
        else:
            logger.warning(f"Authentication failed for container: {container_name}. Status: {status}")
        return is_valid
class ConnectionManager:
    """Track the WebSocket connections currently being served."""

    def __init__(self):
        self.active_connections: List[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        """Accept the WebSocket handshake, then start tracking ``websocket``."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        """Stop tracking ``websocket``; a no-op if it is not (or no longer) tracked.

        Tolerating unknown sockets makes cleanup paths safe to run more than
        once instead of raising ``ValueError`` from ``list.remove``.
        """
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass


# Module-wide instances used by the endpoints.
manager = ConnectionManager()
auth_manager = AuthenticationManager()
@app.get("/status")
async def status():
    """Health check that also reports the host OS family.

    ``os`` is "mac" for Darwin/macOS and "windows" for Windows. Otherwise it
    is the lower-cased ``platform.system()`` value (e.g. "linux").
    """
    os_name = platform.system().lower()
    if "darwin" in os_name or os_name in ("macos", "mac"):
        os_name = "mac"
    elif "windows" in os_name:
        os_name = "windows"
    return {"status": "ok", "os": os_name}
async def _ws_reject(websocket: WebSocket, error: str) -> None:
    """Send an authentication error to the client and close the socket.

    Both steps are best-effort. The peer may already be gone (for example it
    disconnected mid-handshake), so nobody is left to notify and failures are
    only logged at debug level.
    """
    try:
        await websocket.send_json({"success": False, "error": error})
    except Exception:
        logger.debug("Could not send error to websocket client", exc_info=True)
    try:
        await websocket.close()
    except Exception:
        logger.debug("Could not close websocket", exc_info=True)


async def _ws_authenticate(websocket: WebSocket, server_container_name: str) -> bool:
    """Run the cloud authentication handshake on a freshly accepted socket.

    The first client message must be ``{"command": "authenticate", "params":
    {"api_key": ..., "container_name": ...}}``. On success a confirmation is
    sent and True is returned. Otherwise an error is sent (best-effort), the
    socket is closed and False is returned. Errors never propagate.
    """
    try:
        logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
        auth_data = await websocket.receive_json()

        if auth_data.get("command") != "authenticate":
            await _ws_reject(websocket, "First message must be authentication")
            return False

        client_api_key = auth_data.get("params", {}).get("api_key")
        client_container_name = auth_data.get("params", {}).get("container_name")
        if not client_api_key:
            await _ws_reject(websocket, "API key required")
            return False
        if not client_container_name:
            await _ws_reject(websocket, "Container name required")
            return False

        if not await auth_manager.auth(client_container_name, client_api_key):
            await _ws_reject(websocket, "Authentication failed")
            return False

        logger.info(f"Authentication successful for VM: {client_container_name}")
        await websocket.send_json({
            "success": True,
            "message": "Authentication successful"
        })
        return True
    except WebSocketDisconnect:
        # The client left before finishing the handshake; there is nobody to
        # reply to, and trying to would only raise again.
        logger.info("Client disconnected during authentication")
        return False
    except Exception as e:
        logger.error(f"Error during authentication handshake: {str(e)}")
        await _ws_reject(websocket, "Authentication failed")
        return False


async def _ws_run_command(websocket: WebSocket, command: Any, params: Dict[str, Any]) -> None:
    """Execute one registered command and send its result (or error) to the client.

    Only params named in the handler's signature are forwarded. Sync handlers
    run via ``asyncio.to_thread`` so they do not block the event loop. Handler
    results are unpacked into the reply with ``**``, so they must be mappings.
    """
    try:
        handler_func = handlers[command]
        sig = inspect.signature(handler_func)
        filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
        if asyncio.iscoroutinefunction(handler_func):
            result = await handler_func(**filtered_params)
        else:
            result = await asyncio.to_thread(handler_func, **filtered_params)
        await websocket.send_json({"success": True, **result})
    except WebSocketDisconnect:
        # A client disconnect while replying is not a command failure.
        raise
    except Exception as cmd_error:
        logger.error(f"Error executing command {command}: {str(cmd_error)}")
        logger.error(traceback.format_exc())
        await websocket.send_json({"success": False, "error": str(cmd_error)})


@app.websocket("/ws", name="websocket_endpoint")
async def websocket_endpoint(websocket: WebSocket):
    """Serve the command protocol over a WebSocket connection.

    When the server's CONTAINER_NAME env var is set, the client must first
    complete the authentication handshake. After that, each JSON message
    ``{"command": ..., "params": {...}}`` is dispatched through ``handlers``
    and answered with one JSON reply. The connection is removed from
    ``manager`` whenever this handler exits, however it exits.
    """
    await manager.connect(websocket)
    try:
        # CONTAINER_NAME being set indicates a cloud provider deployment.
        server_container_name = os.environ.get("CONTAINER_NAME")
        if server_container_name and not await _ws_authenticate(websocket, server_container_name):
            return

        while True:
            try:
                data = await websocket.receive_json()
                command = data.get("command")
                params = data.get("params", {})
                if command not in handlers:
                    await websocket.send_json(
                        {"success": False, "error": f"Unknown command: {command}"}
                    )
                    continue
                await _ws_run_command(websocket, command, params)
            except WebSocketDisconnect:
                raise
            except json.JSONDecodeError as json_err:
                logger.error(f"JSON decode error: {str(json_err)}")
                await websocket.send_json(
                    {"success": False, "error": f"Invalid JSON: {str(json_err)}"}
                )
            except Exception as loop_error:
                # Per-message failure: report it and keep serving the connection.
                logger.error(f"Error in message loop: {str(loop_error)}")
                logger.error(traceback.format_exc())
                await websocket.send_json({"success": False, "error": str(loop_error)})
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"Fatal error in websocket connection: {str(e)}")
        logger.error(traceback.format_exc())
        try:
            await websocket.close()
        except Exception:
            # Best-effort close; the socket may already be unusable.
            logger.debug("Could not close websocket", exc_info=True)
    finally:
        # Single cleanup point, so no exit path can leave a stale entry behind.
        manager.disconnect(websocket)
@app.post("/cmd")
async def cmd_endpoint(
    request: Request,
    container_name: Optional[str] = Header(None, alias="X-Container-Name"),
    api_key: Optional[str] = Header(None, alias="X-API-Key")
):
    """
    HTTP fallback for clients whose WebSocket connection fails.

    Runs a single command and streams exactly one ``data: <json>`` event that
    holds either ``{"success": true, ...result}`` or
    ``{"success": false, "error": "..."}``.

    Headers (only checked when the server's CONTAINER_NAME env var is set):
        - X-Container-Name: container name for cloud authentication
        - X-API-Key: API key for cloud authentication

    Body:
        {"command": "command_name", "params": {...}}

    Raises:
        HTTPException: 400 for an unparsable body, a missing command or an
        unknown command; 401 when cloud authentication fails.
    """
    try:
        payload = await request.json()
        command = payload.get("command")
        params = payload.get("params", {})
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")

    if not command:
        raise HTTPException(status_code=400, detail="Command is required")

    # A configured CONTAINER_NAME means we run at a cloud provider: authenticate.
    cloud_container = os.environ.get("CONTAINER_NAME")
    if cloud_container:
        logger.info(f"Cloud provider detected. CONTAINER_NAME: {cloud_container}. Performing authentication...")
        if not container_name:
            raise HTTPException(status_code=401, detail="Container name required")
        if not api_key:
            raise HTTPException(status_code=401, detail="API key required")
        if not await auth_manager.auth(container_name, api_key):
            raise HTTPException(status_code=401, detail="Authentication failed")

    if command not in handlers:
        raise HTTPException(status_code=400, detail=f"Unknown command: {command}")

    async def stream_result():
        """Run the handler and yield its outcome as a single event."""
        try:
            handler_func = handlers[command]
            accepted = inspect.signature(handler_func).parameters
            call_kwargs = {name: value for name, value in params.items() if name in accepted}
            # Sync handlers go to a worker thread so the event loop stays free.
            if asyncio.iscoroutinefunction(handler_func):
                outcome = await handler_func(**call_kwargs)
            else:
                outcome = await asyncio.to_thread(handler_func, **call_kwargs)
            event = {"success": True, **outcome}
            yield f"data: {json.dumps(event)}\n\n"
        except Exception as cmd_error:
            logger.error(f"Error executing command {command}: {str(cmd_error)}")
            logger.error(traceback.format_exc())
            failure = {"success": False, "error": str(cmd_error)}
            yield f"data: {json.dumps(failure)}\n\n"

    stream_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive"}
    return StreamingResponse(stream_result(), media_type="text/plain", headers=stream_headers)
# Maximum number of results (turns) /responses consumes from ComputerAgent.run().
_AGENT_MAX_TURNS = 2


class _EnvOverride:
    """Context manager that temporarily sets environment variables.

    Values are converted with ``str()``. On exit, every variable that was
    touched is restored to its previous value, or removed if it was unset.

    NOTE(review): ``os.environ`` is process-wide, and /responses keeps the
    override active across ``await`` points. Concurrent requests can therefore
    observe each other's overrides; serialize them if that matters.
    """

    def __init__(self, overrides: Dict[str, str]):
        self.overrides = overrides
        self._original: Dict[str, Optional[str]] = {}

    def __enter__(self) -> "_EnvOverride":
        try:
            for k, v in (self.overrides or {}).items():
                self._original[k] = os.environ.get(k)
                os.environ[k] = str(v)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the
            # variables already applied before propagating the error.
            self._restore()
            raise
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self._restore()

    def _restore(self) -> None:
        """Put back every recorded variable and forget the recorded state."""
        for k, old in self._original.items():
            if old is None:
                os.environ.pop(k, None)
            else:
                os.environ[k] = old
        self._original.clear()


def _to_messages(data: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    """Normalize the /responses ``input`` field into a message list.

    A string becomes a single user message; a list is passed through as-is.

    Raises:
        HTTPException: 400 for any other type.
    """
    if isinstance(data, str):
        return [{"role": "user", "content": data}]
    if isinstance(data, list):
        return data
    raise HTTPException(status_code=400, detail="'input' must be a string or a list of messages")


def _split_key_combo(combo: str) -> List[str]:
    """Split a key combo such as "ctrl+c" or "ctrl-shift-t" into key names.

    "+" and "-" both separate keys, except when they are the key itself:
    "-" -> ["-"], "ctrl+-" -> ["ctrl", "-"], "ctrl++" -> ["ctrl", "+"].
    """
    if len(combo) <= 1:
        return [combo]
    parts: List[str] = []
    token = ""
    for ch in combo:
        if ch in "+-" and token:
            parts.append(token)
            token = ""
        else:
            # A separator with no pending token is the literal key.
            token += ch
    if token:
        parts.append(token)
    return parts


@app.post("/responses")
async def agent_response_endpoint(
    request: Request,
    api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
    """
    Minimal proxy to run ComputerAgent for up to 2 turns.

    Security:
    - If CONTAINER_NAME is set on the server, require X-API-Key
      and validate using AuthenticationManager unless CUA_ENABLE_PUBLIC_PROXY is true.

    Body JSON:
    {
        "model": "...",                 # required
        "input": "... or messages[]",   # required
        "agent_kwargs": { ... },        # optional, passed directly to ComputerAgent
        "env": { ... }                  # optional env overrides for agent
    }

    Returns a JSON object with ``model``, ``error``, ``output``, ``usage`` and
    ``status`` ("completed" or "failed"). A malformed body yields HTTP 400 and
    failed authentication yields HTTP 401.
    """
    from agent.agent import ComputerAgent

    # Authenticate via AuthenticationManager if running in cloud (CONTAINER_NAME set)
    container_name = os.environ.get("CONTAINER_NAME")
    if container_name:
        is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in ["1", "true", "yes", "y", "on"]
        if not is_public:
            if not api_key:
                raise HTTPException(status_code=401, detail="Missing AGENT PROXY auth headers")
            ok = await auth_manager.auth(container_name, api_key)
            if not ok:
                raise HTTPException(status_code=401, detail="Unauthorized")

    # Parse and validate the request body
    try:
        body = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="JSON body must be an object")

    model = body.get("model")
    input_data = body.get("input")
    if not model or input_data is None:
        raise HTTPException(status_code=400, detail="'model' and 'input' are required")

    agent_kwargs: Dict[str, Any] = body.get("agent_kwargs") or {}
    # NOTE(review): callers can set arbitrary process environment variables
    # here; with CUA_ENABLE_PUBLIC_PROXY enabled that includes unauthenticated
    # callers. Consider an allow-list.
    env_overrides: Dict[str, str] = body.get("env") or {}
    if not isinstance(agent_kwargs, dict) or not isinstance(env_overrides, dict):
        raise HTTPException(status_code=400, detail="'agent_kwargs' and 'env' must be JSON objects")

    messages = _to_messages(input_data)

    # Define a direct computer tool that implements the AsyncComputerHandler protocol
    # and delegates to our existing automation/file/accessibility handlers.
    from agent.computers import AsyncComputerHandler  # runtime-checkable Protocol

    class DirectComputer(AsyncComputerHandler):
        """Computer tool backed by this server's module-level handler singletons."""

        def __init__(self):
            self._auto = automation_handler
            self._file = file_handler
            self._access = accessibility_handler

        async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
            sys_name = platform.system().lower()
            if "darwin" in sys_name or sys_name in ("macos", "mac"):
                return "mac"
            if "windows" in sys_name:
                return "windows"
            return "linux"

        async def get_dimensions(self) -> tuple[int, int]:
            size = await self._auto.get_screen_size()
            return size["width"], size["height"]

        async def screenshot(self) -> str:
            img_b64 = await self._auto.screenshot()
            return img_b64["image_data"]

        async def click(self, x: int, y: int, button: str = "left") -> None:
            # Buttons other than "right" fall back to a left click.
            if button == "right":
                await self._auto.right_click(x, y)
            else:
                await self._auto.left_click(x, y)

        async def double_click(self, x: int, y: int) -> None:
            await self._auto.double_click(x, y)

        async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
            await self._auto.move_cursor(x, y)
            await self._auto.scroll(scroll_x, scroll_y)

        async def type(self, text: str) -> None:
            await self._auto.type_text(text)

        async def wait(self, ms: int = 1000) -> None:
            await asyncio.sleep(ms / 1000.0)

        async def move(self, x: int, y: int) -> None:
            await self._auto.move_cursor(x, y)

        async def keypress(self, keys: Union[List[str], str]) -> None:
            parts = _split_key_combo(keys) if isinstance(keys, str) else keys
            if len(parts) == 1:
                await self._auto.press_key(parts[0])
            else:
                await self._auto.hotkey(*parts)

        async def drag(self, path: List[Dict[str, int]]) -> None:
            if not path:
                return
            start = path[0]
            await self._auto.mouse_down(start["x"], start["y"])
            for pt in path[1:]:
                await self._auto.move_cursor(pt["x"], pt["y"])
            end = path[-1]
            await self._auto.mouse_up(end["x"], end["y"])

        async def get_current_url(self) -> str:
            # Not available in this server context
            return ""

        async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
            await self._auto.mouse_down(x, y, button="left")

        async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
            await self._auto.mouse_up(x, y, button="left")

    error: Optional[str] = None
    total_output: List[Any] = []
    total_usage: Dict[str, Any] = {}

    with _EnvOverride(env_overrides):
        # Prepare tools: if caller did not pass tools, inject our DirectComputer
        tools = agent_kwargs.get("tools")
        if not tools:
            tools = [DirectComputer()]
            agent_kwargs = {**agent_kwargs, "tools": tools}

        agent = ComputerAgent(model=model, **agent_kwargs)  # type: ignore[arg-type]

        turns = 0
        run_iter = None
        try:
            run_iter = agent.run(messages)
            async for result in run_iter:
                total_output += result["output"]
                # Merge numeric usage counters; other usage values are overwritten.
                if isinstance(result, dict) and isinstance(result.get("usage"), dict):
                    for k, v in result["usage"].items():
                        if isinstance(v, (int, float)):
                            total_usage[k] = total_usage.get(k, 0) + v
                        else:
                            total_usage[k] = v
                turns += 1
                if turns >= _AGENT_MAX_TURNS:
                    break
        except Exception as e:
            logger.error(f"Error running agent: {str(e)}")
            logger.error(traceback.format_exc())
            # Fall back to the exception type so an empty message still
            # reports as a failure.
            error = str(e) or type(e).__name__
        finally:
            # Leaving ``async for`` early does not finalize an async generator;
            # close it now so its cleanup runs while the env overrides apply.
            aclose = getattr(run_iter, "aclose", None)
            if aclose is not None:
                try:
                    await aclose()
                except Exception:
                    logger.warning("Error while closing agent run", exc_info=True)

    payload = {
        "model": model,
        "error": error,
        "output": total_output,
        "usage": total_usage,
        "status": "completed" if error is None else "failed",
    }
    # CORS headers come from the app-wide CORSMiddleware; these headers only
    # disable caching and request keep-alive.
    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    return JSONResponse(content=payload, headers=headers)
if __name__ == "__main__":
    # Executed directly rather than via an external ASGI server:
    # serve on every interface, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")