mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 02:19:58 -06:00
Added auth manager
This commit is contained in:
@@ -12,6 +12,8 @@ from io import StringIO
|
||||
from .handlers.factory import HandlerFactory
|
||||
import os
|
||||
import aiohttp
|
||||
import hashlib
|
||||
import time
|
||||
|
||||
# Set up logging with more detail
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -86,6 +88,95 @@ handlers = {
|
||||
}
|
||||
|
||||
|
||||
class AuthenticationManager:
|
||||
def __init__(self):
|
||||
self.sessions: Dict[str, Dict[str, Any]] = {}
|
||||
self.container_name = os.environ.get("CONTAINER_NAME")
|
||||
|
||||
def _hash_credentials(self, container_name: str, api_key: str) -> str:
|
||||
"""Create a hash of container name and API key for session identification"""
|
||||
combined = f"{container_name}:{api_key}"
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
|
||||
"""Check if a session is still valid based on expiration time"""
|
||||
if not session_data.get('valid', False):
|
||||
return False
|
||||
|
||||
expires_at = session_data.get('expires_at', 0)
|
||||
return time.time() < expires_at
|
||||
|
||||
async def auth(self, container_name: str, api_key: str) -> bool:
|
||||
"""Authenticate container name and API key, using cached sessions when possible"""
|
||||
# If no CONTAINER_NAME is set, always allow access (local development)
|
||||
if not self.container_name:
|
||||
logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
|
||||
return True
|
||||
|
||||
# Layer 1: VM Identity Verification
|
||||
if container_name != self.container_name:
|
||||
logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
|
||||
return False
|
||||
|
||||
# Create hash for session lookup
|
||||
session_hash = self._hash_credentials(container_name, api_key)
|
||||
|
||||
# Check if we have a valid cached session
|
||||
if session_hash in self.sessions:
|
||||
session_data = self.sessions[session_hash]
|
||||
if self._is_session_valid(session_data):
|
||||
logger.info(f"Using cached authentication for container: {container_name}")
|
||||
return session_data['valid']
|
||||
else:
|
||||
# Remove expired session
|
||||
del self.sessions[session_hash]
|
||||
|
||||
# No valid cached session, authenticate with API
|
||||
logger.info(f"Authenticating with TryCUA API for container: {container_name}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
async with session.get(
|
||||
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
|
||||
headers=headers,
|
||||
) as resp:
|
||||
is_valid = resp.status == 200 and bool((await resp.text()).strip())
|
||||
|
||||
# Cache the result with 5 second expiration
|
||||
self.sessions[session_hash] = {
|
||||
'valid': is_valid,
|
||||
'expires_at': time.time() + 5 # 5 seconds from now
|
||||
}
|
||||
|
||||
if is_valid:
|
||||
logger.info(f"Authentication successful for container: {container_name}")
|
||||
else:
|
||||
logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}")
|
||||
|
||||
return is_valid
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
|
||||
# Cache failed result to avoid repeated requests
|
||||
self.sessions[session_hash] = {
|
||||
'valid': False,
|
||||
'expires_at': time.time() + 5
|
||||
}
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during authentication: {str(e)}")
|
||||
# Cache failed result to avoid repeated requests
|
||||
self.sessions[session_hash] = {
|
||||
'valid': False,
|
||||
'expires_at': time.time() + 5
|
||||
}
|
||||
return False
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: List[WebSocket] = []
|
||||
@@ -99,6 +190,7 @@ class ConnectionManager:
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
auth_manager = AuthenticationManager()
|
||||
|
||||
|
||||
@app.websocket("/ws", name="websocket_endpoint")
|
||||
@@ -109,12 +201,12 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
await manager.connect(websocket)
|
||||
|
||||
# Check if CONTAINER_NAME is set (indicating cloud provider)
|
||||
container_name = os.environ.get("CONTAINER_NAME")
|
||||
server_container_name = os.environ.get("CONTAINER_NAME")
|
||||
|
||||
# If cloud provider, perform authentication handshake
|
||||
if container_name:
|
||||
if server_container_name:
|
||||
try:
|
||||
logger.info(f"Cloud provider detected. CONTAINER_NAME: {container_name}. Waiting for authentication...")
|
||||
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
|
||||
|
||||
# Wait for authentication message
|
||||
auth_data = await websocket.receive_json()
|
||||
@@ -133,18 +225,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
client_api_key = auth_data.get("params", {}).get("api_key")
|
||||
client_container_name = auth_data.get("params", {}).get("container_name")
|
||||
|
||||
# Layer 1: VM Identity Verification
|
||||
if client_container_name != container_name:
|
||||
logger.warning(f"VM name mismatch. Expected: {container_name}, Got: {client_container_name}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "VM name mismatch"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Layer 2: API Key Validation with TryCUA API
|
||||
# Validate credentials using AuthenticationManager
|
||||
if not client_api_key:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
@@ -154,58 +235,34 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# Validate with TryCUA API
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {client_api_key}"
|
||||
}
|
||||
|
||||
async with session.get(
|
||||
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
|
||||
headers=headers,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
error_msg = await resp.text()
|
||||
logger.warning(f"API validation failed: {error_msg}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
# If we get a 200 response with VNC URL, the VM exists and user has access
|
||||
vnc_url = (await resp.text()).strip()
|
||||
if not vnc_url:
|
||||
logger.warning(f"No VNC URL returned for VM: {container_name}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "VM not found"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
logger.info(f"Authentication successful for VM: {container_name}")
|
||||
await websocket.send_json({
|
||||
"success": True,
|
||||
"message": "Authenticated"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating with TryCUA API: {e}")
|
||||
if not client_container_name:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication service unavailable"
|
||||
"error": "Container name required"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
|
||||
# Use AuthenticationManager for validation
|
||||
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
|
||||
if not is_authenticated:
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
})
|
||||
await websocket.close()
|
||||
manager.disconnect(websocket)
|
||||
return
|
||||
|
||||
logger.info(f"Authentication successful for VM: {client_container_name}")
|
||||
await websocket.send_json({
|
||||
"success": True,
|
||||
"message": "Authentication successful"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {e}")
|
||||
logger.error(f"Error during authentication handshake: {str(e)}")
|
||||
await websocket.send_json({
|
||||
"success": False,
|
||||
"error": "Authentication failed"
|
||||
@@ -310,45 +367,17 @@ async def cmd_endpoint(
|
||||
if server_container_name:
|
||||
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
|
||||
|
||||
# Layer 1: VM Identity Verification
|
||||
if container_name != server_container_name:
|
||||
logger.warning(f"VM name mismatch. Expected: {server_container_name}, Got: {container_name}")
|
||||
raise HTTPException(status_code=401, detail="VM name mismatch")
|
||||
# Validate required headers
|
||||
if not container_name:
|
||||
raise HTTPException(status_code=401, detail="Container name required")
|
||||
|
||||
# Layer 2: API Key Validation with TryCUA API
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="API key required")
|
||||
|
||||
# Validate with TryCUA API
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
}
|
||||
|
||||
async with session.get(
|
||||
f"https://www.trycua.com/api/vm/auth?container_name={server_container_name}",
|
||||
headers=headers,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
logger.warning(f"API key validation failed. Status: {resp.status}")
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
auth_failed = not (await resp.text()).strip()
|
||||
if auth_failed:
|
||||
logger.warning(f"API key validation failed.")
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
logger.info("Authentication successful")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Authentication service unavailable")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during authentication: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Authentication failed")
|
||||
# Validate with AuthenticationManager
|
||||
is_authenticated = await auth_manager.auth(container_name, api_key)
|
||||
if not is_authenticated:
|
||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||
|
||||
if command not in handlers:
|
||||
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")
|
||||
|
||||
Reference in New Issue
Block a user