mirror of
https://github.com/trycua/computer.git
synced 2026-01-07 05:50:13 -06:00
Merge pull request #318 from trycua/fix/networking-fallbacks
Add REST API Support with WebSocket Fallback for Computer Server
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing import List, Dict, Any, Optional
|
||||
import uvicorn
|
||||
import logging
|
||||
import asyncio
|
||||
@@ -11,6 +12,8 @@ from io import StringIO
|
||||
from .handlers.factory import HandlerFactory
|
||||
import os
|
||||
import aiohttp
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
# Set up logging with more detail
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -27,12 +30,156 @@ app = FastAPI(
|
||||
websocket_max_size=WEBSOCKET_MAX_SIZE,
|
||||
)
|
||||
|
||||
# Wire-protocol revision reported by the "version" command.
protocol_version = 1

# Resolve the installed distribution version with the standard library.
# The import is kept next to its only use; unlike the previous
# pkg_resources-based lookup, a missing setuptools installation can no longer
# turn the fallback path into a NameError.
from importlib import metadata as _importlib_metadata

try:
    package_version = _importlib_metadata.version("cua-computer-server")
except _importlib_metadata.PackageNotFoundError:
    # Running from a source checkout / not installed as a distribution.
    package_version = "unknown"
|
||||
|
||||
# Handler objects produced by the factory; each one exposes the callables
# that the command table below dispatches to.
(
    accessibility_handler,
    automation_handler,
    diorama_handler,
    file_handler,
) = HandlerFactory.create_handlers()

# Command name -> callable. Each command is served by the handler method of
# the same name; "version" is answered inline.
handlers = {
    "version": lambda: {"protocol": protocol_version, "package": package_version},
    # App-Use commands
    "diorama_cmd": diorama_handler.diorama_cmd,
    # Accessibility commands
    **{
        name: getattr(accessibility_handler, name)
        for name in ("get_accessibility_tree", "find_element")
    },
    # Shell commands
    "run_command": automation_handler.run_command,
    # File system commands
    **{
        name: getattr(file_handler, name)
        for name in (
            "file_exists",
            "directory_exists",
            "list_dir",
            "read_text",
            "write_text",
            "read_bytes",
            "write_bytes",
            "get_file_size",
            "delete_file",
            "create_dir",
            "delete_dir",
        )
    },
    **{
        name: getattr(automation_handler, name)
        for name in (
            # Mouse commands
            "mouse_down",
            "mouse_up",
            "left_click",
            "right_click",
            "double_click",
            "move_cursor",
            "drag_to",
            "drag",
            # Keyboard commands
            "key_down",
            "key_up",
            "type_text",
            "press_key",
            "hotkey",
            # Scrolling actions
            "scroll",
            "scroll_down",
            "scroll_up",
            # Screen actions
            "screenshot",
            "get_cursor_position",
            "get_screen_size",
            # Clipboard actions
            "copy_to_clipboard",
            "set_clipboard",
        )
    },
}
|
||||
|
||||
|
||||
class AuthenticationManager:
    """Authenticate (container name, API key) pairs, caching recent verdicts.

    When the ``CONTAINER_NAME`` environment variable is unset every request is
    allowed (local development). Otherwise the client-supplied container name
    must equal it and the API key is checked against the TryCUA API. Both
    successful and failed verdicts are cached for ``SESSION_TTL_SECONDS`` so
    repeated requests do not hit the remote API each time.
    """

    # Seconds a cached verdict (positive or negative) remains usable.
    SESSION_TTL_SECONDS = 5
    # Upper bound, in seconds, for the whole remote authentication request.
    AUTH_TIMEOUT_SECONDS = 10
    # Remote endpoint that validates an API key for a container.
    AUTH_URL = "https://www.trycua.com/api/vm/auth"

    def __init__(self):
        # session hash -> {'valid': bool, 'expires_at': epoch seconds}
        self.sessions: Dict[str, Dict[str, Any]] = {}
        self.container_name = os.environ.get("CONTAINER_NAME")

    def _hash_credentials(self, container_name: str, api_key: str) -> str:
        """Create a hash of container name and API key for session identification"""
        combined = f"{container_name}:{api_key}"
        return hashlib.sha256(combined.encode()).hexdigest()

    def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
        """Return True only for a positive verdict that has not expired yet."""
        if not session_data.get('valid', False):
            return False

        expires_at = session_data.get('expires_at', 0)
        return time.time() < expires_at

    def _cached_result(self, session_hash: str) -> Optional[bool]:
        """Return the cached verdict for ``session_hash``, if still fresh.

        Returns:
            True/False for a non-expired cache entry (negative verdicts are
            honoured too), or None when nothing usable is cached. An expired
            entry is removed as a side effect.
        """
        session_data = self.sessions.get(session_hash)
        if session_data is None:
            return None
        if time.time() >= session_data.get('expires_at', 0):
            del self.sessions[session_hash]
            return None
        return bool(session_data.get('valid', False))

    def _remember(self, session_hash: str, is_valid: bool) -> None:
        """Cache a verdict and drop every entry that has already expired.

        Purging here keeps the cache from growing without bound when many
        different (mostly invalid) keys are presented.
        """
        now = time.time()
        expired = [
            key
            for key, data in self.sessions.items()
            if data.get('expires_at', 0) <= now
        ]
        for key in expired:
            del self.sessions[key]

        self.sessions[session_hash] = {
            'valid': is_valid,
            'expires_at': now + self.SESSION_TTL_SECONDS,
        }

    async def auth(self, container_name: str, api_key: str) -> bool:
        """Authenticate container name and API key, using cached sessions when possible.

        Args:
            container_name: Container name presented by the client.
            api_key: API key presented by the client.

        Returns:
            True when access is granted, False otherwise. Errors while
            contacting the API are logged and reported as False.
        """
        # If no CONTAINER_NAME is set, always allow access (local development)
        if not self.container_name:
            logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
            return True

        # Layer 1: VM Identity Verification
        if container_name != self.container_name:
            logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
            return False

        session_hash = self._hash_credentials(container_name, api_key)

        # Layer 2: reuse a fresh verdict, whether it was a success or a failure
        cached = self._cached_result(session_hash)
        if cached is not None:
            logger.info(f"Using cached authentication for container: {container_name}")
            return cached

        # No usable cached session, authenticate with API
        logger.info(f"Authenticating with TryCUA API for container: {container_name}")

        try:
            timeout = aiohttp.ClientTimeout(total=self.AUTH_TIMEOUT_SECONDS)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.get(
                    self.AUTH_URL,
                    # Passed as params so the value is URL-encoded for us.
                    params={"container_name": container_name},
                    headers={"Authorization": f"Bearer {api_key}"},
                ) as resp:
                    status = resp.status
                    # A 200 with a non-empty body is treated as success.
                    is_valid = status == 200 and bool((await resp.text()).strip())
        except aiohttp.ClientError as e:
            logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
            # Cache failed result to avoid repeated requests
            self._remember(session_hash, False)
            return False
        except Exception as e:
            logger.error(f"Unexpected error during authentication: {str(e)}")
            # Cache failed result to avoid repeated requests
            self._remember(session_hash, False)
            return False

        self._remember(session_hash, is_valid)

        if is_valid:
            logger.info(f"Authentication successful for container: {container_name}")
        else:
            logger.warning(f"Authentication failed for container: {container_name}. Status: {status}")

        return is_valid
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
# Create OS-specific handlers
|
||||
self.accessibility_handler, self.automation_handler, self.diorama_handler, self.file_handler = HandlerFactory.create_handlers()
|
||||
|
||||
async def connect(self, websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
@@ -43,20 +190,23 @@ class ConnectionManager:
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
auth_manager = AuthenticationManager()
|
||||
|
||||
|
||||
@app.websocket("/ws", name="websocket_endpoint")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
global handlers
|
||||
|
||||
# WebSocket message size is configured at the app or endpoint level, not on the instance
|
||||
await manager.connect(websocket)
|
||||
|
||||
# Check if CONTAINER_NAME is set (indicating cloud provider)
|
||||
container_name = os.environ.get("CONTAINER_NAME")
|
||||
server_container_name = os.environ.get("CONTAINER_NAME")
|
||||
|
||||
# If cloud provider, perform authentication handshake
|
||||
if container_name:
|
||||
if server_container_name:
|
||||
try:
|
||||
logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...")
|
||||
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
|
||||
|
||||
# Wait for authentication message
|
||||
auth_data = await websocket.receive_json()
|
||||
@@ -75,18 +225,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
client_api_key = auth_data.get("params", {}).get("api_key")
|
||||
client_container_name = auth_data.get("params", {}).get("container_name")
|
||||
|
||||
# Layer 1: VM Identity Verification
|
||||
if client_container_name != container_name:
|
||||
logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "VM name mismatch"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Layer 2: API Key Validation with TryCUA API
|
||||
# Validate credentials using AuthenticationManager
|
||||
if not client_api_key:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
@@ -96,58 +235,34 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Validate with TryCUA API
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {client_api_key}"
|
||||
}
|
||||
|
||||
async with session.get(
|
||||
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
|
||||
headers=headers,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_msg = await resp.text()
|
||||
logger.warning(f"API validation failed: {error_msg}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# If we get a 200 response with VNC URL, the VM exists and user has access
|
||||
vnc_url = (await resp.text()).strip()
|
||||
if not vnc_url:
|
||||
logger.warning(f"No VNC URL returned for VM: {container_name}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "VM not found"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
logger.info(f"Authentication successful for VM: {container_name}")
|
||||
await websocket.send_json({
|
||||
"success": True,
|
||||
"message": "Authenticated"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating with TryCUA API: {e}")
|
||||
if not client_container_name:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication service unavailable"
|
||||
"error": "Container name required"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
|
||||
# Use AuthenticationManager for validation
|
||||
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
|
||||
if not is_authenticated:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
logger.info(f"Authentication successful for VM: {client_container_name}")
|
||||
await websocket.send_json({
|
||||
"success": True,
|
||||
"message": "Authentication successful"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
logger.error(f"Error during authentication handshake: {str(e)}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
@@ -156,55 +271,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Map commands to appropriate handler methods
|
||||
handlers = {
|
||||
# App-Use commands
|
||||
"diorama_cmd": manager.diorama_handler.diorama_cmd,
|
||||
# Accessibility commands
|
||||
"get_accessibility_tree": manager.accessibility_handler.get_accessibility_tree,
|
||||
"find_element": manager.accessibility_handler.find_element,
|
||||
# Shell commands
|
||||
"run_command": manager.automation_handler.run_command,
|
||||
# File system commands
|
||||
"file_exists": manager.file_handler.file_exists,
|
||||
"directory_exists": manager.file_handler.directory_exists,
|
||||
"list_dir": manager.file_handler.list_dir,
|
||||
"read_text": manager.file_handler.read_text,
|
||||
"write_text": manager.file_handler.write_text,
|
||||
"read_bytes": manager.file_handler.read_bytes,
|
||||
"write_bytes": manager.file_handler.write_bytes,
|
||||
"get_file_size": manager.file_handler.get_file_size,
|
||||
"delete_file": manager.file_handler.delete_file,
|
||||
"create_dir": manager.file_handler.create_dir,
|
||||
"delete_dir": manager.file_handler.delete_dir,
|
||||
# Mouse commands
|
||||
"mouse_down": manager.automation_handler.mouse_down,
|
||||
"mouse_up": manager.automation_handler.mouse_up,
|
||||
"left_click": manager.automation_handler.left_click,
|
||||
"right_click": manager.automation_handler.right_click,
|
||||
"double_click": manager.automation_handler.double_click,
|
||||
"move_cursor": manager.automation_handler.move_cursor,
|
||||
"drag_to": manager.automation_handler.drag_to,
|
||||
"drag": manager.automation_handler.drag,
|
||||
# Keyboard commands
|
||||
"key_down": manager.automation_handler.key_down,
|
||||
"key_up": manager.automation_handler.key_up,
|
||||
"type_text": manager.automation_handler.type_text,
|
||||
"press_key": manager.automation_handler.press_key,
|
||||
"hotkey": manager.automation_handler.hotkey,
|
||||
# Scrolling actions
|
||||
"scroll": manager.automation_handler.scroll,
|
||||
"scroll_down": manager.automation_handler.scroll_down,
|
||||
"scroll_up": manager.automation_handler.scroll_up,
|
||||
# Screen actions
|
||||
"screenshot": manager.automation_handler.screenshot,
|
||||
"get_cursor_position": manager.automation_handler.get_cursor_position,
|
||||
"get_screen_size": manager.automation_handler.get_screen_size,
|
||||
# Clipboard actions
|
||||
"copy_to_clipboard": manager.automation_handler.copy_to_clipboard,
|
||||
"set_clipboard": manager.automation_handler.set_clipboard,
|
||||
}
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
@@ -224,7 +290,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
sig = inspect.signature(handler_func)
|
||||
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
|
||||
|
||||
result = await handler_func(**filtered_params)
|
||||
# Handle both sync and async functions
|
||||
if asyncio.iscoroutinefunction(handler_func):
|
||||
result = await handler_func(**filtered_params)
|
||||
else:
|
||||
# Run sync functions in thread pool to avoid blocking event loop
|
||||
result = await asyncio.to_thread(handler_func, **filtered_params)
|
||||
await websocket.send_json({"success": True, **result})
|
||||
except Exception as cmd_error:
|
||||
logger.error(f"Error executing command {command}: {str(cmd_error)}")
|
||||
@@ -256,5 +327,100 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
manager.disconnect(websocket)
|
||||
|
||||
|
||||
@app.post("/cmd")
async def cmd_endpoint(
    request: Request,
    container_name: Optional[str] = Header(None, alias="X-Container-Name"),
    api_key: Optional[str] = Header(None, alias="X-API-Key")
):
    """
    Backup endpoint for when WebSocket connections fail.
    Accepts commands via HTTP POST with streaming response.

    Headers:
    - X-Container-Name: Container name for cloud authentication
    - X-API-Key: API key for cloud authentication

    Body:
    {
        "command": "command_name",
        "params": {...}
    }

    Raises:
        HTTPException: 400 for a malformed body, a missing or unknown command
            or non-object params; 401 when CONTAINER_NAME is set and the
            credentials are missing or rejected.
    """
    # Parse request body
    try:
        body = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}") from e

    # Valid JSON that is not an object (e.g. a list) has no command to run.
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON body: expected a JSON object")

    command = body.get("command")
    # A missing or null "params" means "no parameters".
    params = body.get("params") or {}

    # Non-string commands cannot name a handler (and may be unhashable).
    if not command or not isinstance(command, str):
        raise HTTPException(status_code=400, detail="Command is required")

    if not isinstance(params, dict):
        raise HTTPException(status_code=400, detail="Params must be a JSON object")

    # Check if CONTAINER_NAME is set (indicating cloud provider)
    server_container_name = os.environ.get("CONTAINER_NAME")

    # If cloud provider, perform authentication
    if server_container_name:
        logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")

        # Validate required headers
        if not container_name:
            raise HTTPException(status_code=401, detail="Container name required")

        if not api_key:
            raise HTTPException(status_code=401, detail="API key required")

        # Validate with AuthenticationManager
        is_authenticated = await auth_manager.auth(container_name, api_key)
        if not is_authenticated:
            raise HTTPException(status_code=401, detail="Authentication failed")

    if command not in handlers:
        raise HTTPException(status_code=400, detail=f"Unknown command: {command}")

    async def generate_response():
        """Generate streaming response for the command execution"""
        try:
            # Filter params to only include those accepted by the handler function
            handler_func = handlers[command]
            sig = inspect.signature(handler_func)
            filtered_params = {k: v for k, v in params.items() if k in sig.parameters}

            # Handle both sync and async functions
            if asyncio.iscoroutinefunction(handler_func):
                result = await handler_func(**filtered_params)
            else:
                # Run sync functions in thread pool to avoid blocking event loop
                result = await asyncio.to_thread(handler_func, **filtered_params)

            # Stream the successful result
            response_data = {"success": True, **result}
            yield f"data: {json.dumps(response_data)}\n\n"

        except Exception as cmd_error:
            # Failures are reported in-band because the 200 status line has
            # already been committed once streaming starts.
            logger.error(f"Error executing command {command}: {str(cmd_error)}")
            logger.error(traceback.format_exc())

            # Stream the error result
            error_data = {"success": False, "error": str(cmd_error)}
            yield f"data: {json.dumps(error_data)}\n\n"

    return StreamingResponse(
        generate_response(),
        media_type="text/plain",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "POST, OPTIONS",
            "Access-Control-Allow-Headers": "Content-Type, X-Container-Name, X-API-Key"
        }
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
"""
|
||||
Connection test script for Computer Server.
|
||||
|
||||
This script tests the WebSocket connection to the Computer Server and keeps
|
||||
it alive, allowing you to verify the server is running correctly.
|
||||
This script tests both WebSocket (/ws) and REST (/cmd) connections to the Computer Server
|
||||
and keeps it alive, allowing you to verify the server is running correctly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -11,10 +11,14 @@ import json
|
||||
import websockets
|
||||
import argparse
|
||||
import sys
|
||||
import aiohttp
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None):
|
||||
"""Test connection to the Computer Server."""
|
||||
async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
|
||||
"""Test WebSocket connection to the Computer Server."""
|
||||
if container_name:
|
||||
# Container mode: use WSS with container domain and port 8443
|
||||
uri = f"wss://{container_name}.containers.cloud.trycua.com:8443/ws"
|
||||
@@ -26,15 +30,45 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
|
||||
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("Connection established!")
|
||||
print("WebSocket connection established!")
|
||||
|
||||
# If container connection, send authentication first
|
||||
if container_name:
|
||||
if not api_key:
|
||||
print("Error: API key required for container connections")
|
||||
return False
|
||||
|
||||
print("Sending authentication...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": api_key,
|
||||
"container_name": container_name
|
||||
}
|
||||
}
|
||||
await websocket.send(json.dumps(auth_message))
|
||||
auth_response = await websocket.recv()
|
||||
print(f"Authentication response: {auth_response}")
|
||||
|
||||
# Check if authentication was successful
|
||||
auth_data = json.loads(auth_response)
|
||||
if not auth_data.get("success", False):
|
||||
print("Authentication failed!")
|
||||
return False
|
||||
print("Authentication successful!")
|
||||
|
||||
# Send a test command to get version
|
||||
await websocket.send(json.dumps({"command": "version", "params": {}}))
|
||||
response = await websocket.recv()
|
||||
print(f"Version response: {response}")
|
||||
|
||||
# Send a test command to get screen size
|
||||
await websocket.send(json.dumps({"command": "get_screen_size", "params": {}}))
|
||||
response = await websocket.recv()
|
||||
print(f"Response: {response}")
|
||||
print(f"Screen size response: {response}")
|
||||
|
||||
if keep_alive:
|
||||
print("\nKeeping connection alive. Press Ctrl+C to exit...")
|
||||
print("\nKeeping WebSocket connection alive. Press Ctrl+C to exit...")
|
||||
while True:
|
||||
# Send a command every 5 seconds to keep the connection alive
|
||||
await asyncio.sleep(5)
|
||||
@@ -44,24 +78,115 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
|
||||
response = await websocket.recv()
|
||||
print(f"Cursor position: {response}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"Connection closed: {e}")
|
||||
print(f"WebSocket connection closed: {e}")
|
||||
return False
|
||||
except ConnectionRefusedError:
|
||||
print(f"Connection refused. Is the server running at {host}:{port}?")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"WebSocket error: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _post_rest_command(session, url, headers, command, success_label, failure_label):
    """POST one parameterless command to the /cmd endpoint and print the outcome.

    Returns True when the server answered with HTTP 200, False otherwise.
    """
    async with session.post(
        url,
        json={"command": command, "params": {}},
        headers=headers
    ) as response:
        if response.status == 200:
            text = await response.text()
            print(f"{success_label}: {text}")
            return True

        print(f"{failure_label} request failed with status: {response.status}")
        print(await response.text())
        return False


async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
    """Test REST connection to the Computer Server.

    Sends "screenshot" and "get_screen_size" to the /cmd endpoint; with
    keep_alive it then polls "get_cursor_position" every 5 seconds until a
    request fails or the task is interrupted.

    Returns:
        True when every request returned HTTP 200, False otherwise.
    """
    if container_name:
        # Container mode: use HTTPS with container domain and port 8443
        base_url = f"https://{container_name}.containers.cloud.trycua.com:8443"
        print(f"Connecting to container {container_name} at {base_url}...")
    else:
        # Local mode: use HTTP with specified host and port
        base_url = f"http://{host}:{port}"
        print(f"Connecting to local server at {base_url}...")

    # Fail before opening a session: container requests cannot succeed
    # without credentials.
    if container_name and not api_key:
        print("Error: API key required for container connections")
        return False

    cmd_url = f"{base_url}/cmd"

    try:
        async with aiohttp.ClientSession() as session:
            print("REST connection established!")

            # Prepare headers for container authentication
            headers = {}
            if container_name:
                headers["X-Container-Name"] = container_name
                headers["X-API-Key"] = api_key
                print("Using container authentication headers")

            # One-shot probes: (command, success label, failure label)
            probes = (
                ("screenshot", "Screenshot response", "Screenshot"),
                ("get_screen_size", "Screen size response", "Screen size"),
            )
            for command, success_label, failure_label in probes:
                if not await _post_rest_command(
                    session, cmd_url, headers, command, success_label, failure_label
                ):
                    return False

            if keep_alive:
                print("\nKeeping REST connection alive. Press Ctrl+C to exit...")
                while True:
                    # Send a command every 5 seconds to keep testing
                    await asyncio.sleep(5)
                    if not await _post_rest_command(
                        session, cmd_url, headers,
                        "get_cursor_position", "Cursor position", "Cursor position"
                    ):
                        return False

    except aiohttp.ClientError as e:
        print(f"REST connection error: {e}")
        return False
    except Exception as e:
        print(f"REST error: {e}")
        return False

    return True
|
||||
|
||||
|
||||
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
    """Run the connection test over REST when use_rest is true, else over WebSocket.

    All remaining arguments are forwarded unchanged to the selected tester,
    whose result is returned.
    """
    tester = test_rest_connection if use_rest else test_websocket_connection
    return await tester(host, port, keep_alive, container_name, api_key)
|
||||
|
||||
|
||||
def parse_args(argv=None):
    """Parse the command-line options of the connection test script.

    Args:
        argv: Optional list of argument strings; when None, argparse reads
            the process arguments (the previous behaviour).

    Returns:
        argparse.Namespace with host, port, container_name, api_key,
        keep_alive and rest attributes.
    """
    parser = argparse.ArgumentParser(description="Test connection to Computer Server")
    parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
    # Each option is registered exactly once; registering "--port" or
    # "--container-name" a second time makes argparse raise a conflict error.
    parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
    parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
    parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
    parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
    parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
    return parser.parse_args(argv)
|
||||
|
||||
|
||||
@@ -71,11 +196,27 @@ async def main():
|
||||
# Convert hyphenated argument to underscore for function parameter
|
||||
container_name = getattr(args, 'container_name', None)
|
||||
|
||||
# Get API key from argument or environment variable
|
||||
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')
|
||||
|
||||
# Check if container name is provided but API key is missing
|
||||
if container_name and not api_key:
|
||||
print("Warning: Container name provided but no API key found.")
|
||||
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
|
||||
return 1
|
||||
|
||||
print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
|
||||
if container_name:
|
||||
print(f"Container: {container_name}")
|
||||
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")
|
||||
|
||||
success = await test_connection(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
keep_alive=args.keep_alive,
|
||||
container_name=container_name
|
||||
container_name=container_name,
|
||||
use_rest=args.rest,
|
||||
api_key=api_key
|
||||
)
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
@@ -753,7 +753,7 @@ class Computer:
|
||||
|
||||
|
||||
# Add virtual environment management functions to computer interface
|
||||
async def venv_install(self, venv_name: str, requirements: list[str]) -> tuple[str, str]:
|
||||
async def venv_install(self, venv_name: str, requirements: list[str]):
|
||||
"""Install packages in a virtual environment.
|
||||
|
||||
Args:
|
||||
@@ -771,14 +771,14 @@ class Computer:
|
||||
|
||||
# Check if venv exists, if not create it
|
||||
check_cmd = f"test -d {venv_path} || ({create_cmd})"
|
||||
_, _ = await self.interface.run_command(check_cmd)
|
||||
_ = await self.interface.run_command(check_cmd)
|
||||
|
||||
# Install packages
|
||||
requirements_str = " ".join(requirements)
|
||||
install_cmd = f". {venv_path}/bin/activate && pip install {requirements_str}"
|
||||
return await self.interface.run_command(install_cmd)
|
||||
|
||||
async def venv_cmd(self, venv_name: str, command: str) -> tuple[str, str]:
|
||||
async def venv_cmd(self, venv_name: str, command: str):
|
||||
"""Execute a shell command in a virtual environment.
|
||||
|
||||
Args:
|
||||
@@ -792,9 +792,9 @@ class Computer:
|
||||
|
||||
# Check if virtual environment exists
|
||||
check_cmd = f"test -d {venv_path}"
|
||||
stdout, stderr = await self.interface.run_command(check_cmd)
|
||||
result = await self.interface.run_command(check_cmd)
|
||||
|
||||
if stderr or "test:" in stdout: # venv doesn't exist
|
||||
if result.stderr or "test:" in result.stdout: # venv doesn't exist
|
||||
return "", f"Virtual environment '{venv_name}' does not exist. Create it first using venv_install."
|
||||
|
||||
# Activate virtual environment and run command
|
||||
@@ -890,21 +890,21 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
|
||||
# Execute the Python code in the virtual environment
|
||||
python_command = f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
|
||||
stdout, stderr = await self.venv_cmd(venv_name, python_command)
|
||||
result = await self.venv_cmd(venv_name, python_command)
|
||||
|
||||
# Parse the output to extract the payload
|
||||
start_marker = "<<<VENV_EXEC_START>>>"
|
||||
end_marker = "<<<VENV_EXEC_END>>>"
|
||||
|
||||
# Print original stdout
|
||||
print(stdout[:stdout.find(start_marker)])
|
||||
print(result.stdout[:result.stdout.find(start_marker)])
|
||||
|
||||
if start_marker in stdout and end_marker in stdout:
|
||||
start_idx = stdout.find(start_marker) + len(start_marker)
|
||||
end_idx = stdout.find(end_marker)
|
||||
if start_marker in result.stdout and end_marker in result.stdout:
|
||||
start_idx = result.stdout.find(start_marker) + len(start_marker)
|
||||
end_idx = result.stdout.find(end_marker)
|
||||
|
||||
if start_idx < end_idx:
|
||||
output_json = stdout[start_idx:end_idx]
|
||||
output_json = result.stdout[start_idx:end_idx]
|
||||
|
||||
try:
|
||||
# Decode and deserialize the output payload from JSON
|
||||
@@ -923,4 +923,4 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
raise Exception("Invalid output format: markers found but no content between them")
|
||||
else:
|
||||
# Fallback: return stdout/stderr if no payload markers found
|
||||
raise Exception(f"No output payload found. stdout: {stdout}, stderr: {stderr}")
|
||||
raise Exception(f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}")
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
import websockets
|
||||
import aiohttp
|
||||
|
||||
from ..logger import Logger, LogLevel
|
||||
from .base import BaseComputerInterface
|
||||
@@ -57,6 +58,17 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
@property
def rest_uri(self) -> str:
    """Build the REST endpoint URL from the current IP address.

    When an API key is configured the endpoint uses ``https`` on port
    8443; without one it uses plain ``http`` on port 8000.

    Returns:
        REST URI (``.../cmd``) for the Computer API Server
    """
    if self.api_key:
        scheme, port = "https", "8443"
    else:
        scheme, port = "http", "8000"
    return f"{scheme}://{self.ip_address}:{port}/cmd"
|
||||
|
||||
# Mouse actions
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
|
||||
@@ -677,7 +689,7 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
|
||||
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
async def _send_command_ws(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Send command through WebSocket."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
@@ -717,7 +729,151 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
|
||||
raise last_error if last_error else RuntimeError("Failed to send command")
|
||||
|
||||
async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    """Send one command to the REST endpoint and return the decoded reply.

    A single POST is issued to ``self.rest_uri``; nothing is retried and no
    connection is kept around afterwards. Failures are never raised: they
    are reported through the returned dict (``success`` False plus an
    ``error`` / ``message`` pair).

    Args:
        command: Name of the command to run on the server.
        params: Command parameters; an empty dict is sent when omitted.

    Returns:
        The JSON object that follows the ``data: `` prefix of the response
        body, or an error dict when the request fails or the body does not
        have that shape.
    """
    data_prefix = "data: "

    def malformed(raw_body: str) -> Dict[str, Any]:
        # Single place that builds the "unparseable reply" error result.
        return {
            "success": False,
            "error": "Server returned malformed response",
            "message": raw_body,
        }

    try:
        body = {"command": command, "params": params or {}}

        request_headers = {"Content-Type": "application/json"}
        if self.api_key:
            request_headers["X-API-Key"] = self.api_key
        if self.vm_name:
            request_headers["X-Container-Name"] = self.vm_name

        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.rest_uri, json=body, headers=request_headers
            ) as response:
                raw_body = (await response.text()).strip()

        # The reply is expected to be a single "data: <json>" line.
        if not raw_body.startswith(data_prefix):
            return malformed(raw_body)

        try:
            return json.loads(raw_body[len(data_prefix):])
        except json.JSONDecodeError:
            return malformed(raw_body)

    except Exception as exc:
        # Transport-level problems are reported as a result, not raised.
        return {
            "success": False,
            "error": "Request failed",
            "message": str(exc),
        }
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    """Send a command over REST, falling back to WebSocket on transport errors.

    The REST result is returned as-is unless it is a failure whose ``error``
    is "Request failed" or "Server returned malformed response". For those
    two errors the command is re-sent over WebSocket; if that attempt raises,
    the original REST error dict is returned.

    Args:
        command: Name of the command to run on the server.
        params: Optional command parameters.

    Returns:
        The REST reply, the WebSocket reply, or the REST error dict.
    """
    rest_reply = await self._send_command_rest(command, params)

    # Only these REST-side errors justify trying the other transport.
    fallback_errors = ("Request failed", "Server returned malformed response")
    if rest_reply.get("success", True) or rest_reply.get("error") not in fallback_errors:
        return rest_reply

    self.logger.debug(f"REST API failed for command '{command}', trying WebSocket fallback")
    try:
        return await self._send_command_ws(command, params)
    except Exception as ws_error:
        self.logger.debug(f"WebSocket fallback also failed: {ws_error}")
        # Surface the REST error that triggered the fallback.
        return rest_reply
|
||||
|
||||
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
    """Wait for the Computer API Server to be ready.

    The REST endpoint is probed once with a ``version`` command. If the
    probe fails, readiness is delegated entirely to ``_wait_for_ready_ws``.
    Otherwise ``get_screen_size`` is polled through ``_send_command`` until
    it succeeds or ``timeout`` seconds have elapsed.

    Args:
        timeout: Maximum number of seconds to keep polling.
        interval: Seconds to sleep between polling attempts.

    Raises:
        TimeoutError: The server did not answer successfully within
            ``timeout`` seconds.
        RuntimeError: An unexpected error occurred while polling.
        Exception: Whatever ``_wait_for_ready_ws`` raises when the
            WebSocket fallback also fails.
    """
    # Check if REST API is available. This is an explicit check rather than
    # an `assert`: assertions are stripped under `python -O`, which would
    # silently disable the WebSocket fallback.
    rest_available = False
    rest_failure = "unknown error"
    try:
        probe = await self._send_command_rest("version", {})
        if probe.get("success", True):
            rest_available = True
        else:
            rest_failure = probe.get("message") or probe.get("error") or rest_failure
    except Exception as e:
        rest_failure = e

    if not rest_available:
        self.logger.debug(f"REST API failed for command 'version', trying WebSocket fallback: {rest_failure}")
        try:
            await self._wait_for_ready_ws(timeout, interval)
        except Exception as e:
            self.logger.debug(f"WebSocket fallback also failed: {e}")
            raise  # bare raise keeps the original traceback
        return

    start_time = time.time()
    last_error = None
    attempt_count = 0
    progress_interval = 10  # Log progress every 10 seconds
    last_progress_time = start_time

    try:
        self.logger.info(
            f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
        )

        # Poll the server with get_screen_size until it answers successfully
        while time.time() - start_time < timeout:
            try:
                attempt_count += 1
                current_time = time.time()

                # Log progress periodically without flooding logs
                if current_time - last_progress_time >= progress_interval:
                    elapsed = current_time - start_time
                    self.logger.info(
                        f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
                    )
                    last_progress_time = current_time

                # Test the server with a simple get_screen_size command
                result = await self._send_command("get_screen_size")
                if result.get("success", False):
                    elapsed = time.time() - start_time
                    self.logger.info(
                        f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
                    )
                    return  # Server is ready

                last_error = result.get("error", "Unknown error")
                self.logger.debug(f"Initial connection command failed: {last_error}")

            except Exception as e:
                last_error = e
                self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")

            # Wait before trying again
            await asyncio.sleep(interval)

        # If we get here, we've timed out
        error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
        if last_error:
            error_msg += f": {str(last_error)}"
        self.logger.error(error_msg)
        raise TimeoutError(error_msg)

    except Exception as e:
        if isinstance(e, TimeoutError):
            raise
        error_msg = f"Error while waiting for server: {str(e)}"
        self.logger.error(error_msg)
        # Chain the cause so the underlying failure stays visible.
        raise RuntimeError(error_msg) from e
|
||||
|
||||
async def _wait_for_ready_ws(self, timeout: int = 60, interval: float = 1.0):
|
||||
"""Wait for WebSocket connection to become available."""
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
@@ -755,7 +911,7 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
# Test the connection with a simple command
|
||||
try:
|
||||
await self._send_command("get_screen_size")
|
||||
await self._send_command_ws("get_screen_size")
|
||||
elapsed = time.time() - start_time
|
||||
self.logger.info(
|
||||
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
|
||||
|
||||
Reference in New Issue
Block a user