From 217b108bd145d6fabd121009a4be6794d0a22c80 Mon Sep 17 00:00:00 2001
From: Dillon DuPont
Date: Thu, 10 Jul 2025 14:57:55 -0400
Subject: [PATCH] Added /cmd REST endpoint to computer-server

---
Note: line breaks and indentation in this patch were reconstructed from a
whitespace-collapsed copy. Hunk sizes match the hunk headers, but the
indentation of context lines and the blob hashes on the "index" line are
unverified; use `git apply --ignore-whitespace` if a hunk is rejected.

 .../computer-server/computer_server/main.py   | 234 ++++++++++++++----
 1 file changed, 181 insertions(+), 53 deletions(-)

diff --git a/libs/python/computer-server/computer_server/main.py b/libs/python/computer-server/computer_server/main.py
index 29b19faf..7dc0a61d 100644
--- a/libs/python/computer-server/computer_server/main.py
+++ b/libs/python/computer-server/computer_server/main.py
@@ -1,5 +1,6 @@
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
-from typing import List, Dict, Any
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
+from fastapi.responses import StreamingResponse
+from typing import List, Dict, Any, Optional
 import uvicorn
 import logging
 import asyncio
@@ -27,12 +28,67 @@ app = FastAPI(
     websocket_max_size=WEBSOCKET_MAX_SIZE,
 )
 
+protocol_version = 1
+try:
+    from importlib.metadata import version as _distribution_version
+    package_version = _distribution_version("cua-computer-server")
+except ImportError:  # also covers PackageNotFoundError, an ImportError subclass
+    package_version = "unknown"
+
+accessibility_handler, automation_handler, diorama_handler, file_handler = HandlerFactory.create_handlers()
+handlers = {
+    "version": lambda: {"protocol": protocol_version, "package": package_version},
+    # App-Use commands
+    "diorama_cmd": diorama_handler.diorama_cmd,
+    # Accessibility commands
+    "get_accessibility_tree": accessibility_handler.get_accessibility_tree,
+    "find_element": accessibility_handler.find_element,
+    # Shell commands
+    "run_command": automation_handler.run_command,
+    # File system commands
+    "file_exists": file_handler.file_exists,
+    "directory_exists": file_handler.directory_exists,
+    "list_dir": file_handler.list_dir,
+    "read_text": file_handler.read_text,
+    "write_text": file_handler.write_text,
+    "read_bytes": file_handler.read_bytes,
+    "write_bytes": file_handler.write_bytes,
+    "get_file_size": file_handler.get_file_size,
+    "delete_file": file_handler.delete_file,
+    "create_dir": file_handler.create_dir,
+    "delete_dir": file_handler.delete_dir,
+    # Mouse commands
+    "mouse_down": automation_handler.mouse_down,
+    "mouse_up": automation_handler.mouse_up,
+    "left_click": automation_handler.left_click,
+    "right_click": automation_handler.right_click,
+    "double_click": automation_handler.double_click,
+    "move_cursor": automation_handler.move_cursor,
+    "drag_to": automation_handler.drag_to,
+    "drag": automation_handler.drag,
+    # Keyboard commands
+    "key_down": automation_handler.key_down,
+    "key_up": automation_handler.key_up,
+    "type_text": automation_handler.type_text,
+    "press_key": automation_handler.press_key,
+    "hotkey": automation_handler.hotkey,
+    # Scrolling actions
+    "scroll": automation_handler.scroll,
+    "scroll_down": automation_handler.scroll_down,
+    "scroll_up": automation_handler.scroll_up,
+    # Screen actions
+    "screenshot": automation_handler.screenshot,
+    "get_cursor_position": automation_handler.get_cursor_position,
+    "get_screen_size": automation_handler.get_screen_size,
+    # Clipboard actions
+    "copy_to_clipboard": automation_handler.copy_to_clipboard,
+    "set_clipboard": automation_handler.set_clipboard,
+}
+
 
 class ConnectionManager:
     def __init__(self):
         self.active_connections: List[WebSocket] = []
-        # Create OS-specific handlers
-        self.accessibility_handler, self.automation_handler, self.diorama_handler, self.file_handler = HandlerFactory.create_handlers()
 
     async def connect(self, websocket: WebSocket):
         await websocket.accept()
@@ -47,6 +103,8 @@ manager = ConnectionManager()
 
 @app.websocket("/ws", name="websocket_endpoint")
 async def websocket_endpoint(websocket: WebSocket):
+    global handlers
+
     # WebSocket message size is configured at the app or endpoint level, not on the instance
     await manager.connect(websocket)
 
@@ -156,55 +214,6 @@ async def websocket_endpoint(websocket: WebSocket):
             manager.disconnect(websocket)
             return
 
-    # Map commands to appropriate handler methods
-    handlers = {
-        # App-Use commands
-        "diorama_cmd": manager.diorama_handler.diorama_cmd,
-        # Accessibility commands
-        "get_accessibility_tree": manager.accessibility_handler.get_accessibility_tree,
-        "find_element": manager.accessibility_handler.find_element,
-        # Shell commands
-        "run_command": manager.automation_handler.run_command,
-        # File system commands
-        "file_exists": manager.file_handler.file_exists,
-        "directory_exists": manager.file_handler.directory_exists,
-        "list_dir": manager.file_handler.list_dir,
-        "read_text": manager.file_handler.read_text,
-        "write_text": manager.file_handler.write_text,
-        "read_bytes": manager.file_handler.read_bytes,
-        "write_bytes": manager.file_handler.write_bytes,
-        "get_file_size": manager.file_handler.get_file_size,
-        "delete_file": manager.file_handler.delete_file,
-        "create_dir": manager.file_handler.create_dir,
-        "delete_dir": manager.file_handler.delete_dir,
-        # Mouse commands
-        "mouse_down": manager.automation_handler.mouse_down,
-        "mouse_up": manager.automation_handler.mouse_up,
-        "left_click": manager.automation_handler.left_click,
-        "right_click": manager.automation_handler.right_click,
-        "double_click": manager.automation_handler.double_click,
-        "move_cursor": manager.automation_handler.move_cursor,
-        "drag_to": manager.automation_handler.drag_to,
-        "drag": manager.automation_handler.drag,
-        # Keyboard commands
-        "key_down": manager.automation_handler.key_down,
-        "key_up": manager.automation_handler.key_up,
-        "type_text": manager.automation_handler.type_text,
-        "press_key": manager.automation_handler.press_key,
-        "hotkey": manager.automation_handler.hotkey,
-        # Scrolling actions
-        "scroll": manager.automation_handler.scroll,
-        "scroll_down": manager.automation_handler.scroll_down,
-        "scroll_up": manager.automation_handler.scroll_up,
-        # Screen actions
-        "screenshot": manager.automation_handler.screenshot,
-        "get_cursor_position": manager.automation_handler.get_cursor_position,
-        "get_screen_size": manager.automation_handler.get_screen_size,
-        # Clipboard actions
-        "copy_to_clipboard": manager.automation_handler.copy_to_clipboard,
-        "set_clipboard": manager.automation_handler.set_clipboard,
-    }
-
     try:
         while True:
             try:
@@ -256,5 +265,124 @@ async def websocket_endpoint(websocket: WebSocket):
         manager.disconnect(websocket)
 
 
+@app.post("/cmd")
+async def cmd_endpoint(
+    request: Request,
+    container_name: Optional[str] = Header(None, alias="X-Container-Name"),
+    api_key: Optional[str] = Header(None, alias="X-API-Key")
+):
+    """
+    Backup endpoint for when WebSocket connections fail.
+    Accepts commands via HTTP POST with streaming response.
+
+    Headers:
+    - X-Container-Name: Container name for cloud authentication
+    - X-API-Key: API key for cloud authentication
+
+    Body:
+    {
+        "command": "command_name",
+        "params": {...}
+    }
+    """
+
+    # Parse request body
+    try:
+        body = await request.json()
+        command = body.get("command")
+        params = body.get("params") or {}
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
+
+    if not command:
+        raise HTTPException(status_code=400, detail="Command is required")
+
+    # Check if CONTAINER_NAME is set (indicating cloud provider)
+    server_container_name = os.environ.get("CONTAINER_NAME")
+
+    # If cloud provider, perform authentication
+    if server_container_name:
+        logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
+
+        # Layer 1: VM Identity Verification
+        if container_name != server_container_name:
+            logger.warning(f"VM name mismatch. Expected: {server_container_name}, Got: {container_name}")
+            raise HTTPException(status_code=401, detail="VM name mismatch")
+
+        # Layer 2: API Key Validation with TryCUA API
+        if not api_key:
+            raise HTTPException(status_code=401, detail="API key required")
+
+        try:
+            async with aiohttp.ClientSession() as session:
+                headers = {
+                    "Authorization": f"Bearer {api_key}"
+                }
+
+                async with session.get(
+                    f"https://www.trycua.com/api/vm/auth?container_name={server_container_name}",
+                    headers=headers,
+                ) as resp:
+                    if resp.status != 200:
+                        logger.warning(f"API key validation failed. Status: {resp.status}")
+                        raise HTTPException(status_code=401, detail="Invalid API key")
+
+                    auth_response = await resp.json()
+                    if not auth_response.get("success"):
+                        logger.warning(f"API key validation failed. Response: {auth_response}")
+                        raise HTTPException(status_code=401, detail="Invalid API key")
+
+                    logger.info("Authentication successful")
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
+            raise HTTPException(status_code=500, detail="Authentication service unavailable")
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error during authentication: {str(e)}")
+            raise HTTPException(status_code=500, detail="Authentication failed")
+
+    if command not in handlers:
+        raise HTTPException(status_code=400, detail=f"Unknown command: {command}")
+
+    async def generate_response():
+        """Generate streaming response for the command execution"""
+        try:
+            # Filter params to only include those accepted by the handler function
+            handler_func = handlers[command]
+            sig = inspect.signature(handler_func)
+            filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
+
+            # Execute the command; sync handlers (e.g. "version") return a plain value
+            result = handler_func(**filtered_params)
+            if inspect.isawaitable(result):
+                result = await result
+
+            # Stream the successful result
+            response_data = {"success": True, **result}
+            yield f"data: {json.dumps(response_data)}\n\n"
+
+        except Exception as cmd_error:
+            logger.error(f"Error executing command {command}: {str(cmd_error)}")
+            logger.error(traceback.format_exc())
+
+            # Stream the error result
+            error_data = {"success": False, "error": str(cmd_error)}
+            yield f"data: {json.dumps(error_data)}\n\n"
+
+    return StreamingResponse(
+        generate_response(),
+        media_type="text/plain",
+        headers={
+            "Cache-Control": "no-cache",
+            "Connection": "keep-alive",
+            "Access-Control-Allow-Origin": "*",
+            "Access-Control-Allow-Methods": "POST, OPTIONS",
+            "Access-Control-Allow-Headers": "Content-Type, X-Container-Name, X-API-Key"
+        }
+    )
+
+
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=8000)