mirror of
https://github.com/trycua/computer.git
synced 2026-01-01 19:10:30 -06:00
add auth flow
This commit is contained in:
@@ -8,11 +8,11 @@ import traceback
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from io import StringIO
|
||||
from .handlers.factory import HandlerFactory
|
||||
import os
|
||||
import aiohttp
|
||||
|
||||
# Set up logging with more detail
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configure WebSocket with larger message size
|
||||
@@ -48,6 +48,112 @@ manager = ConnectionManager()
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
# WebSocket message size is configured at the app or endpoint level, not on the instance
|
||||
await manager.connect(websocket)
|
||||
|
||||
# Check if VM_NAME is set (indicating cloud provider)
|
||||
vm_name = os.environ.get("VM_NAME")
|
||||
|
||||
# If cloud provider, perform authentication handshake
|
||||
if vm_name:
|
||||
try:
|
||||
logger.info(f"Cloud provider detected. VM_NAME: {vm_name}. Waiting for authentication...")
|
||||
|
||||
# Wait for authentication message
|
||||
auth_data = await websocket.receive_json()
|
||||
|
||||
# Validate auth message format
|
||||
if auth_data.get("command") != "authenticate":
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "First message must be authentication"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Extract credentials
|
||||
client_api_key = auth_data.get("params", {}).get("api_key")
|
||||
client_vm_name = auth_data.get("params", {}).get("vm_name")
|
||||
|
||||
# Layer 1: VM Identity Verification
|
||||
if client_vm_name != vm_name:
|
||||
logger.warning(f"VM name mismatch. Expected: {vm_name}, Got: {client_vm_name}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "VM name mismatch"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Layer 2: API Key Validation with TryCUA API
|
||||
if not client_api_key:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "API key required"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Validate with TryCUA API
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {"Authorization": f"Bearer {client_api_key}"}
|
||||
params = {"vm_name": vm_name}
|
||||
|
||||
async with session.get(
|
||||
"https://trycua.com/api/vm-host",
|
||||
headers=headers,
|
||||
params=params
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_msg = await resp.text()
|
||||
logger.warning(f"API validation failed: {error_msg}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# If we get a 200 response with VNC URL, the VM exists and user has access
|
||||
vnc_url = (await resp.text()).strip()
|
||||
if not vnc_url:
|
||||
logger.warning(f"No VNC URL returned for VM: {vm_name}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "VM not found"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
logger.info(f"Authentication successful for VM: {vm_name}")
|
||||
await websocket.send_json({
|
||||
"success": True,
|
||||
"message": "Authenticated"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating with TryCUA API: {e}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication service unavailable"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Map commands to appropriate handler methods
|
||||
handlers = {
|
||||
|
||||
@@ -393,12 +393,25 @@ class Computer:
|
||||
self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
|
||||
from .interface.base import BaseComputerInterface
|
||||
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
|
||||
),
|
||||
)
|
||||
# Pass authentication credentials if using cloud provider
|
||||
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address,
|
||||
api_key=self.api_key,
|
||||
vm_name=self.config.name
|
||||
),
|
||||
)
|
||||
else:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address
|
||||
),
|
||||
)
|
||||
|
||||
# Wait for the WebSocket interface to be ready
|
||||
self.logger.info("Connecting to WebSocket interface...")
|
||||
@@ -493,6 +506,11 @@ class Computer:
|
||||
|
||||
# Call the provider's get_ip method which will wait indefinitely
|
||||
storage_param = "ephemeral" if self.ephemeral else self.storage
|
||||
|
||||
# Log the image being used
|
||||
self.logger.info(f"Running VM using image: {self.image}")
|
||||
|
||||
# Call provider.get_ip with explicit image parameter
|
||||
ip = await self.config.vm_provider.get_ip(
|
||||
name=self.config.name,
|
||||
storage=storage_param,
|
||||
|
||||
@@ -8,17 +8,21 @@ from ..logger import Logger, LogLevel
|
||||
class BaseComputerInterface(ABC):
|
||||
"""Base class for computer control interfaces."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
"""Initialize interface.
|
||||
|
||||
Args:
|
||||
ip_address: IP address of the computer to control
|
||||
username: Username for authentication
|
||||
password: Password for authentication
|
||||
api_key: Optional API key for cloud authentication
|
||||
vm_name: Optional VM name for cloud authentication
|
||||
"""
|
||||
self.ip_address = ip_address
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.api_key = api_key
|
||||
self.vm_name = vm_name
|
||||
self.logger = Logger("cua.interface", LogLevel.NORMAL)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Factory for creating computer interfaces."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
from .base import BaseComputerInterface
|
||||
|
||||
class InterfaceFactory:
|
||||
@@ -9,13 +9,17 @@ class InterfaceFactory:
|
||||
@staticmethod
|
||||
def create_interface_for_os(
|
||||
os: Literal['macos', 'linux'],
|
||||
ip_address: str
|
||||
ip_address: str,
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None
|
||||
) -> BaseComputerInterface:
|
||||
"""Create an interface for the specified OS.
|
||||
|
||||
Args:
|
||||
os: Operating system type ('macos' or 'linux')
|
||||
ip_address: IP address of the computer to control
|
||||
api_key: Optional API key for cloud authentication
|
||||
vm_name: Optional VM name for cloud authentication
|
||||
|
||||
Returns:
|
||||
BaseComputerInterface: The appropriate interface for the OS
|
||||
@@ -28,8 +32,8 @@ class InterfaceFactory:
|
||||
from .linux import LinuxComputerInterface
|
||||
|
||||
if os == 'macos':
|
||||
return MacOSComputerInterface(ip_address)
|
||||
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
elif os == 'linux':
|
||||
return LinuxComputerInterface(ip_address)
|
||||
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported OS type: {os}")
|
||||
@@ -15,8 +15,8 @@ from .models import Key, KeyType
|
||||
class LinuxComputerInterface(BaseComputerInterface):
|
||||
"""Interface for Linux."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
|
||||
super().__init__(ip_address, username, password)
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
self._reconnect_task = None
|
||||
self._closed = False
|
||||
@@ -86,6 +86,32 @@ class LinuxComputerInterface(BaseComputerInterface):
|
||||
timeout=30,
|
||||
)
|
||||
self.logger.info("WebSocket connection established")
|
||||
|
||||
# If api_key and vm_name are provided, perform authentication handshake
|
||||
if self.api_key and self.vm_name:
|
||||
self.logger.info("Performing authentication handshake...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": self.api_key,
|
||||
"vm_name": self.vm_name
|
||||
}
|
||||
}
|
||||
await self._ws.send(json.dumps(auth_message))
|
||||
|
||||
# Wait for authentication response
|
||||
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||||
auth_result = json.loads(auth_response)
|
||||
|
||||
if not auth_result.get("success"):
|
||||
error_msg = auth_result.get("error", "Authentication failed")
|
||||
self.logger.error(f"Authentication failed: {error_msg}")
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
raise ConnectionError(f"Authentication failed: {error_msg}")
|
||||
|
||||
self.logger.info("Authentication successful")
|
||||
|
||||
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
|
||||
self._last_ping = time.time()
|
||||
retry_count = 0 # Reset retry count on successful connection
|
||||
|
||||
@@ -13,10 +13,10 @@ from .models import Key, KeyType
|
||||
|
||||
|
||||
class MacOSComputerInterface(BaseComputerInterface):
|
||||
"""Interface for MacOS."""
|
||||
"""Interface for macOS."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
|
||||
super().__init__(ip_address, username, password)
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
self._reconnect_task = None
|
||||
self._closed = False
|
||||
@@ -27,7 +27,7 @@ class MacOSComputerInterface(BaseComputerInterface):
|
||||
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
|
||||
self._log_connection_attempts = True # Flag to control connection attempt logging
|
||||
|
||||
# Set logger name for MacOS interface
|
||||
# Set logger name for macOS interface
|
||||
self.logger = Logger("cua.interface.macos", LogLevel.NORMAL)
|
||||
|
||||
@property
|
||||
@@ -86,6 +86,32 @@ class MacOSComputerInterface(BaseComputerInterface):
|
||||
timeout=30,
|
||||
)
|
||||
self.logger.info("WebSocket connection established")
|
||||
|
||||
# If api_key and vm_name are provided, perform authentication handshake
|
||||
if self.api_key and self.vm_name:
|
||||
self.logger.info("Performing authentication handshake...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": self.api_key,
|
||||
"vm_name": self.vm_name
|
||||
}
|
||||
}
|
||||
await self._ws.send(json.dumps(auth_message))
|
||||
|
||||
# Wait for authentication response
|
||||
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||||
auth_result = json.loads(auth_response)
|
||||
|
||||
if not auth_result.get("success"):
|
||||
error_msg = auth_result.get("error", "Authentication failed")
|
||||
self.logger.error(f"Authentication failed: {error_msg}")
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
raise ConnectionError(f"Authentication failed: {error_msg}")
|
||||
|
||||
self.logger.info("Authentication successful")
|
||||
|
||||
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
|
||||
self._last_ping = time.time()
|
||||
retry_count = 0 # Reset retry count on successful connection
|
||||
|
||||
@@ -19,7 +19,7 @@ class CloudProvider(BaseVMProvider):
|
||||
"""Cloud VM Provider implementation using /api/vm-host endpoint."""
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = None,
|
||||
api_key: str,
|
||||
endpoint_url: str = "https://trycua.com/api/vm-host",
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
@@ -56,13 +56,13 @@ class CloudProvider(BaseVMProvider):
|
||||
vnc_url = (await resp.text()).strip()
|
||||
parsed = urlparse(vnc_url)
|
||||
hostname = parsed.hostname
|
||||
return {"name": vm_name, "status": "available", "vnc_url": vnc_url, "hostname": hostname}
|
||||
return {"name": name, "status": "available", "vnc_url": vnc_url, "hostname": hostname}
|
||||
else:
|
||||
try:
|
||||
error = await resp.json()
|
||||
except Exception:
|
||||
error = {"error": await resp.text()}
|
||||
return {"name": vm_name, "status": "error", **error}
|
||||
return {"name": name, "status": "error", **error}
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
logger.warning("CloudProvider.list_vms is not implemented")
|
||||
@@ -83,9 +83,13 @@ class CloudProvider(BaseVMProvider):
|
||||
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""
|
||||
Return the VM's IP address as '{vm_name}.us.vms.trycua.com'.
|
||||
Uses the provided 'name' argument (the VM name requested by the caller).
|
||||
Uses the provided 'name' argument (the VM name requested by the caller),
|
||||
falling back to self.name only if 'name' is None.
|
||||
Retries up to 3 times with retry_delay seconds if hostname is not available.
|
||||
"""
|
||||
if name is None:
|
||||
raise ValueError("VM name is required for CloudProvider.get_ip")
|
||||
|
||||
attempts = 3
|
||||
last_error = None
|
||||
for attempt in range(attempts):
|
||||
|
||||
Reference in New Issue
Block a user