mirror of
https://github.com/trycua/computer.git
synced 2026-01-04 12:30:08 -06:00
Added /cmd REST endpoint to computer-server
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uvicorn
|
||||
import logging
|
||||
import asyncio
|
||||
@@ -27,12 +28,67 @@ app = FastAPI(
|
||||
websocket_max_size=WEBSOCKET_MAX_SIZE,
|
||||
)
|
||||
|
||||
# Protocol revision reported to clients by the "version" command.
protocol_version = 1

# Resolve the installed distribution version, falling back to "unknown".
# The import is guarded separately from the lookup: if pkg_resources itself is
# missing, an `except pkg_resources.DistributionNotFound` clause would raise
# NameError while being evaluated instead of providing the fallback.
try:
    import pkg_resources
except ImportError:
    package_version = "unknown"
else:
    try:
        package_version = pkg_resources.get_distribution("cua-computer-server").version
    except pkg_resources.DistributionNotFound:
        package_version = "unknown"
|
||||
|
||||
# Build the handler objects once at import time; the command table below
# exposes their methods by name.
(
    accessibility_handler,
    automation_handler,
    diorama_handler,
    file_handler,
) = HandlerFactory.create_handlers()


def _report_version():
    """Return the protocol and package versions for the "version" command."""
    return {"protocol": protocol_version, "package": package_version}


# Each entry pairs a handler object with the command names it serves. Every
# command name is also the name of the method on that object, and the groups
# are listed in the order the commands appear in the resulting table.
_COMMAND_GROUPS = (
    # App-Use commands
    (diorama_handler, ("diorama_cmd",)),
    # Accessibility commands
    (accessibility_handler, ("get_accessibility_tree", "find_element")),
    # Shell commands
    (automation_handler, ("run_command",)),
    # File system commands
    (
        file_handler,
        (
            "file_exists",
            "directory_exists",
            "list_dir",
            "read_text",
            "write_text",
            "read_bytes",
            "write_bytes",
            "get_file_size",
            "delete_file",
            "create_dir",
            "delete_dir",
        ),
    ),
    (
        automation_handler,
        (
            # Mouse commands
            "mouse_down",
            "mouse_up",
            "left_click",
            "right_click",
            "double_click",
            "move_cursor",
            "drag_to",
            "drag",
            # Keyboard commands
            "key_down",
            "key_up",
            "type_text",
            "press_key",
            "hotkey",
            # Scrolling actions
            "scroll",
            "scroll_down",
            "scroll_up",
            # Screen actions
            "screenshot",
            "get_cursor_position",
            "get_screen_size",
            # Clipboard actions
            "copy_to_clipboard",
            "set_clipboard",
        ),
    ),
)

# Command name -> callable, built in a single expression so `handlers` is only
# bound once every method lookup has succeeded.
handlers = {
    "version": _report_version,
    **{
        command_name: getattr(owner, command_name)
        for owner, command_names in _COMMAND_GROUPS
        for command_name in command_names
    },
}
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
# Create OS-specific handlers
|
||||
self.accessibility_handler, self.automation_handler, self.diorama_handler, self.file_handler = HandlerFactory.create_handlers()
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
@@ -47,6 +103,8 @@ manager = ConnectionManager()
|
||||
|
||||
@app.websocket("/ws", name="websocket_endpoint")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global handlers
|
||||
|
||||
# WebSocket message size is configured at the app or endpoint level, not on the instance
|
||||
await manager.connect(websocket)
|
||||
|
||||
@@ -156,55 +214,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Map commands to appropriate handler methods
|
||||
handlers = {
|
||||
# App-Use commands
|
||||
"diorama_cmd": manager.diorama_handler.diorama_cmd,
|
||||
# Accessibility commands
|
||||
"get_accessibility_tree": manager.accessibility_handler.get_accessibility_tree,
|
||||
"find_element": manager.accessibility_handler.find_element,
|
||||
# Shell commands
|
||||
"run_command": manager.automation_handler.run_command,
|
||||
# File system commands
|
||||
"file_exists": manager.file_handler.file_exists,
|
||||
"directory_exists": manager.file_handler.directory_exists,
|
||||
"list_dir": manager.file_handler.list_dir,
|
||||
"read_text": manager.file_handler.read_text,
|
||||
"write_text": manager.file_handler.write_text,
|
||||
"read_bytes": manager.file_handler.read_bytes,
|
||||
"write_bytes": manager.file_handler.write_bytes,
|
||||
"get_file_size": manager.file_handler.get_file_size,
|
||||
"delete_file": manager.file_handler.delete_file,
|
||||
"create_dir": manager.file_handler.create_dir,
|
||||
"delete_dir": manager.file_handler.delete_dir,
|
||||
# Mouse commands
|
||||
"mouse_down": manager.automation_handler.mouse_down,
|
||||
"mouse_up": manager.automation_handler.mouse_up,
|
||||
"left_click": manager.automation_handler.left_click,
|
||||
"right_click": manager.automation_handler.right_click,
|
||||
"double_click": manager.automation_handler.double_click,
|
||||
"move_cursor": manager.automation_handler.move_cursor,
|
||||
"drag_to": manager.automation_handler.drag_to,
|
||||
"drag": manager.automation_handler.drag,
|
||||
# Keyboard commands
|
||||
"key_down": manager.automation_handler.key_down,
|
||||
"key_up": manager.automation_handler.key_up,
|
||||
"type_text": manager.automation_handler.type_text,
|
||||
"press_key": manager.automation_handler.press_key,
|
||||
"hotkey": manager.automation_handler.hotkey,
|
||||
# Scrolling actions
|
||||
"scroll": manager.automation_handler.scroll,
|
||||
"scroll_down": manager.automation_handler.scroll_down,
|
||||
"scroll_up": manager.automation_handler.scroll_up,
|
||||
# Screen actions
|
||||
"screenshot": manager.automation_handler.screenshot,
|
||||
"get_cursor_position": manager.automation_handler.get_cursor_position,
|
||||
"get_screen_size": manager.automation_handler.get_screen_size,
|
||||
# Clipboard actions
|
||||
"copy_to_clipboard": manager.automation_handler.copy_to_clipboard,
|
||||
"set_clipboard": manager.automation_handler.set_clipboard,
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
@@ -256,5 +265,124 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
manager.disconnect(websocket)
|
||||
|
||||
|
||||
async def _authenticate_cloud_request(
    server_container_name: str,
    container_name: Optional[str],
    api_key: Optional[str],
) -> None:
    """Authenticate a /cmd request when the server runs under a cloud provider.

    Args:
        server_container_name: Value of the server's CONTAINER_NAME variable.
        container_name: Value of the request's X-Container-Name header.
        api_key: Value of the request's X-API-Key header.

    Raises:
        HTTPException: 401 when the container name does not match, the API key
            is missing, or the TryCUA API rejects the key; 500 when the
            validation request itself fails.
    """
    logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")

    # Layer 1: VM Identity Verification
    if container_name != server_container_name:
        logger.warning(f"VM name mismatch. Expected: {server_container_name}, Got: {container_name}")
        raise HTTPException(status_code=401, detail="VM name mismatch")

    # Layer 2: API Key Validation with TryCUA API
    if not api_key:
        raise HTTPException(status_code=401, detail="API key required")

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"https://www.trycua.com/api/vm/auth?container_name={server_container_name}",
                headers={"Authorization": f"Bearer {api_key}"},
            ) as resp:
                if resp.status != 200:
                    logger.warning(f"API key validation failed. Status: {resp.status}")
                    raise HTTPException(status_code=401, detail="Invalid API key")

                auth_response = await resp.json()
                if not auth_response.get("success"):
                    logger.warning(f"API key validation failed. Response: {auth_response}")
                    raise HTTPException(status_code=401, detail="Invalid API key")

                logger.info("Authentication successful")
    except HTTPException:
        # Rejections raised above must reach the client unchanged rather than
        # being rewritten by the generic handler below.
        raise
    except aiohttp.ClientError as e:
        logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
        raise HTTPException(status_code=500, detail="Authentication service unavailable") from e
    except Exception as e:
        logger.error(f"Unexpected error during authentication: {str(e)}")
        raise HTTPException(status_code=500, detail="Authentication failed") from e


@app.post("/cmd")
async def cmd_endpoint(
    request: Request,
    container_name: Optional[str] = Header(None, alias="X-Container-Name"),
    api_key: Optional[str] = Header(None, alias="X-API-Key")
):
    """
    Backup endpoint for when WebSocket connections fail.
    Accepts commands via HTTP POST with streaming response.

    Headers:
    - X-Container-Name: Container name for cloud authentication
    - X-API-Key: API key for cloud authentication

    Body:
    {
        "command": "command_name",
        "params": {...}
    }

    Returns:
        A StreamingResponse carrying one "data: <json>" record. The JSON holds
        "success": True merged with the handler's result, or "success": False
        and an "error" message when the handler raised.

    Raises:
        HTTPException: 400 for a malformed body, a missing or unknown command,
            or non-object params; 401/500 for authentication failures when
            CONTAINER_NAME is set.
    """
    # Parse request body
    try:
        body = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}") from e

    # Valid JSON that is not an object (a list, string, number, ...) has no
    # "command" key to read, so reject it explicitly.
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body: expected a JSON object")

    command = body.get("command")
    # A non-string command (e.g. a list) would be unhashable in the handler
    # lookup below and surface as a server error.
    if not command or not isinstance(command, str):
        raise HTTPException(status_code=400, detail="Command is required")

    # An explicit `"params": null` is treated the same as omitted params.
    params = body.get("params")
    if params is None:
        params = {}
    if not isinstance(params, dict):
        raise HTTPException(status_code=400, detail="Params must be a JSON object")

    # Check if CONTAINER_NAME is set (indicating cloud provider)
    server_container_name = os.environ.get("CONTAINER_NAME")

    # If cloud provider, perform authentication
    if server_container_name:
        await _authenticate_cloud_request(server_container_name, container_name, api_key)

    if command not in handlers:
        raise HTTPException(status_code=400, detail=f"Unknown command: {command}")

    async def generate_response():
        """Generate streaming response for the command execution"""
        try:
            # Filter params to only include those accepted by the handler function
            handler_func = handlers[command]
            accepted = inspect.signature(handler_func).parameters
            filtered_params = {k: v for k, v in params.items() if k in accepted}

            # Execute the command. Not every registered handler is a coroutine
            # function (the "version" entry is a plain callable), so only await
            # when the call actually produced an awaitable.
            result = handler_func(**filtered_params)
            if inspect.isawaitable(result):
                result = await result

            # Stream the successful result
            response_data = {"success": True, **result}
            yield f"data: {json.dumps(response_data)}\n\n"

        except Exception as cmd_error:
            # logger.exception records the traceback along with the message.
            logger.exception(f"Error executing command {command}: {str(cmd_error)}")

            # Stream the error result
            error_data = {"success": False, "error": str(cmd_error)}
            yield f"data: {json.dumps(error_data)}\n\n"

    # NOTE(review): records use "data: ...\n\n" framing but are served as
    # text/plain; the media type is left unchanged for existing clients.
    return StreamingResponse(
        generate_response(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, X-Container-Name, X-API-Key"
        }
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on port 8000, listening on
    # every network interface.
    listen_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **listen_options)
|
||||
|
||||
Reference in New Issue
Block a user