Added /cmd REST endpoint to computer-server

This commit is contained in:
Dillon DuPont
2025-07-10 14:57:55 -04:00
parent cb2f8c3d2a
commit 217b108bd1

View File

@@ -1,5 +1,6 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List, Dict, Any
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
from fastapi.responses import StreamingResponse
from typing import List, Dict, Any, Optional
import uvicorn
import logging
import asyncio
@@ -27,12 +28,67 @@ app = FastAPI(
websocket_max_size=WEBSOCKET_MAX_SIZE,
)
protocol_version = 1
try:
import pkg_resources
package_version = pkg_resources.get_distribution("cua-computer-server").version
except pkg_resources.DistributionNotFound:
package_version = "unknown"
accessibility_handler, automation_handler, diorama_handler, file_handler = HandlerFactory.create_handlers()
handlers = {
"version": lambda: {"protocol": protocol_version, "package": package_version},
# App-Use commands
"diorama_cmd": diorama_handler.diorama_cmd,
# Accessibility commands
"get_accessibility_tree": accessibility_handler.get_accessibility_tree,
"find_element": accessibility_handler.find_element,
# Shell commands
"run_command": automation_handler.run_command,
# File system commands
"file_exists": file_handler.file_exists,
"directory_exists": file_handler.directory_exists,
"list_dir": file_handler.list_dir,
"read_text": file_handler.read_text,
"write_text": file_handler.write_text,
"read_bytes": file_handler.read_bytes,
"write_bytes": file_handler.write_bytes,
"get_file_size": file_handler.get_file_size,
"delete_file": file_handler.delete_file,
"create_dir": file_handler.create_dir,
"delete_dir": file_handler.delete_dir,
# Mouse commands
"mouse_down": automation_handler.mouse_down,
"mouse_up": automation_handler.mouse_up,
"left_click": automation_handler.left_click,
"right_click": automation_handler.right_click,
"double_click": automation_handler.double_click,
"move_cursor": automation_handler.move_cursor,
"drag_to": automation_handler.drag_to,
"drag": automation_handler.drag,
# Keyboard commands
"key_down": automation_handler.key_down,
"key_up": automation_handler.key_up,
"type_text": automation_handler.type_text,
"press_key": automation_handler.press_key,
"hotkey": automation_handler.hotkey,
# Scrolling actions
"scroll": automation_handler.scroll,
"scroll_down": automation_handler.scroll_down,
"scroll_up": automation_handler.scroll_up,
# Screen actions
"screenshot": automation_handler.screenshot,
"get_cursor_position": automation_handler.get_cursor_position,
"get_screen_size": automation_handler.get_screen_size,
# Clipboard actions
"copy_to_clipboard": automation_handler.copy_to_clipboard,
"set_clipboard": automation_handler.set_clipboard,
}
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
# Create OS-specific handlers
self.accessibility_handler, self.automation_handler, self.diorama_handler, self.file_handler = HandlerFactory.create_handlers()
async def connect(self, websocket: WebSocket):
await websocket.accept()
@@ -47,6 +103,8 @@ manager = ConnectionManager()
@app.websocket("/ws", name="websocket_endpoint")
async def websocket_endpoint(websocket: WebSocket):
global handlers
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
@@ -156,55 +214,6 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket)
return
# Map commands to appropriate handler methods
handlers = {
# App-Use commands
"diorama_cmd": manager.diorama_handler.diorama_cmd,
# Accessibility commands
"get_accessibility_tree": manager.accessibility_handler.get_accessibility_tree,
"find_element": manager.accessibility_handler.find_element,
# Shell commands
"run_command": manager.automation_handler.run_command,
# File system commands
"file_exists": manager.file_handler.file_exists,
"directory_exists": manager.file_handler.directory_exists,
"list_dir": manager.file_handler.list_dir,
"read_text": manager.file_handler.read_text,
"write_text": manager.file_handler.write_text,
"read_bytes": manager.file_handler.read_bytes,
"write_bytes": manager.file_handler.write_bytes,
"get_file_size": manager.file_handler.get_file_size,
"delete_file": manager.file_handler.delete_file,
"create_dir": manager.file_handler.create_dir,
"delete_dir": manager.file_handler.delete_dir,
# Mouse commands
"mouse_down": manager.automation_handler.mouse_down,
"mouse_up": manager.automation_handler.mouse_up,
"left_click": manager.automation_handler.left_click,
"right_click": manager.automation_handler.right_click,
"double_click": manager.automation_handler.double_click,
"move_cursor": manager.automation_handler.move_cursor,
"drag_to": manager.automation_handler.drag_to,
"drag": manager.automation_handler.drag,
# Keyboard commands
"key_down": manager.automation_handler.key_down,
"key_up": manager.automation_handler.key_up,
"type_text": manager.automation_handler.type_text,
"press_key": manager.automation_handler.press_key,
"hotkey": manager.automation_handler.hotkey,
# Scrolling actions
"scroll": manager.automation_handler.scroll,
"scroll_down": manager.automation_handler.scroll_down,
"scroll_up": manager.automation_handler.scroll_up,
# Screen actions
"screenshot": manager.automation_handler.screenshot,
"get_cursor_position": manager.automation_handler.get_cursor_position,
"get_screen_size": manager.automation_handler.get_screen_size,
# Clipboard actions
"copy_to_clipboard": manager.automation_handler.copy_to_clipboard,
"set_clipboard": manager.automation_handler.set_clipboard,
}
try:
while True:
try:
@@ -256,5 +265,124 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket)
@app.post("/cmd")
async def cmd_endpoint(
request: Request,
container_name: Optional[str] = Header(None, alias="X-Container-Name"),
api_key: Optional[str] = Header(None, alias="X-API-Key")
):
"""
Backup endpoint for when WebSocket connections fail.
Accepts commands via HTTP POST with streaming response.
Headers:
- X-Container-Name: Container name for cloud authentication
- X-API-Key: API key for cloud authentication
Body:
{
"command": "command_name",
"params": {...}
}
"""
global handlers
# Parse request body
try:
body = await request.json()
command = body.get("command")
params = body.get("params", {})
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication
if server_container_name:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
# Layer 1: VM Identity Verification
if container_name != server_container_name:
logger.warning(f"VM name mismatch. Expected: {server_container_name}, Got: {container_name}")
raise HTTPException(status_code=401, detail="VM name mismatch")
# Layer 2: API Key Validation with TryCUA API
if not api_key:
raise HTTPException(status_code=401, detail="API key required")
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={server_container_name}",
headers=headers,
) as resp:
if resp.status != 200:
logger.warning(f"API key validation failed. Status: {resp.status}")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_response = await resp.json()
if not auth_response.get("success"):
logger.warning(f"API key validation failed. Response: {auth_response}")
raise HTTPException(status_code=401, detail="Invalid API key")
logger.info("Authentication successful")
except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
raise HTTPException(status_code=500, detail="Authentication service unavailable")
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
raise HTTPException(status_code=500, detail="Authentication failed")
if command not in handlers:
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")
async def generate_response():
"""Generate streaming response for the command execution"""
try:
# Filter params to only include those accepted by the handler function
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
# Execute the command
result = await handler_func(**filtered_params)
# Stream the successful result
response_data = {"success": True, **result}
yield f"data: {json.dumps(response_data)}\n\n"
except Exception as cmd_error:
logger.error(f"Error executing command {command}: {str(cmd_error)}")
logger.error(traceback.format_exc())
# Stream the error result
error_data = {"success": False, "error": str(cmd_error)}
yield f"data: {json.dumps(error_data)}\n\n"
return StreamingResponse(
generate_response(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, X-Container-Name, X-API-Key"
}
)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)