Added REST based send_command and wait_for_ready

This commit is contained in:
Dillon DuPont
2025-07-10 16:30:20 -04:00
parent 669832030e
commit 5a4d9598a8
2 changed files with 311 additions and 14 deletions

View File

@@ -2,8 +2,8 @@
"""
Connection test script for Computer Server.
This script tests the WebSocket connection to the Computer Server and keeps
it alive, allowing you to verify the server is running correctly.
This script tests both WebSocket (/ws) and REST (/cmd) connections to the Computer Server
and keeps it alive, allowing you to verify the server is running correctly.
"""
import asyncio
@@ -11,10 +11,14 @@ import json
import websockets
import argparse
import sys
import aiohttp
import os
import dotenv
dotenv.load_dotenv()
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None):
"""Test connection to the Computer Server."""
async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
"""Test WebSocket connection to the Computer Server."""
if container_name:
# Container mode: use WSS with container domain and port 8443
uri = f"wss://{container_name}.containers.cloud.trycua.com:8443/ws"
@@ -26,15 +30,45 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
try:
async with websockets.connect(uri) as websocket:
print("Connection established!")
print("WebSocket connection established!")
# If container connection, send authentication first
if container_name:
if not api_key:
print("Error: API key required for container connections")
return False
print("Sending authentication...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": api_key,
"container_name": container_name
}
}
await websocket.send(json.dumps(auth_message))
auth_response = await websocket.recv()
print(f"Authentication response: {auth_response}")
# Check if authentication was successful
auth_data = json.loads(auth_response)
if not auth_data.get("success", False):
print("Authentication failed!")
return False
print("Authentication successful!")
# Send a test command to get version
await websocket.send(json.dumps({"command": "version", "params": {}}))
response = await websocket.recv()
print(f"Version response: {response}")
# Send a test command to get screen size
await websocket.send(json.dumps({"command": "get_screen_size", "params": {}}))
response = await websocket.recv()
print(f"Response: {response}")
print(f"Screen size response: {response}")
if keep_alive:
print("\nKeeping connection alive. Press Ctrl+C to exit...")
print("\nKeeping WebSocket connection alive. Press Ctrl+C to exit...")
while True:
# Send a command every 5 seconds to keep the connection alive
await asyncio.sleep(5)
@@ -44,24 +78,115 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
response = await websocket.recv()
print(f"Cursor position: {response}")
except websockets.exceptions.ConnectionClosed as e:
print(f"Connection closed: {e}")
print(f"WebSocket connection closed: {e}")
return False
except ConnectionRefusedError:
print(f"Connection refused. Is the server running at {host}:{port}?")
return False
except Exception as e:
print(f"Error: {e}")
print(f"WebSocket error: {e}")
return False
return True
async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
"""Test REST connection to the Computer Server."""
if container_name:
# Container mode: use HTTPS with container domain and port 8443
base_url = f"https://{container_name}.containers.cloud.trycua.com:8443"
print(f"Connecting to container {container_name} at {base_url}...")
else:
# Local mode: use HTTP with specified host and port
base_url = f"http://{host}:{port}"
print(f"Connecting to local server at {base_url}...")
try:
async with aiohttp.ClientSession() as session:
print("REST connection established!")
# Prepare headers for container authentication
headers = {}
if container_name:
if not api_key:
print("Error: API key required for container connections")
return False
headers["X-Container-Name"] = container_name
headers["X-API-Key"] = api_key
print(f"Using container authentication headers")
# Test screenshot endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "screenshot", "params": {}},
headers=headers
) as response:
if response.status == 200:
text = await response.text()
print(f"Screenshot response: {text}")
else:
print(f"Screenshot request failed with status: {response.status}")
print(await response.text())
return False
# Test screen size endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "get_screen_size", "params": {}},
headers=headers
) as response:
if response.status == 200:
text = await response.text()
print(f"Screen size response: {text}")
else:
print(f"Screen size request failed with status: {response.status}")
print(await response.text())
return False
if keep_alive:
print("\nKeeping REST connection alive. Press Ctrl+C to exit...")
while True:
# Send a command every 5 seconds to keep testing
await asyncio.sleep(5)
async with session.post(
f"{base_url}/cmd",
json={"command": "get_cursor_position", "params": {}},
headers=headers
) as response:
if response.status == 200:
text = await response.text()
print(f"Cursor position: {text}")
else:
print(f"Cursor position request failed with status: {response.status}")
print(await response.text())
return False
except aiohttp.ClientError as e:
print(f"REST connection error: {e}")
return False
except Exception as e:
print(f"REST error: {e}")
return False
return True
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
"""Test connection to the Computer Server using WebSocket or REST."""
if use_rest:
return await test_rest_connection(host, port, keep_alive, container_name, api_key)
else:
return await test_websocket_connection(host, port, keep_alive, container_name, api_key)
def parse_args():
parser = argparse.ArgumentParser(description="Test connection to Computer Server")
parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
parser.add_argument("--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("--container-name", help="Container name for cloud connection (uses WSS and port 8443)")
parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
return parser.parse_args()
@@ -71,11 +196,27 @@ async def main():
# Convert hyphenated argument to underscore for function parameter
container_name = getattr(args, 'container_name', None)
# Get API key from argument or environment variable
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')
# Check if container name is provided but API key is missing
if container_name and not api_key:
print("Warning: Container name provided but no API key found.")
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
return 1
print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
if container_name:
print(f"Container: {container_name}")
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")
success = await test_connection(
host=args.host,
port=args.port,
keep_alive=args.keep_alive,
container_name=container_name
container_name=container_name,
use_rest=args.rest,
api_key=api_key
)
return 0 if success else 1

View File

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import websockets
import aiohttp
from ..logger import Logger, LogLevel
from .base import BaseComputerInterface
@@ -57,6 +58,17 @@ class GenericComputerInterface(BaseComputerInterface):
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/ws"
@property
def rest_uri(self) -> str:
"""Get the REST URI using the current IP address.
Returns:
REST URI for the Computer API Server
"""
protocol = "https" if self.api_key else "http"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/cmd"
# Mouse actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
@@ -677,7 +689,7 @@ class GenericComputerInterface(BaseComputerInterface):
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
async def _send_command_ws(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through WebSocket."""
max_retries = 3
retry_count = 0
@@ -717,7 +729,151 @@ class GenericComputerInterface(BaseComputerInterface):
raise last_error if last_error else RuntimeError("Failed to send command")
async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through REST API without retries or connection management."""
try:
# Prepare the request payload
payload = {"command": command, "params": params or {}}
# Prepare headers
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-API-Key"] = self.api_key
if self.vm_name:
headers["X-Container-Name"] = self.vm_name
# Send the request
async with aiohttp.ClientSession() as session:
async with session.post(
self.rest_uri,
json=payload,
headers=headers
) as response:
# Get the response text
response_text = await response.text()
# Trim whitespace
response_text = response_text.strip()
# Check if it starts with "data: "
if response_text.startswith("data: "):
# Extract everything after "data: "
json_str = response_text[6:] # Remove "data: " prefix
try:
return json.loads(json_str)
except json.JSONDecodeError:
return {
"success": False,
"error": "Server returned malformed response",
"message": response_text
}
else:
# Return error response
return {
"success": False,
"error": "Server returned malformed response",
"message": response_text
}
except Exception as e:
return {
"success": False,
"error": "Request failed",
"message": str(e)
}
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command using REST API with WebSocket fallback."""
# Try REST API first
result = await self._send_command_rest(command, params)
# If REST failed with "Request failed", try WebSocket as fallback
if not result.get("success", True) and (result.get("error") == "Request failed" or result.get("error") == "Server returned malformed response"):
self.logger.debug(f"REST API failed for command '{command}', trying WebSocket fallback")
try:
return await self._send_command_ws(command, params)
except Exception as e:
self.logger.debug(f"WebSocket fallback also failed: {e}")
# Return the original REST error
return result
return result
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
"""Wait for Computer API Server to be ready by testing version command."""
# Check if REST API is available
try:
result = await self._send_command_rest("version", {})
assert result.get("success", True)
except Exception as e:
self.logger.debug(f"REST API failed for command 'version', trying WebSocket fallback: {e}")
try:
await self._wait_for_ready_ws(timeout, interval)
return
except Exception as e:
self.logger.debug(f"WebSocket fallback also failed: {e}")
raise e
start_time = time.time()
last_error = None
attempt_count = 0
progress_interval = 10 # Log progress every 10 seconds
last_progress_time = start_time
try:
self.logger.info(
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
)
# Wait for the server to respond to get_screen_size command
while time.time() - start_time < timeout:
try:
attempt_count += 1
current_time = time.time()
# Log progress periodically without flooding logs
if current_time - last_progress_time >= progress_interval:
elapsed = current_time - start_time
self.logger.info(
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
)
last_progress_time = current_time
# Test the server with a simple get_screen_size command
result = await self._send_command("get_screen_size")
if result.get("success", False):
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
)
return # Server is ready
else:
last_error = result.get("error", "Unknown error")
self.logger.debug(f"Initial connection command failed: {last_error}")
except Exception as e:
last_error = e
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
# Wait before trying again
await asyncio.sleep(interval)
# If we get here, we've timed out
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
if last_error:
error_msg += f": {str(last_error)}"
self.logger.error(error_msg)
raise TimeoutError(error_msg)
except Exception as e:
if isinstance(e, TimeoutError):
raise
error_msg = f"Error while waiting for server: {str(e)}"
self.logger.error(error_msg)
raise RuntimeError(error_msg)
async def _wait_for_ready_ws(self, timeout: int = 60, interval: float = 1.0):
"""Wait for WebSocket connection to become available."""
start_time = time.time()
last_error = None
@@ -755,7 +911,7 @@ class GenericComputerInterface(BaseComputerInterface):
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
# Test the connection with a simple command
try:
await self._send_command("get_screen_size")
await self._send_command_ws("get_screen_size")
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"