mirror of
https://github.com/trycua/computer.git
synced 2026-02-18 12:28:51 -06:00
Added REST-based send_command and wait_for_ready
This commit is contained in:
@@ -2,8 +2,8 @@
|
||||
"""
|
||||
Connection test script for Computer Server.
|
||||
|
||||
This script tests the WebSocket connection to the Computer Server and keeps
|
||||
it alive, allowing you to verify the server is running correctly.
|
||||
This script tests both WebSocket (/ws) and REST (/cmd) connections to the Computer Server
|
||||
and keeps it alive, allowing you to verify the server is running correctly.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -11,10 +11,14 @@ import json
|
||||
import websockets
|
||||
import argparse
|
||||
import sys
|
||||
import aiohttp
|
||||
import os
|
||||
|
||||
import dotenv
|
||||
dotenv.load_dotenv()
|
||||
|
||||
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None):
|
||||
"""Test connection to the Computer Server."""
|
||||
async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
|
||||
"""Test WebSocket connection to the Computer Server."""
|
||||
if container_name:
|
||||
# Container mode: use WSS with container domain and port 8443
|
||||
uri = f"wss://{container_name}.containers.cloud.trycua.com:8443/ws"
|
||||
@@ -26,15 +30,45 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
|
||||
|
||||
try:
|
||||
async with websockets.connect(uri) as websocket:
|
||||
print("Connection established!")
|
||||
print("WebSocket connection established!")
|
||||
|
||||
# If container connection, send authentication first
|
||||
if container_name:
|
||||
if not api_key:
|
||||
print("Error: API key required for container connections")
|
||||
return False
|
||||
|
||||
print("Sending authentication...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": api_key,
|
||||
"container_name": container_name
|
||||
}
|
||||
}
|
||||
await websocket.send(json.dumps(auth_message))
|
||||
auth_response = await websocket.recv()
|
||||
print(f"Authentication response: {auth_response}")
|
||||
|
||||
# Check if authentication was successful
|
||||
auth_data = json.loads(auth_response)
|
||||
if not auth_data.get("success", False):
|
||||
print("Authentication failed!")
|
||||
return False
|
||||
print("Authentication successful!")
|
||||
|
||||
# Send a test command to get version
|
||||
await websocket.send(json.dumps({"command": "version", "params": {}}))
|
||||
response = await websocket.recv()
|
||||
print(f"Version response: {response}")
|
||||
|
||||
# Send a test command to get screen size
|
||||
await websocket.send(json.dumps({"command": "get_screen_size", "params": {}}))
|
||||
response = await websocket.recv()
|
||||
print(f"Response: {response}")
|
||||
print(f"Screen size response: {response}")
|
||||
|
||||
if keep_alive:
|
||||
print("\nKeeping connection alive. Press Ctrl+C to exit...")
|
||||
print("\nKeeping WebSocket connection alive. Press Ctrl+C to exit...")
|
||||
while True:
|
||||
# Send a command every 5 seconds to keep the connection alive
|
||||
await asyncio.sleep(5)
|
||||
@@ -44,24 +78,115 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain
|
||||
response = await websocket.recv()
|
||||
print(f"Cursor position: {response}")
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
print(f"Connection closed: {e}")
|
||||
print(f"WebSocket connection closed: {e}")
|
||||
return False
|
||||
except ConnectionRefusedError:
|
||||
print(f"Connection refused. Is the server running at {host}:{port}?")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"WebSocket error: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def _post_rest_command(session, base_url, headers, command, success_label, failure_label):
    """POST one parameterless command to ``{base_url}/cmd`` and print the outcome.

    On HTTP 200 the response body is printed prefixed with ``success_label``;
    on any other status the status code (prefixed with ``failure_label``) and
    then the response body are printed.

    Returns:
        True if the response status was 200, False otherwise.
    """
    async with session.post(
        f"{base_url}/cmd",
        json={"command": command, "params": {}},
        headers=headers,
    ) as response:
        body = await response.text()
        if response.status == 200:
            print(f"{success_label}: {body}")
            return True
        print(f"{failure_label} request failed with status: {response.status}")
        print(body)
        return False


async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
    """Test REST connection to the Computer Server.

    Sends ``screenshot`` and ``get_screen_size`` commands to the ``/cmd``
    endpoint. With ``keep_alive`` set, it then sends ``get_cursor_position``
    every 5 seconds until a request fails or the task is interrupted.

    Args:
        host: Host of a local server (ignored when ``container_name`` is set).
        port: Port of a local server (ignored when ``container_name`` is set).
        keep_alive: Keep polling after the initial checks succeed.
        container_name: If given, target
            ``https://<container_name>.containers.cloud.trycua.com:8443`` and
            send the ``X-Container-Name`` / ``X-API-Key`` headers.
        api_key: API key; required whenever ``container_name`` is given.

    Returns:
        True if every request returned HTTP 200, False on any failure
        (including a missing API key or a connection error).
    """
    if container_name:
        # Container mode: use HTTPS with container domain and port 8443
        base_url = f"https://{container_name}.containers.cloud.trycua.com:8443"
        print(f"Connecting to container {container_name} at {base_url}...")
        # Validate credentials up front so no session is opened (and no
        # "established" message is printed) for a run that cannot succeed.
        if not api_key:
            print("Error: API key required for container connections")
            return False
        headers = {"X-Container-Name": container_name, "X-API-Key": api_key}
    else:
        # Local mode: use HTTP with specified host and port
        base_url = f"http://{host}:{port}"
        print(f"Connecting to local server at {base_url}...")
        headers = {}

    try:
        async with aiohttp.ClientSession() as session:
            # NOTE: no request has been sent yet; this only reports that the
            # client session object was created.
            print("REST connection established!")
            if headers:
                print("Using container authentication headers")

            # Test screenshot endpoint
            if not await _post_rest_command(
                session, base_url, headers, "screenshot", "Screenshot response", "Screenshot"
            ):
                return False

            # Test screen size endpoint
            if not await _post_rest_command(
                session, base_url, headers, "get_screen_size", "Screen size response", "Screen size"
            ):
                return False

            if keep_alive:
                print("\nKeeping REST connection alive. Press Ctrl+C to exit...")
                while True:
                    # Send a command every 5 seconds to keep testing
                    await asyncio.sleep(5)
                    if not await _post_rest_command(
                        session, base_url, headers, "get_cursor_position", "Cursor position", "Cursor position"
                    ):
                        return False

    except aiohttp.ClientError as e:
        print(f"REST connection error: {e}")
        return False
    except Exception as e:
        print(f"REST error: {e}")
        return False

    return True
|
||||
|
||||
|
||||
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
    """Test connection to the Computer Server using WebSocket or REST.

    Delegates to ``test_rest_connection`` when ``use_rest`` is true and to
    ``test_websocket_connection`` otherwise, forwarding the remaining
    arguments unchanged, and returns whatever the selected test returns.
    """
    run_test = test_rest_connection if use_rest else test_websocket_connection
    return await run_test(host, port, keep_alive, container_name, api_key)
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line options of the connection test script.

    Each option string is registered exactly once; registering ``--port`` or
    ``--container-name`` a second time makes argparse raise a
    conflicting-option error before any argument is parsed.

    Returns:
        The ``argparse.Namespace`` with ``host``, ``port``,
        ``container_name``, ``api_key``, ``keep_alive`` and ``rest``.
    """
    parser = argparse.ArgumentParser(description="Test connection to Computer Server")
    parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
    parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
    parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
    parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
    parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
    parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
    return parser.parse_args()
|
||||
|
||||
|
||||
@@ -71,11 +196,27 @@ async def main():
|
||||
# Convert hyphenated argument to underscore for function parameter
|
||||
container_name = getattr(args, 'container_name', None)
|
||||
|
||||
# Get API key from argument or environment variable
|
||||
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')
|
||||
|
||||
# Check if container name is provided but API key is missing
|
||||
if container_name and not api_key:
|
||||
print("Warning: Container name provided but no API key found.")
|
||||
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
|
||||
return 1
|
||||
|
||||
print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
|
||||
if container_name:
|
||||
print(f"Container: {container_name}")
|
||||
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")
|
||||
|
||||
success = await test_connection(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
keep_alive=args.keep_alive,
|
||||
container_name=container_name
|
||||
container_name=container_name,
|
||||
use_rest=args.rest,
|
||||
api_key=api_key
|
||||
)
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
import websockets
|
||||
import aiohttp
|
||||
|
||||
from ..logger import Logger, LogLevel
|
||||
from .base import BaseComputerInterface
|
||||
@@ -57,6 +58,17 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
@property
def rest_uri(self) -> str:
    """Get the REST URI using the current IP address.

    When ``self.api_key`` is set the URI uses ``https`` on port 8443;
    otherwise it uses ``http`` on port 8000. The path is always ``/cmd``.

    Returns:
        REST URI for the Computer API Server
    """
    if self.api_key:
        scheme, port = "https", "8443"
    else:
        scheme, port = "http", "8000"
    return f"{scheme}://{self.ip_address}:{port}/cmd"
|
||||
|
||||
# Mouse actions
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
|
||||
@@ -677,7 +689,7 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
|
||||
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
async def _send_command_ws(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Send command through WebSocket."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
@@ -717,7 +729,151 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
|
||||
raise last_error if last_error else RuntimeError("Failed to send command")
|
||||
|
||||
async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    """Send command through REST API without retries or connection management.

    POSTs ``{"command": command, "params": params or {}}`` as JSON to
    ``self.rest_uri``. ``X-API-Key`` and ``X-Container-Name`` headers are
    added when ``self.api_key`` / ``self.vm_name`` are set.

    The response body is expected to be ``data: <json>``. The HTTP status
    code is not inspected; only the body decides the result.

    Returns:
        The decoded JSON payload on success. Otherwise a dict with
        ``success=False`` and ``error`` set to either
        ``"Server returned malformed response"`` (``message`` holds the
        stripped body) or ``"Request failed"`` (``message`` holds the
        exception text). This method does not raise ``Exception`` subclasses.
    """
    data_prefix = "data: "
    try:
        # Prepare the request payload
        payload = {"command": command, "params": params or {}}

        # Prepare headers
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        if self.vm_name:
            headers["X-Container-Name"] = self.vm_name

        # Send the request; a fresh session is used for every call
        async with aiohttp.ClientSession() as session:
            async with session.post(self.rest_uri, json=payload, headers=headers) as response:
                response_text = (await response.text()).strip()

        if response_text.startswith(data_prefix):
            try:
                return json.loads(response_text[len(data_prefix):])
            except json.JSONDecodeError:
                # Fall through to the shared malformed-response result
                pass

        return {
            "success": False,
            "error": "Server returned malformed response",
            "message": response_text,
        }

    except Exception as e:
        # Deliberately broad: callers rely on a result dict, not an exception
        return {
            "success": False,
            "error": "Request failed",
            "message": str(e),
        }
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    """Send command using REST API with WebSocket fallback.

    The REST result is returned as-is unless it reports ``success`` as falsy
    with ``error`` equal to ``"Request failed"`` or
    ``"Server returned malformed response"``. In that case the command is
    retried through ``_send_command_ws``; if the WebSocket attempt raises,
    the original REST error dict is returned instead of propagating.
    """
    # Try REST API first
    result = await self._send_command_rest(command, params)

    rest_transport_errors = ("Request failed", "Server returned malformed response")
    if result.get("success", True) or result.get("error") not in rest_transport_errors:
        return result

    # REST could not deliver a usable answer: try WebSocket as fallback
    self.logger.debug(f"REST API failed for command '{command}', trying WebSocket fallback")
    try:
        return await self._send_command_ws(command, params)
    except Exception as e:
        self.logger.debug(f"WebSocket fallback also failed: {e}")
        # Return the original REST error
        return result
|
||||
|
||||
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
    """Wait for Computer API Server to be ready by testing version command.

    First probes the REST API once with a ``version`` command. If that probe
    fails, readiness is delegated entirely to ``_wait_for_ready_ws``.
    Otherwise ``get_screen_size`` is sent through ``_send_command`` every
    ``interval`` seconds until it reports success or ``timeout`` seconds
    have elapsed.

    Args:
        timeout: Maximum number of seconds to wait.
        interval: Seconds to sleep between attempts.

    Raises:
        TimeoutError: The server did not report success within ``timeout``.
        RuntimeError: An unexpected error occurred while polling.
        Exception: Whatever ``_wait_for_ready_ws`` raises when the WebSocket
            fallback is used and fails.
    """
    # Check if REST API is available. This is an explicit check rather than
    # an ``assert`` so the fallback still works under ``python -O``.
    try:
        probe = await self._send_command_rest("version", {})
        rest_available = bool(probe.get("success", True))
        probe_failure = probe.get("message") or probe.get("error")
    except Exception as e:
        rest_available = False
        probe_failure = e

    if not rest_available:
        self.logger.debug(
            f"REST API failed for command 'version', trying WebSocket fallback: {probe_failure}"
        )
        try:
            await self._wait_for_ready_ws(timeout, interval)
        except Exception as e:
            self.logger.debug(f"WebSocket fallback also failed: {e}")
            raise
        return

    start_time = time.time()
    last_error = None
    attempt_count = 0
    progress_interval = 10  # Log progress every 10 seconds
    last_progress_time = start_time

    try:
        self.logger.info(
            f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
        )

        # Wait for the server to respond to get_screen_size command
        while time.time() - start_time < timeout:
            try:
                attempt_count += 1
                current_time = time.time()

                # Log progress periodically without flooding logs
                if current_time - last_progress_time >= progress_interval:
                    elapsed = current_time - start_time
                    self.logger.info(
                        f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
                    )
                    last_progress_time = current_time

                # Test the server with a simple get_screen_size command
                result = await self._send_command("get_screen_size")
                if result.get("success", False):
                    elapsed = time.time() - start_time
                    self.logger.info(
                        f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
                    )
                    return  # Server is ready

                last_error = result.get("error", "Unknown error")
                self.logger.debug(f"Initial connection command failed: {last_error}")

            except Exception as e:
                # Remember the failure and keep polling until the timeout
                last_error = e
                self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")

            # Wait before trying again
            await asyncio.sleep(interval)

        # If we get here, we've timed out
        error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
        if last_error:
            error_msg += f": {str(last_error)}"
        self.logger.error(error_msg)
        raise TimeoutError(error_msg)

    except TimeoutError:
        # Already logged above; let callers see the timeout unchanged
        raise
    except Exception as e:
        error_msg = f"Error while waiting for server: {str(e)}"
        self.logger.error(error_msg)
        raise RuntimeError(error_msg) from e
|
||||
|
||||
async def _wait_for_ready_ws(self, timeout: int = 60, interval: float = 1.0):
|
||||
"""Wait for WebSocket connection to become available."""
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
@@ -755,7 +911,7 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
# Test the connection with a simple command
|
||||
try:
|
||||
await self._send_command("get_screen_size")
|
||||
await self._send_command_ws("get_screen_size")
|
||||
elapsed = time.time() - start_time
|
||||
self.logger.info(
|
||||
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
|
||||
|
||||
Reference in New Issue
Block a user