add auth flow

This commit is contained in:
Dillon DuPont
2025-05-26 20:49:44 -04:00
parent b370c9a225
commit de680a2941
7 changed files with 212 additions and 24 deletions

View File

@@ -8,11 +8,11 @@ import traceback
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
# Set up logging with more detail
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configure WebSocket with larger message size
@@ -48,6 +48,112 @@ manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket):
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
# Check if VM_NAME is set (indicating cloud provider)
vm_name = os.environ.get("VM_NAME")
# If cloud provider, perform authentication handshake
if vm_name:
try:
logger.info(f"Cloud provider detected. VM_NAME: {vm_name}. Waiting for authentication...")
# Wait for authentication message
auth_data = await websocket.receive_json()
# Validate auth message format
if auth_data.get("command") != "authenticate":
await websocket.send_json({
"success": False,
"error": "First message must be authentication"
})
await websocket.close()
manager.disconnect(websocket)
return
# Extract credentials
client_api_key = auth_data.get("params", {}).get("api_key")
client_vm_name = auth_data.get("params", {}).get("vm_name")
# Layer 1: VM Identity Verification
if client_vm_name != vm_name:
logger.warning(f"VM name mismatch. Expected: {vm_name}, Got: {client_vm_name}")
await websocket.send_json({
"success": False,
"error": "VM name mismatch"
})
await websocket.close()
manager.disconnect(websocket)
return
# Layer 2: API Key Validation with TryCUA API
if not client_api_key:
await websocket.send_json({
"success": False,
"error": "API key required"
})
await websocket.close()
manager.disconnect(websocket)
return
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {"Authorization": f"Bearer {client_api_key}"}
params = {"vm_name": vm_name}
async with session.get(
"https://trycua.com/api/vm-host",
headers=headers,
params=params
) as resp:
if resp.status != 200:
error_msg = await resp.text()
logger.warning(f"API validation failed: {error_msg}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# If we get a 200 response with VNC URL, the VM exists and user has access
vnc_url = (await resp.text()).strip()
if not vnc_url:
logger.warning(f"No VNC URL returned for VM: {vm_name}")
await websocket.send_json({
"success": False,
"error": "VM not found"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {vm_name}")
await websocket.send_json({
"success": True,
"message": "Authenticated"
})
except Exception as e:
logger.error(f"Error validating with TryCUA API: {e}")
await websocket.send_json({
"success": False,
"error": "Authentication service unavailable"
})
await websocket.close()
manager.disconnect(websocket)
return
except Exception as e:
logger.error(f"Authentication error: {e}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# Map commands to appropriate handler methods
handlers = {

View File

@@ -393,12 +393,25 @@ class Computer:
self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
),
)
# Pass authentication credentials if using cloud provider
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address
),
)
# Wait for the WebSocket interface to be ready
self.logger.info("Connecting to WebSocket interface...")
@@ -493,6 +506,11 @@ class Computer:
# Call the provider's get_ip method which will wait indefinitely
storage_param = "ephemeral" if self.ephemeral else self.storage
# Log the image being used
self.logger.info(f"Running VM using image: {self.image}")
# Call provider.get_ip with explicit image parameter
ip = await self.config.vm_provider.get_ip(
name=self.config.name,
storage=storage_param,

View File

@@ -8,17 +8,21 @@ from ..logger import Logger, LogLevel
class BaseComputerInterface(ABC):
"""Base class for computer control interfaces."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
"""Initialize interface.
Args:
ip_address: IP address of the computer to control
username: Username for authentication
password: Password for authentication
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
"""
self.ip_address = ip_address
self.username = username
self.password = password
self.api_key = api_key
self.vm_name = vm_name
self.logger = Logger("cua.interface", LogLevel.NORMAL)
@abstractmethod

View File

@@ -1,6 +1,6 @@
"""Factory for creating computer interfaces."""
from typing import Literal
from typing import Literal, Optional
from .base import BaseComputerInterface
class InterfaceFactory:
@@ -9,13 +9,17 @@ class InterfaceFactory:
@staticmethod
def create_interface_for_os(
os: Literal['macos', 'linux'],
ip_address: str
ip_address: str,
api_key: Optional[str] = None,
vm_name: Optional[str] = None
) -> BaseComputerInterface:
"""Create an interface for the specified OS.
Args:
os: Operating system type ('macos' or 'linux')
ip_address: IP address of the computer to control
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
Returns:
BaseComputerInterface: The appropriate interface for the OS
@@ -28,8 +32,8 @@ class InterfaceFactory:
from .linux import LinuxComputerInterface
if os == 'macos':
return MacOSComputerInterface(ip_address)
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'linux':
return LinuxComputerInterface(ip_address)
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
else:
raise ValueError(f"Unsupported OS type: {os}")

View File

@@ -15,8 +15,8 @@ from .models import Key, KeyType
class LinuxComputerInterface(BaseComputerInterface):
"""Interface for Linux."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
super().__init__(ip_address, username, password)
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
@@ -86,6 +86,32 @@ class LinuxComputerInterface(BaseComputerInterface):
timeout=30,
)
self.logger.info("WebSocket connection established")
# If api_key and vm_name are provided, perform authentication handshake
if self.api_key and self.vm_name:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"vm_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
await self._ws.close()
self._ws = None
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection

View File

@@ -13,10 +13,10 @@ from .models import Key, KeyType
class MacOSComputerInterface(BaseComputerInterface):
"""Interface for MacOS."""
"""Interface for macOS."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume"):
super().__init__(ip_address, username, password)
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
@@ -27,7 +27,7 @@ class MacOSComputerInterface(BaseComputerInterface):
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
self._log_connection_attempts = True # Flag to control connection attempt logging
# Set logger name for MacOS interface
# Set logger name for macOS interface
self.logger = Logger("cua.interface.macos", LogLevel.NORMAL)
@property
@@ -86,6 +86,32 @@ class MacOSComputerInterface(BaseComputerInterface):
timeout=30,
)
self.logger.info("WebSocket connection established")
# If api_key and vm_name are provided, perform authentication handshake
if self.api_key and self.vm_name:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"vm_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
await self._ws.close()
self._ws = None
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection

View File

@@ -19,7 +19,7 @@ class CloudProvider(BaseVMProvider):
"""Cloud VM Provider implementation using /api/vm-host endpoint."""
def __init__(
self,
api_key: str = None,
api_key: str,
endpoint_url: str = "https://trycua.com/api/vm-host",
verbose: bool = False,
**kwargs,
@@ -56,13 +56,13 @@ class CloudProvider(BaseVMProvider):
vnc_url = (await resp.text()).strip()
parsed = urlparse(vnc_url)
hostname = parsed.hostname
return {"name": vm_name, "status": "available", "vnc_url": vnc_url, "hostname": hostname}
return {"name": name, "status": "available", "vnc_url": vnc_url, "hostname": hostname}
else:
try:
error = await resp.json()
except Exception:
error = {"error": await resp.text()}
return {"name": vm_name, "status": "error", **error}
return {"name": name, "status": "error", **error}
async def list_vms(self) -> List[Dict[str, Any]]:
logger.warning("CloudProvider.list_vms is not implemented")
@@ -83,9 +83,13 @@ class CloudProvider(BaseVMProvider):
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""
Return the VM's IP address as '{vm_name}.us.vms.trycua.com'.
Uses the provided 'name' argument (the VM name requested by the caller).
Uses the provided 'name' argument (the VM name requested by the caller),
falling back to self.name only if 'name' is None.
Retries up to 3 times with retry_delay seconds if hostname is not available.
"""
if name is None:
raise ValueError("VM name is required for CloudProvider.get_ip")
attempts = 3
last_error = None
for attempt in range(attempts):