add auth flow

This commit is contained in:
Dillon DuPont
2025-05-26 20:49:44 -04:00
parent b370c9a225
commit de680a2941
7 changed files with 212 additions and 24 deletions

View File

@@ -8,11 +8,11 @@ import traceback
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
# Set up logging with more detail
logging.basicConfig(
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configure WebSocket with larger message size
@@ -48,6 +48,112 @@ manager = ConnectionManager()
async def websocket_endpoint(websocket: WebSocket):
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
# Check if VM_NAME is set (indicating cloud provider)
vm_name = os.environ.get("VM_NAME")
# If cloud provider, perform authentication handshake
if vm_name:
try:
logger.info(f"Cloud provider detected. VM_NAME: {vm_name}. Waiting for authentication...")
# Wait for authentication message
auth_data = await websocket.receive_json()
# Validate auth message format
if auth_data.get("command") != "authenticate":
await websocket.send_json({
"success": False,
"error": "First message must be authentication"
})
await websocket.close()
manager.disconnect(websocket)
return
# Extract credentials
client_api_key = auth_data.get("params", {}).get("api_key")
client_vm_name = auth_data.get("params", {}).get("vm_name")
# Layer 1: VM Identity Verification
if client_vm_name != vm_name:
logger.warning(f"VM name mismatch. Expected: {vm_name}, Got: {client_vm_name}")
await websocket.send_json({
"success": False,
"error": "VM name mismatch"
})
await websocket.close()
manager.disconnect(websocket)
return
# Layer 2: API Key Validation with TryCUA API
if not client_api_key:
await websocket.send_json({
"success": False,
"error": "API key required"
})
await websocket.close()
manager.disconnect(websocket)
return
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {"Authorization": f"Bearer {client_api_key}"}
params = {"vm_name": vm_name}
async with session.get(
"https://trycua.com/api/vm-host",
headers=headers,
params=params
) as resp:
if resp.status != 200:
error_msg = await resp.text()
logger.warning(f"API validation failed: {error_msg}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# If we get a 200 response with VNC URL, the VM exists and user has access
vnc_url = (await resp.text()).strip()
if not vnc_url:
logger.warning(f"No VNC URL returned for VM: {vm_name}")
await websocket.send_json({
"success": False,
"error": "VM not found"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {vm_name}")
await websocket.send_json({
"success": True,
"message": "Authenticated"
})
except Exception as e:
logger.error(f"Error validating with TryCUA API: {e}")
await websocket.send_json({
"success": False,
"error": "Authentication service unavailable"
})
await websocket.close()
manager.disconnect(websocket)
return
except Exception as e:
logger.error(f"Authentication error: {e}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# Map commands to appropriate handler methods
handlers = {