From 669832030ebbecf82f9ee4ae823074565479e533 Mon Sep 17 00:00:00 2001 From: Dillon DuPont Date: Thu, 10 Jul 2025 15:55:59 -0400 Subject: [PATCH] Added auth manager --- .../computer-server/computer_server/main.py | 219 ++++++++++-------- 1 file changed, 124 insertions(+), 95 deletions(-) diff --git a/libs/python/computer-server/computer_server/main.py b/libs/python/computer-server/computer_server/main.py index 5c396691..306f3cd2 100644 --- a/libs/python/computer-server/computer_server/main.py +++ b/libs/python/computer-server/computer_server/main.py @@ -12,6 +12,8 @@ from io import StringIO from .handlers.factory import HandlerFactory import os import aiohttp +import hashlib +import time # Set up logging with more detail logger = logging.getLogger(__name__) @@ -86,6 +88,95 @@ handlers = { } +class AuthenticationManager: + def __init__(self): + self.sessions: Dict[str, Dict[str, Any]] = {} + self.container_name = os.environ.get("CONTAINER_NAME") + + def _hash_credentials(self, container_name: str, api_key: str) -> str: + """Create a hash of container name and API key for session identification""" + combined = f"{container_name}:{api_key}" + return hashlib.sha256(combined.encode()).hexdigest() + + def _is_session_valid(self, session_data: Dict[str, Any]) -> bool: + """Check if a session is still valid based on expiration time""" + if not session_data.get('valid', False): + return False + + expires_at = session_data.get('expires_at', 0) + return time.time() < expires_at + + async def auth(self, container_name: str, api_key: str) -> bool: + """Authenticate container name and API key, using cached sessions when possible""" + # If no CONTAINER_NAME is set, always allow access (local development) + if not self.container_name: + logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)") + return True + + # Layer 1: VM Identity Verification + if container_name != self.container_name: + logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}") + return False + + # Create hash for session lookup + session_hash = self._hash_credentials(container_name, api_key) + + # Check if we have a valid cached session + if session_hash in self.sessions: + session_data = self.sessions[session_hash] + if self._is_session_valid(session_data): + logger.info(f"Using cached authentication for container: {container_name}") + return session_data['valid'] + else: + # Remove expired session + del self.sessions[session_hash] + + # No valid cached session, authenticate with API + logger.info(f"Authenticating with TryCUA API for container: {container_name}") + + try: + async with aiohttp.ClientSession() as session: + headers = { + "Authorization": f"Bearer {api_key}" + } + + async with session.get( + f"https://www.trycua.com/api/vm/auth?container_name={container_name}", + headers=headers, + ) as resp: + is_valid = resp.status == 200 and bool((await resp.text()).strip()) + + # Cache the result with 5 second expiration + self.sessions[session_hash] = { + 'valid': is_valid, + 'expires_at': time.time() + 5 # 5 seconds from now + } + + if is_valid: + logger.info(f"Authentication successful for container: {container_name}") + else: + logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}") + + return is_valid + + except aiohttp.ClientError as e: + logger.error(f"Failed to validate API key with TryCUA API: {str(e)}") + # Cache failed result to avoid repeated requests + self.sessions[session_hash] = { + 'valid': False, + 'expires_at': time.time() + 5 + } + return False + except Exception as e: + logger.error(f"Unexpected error during authentication: {str(e)}") + # Cache failed result to avoid repeated requests + self.sessions[session_hash] = { + 'valid': False, + 'expires_at': time.time() + 5 + } + return False + + class ConnectionManager: def __init__(self): self.active_connections: List[WebSocket] = [] @@ -99,6 +190,7 @@ class ConnectionManager: manager = ConnectionManager() +auth_manager = AuthenticationManager() @app.websocket("/ws", name="websocket_endpoint") @@ -109,12 +201,12 @@ async def websocket_endpoint(websocket: WebSocket): await manager.connect(websocket) # Check if CONTAINER_NAME is set (indicating cloud provider) - container_name = os.environ.get("CONTAINER_NAME") + server_container_name = os.environ.get("CONTAINER_NAME") # If cloud provider, perform authentication handshake - if container_name: + if server_container_name: try: - logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...") + logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...") # Wait for authentication message auth_data = await websocket.receive_json() @@ -133,18 +225,7 @@ async def websocket_endpoint(websocket: WebSocket): client_api_key = auth_data.get("params", {}).get("api_key") client_container_name = auth_data.get("params", {}).get("container_name") - # Layer 1: VM Identity Verification - if client_container_name != container_name: - logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}") - await websocket.send_json({ - "success": False, - "error": "VM name mismatch" - }) - await websocket.close() - manager.disconnect(websocket) - return - - # Layer 2: API Key Validation with TryCUA API + # Validate credentials using AuthenticationManager if not client_api_key: await websocket.send_json({ "success": False, @@ -154,58 +235,34 @@ async def websocket_endpoint(websocket: WebSocket): manager.disconnect(websocket) return - # Validate with TryCUA API - try: - async with aiohttp.ClientSession() as session: - headers = { - "Authorization": f"Bearer {client_api_key}" - } - - async with session.get( - f"https://www.trycua.com/api/vm/auth?container_name={container_name}", - headers=headers, - ) as resp: - if resp.status != 200: - error_msg = await resp.text() - logger.warning(f"API validation failed: {error_msg}") - await websocket.send_json({ - "success": False, - "error": "Authentication failed" - }) - await websocket.close() - manager.disconnect(websocket) - return - - # If we get a 200 response with VNC URL, the VM exists and user has access - vnc_url = (await resp.text()).strip() - if not vnc_url: - logger.warning(f"No VNC URL returned for VM: {container_name}") - await websocket.send_json({ - "success": False, - "error": "VM not found" - }) - await websocket.close() - manager.disconnect(websocket) - return - - logger.info(f"Authentication successful for VM: {container_name}") - await websocket.send_json({ - "success": True, - "message": "Authenticated" - }) - - except Exception as e: - logger.error(f"Error validating with TryCUA API: {e}") + if not client_container_name: await websocket.send_json({ "success": False, - "error": "Authentication service unavailable" + "error": "Container name required" }) await websocket.close() manager.disconnect(websocket) return - + + # Use AuthenticationManager for validation + is_authenticated = await auth_manager.auth(client_container_name, client_api_key) + if not is_authenticated: + await websocket.send_json({ + "success": False, + "error": "Authentication failed" + }) + await websocket.close() + manager.disconnect(websocket) + return + + logger.info(f"Authentication successful for VM: {client_container_name}") + await websocket.send_json({ + "success": True, + "message": "Authentication successful" + }) + except Exception as e: - logger.error(f"Authentication error: {e}") + logger.error(f"Error during authentication handshake: {str(e)}") await websocket.send_json({ "success": False, "error": "Authentication failed" @@ -310,45 +367,17 @@ async def cmd_endpoint( if server_container_name: logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...") - # Layer 1: VM Identity Verification - if container_name != server_container_name: - logger.warning(f"VM name mismatch. Expected: {server_container_name}, Got: {container_name}") - raise HTTPException(status_code=401, detail="VM name mismatch") + # Validate required headers + if not container_name: + raise HTTPException(status_code=401, detail="Container name required") - # Layer 2: API Key Validation with TryCUA API if not api_key: raise HTTPException(status_code=401, detail="API key required") - # Validate with TryCUA API - try: - async with aiohttp.ClientSession() as session: - headers = { - "Authorization": f"Bearer {api_key}" - } - - async with session.get( - f"https://www.trycua.com/api/vm/auth?container_name={server_container_name}", - headers=headers, - ) as resp: - if resp.status != 200: - logger.warning(f"API key validation failed. Status: {resp.status}") - raise HTTPException(status_code=401, detail="Invalid API key") - - auth_failed = not (await resp.text()).strip() - if auth_failed: - logger.warning(f"API key validation failed.") - raise HTTPException(status_code=401, detail="Invalid API key") - - logger.info("Authentication successful") - - except aiohttp.ClientError as e: - logger.error(f"Failed to validate API key with TryCUA API: {str(e)}") - raise HTTPException(status_code=500, detail="Authentication service unavailable") - except HTTPException: - raise - except Exception as e: - logger.error(f"Unexpected error during authentication: {str(e)}") - raise HTTPException(status_code=500, detail="Authentication failed") + # Validate with AuthenticationManager + is_authenticated = await auth_manager.auth(container_name, api_key) + if not is_authenticated: + raise HTTPException(status_code=401, detail="Authentication failed") if command not in handlers: raise HTTPException(status_code=400, detail=f"Unknown command: {command}")