Merge pull request #318 from trycua/fix/networking-fallbacks

Add REST API Support with WebSocket Fallback for Computer Server
This commit is contained in:
ddupont
2025-07-10 17:23:55 -04:00
committed by GitHub
4 changed files with 603 additions and 140 deletions

View File

@@ -1,5 +1,6 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from typing import List, Dict, Any
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, HTTPException, Header
from fastapi.responses import StreamingResponse
from typing import List, Dict, Any, Optional
import uvicorn
import logging
import asyncio
@@ -11,6 +12,8 @@ from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
import hashlib
import time
# Set up logging with more detail
logger = logging.getLogger(__name__)
@@ -27,12 +30,156 @@ app = FastAPI(
websocket_max_size=WEBSOCKET_MAX_SIZE,
)
protocol_version = 1
try:
import pkg_resources
package_version = pkg_resources.get_distribution("cua-computer-server").version
except pkg_resources.DistributionNotFound:
package_version = "unknown"
accessibility_handler, automation_handler, diorama_handler, file_handler = HandlerFactory.create_handlers()
handlers = {
"version": lambda: {"protocol": protocol_version, "package": package_version},
# App-Use commands
"diorama_cmd": diorama_handler.diorama_cmd,
# Accessibility commands
"get_accessibility_tree": accessibility_handler.get_accessibility_tree,
"find_element": accessibility_handler.find_element,
# Shell commands
"run_command": automation_handler.run_command,
# File system commands
"file_exists": file_handler.file_exists,
"directory_exists": file_handler.directory_exists,
"list_dir": file_handler.list_dir,
"read_text": file_handler.read_text,
"write_text": file_handler.write_text,
"read_bytes": file_handler.read_bytes,
"write_bytes": file_handler.write_bytes,
"get_file_size": file_handler.get_file_size,
"delete_file": file_handler.delete_file,
"create_dir": file_handler.create_dir,
"delete_dir": file_handler.delete_dir,
# Mouse commands
"mouse_down": automation_handler.mouse_down,
"mouse_up": automation_handler.mouse_up,
"left_click": automation_handler.left_click,
"right_click": automation_handler.right_click,
"double_click": automation_handler.double_click,
"move_cursor": automation_handler.move_cursor,
"drag_to": automation_handler.drag_to,
"drag": automation_handler.drag,
# Keyboard commands
"key_down": automation_handler.key_down,
"key_up": automation_handler.key_up,
"type_text": automation_handler.type_text,
"press_key": automation_handler.press_key,
"hotkey": automation_handler.hotkey,
# Scrolling actions
"scroll": automation_handler.scroll,
"scroll_down": automation_handler.scroll_down,
"scroll_up": automation_handler.scroll_up,
# Screen actions
"screenshot": automation_handler.screenshot,
"get_cursor_position": automation_handler.get_cursor_position,
"get_screen_size": automation_handler.get_screen_size,
# Clipboard actions
"copy_to_clipboard": automation_handler.copy_to_clipboard,
"set_clipboard": automation_handler.set_clipboard,
}
class AuthenticationManager:
def __init__(self):
self.sessions: Dict[str, Dict[str, Any]] = {}
self.container_name = os.environ.get("CONTAINER_NAME")
def _hash_credentials(self, container_name: str, api_key: str) -> str:
"""Create a hash of container name and API key for session identification"""
combined = f"{container_name}:{api_key}"
return hashlib.sha256(combined.encode()).hexdigest()
def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
"""Check if a session is still valid based on expiration time"""
if not session_data.get('valid', False):
return False
expires_at = session_data.get('expires_at', 0)
return time.time() < expires_at
async def auth(self, container_name: str, api_key: str) -> bool:
"""Authenticate container name and API key, using cached sessions when possible"""
# If no CONTAINER_NAME is set, always allow access (local development)
if not self.container_name:
logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
return True
# Layer 1: VM Identity Verification
if container_name != self.container_name:
logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
return False
# Create hash for session lookup
session_hash = self._hash_credentials(container_name, api_key)
# Check if we have a valid cached session
if session_hash in self.sessions:
session_data = self.sessions[session_hash]
if self._is_session_valid(session_data):
logger.info(f"Using cached authentication for container: {container_name}")
return session_data['valid']
else:
# Remove expired session
del self.sessions[session_hash]
# No valid cached session, authenticate with API
logger.info(f"Authenticating with TryCUA API for container: {container_name}")
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
is_valid = resp.status == 200 and bool((await resp.text()).strip())
# Cache the result with 5 second expiration
self.sessions[session_hash] = {
'valid': is_valid,
'expires_at': time.time() + 5 # 5 seconds from now
}
if is_valid:
logger.info(f"Authentication successful for container: {container_name}")
else:
logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}")
return is_valid
except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + 5
}
return False
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + 5
}
return False
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
# Create OS-specific handlers
self.accessibility_handler, self.automation_handler, self.diorama_handler, self.file_handler = HandlerFactory.create_handlers()
async def connect(self, websocket: WebSocket):
await websocket.accept()
@@ -43,20 +190,23 @@ class ConnectionManager:
manager = ConnectionManager()
auth_manager = AuthenticationManager()
@app.websocket("/ws", name="websocket_endpoint")
async def websocket_endpoint(websocket: WebSocket):
global handlers
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
# Check if CONTAINER_NAME is set (indicating cloud provider)
container_name = os.environ.get("CONTAINER_NAME")
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication handshake
if container_name:
if server_container_name:
try:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...")
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
# Wait for authentication message
auth_data = await websocket.receive_json()
@@ -75,18 +225,7 @@ async def websocket_endpoint(websocket: WebSocket):
client_api_key = auth_data.get("params", {}).get("api_key")
client_container_name = auth_data.get("params", {}).get("container_name")
# Layer 1: VM Identity Verification
if client_container_name != container_name:
logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}")
await websocket.send_json({
"success": False,
"error": "VM name mismatch"
})
await websocket.close()
manager.disconnect(websocket)
return
# Layer 2: API Key Validation with TryCUA API
# Validate credentials using AuthenticationManager
if not client_api_key:
await websocket.send_json({
"success": False,
@@ -96,58 +235,34 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket)
return
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {client_api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
if resp.status != 200:
error_msg = await resp.text()
logger.warning(f"API validation failed: {error_msg}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# If we get a 200 response with VNC URL, the VM exists and user has access
vnc_url = (await resp.text()).strip()
if not vnc_url:
logger.warning(f"No VNC URL returned for VM: {container_name}")
await websocket.send_json({
"success": False,
"error": "VM not found"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {container_name}")
await websocket.send_json({
"success": True,
"message": "Authenticated"
})
except Exception as e:
logger.error(f"Error validating with TryCUA API: {e}")
if not client_container_name:
await websocket.send_json({
"success": False,
"error": "Authentication service unavailable"
"error": "Container name required"
})
await websocket.close()
manager.disconnect(websocket)
return
# Use AuthenticationManager for validation
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
if not is_authenticated:
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {client_container_name}")
await websocket.send_json({
"success": True,
"message": "Authentication successful"
})
except Exception as e:
logger.error(f"Authentication error: {e}")
logger.error(f"Error during authentication handshake: {str(e)}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
@@ -156,55 +271,6 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket)
return
# Map commands to appropriate handler methods
handlers = {
# App-Use commands
"diorama_cmd": manager.diorama_handler.diorama_cmd,
# Accessibility commands
"get_accessibility_tree": manager.accessibility_handler.get_accessibility_tree,
"find_element": manager.accessibility_handler.find_element,
# Shell commands
"run_command": manager.automation_handler.run_command,
# File system commands
"file_exists": manager.file_handler.file_exists,
"directory_exists": manager.file_handler.directory_exists,
"list_dir": manager.file_handler.list_dir,
"read_text": manager.file_handler.read_text,
"write_text": manager.file_handler.write_text,
"read_bytes": manager.file_handler.read_bytes,
"write_bytes": manager.file_handler.write_bytes,
"get_file_size": manager.file_handler.get_file_size,
"delete_file": manager.file_handler.delete_file,
"create_dir": manager.file_handler.create_dir,
"delete_dir": manager.file_handler.delete_dir,
# Mouse commands
"mouse_down": manager.automation_handler.mouse_down,
"mouse_up": manager.automation_handler.mouse_up,
"left_click": manager.automation_handler.left_click,
"right_click": manager.automation_handler.right_click,
"double_click": manager.automation_handler.double_click,
"move_cursor": manager.automation_handler.move_cursor,
"drag_to": manager.automation_handler.drag_to,
"drag": manager.automation_handler.drag,
# Keyboard commands
"key_down": manager.automation_handler.key_down,
"key_up": manager.automation_handler.key_up,
"type_text": manager.automation_handler.type_text,
"press_key": manager.automation_handler.press_key,
"hotkey": manager.automation_handler.hotkey,
# Scrolling actions
"scroll": manager.automation_handler.scroll,
"scroll_down": manager.automation_handler.scroll_down,
"scroll_up": manager.automation_handler.scroll_up,
# Screen actions
"screenshot": manager.automation_handler.screenshot,
"get_cursor_position": manager.automation_handler.get_cursor_position,
"get_screen_size": manager.automation_handler.get_screen_size,
# Clipboard actions
"copy_to_clipboard": manager.automation_handler.copy_to_clipboard,
"set_clipboard": manager.automation_handler.set_clipboard,
}
try:
while True:
try:
@@ -224,7 +290,12 @@ async def websocket_endpoint(websocket: WebSocket):
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
result = await handler_func(**filtered_params)
# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
else:
# Run sync functions in thread pool to avoid blocking event loop
result = await asyncio.to_thread(handler_func, **filtered_params)
await websocket.send_json({"success": True, **result})
except Exception as cmd_error:
logger.error(f"Error executing command {command}: {str(cmd_error)}")
@@ -256,5 +327,100 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket)
@app.post("/cmd")
async def cmd_endpoint(
request: Request,
container_name: Optional[str] = Header(None, alias="X-Container-Name"),
api_key: Optional[str] = Header(None, alias="X-API-Key")
):
"""
Backup endpoint for when WebSocket connections fail.
Accepts commands via HTTP POST with streaming response.
Headers:
- X-Container-Name: Container name for cloud authentication
- X-API-Key: API key for cloud authentication
Body:
{
"command": "command_name",
"params": {...}
}
"""
global handlers
# Parse request body
try:
body = await request.json()
command = body.get("command")
params = body.get("params", {})
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication
if server_container_name:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
# Validate required headers
if not container_name:
raise HTTPException(status_code=401, detail="Container name required")
if not api_key:
raise HTTPException(status_code=401, detail="API key required")
# Validate with AuthenticationManager
is_authenticated = await auth_manager.auth(container_name, api_key)
if not is_authenticated:
raise HTTPException(status_code=401, detail="Authentication failed")
if command not in handlers:
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")
async def generate_response():
"""Generate streaming response for the command execution"""
try:
# Filter params to only include those accepted by the handler function
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
else:
# Run sync functions in thread pool to avoid blocking event loop
result = await asyncio.to_thread(handler_func, **filtered_params)
# Stream the successful result
response_data = {"success": True, **result}
yield f"data: {json.dumps(response_data)}\n\n"
except Exception as cmd_error:
logger.error(f"Error executing command {command}: {str(cmd_error)}")
logger.error(traceback.format_exc())
# Stream the error result
error_data = {"success": False, "error": str(cmd_error)}
yield f"data: {json.dumps(error_data)}\n\n"
return StreamingResponse(
generate_response(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "POST, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, X-Container-Name, X-API-Key"
}
)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -2,8 +2,8 @@
"""
Connection test script for Computer Server.
This script tests the WebSocket connection to the Computer Server and keeps
it alive, allowing you to verify the server is running correctly.
This script tests both WebSocket (/ws) and REST (/cmd) connections to the Computer Server
and keeps it alive, allowing you to verify the server is running correctly.
"""
import asyncio
@@ -11,10 +11,14 @@ import json
import websockets
import argparse
import sys
import aiohttp
import os
import dotenv
dotenv.load_dotenv()
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None):
"""Test connection to the Computer Server."""
async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
"""Test WebSocket connection to the Computer Server."""
if container_name:
# Container mode: use WSS with container domain and port 8443
uri = f"wss://{container_name}.containers.cloud.trycua.com:8443/ws"
@@ -26,15 +30,45 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
try:
async with websockets.connect(uri) as websocket:
print("Connection established!")
print("WebSocket connection established!")
# If container connection, send authentication first
if container_name:
if not api_key:
print("Error: API key required for container connections")
return False
print("Sending authentication...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": api_key,
"container_name": container_name
}
}
await websocket.send(json.dumps(auth_message))
auth_response = await websocket.recv()
print(f"Authentication response: {auth_response}")
# Check if authentication was successful
auth_data = json.loads(auth_response)
if not auth_data.get("success", False):
print("Authentication failed!")
return False
print("Authentication successful!")
# Send a test command to get version
await websocket.send(json.dumps({"command": "version", "params": {}}))
response = await websocket.recv()
print(f"Version response: {response}")
# Send a test command to get screen size
await websocket.send(json.dumps({"command": "get_screen_size", "params": {}}))
response = await websocket.recv()
print(f"Response: {response}")
print(f"Screen size response: {response}")
if keep_alive:
print("\nKeeping connection alive. Press Ctrl+C to exit...")
print("\nKeeping WebSocket connection alive. Press Ctrl+C to exit...")
while True:
# Send a command every 5 seconds to keep the connection alive
await asyncio.sleep(5)
@@ -44,24 +78,115 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
response = await websocket.recv()
print(f"Cursor position: {response}")
except websockets.exceptions.ConnectionClosed as e:
print(f"Connection closed: {e}")
print(f"WebSocket connection closed: {e}")
return False
except ConnectionRefusedError:
print(f"Connection refused. Is the server running at {host}:{port}?")
return False
except Exception as e:
print(f"Error: {e}")
print(f"WebSocket error: {e}")
return False
return True
async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
"""Test REST connection to the Computer Server."""
if container_name:
# Container mode: use HTTPS with container domain and port 8443
base_url = f"https://{container_name}.containers.cloud.trycua.com:8443"
print(f"Connecting to container {container_name} at {base_url}...")
else:
# Local mode: use HTTP with specified host and port
base_url = f"http://{host}:{port}"
print(f"Connecting to local server at {base_url}...")
try:
async with aiohttp.ClientSession() as session:
print("REST connection established!")
# Prepare headers for container authentication
headers = {}
if container_name:
if not api_key:
print("Error: API key required for container connections")
return False
headers["X-Container-Name"] = container_name
headers["X-API-Key"] = api_key
print(f"Using container authentication headers")
# Test screenshot endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "screenshot", "params": {}},
headers=headers
) as response:
if response.status == 200:
text = await response.text()
print(f"Screenshot response: {text}")
else:
print(f"Screenshot request failed with status: {response.status}")
print(await response.text())
return False
# Test screen size endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "get_screen_size", "params": {}},
headers=headers
) as response:
if response.status == 200:
text = await response.text()
print(f"Screen size response: {text}")
else:
print(f"Screen size request failed with status: {response.status}")
print(await response.text())
return False
if keep_alive:
print("\nKeeping REST connection alive. Press Ctrl+C to exit...")
while True:
# Send a command every 5 seconds to keep testing
await asyncio.sleep(5)
async with session.post(
f"{base_url}/cmd",
json={"command": "get_cursor_position", "params": {}},
headers=headers
) as response:
if response.status == 200:
text = await response.text()
print(f"Cursor position: {text}")
else:
print(f"Cursor position request failed with status: {response.status}")
print(await response.text())
return False
except aiohttp.ClientError as e:
print(f"REST connection error: {e}")
return False
except Exception as e:
print(f"REST error: {e}")
return False
return True
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
"""Test connection to the Computer Server using WebSocket or REST."""
if use_rest:
return await test_rest_connection(host, port, keep_alive, container_name, api_key)
else:
return await test_websocket_connection(host, port, keep_alive, container_name, api_key)
def parse_args():
parser = argparse.ArgumentParser(description="Test connection to Computer Server")
parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
parser.add_argument("--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("--container-name", help="Container name for cloud connection (uses WSS and port 8443)")
parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
return parser.parse_args()
@@ -71,11 +196,27 @@ async def main():
# Convert hyphenated argument to underscore for function parameter
container_name = getattr(args, 'container_name', None)
# Get API key from argument or environment variable
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')
# Check if container name is provided but API key is missing
if container_name and not api_key:
print("Warning: Container name provided but no API key found.")
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
return 1
print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
if container_name:
print(f"Container: {container_name}")
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")
success = await test_connection(
host=args.host,
port=args.port,
keep_alive=args.keep_alive,
container_name=container_name
container_name=container_name,
use_rest=args.rest,
api_key=api_key
)
return 0 if success else 1

View File

@@ -753,7 +753,7 @@ class Computer:
# Add virtual environment management functions to computer interface
async def venv_install(self, venv_name: str, requirements: list[str]) -> tuple[str, str]:
async def venv_install(self, venv_name: str, requirements: list[str]):
"""Install packages in a virtual environment.
Args:
@@ -771,14 +771,14 @@ class Computer:
# Check if venv exists, if not create it
check_cmd = f"test -d {venv_path} || ({create_cmd})"
_, _ = await self.interface.run_command(check_cmd)
_ = await self.interface.run_command(check_cmd)
# Install packages
requirements_str = " ".join(requirements)
install_cmd = f". {venv_path}/bin/activate && pip install {requirements_str}"
return await self.interface.run_command(install_cmd)
async def venv_cmd(self, venv_name: str, command: str) -> tuple[str, str]:
async def venv_cmd(self, venv_name: str, command: str):
"""Execute a shell command in a virtual environment.
Args:
@@ -792,9 +792,9 @@ class Computer:
# Check if virtual environment exists
check_cmd = f"test -d {venv_path}"
stdout, stderr = await self.interface.run_command(check_cmd)
result = await self.interface.run_command(check_cmd)
if stderr or "test:" in stdout: # venv doesn't exist
if result.stderr or "test:" in result.stdout: # venv doesn't exist
return "", f"Virtual environment '{venv_name}' does not exist. Create it first using venv_install."
# Activate virtual environment and run command
@@ -890,21 +890,21 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
# Execute the Python code in the virtual environment
python_command = f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
stdout, stderr = await self.venv_cmd(venv_name, python_command)
result = await self.venv_cmd(venv_name, python_command)
# Parse the output to extract the payload
start_marker = "<<<VENV_EXEC_START>>>"
end_marker = "<<<VENV_EXEC_END>>>"
# Print original stdout
print(stdout[:stdout.find(start_marker)])
print(result.stdout[:result.stdout.find(start_marker)])
if start_marker in stdout and end_marker in stdout:
start_idx = stdout.find(start_marker) + len(start_marker)
end_idx = stdout.find(end_marker)
if start_marker in result.stdout and end_marker in result.stdout:
start_idx = result.stdout.find(start_marker) + len(start_marker)
end_idx = result.stdout.find(end_marker)
if start_idx < end_idx:
output_json = stdout[start_idx:end_idx]
output_json = result.stdout[start_idx:end_idx]
try:
# Decode and deserialize the output payload from JSON
@@ -923,4 +923,4 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
raise Exception("Invalid output format: markers found but no content between them")
else:
# Fallback: return stdout/stderr if no payload markers found
raise Exception(f"No output payload found. stdout: {stdout}, stderr: {stderr}")
raise Exception(f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}")

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import websockets
import aiohttp
from ..logger import Logger, LogLevel
from .base import BaseComputerInterface
@@ -57,6 +58,17 @@ class GenericComputerInterface(BaseComputerInterface):
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/ws"
@property
def rest_uri(self) -> str:
"""Get the REST URI using the current IP address.
Returns:
REST URI for the Computer API Server
"""
protocol = "https" if self.api_key else "http"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/cmd"
# Mouse actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
@@ -677,7 +689,7 @@ class GenericComputerInterface(BaseComputerInterface):
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
async def _send_command_ws(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through WebSocket."""
max_retries = 3
retry_count = 0
@@ -717,7 +729,151 @@ class GenericComputerInterface(BaseComputerInterface):
raise last_error if last_error else RuntimeError("Failed to send command")
async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through REST API without retries or connection management."""
try:
# Prepare the request payload
payload = {"command": command, "params": params or {}}
# Prepare headers
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-API-Key"] = self.api_key
if self.vm_name:
headers["X-Container-Name"] = self.vm_name
# Send the request
async with aiohttp.ClientSession() as session:
async with session.post(
self.rest_uri,
json=payload,
headers=headers
) as response:
# Get the response text
response_text = await response.text()
# Trim whitespace
response_text = response_text.strip()
# Check if it starts with "data: "
if response_text.startswith("data: "):
# Extract everything after "data: "
json_str = response_text[6:] # Remove "data: " prefix
try:
return json.loads(json_str)
except json.JSONDecodeError:
return {
"success": False,
"error": "Server returned malformed response",
"message": response_text
}
else:
# Return error response
return {
"success": False,
"error": "Server returned malformed response",
"message": response_text
}
except Exception as e:
return {
"success": False,
"error": "Request failed",
"message": str(e)
}
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command using REST API with WebSocket fallback."""
# Try REST API first
result = await self._send_command_rest(command, params)
# If REST failed with "Request failed", try WebSocket as fallback
if not result.get("success", True) and (result.get("error") == "Request failed" or result.get("error") == "Server returned malformed response"):
self.logger.debug(f"REST API failed for command '{command}', trying WebSocket fallback")
try:
return await self._send_command_ws(command, params)
except Exception as e:
self.logger.debug(f"WebSocket fallback also failed: {e}")
# Return the original REST error
return result
return result
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
"""Wait for Computer API Server to be ready by testing version command."""
# Check if REST API is available
try:
result = await self._send_command_rest("version", {})
assert result.get("success", True)
except Exception as e:
self.logger.debug(f"REST API failed for command 'version', trying WebSocket fallback: {e}")
try:
await self._wait_for_ready_ws(timeout, interval)
return
except Exception as e:
self.logger.debug(f"WebSocket fallback also failed: {e}")
raise e
start_time = time.time()
last_error = None
attempt_count = 0
progress_interval = 10 # Log progress every 10 seconds
last_progress_time = start_time
try:
self.logger.info(
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
)
# Wait for the server to respond to get_screen_size command
while time.time() - start_time < timeout:
try:
attempt_count += 1
current_time = time.time()
# Log progress periodically without flooding logs
if current_time - last_progress_time >= progress_interval:
elapsed = current_time - start_time
self.logger.info(
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
)
last_progress_time = current_time
# Test the server with a simple get_screen_size command
result = await self._send_command("get_screen_size")
if result.get("success", False):
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
)
return # Server is ready
else:
last_error = result.get("error", "Unknown error")
self.logger.debug(f"Initial connection command failed: {last_error}")
except Exception as e:
last_error = e
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
# Wait before trying again
await asyncio.sleep(interval)
# If we get here, we've timed out
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
if last_error:
error_msg += f": {str(last_error)}"
self.logger.error(error_msg)
raise TimeoutError(error_msg)
except Exception as e:
if isinstance(e, TimeoutError):
raise
error_msg = f"Error while waiting for server: {str(e)}"
self.logger.error(error_msg)
raise RuntimeError(error_msg)
async def _wait_for_ready_ws(self, timeout: int = 60, interval: float = 1.0):
"""Wait for WebSocket connection to become available."""
start_time = time.time()
last_error = None
@@ -755,7 +911,7 @@ class GenericComputerInterface(BaseComputerInterface):
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
# Test the connection with a simple command
try:
await self._send_command("get_screen_size")
await self._send_command_ws("get_screen_size")
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"