Added auth manager

This commit is contained in:
Dillon DuPont
2025-07-10 15:55:59 -04:00
parent a54b3e81b6
commit 669832030e

View File

@@ -12,6 +12,8 @@ from io import StringIO
from .handlers.factory import HandlerFactory
import os
import aiohttp
import hashlib
import time
# Set up logging with more detail
logger = logging.getLogger(__name__)
@@ -86,6 +88,95 @@ handlers = {
}
class AuthenticationManager:
def __init__(self):
self.sessions: Dict[str, Dict[str, Any]] = {}
self.container_name = os.environ.get("CONTAINER_NAME")
def _hash_credentials(self, container_name: str, api_key: str) -> str:
"""Create a hash of container name and API key for session identification"""
combined = f"{container_name}:{api_key}"
return hashlib.sha256(combined.encode()).hexdigest()
def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
"""Check if a session is still valid based on expiration time"""
if not session_data.get('valid', False):
return False
expires_at = session_data.get('expires_at', 0)
return time.time() < expires_at
async def auth(self, container_name: str, api_key: str) -> bool:
"""Authenticate container name and API key, using cached sessions when possible"""
# If no CONTAINER_NAME is set, always allow access (local development)
if not self.container_name:
logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
return True
# Layer 1: VM Identity Verification
if container_name != self.container_name:
logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
return False
# Create hash for session lookup
session_hash = self._hash_credentials(container_name, api_key)
# Check if we have a valid cached session
if session_hash in self.sessions:
session_data = self.sessions[session_hash]
if self._is_session_valid(session_data):
logger.info(f"Using cached authentication for container: {container_name}")
return session_data['valid']
else:
# Remove expired session
del self.sessions[session_hash]
# No valid cached session, authenticate with API
logger.info(f"Authenticating with TryCUA API for container: {container_name}")
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
is_valid = resp.status == 200 and bool((await resp.text()).strip())
# Cache the result with 5 second expiration
self.sessions[session_hash] = {
'valid': is_valid,
'expires_at': time.time() + 5 # 5 seconds from now
}
if is_valid:
logger.info(f"Authentication successful for container: {container_name}")
else:
logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}")
return is_valid
except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + 5
}
return False
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
'valid': False,
'expires_at': time.time() + 5
}
return False
class ConnectionManager:
def __init__(self):
self.active_connections: List[WebSocket] = []
@@ -99,6 +190,7 @@ class ConnectionManager:
manager = ConnectionManager()
auth_manager = AuthenticationManager()
@app.websocket("/ws", name="websocket_endpoint")
@@ -109,12 +201,12 @@ async def websocket_endpoint(websocket: WebSocket):
await manager.connect(websocket)
# Check if CONTAINER_NAME is set (indicating cloud provider)
container_name = os.environ.get("CONTAINER_NAME")
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication handshake
if container_name:
if server_container_name:
try:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...")
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
# Wait for authentication message
auth_data = await websocket.receive_json()
@@ -133,18 +225,7 @@ async def websocket_endpoint(websocket: WebSocket):
client_api_key = auth_data.get("params", {}).get("api_key")
client_container_name = auth_data.get("params", {}).get("container_name")
# Layer 1: VM Identity Verification
if client_container_name != container_name:
logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}")
await websocket.send_json({
"success": False,
"error": "VM name mismatch"
})
await websocket.close()
manager.disconnect(websocket)
return
# Layer 2: API Key Validation with TryCUA API
# Validate credentials using AuthenticationManager
if not client_api_key:
await websocket.send_json({
"success": False,
@@ -154,58 +235,34 @@ async def websocket_endpoint(websocket: WebSocket):
manager.disconnect(websocket)
return
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {client_api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
if resp.status != 200:
error_msg = await resp.text()
logger.warning(f"API validation failed: {error_msg}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
# If we get a 200 response with VNC URL, the VM exists and user has access
vnc_url = (await resp.text()).strip()
if not vnc_url:
logger.warning(f"No VNC URL returned for VM: {container_name}")
await websocket.send_json({
"success": False,
"error": "VM not found"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {container_name}")
await websocket.send_json({
"success": True,
"message": "Authenticated"
})
except Exception as e:
logger.error(f"Error validating with TryCUA API: {e}")
if not client_container_name:
await websocket.send_json({
"success": False,
"error": "Authentication service unavailable"
"error": "Container name required"
})
await websocket.close()
manager.disconnect(websocket)
return
# Use AuthenticationManager for validation
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
if not is_authenticated:
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {client_container_name}")
await websocket.send_json({
"success": True,
"message": "Authentication successful"
})
except Exception as e:
logger.error(f"Authentication error: {e}")
logger.error(f"Error during authentication handshake: {str(e)}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
@@ -310,45 +367,17 @@ async def cmd_endpoint(
if server_container_name:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
# Layer 1: VM Identity Verification
if container_name != server_container_name:
logger.warning(f"VM name mismatch. Expected: {server_container_name}, Got: {container_name}")
raise HTTPException(status_code=401, detail="VM name mismatch")
# Validate required headers
if not container_name:
raise HTTPException(status_code=401, detail="Container name required")
# Layer 2: API Key Validation with TryCUA API
if not api_key:
raise HTTPException(status_code=401, detail="API key required")
# Validate with TryCUA API
try:
async with aiohttp.ClientSession() as session:
headers = {
"Authorization": f"Bearer {api_key}"
}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={server_container_name}",
headers=headers,
) as resp:
if resp.status != 200:
logger.warning(f"API key validation failed. Status: {resp.status}")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_failed = not (await resp.text()).strip()
if auth_failed:
logger.warning(f"API key validation failed.")
raise HTTPException(status_code=401, detail="Invalid API key")
logger.info("Authentication successful")
except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
raise HTTPException(status_code=500, detail="Authentication service unavailable")
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
raise HTTPException(status_code=500, detail="Authentication failed")
# Validate with AuthenticationManager
is_authenticated = await auth_manager.auth(container_name, api_key)
if not is_authenticated:
raise HTTPException(status_code=401, detail="Authentication failed")
if command not in handlers:
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")