diff --git a/libs/python/computer-server/test_connection.py b/libs/python/computer-server/test_connection.py index de4eb2df..8b9f3222 100755 --- a/libs/python/computer-server/test_connection.py +++ b/libs/python/computer-server/test_connection.py @@ -2,8 +2,8 @@ """ Connection test script for Computer Server. -This script tests the WebSocket connection to the Computer Server and keeps -it alive, allowing you to verify the server is running correctly. +This script tests both WebSocket (/ws) and REST (/cmd) connections to the Computer Server +and keeps it alive, allowing you to verify the server is running correctly. """ import asyncio @@ -11,10 +11,14 @@ import json import websockets import argparse import sys +import aiohttp +import os +import dotenv +dotenv.load_dotenv() -async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None): - """Test connection to the Computer Server.""" +async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None): + """Test WebSocket connection to the Computer Server.""" if container_name: # Container mode: use WSS with container domain and port 8443 uri = f"wss://{container_name}.containers.cloud.trycua.com:8443/ws" @@ -26,15 +30,45 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain try: async with websockets.connect(uri) as websocket: - print("Connection established!") + print("WebSocket connection established!") + + # If container connection, send authentication first + if container_name: + if not api_key: + print("Error: API key required for container connections") + return False + + print("Sending authentication...") + auth_message = { + "command": "authenticate", + "params": { + "api_key": api_key, + "container_name": container_name + } + } + await websocket.send(json.dumps(auth_message)) + auth_response = await websocket.recv() + print(f"Authentication response: {auth_response}") + + # Check if authentication 
was successful + auth_data = json.loads(auth_response) + if not auth_data.get("success", False): + print("Authentication failed!") + return False + print("Authentication successful!") + + # Send a test command to get version + await websocket.send(json.dumps({"command": "version", "params": {}})) + response = await websocket.recv() + print(f"Version response: {response}") # Send a test command to get screen size await websocket.send(json.dumps({"command": "get_screen_size", "params": {}})) response = await websocket.recv() - print(f"Response: {response}") + print(f"Screen size response: {response}") if keep_alive: - print("\nKeeping connection alive. Press Ctrl+C to exit...") + print("\nKeeping WebSocket connection alive. Press Ctrl+C to exit...") while True: # Send a command every 5 seconds to keep the connection alive await asyncio.sleep(5) @@ -44,24 +78,115 @@ async def test_connection(host="localhost", port=8000, keep_alive=False, contain response = await websocket.recv() print(f"Cursor position: {response}") except websockets.exceptions.ConnectionClosed as e: - print(f"Connection closed: {e}") + print(f"WebSocket connection closed: {e}") return False except ConnectionRefusedError: print(f"Connection refused. 
async def _post_rest_command(session, url, command, headers, name, success_label, preview_limit=500):
    """POST one parameterless command to the /cmd endpoint and print the outcome.

    Args:
        session: An open ``aiohttp.ClientSession``.
        url: Full URL of the ``/cmd`` endpoint.
        command: Command name placed in the JSON payload.
        headers: HTTP headers to send (authentication headers in container mode).
        name: Human-readable name used in the failure message.
        success_label: Prefix printed before the response body on success.
        preview_limit: Maximum number of body characters echoed on success.

    Returns:
        True when the server answered with HTTP 200, False otherwise.
    """
    async with session.post(
        url,
        json={"command": command, "params": {}},
        headers=headers,
    ) as response:
        body = await response.text()
        if response.status != 200:
            print(f"{name} request failed with status: {response.status}")
            print(body)
            return False

        # Screenshot bodies carry a whole encoded image; echo only a preview so
        # the terminal stays readable.
        if len(body) > preview_limit:
            print(f"{success_label}: {body[:preview_limit]}... [{len(body)} chars total]")
        else:
            print(f"{success_label}: {body}")
        return True


async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
    """Test the REST (/cmd) endpoint of the Computer Server.

    Sends a ``screenshot`` and a ``get_screen_size`` command. With
    ``keep_alive`` it then polls ``get_cursor_position`` every 5 seconds until
    a request fails or the task is interrupted.

    Args:
        host: Host of a local server (ignored when ``container_name`` is set).
        port: Port of a local server (ignored when ``container_name`` is set).
        keep_alive: Keep polling after the initial checks succeed.
        container_name: Cloud container name; switches to HTTPS on port 8443
            and enables the authentication headers.
        api_key: API key, required when ``container_name`` is set.

    Returns:
        True if every request succeeded, False on any failure.
    """
    if container_name:
        # Container mode: use HTTPS with container domain and port 8443
        base_url = f"https://{container_name}.containers.cloud.trycua.com:8443"
        print(f"Connecting to container {container_name} at {base_url}...")
        # Validate credentials before any network resource is opened.
        if not api_key:
            print("Error: API key required for container connections")
            return False
        headers = {"X-Container-Name": container_name, "X-API-Key": api_key}
        print("Using container authentication headers")
    else:
        # Local mode: use HTTP with specified host and port
        base_url = f"http://{host}:{port}"
        print(f"Connecting to local server at {base_url}...")
        headers = {}

    cmd_url = f"{base_url}/cmd"

    try:
        async with aiohttp.ClientSession() as session:
            if not await _post_rest_command(
                session, cmd_url, "screenshot", headers, "Screenshot", "Screenshot response"
            ):
                return False
            # Creating a ClientSession opens no connection, so only claim success
            # once the server has actually answered a request.
            print("REST connection established!")

            if not await _post_rest_command(
                session, cmd_url, "get_screen_size", headers, "Screen size", "Screen size response"
            ):
                return False

            if keep_alive:
                print("\nKeeping REST connection alive. Press Ctrl+C to exit...")
                while True:
                    # Send a command every 5 seconds to keep testing
                    await asyncio.sleep(5)
                    if not await _post_rest_command(
                        session, cmd_url, "get_cursor_position", headers, "Cursor position", "Cursor position"
                    ):
                        return False

    except aiohttp.ClientError as e:
        print(f"REST connection error: {e}")
        return False
    except Exception as e:
        print(f"REST error: {e}")
        return False

    return True


async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
    """Test connection to the Computer Server using WebSocket or REST.

    Args:
        host: Host of a local server.
        port: Port of a local server.
        keep_alive: Keep the connection under test alive after the initial checks.
        container_name: Cloud container name, if testing a container.
        use_rest: Use the REST endpoint (/cmd) instead of WebSocket (/ws).
        api_key: API key for container authentication.

    Returns:
        The boolean result of the selected transport test.
    """
    # The conditional only evaluates the chosen name, so each transport is
    # looked up only when it is actually requested.
    runner = test_rest_connection if use_rest else test_websocket_connection
    return await runner(
        host=host,
        port=port,
        keep_alive=keep_alive,
        container_name=container_name,
        api_key=api_key,
    )
parser.parse_args() @@ -71,11 +196,27 @@ async def main(): # Convert hyphenated argument to underscore for function parameter container_name = getattr(args, 'container_name', None) + # Get API key from argument or environment variable + api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY') + + # Check if container name is provided but API key is missing + if container_name and not api_key: + print("Warning: Container name provided but no API key found.") + print("Please provide --api-key argument or set CUA_API_KEY environment variable.") + return 1 + + print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...") + if container_name: + print(f"Container: {container_name}") + print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}") + success = await test_connection( host=args.host, port=args.port, keep_alive=args.keep_alive, - container_name=container_name + container_name=container_name, + use_rest=args.rest, + api_key=api_key ) return 0 if success else 1 diff --git a/libs/python/computer/computer/interface/generic.py b/libs/python/computer/computer/interface/generic.py index 8f5c3a2c..41bea4be 100644 --- a/libs/python/computer/computer/interface/generic.py +++ b/libs/python/computer/computer/interface/generic.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple from PIL import Image import websockets +import aiohttp from ..logger import Logger, LogLevel from .base import BaseComputerInterface @@ -57,6 +58,17 @@ class GenericComputerInterface(BaseComputerInterface): protocol = "wss" if self.api_key else "ws" port = "8443" if self.api_key else "8000" return f"{protocol}://{self.ip_address}:{port}/ws" + + @property + def rest_uri(self) -> str: + """Get the REST URI using the current IP address. 
async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    """Send one command through the REST endpoint (single attempt, no retries).

    The request is POSTed to ``self.rest_uri``; ``X-API-Key`` and
    ``X-Container-Name`` headers are added when ``self.api_key`` /
    ``self.vm_name`` are set. The body is accepted only when it has the form
    ``data: <json object>``.

    Args:
        command: Command name.
        params: Command parameters; ``None`` is sent as an empty dict.

    Returns:
        The decoded JSON object on success. Failures are reported as a dict
        with ``success`` False and ``error`` set to either
        ``"Server returned malformed response"`` or ``"Request failed"``;
        ordinary exceptions are not propagated.
    """
    data_prefix = "data: "
    response_text = ""

    def malformed() -> Dict[str, Any]:
        return {
            "success": False,
            "error": "Server returned malformed response",
            "message": response_text,
        }

    try:
        # Prepare the request payload
        payload = {"command": command, "params": params or {}}

        # Prepare headers
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        if self.vm_name:
            headers["X-Container-Name"] = self.vm_name

        # Send the request
        async with aiohttp.ClientSession() as session:
            async with session.post(self.rest_uri, json=payload, headers=headers) as response:
                response_text = (await response.text()).strip()

        if not response_text.startswith(data_prefix):
            return malformed()

        try:
            decoded = json.loads(response_text[len(data_prefix):])
        except json.JSONDecodeError:
            return malformed()

        # Callers use dict methods on the result, so a JSON list or scalar
        # is treated as malformed rather than returned.
        if not isinstance(decoded, dict):
            return malformed()
        return decoded

    except Exception as e:
        return {
            "success": False,
            "error": "Request failed",
            "message": str(e),
        }


async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
    """Send a command using the REST API, falling back to WebSocket.

    The WebSocket path is tried only when the REST attempt reports a
    transport-level problem ("Request failed" or "Server returned malformed
    response"). Any other REST result, including a command-level failure, is
    returned as-is.

    Returns:
        The REST result, or the WebSocket result when the fallback was used
        and succeeded. If the fallback raises, the original REST error dict
        is returned.
    """
    fallback_errors = ("Request failed", "Server returned malformed response")

    # Try REST API first
    result = await self._send_command_rest(command, params)

    if not result.get("success", True) and result.get("error") in fallback_errors:
        self.logger.debug(f"REST API failed for command '{command}', trying WebSocket fallback")
        try:
            return await self._send_command_ws(command, params)
        except Exception as e:
            self.logger.debug(f"WebSocket fallback also failed: {e}")
            # Return the original REST error
            return result

    return result


async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
    """Wait for the Computer API Server to be ready.

    The REST API is probed once with the ``version`` command. If that probe
    fails, readiness is delegated to ``_wait_for_ready_ws``. Otherwise the
    server is polled with ``get_screen_size`` until it reports success.

    Args:
        timeout: Maximum number of seconds to wait.
        interval: Seconds to sleep between polling attempts.

    Raises:
        TimeoutError: The server did not report success within ``timeout``.
        RuntimeError: An unexpected error occurred while polling.
        Exception: Whatever ``_wait_for_ready_ws`` raises when the WebSocket
            fallback fails.
    """
    # Check if REST API is available. This is an explicit check rather than an
    # ``assert``: asserts are stripped under ``python -O``, which would make
    # the WebSocket fallback unreachable.
    rest_error = None
    try:
        probe = await self._send_command_rest("version", {})
        if not probe.get("success", True):
            rest_error = probe.get("error", "Unknown error")
    except Exception as e:
        rest_error = e

    if rest_error is not None:
        self.logger.debug(
            f"REST API failed for command 'version', trying WebSocket fallback: {rest_error}"
        )
        try:
            await self._wait_for_ready_ws(timeout, interval)
            return
        except Exception as e:
            self.logger.debug(f"WebSocket fallback also failed: {e}")
            raise

    # A monotonic clock keeps the timeout correct across wall-clock adjustments.
    start_time = time.monotonic()
    last_error = None
    attempt_count = 0
    progress_interval = 10  # Log progress every 10 seconds
    last_progress_time = start_time

    try:
        self.logger.info(
            f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
        )

        # Wait for the server to respond to get_screen_size command
        while time.monotonic() - start_time < timeout:
            try:
                attempt_count += 1
                current_time = time.monotonic()

                # Log progress periodically without flooding logs
                if current_time - last_progress_time >= progress_interval:
                    elapsed = current_time - start_time
                    self.logger.info(
                        f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
                    )
                    last_progress_time = current_time

                # Test the server with a simple get_screen_size command
                result = await self._send_command("get_screen_size")
                if result.get("success", False):
                    elapsed = time.monotonic() - start_time
                    self.logger.info(
                        f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
                    )
                    return  # Server is ready

                last_error = result.get("error", "Unknown error")
                self.logger.debug(f"Initial connection command failed: {last_error}")

            except Exception as e:
                last_error = e
                self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")

            # Wait before trying again
            await asyncio.sleep(interval)

        # If we get here, we've timed out
        error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
        if last_error:
            error_msg += f": {str(last_error)}"
        self.logger.error(error_msg)
        raise TimeoutError(error_msg)

    except TimeoutError:
        raise
    except Exception as e:
        error_msg = f"Error while waiting for server: {str(e)}"
        self.logger.error(error_msg)
        raise RuntimeError(error_msg) from e
self._send_command_ws("get_screen_size") elapsed = time.time() - start_time self.logger.info( f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"