diff --git a/libs/python/computer/computer/providers/cloud/provider.py b/libs/python/computer/computer/providers/cloud/provider.py index 1cfba161..a54e01c4 100644 --- a/libs/python/computer/computer/providers/cloud/provider.py +++ b/libs/python/computer/computer/providers/cloud/provider.py @@ -1,6 +1,11 @@ -"""Cloud VM provider implementation. +"""Cloud VM provider implementation using CUA Public API. -This module contains a stub implementation for a future cloud VM provider. +Implements the following public API endpoints: + +- GET /v1/vms +- POST /v1/vms/:name/start +- POST /v1/vms/:name/stop +- POST /v1/vms/:name/restart """ import logging @@ -14,6 +19,10 @@ logger = logging.getLogger(__name__) import asyncio import aiohttp from urllib.parse import urlparse +import os + + +DEFAULT_API_BASE = os.getenv("CUA_API_BASE", "https://api.cua.ai") class CloudProvider(BaseVMProvider): """Cloud VM Provider implementation.""" @@ -21,6 +30,7 @@ class CloudProvider(BaseVMProvider): self, api_key: str, verbose: bool = False, + api_base: Optional[str] = None, **kwargs, ): """ @@ -32,6 +42,7 @@ class CloudProvider(BaseVMProvider): assert api_key, "api_key required for CloudProvider" self.api_key = api_key self.verbose = verbose + self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/") @property def provider_type(self) -> VMProviderType: @@ -44,24 +55,158 @@ class CloudProvider(BaseVMProvider): pass async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]: - """Get VM VNC URL by name using the cloud API.""" - return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"} + """Get VM information by querying the VM status endpoint. + + - Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com" + - Probe https://{hostname}:8443/status with a short timeout + - If JSON contains a "status" field, return it; otherwise infer + - Fallback to DNS resolve check to distinguish unknown vs not_found + """ + hostname = await self.get_ip(name=name) + + # Try HTTPS probe to the computer-server status endpoint (8443) + try: + timeout = aiohttp.ClientTimeout(total=3) + async with aiohttp.ClientSession(timeout=timeout) as session: + url = f"https://{hostname}:8443/status" + async with session.get(url, allow_redirects=False) as resp: + status_code = resp.status + vm_status: str + vm_os_type: Optional[str] = None + if status_code == 200: + try: + data = await resp.json(content_type=None) + vm_status = str(data.get("status", "ok")) + if isinstance(data, dict) and "os_type" in data: + vm_os_type = str(data.get("os_type")) + except Exception: + vm_status = "unknown" + elif status_code < 500: + vm_status = "unknown" + else: + vm_status = "unknown" + return { + "name": name, + "status": "running" if vm_status == "ok" else vm_status, + "hostname": hostname, + "os_type": vm_os_type, + } + except Exception: + # Fall back to a DNS resolve check + try: + loop = asyncio.get_event_loop() + await loop.getaddrinfo(hostname, 443) + # Host resolves, but HTTPS probe failed → treat as unknown + return { + "name": name, + "status": "unknown", + "hostname": hostname, + } + except Exception: + # Host does not resolve → not found + return {"name": name, "status": "not_found", "hostname": hostname} async def list_vms(self) -> List[Dict[str, Any]]: - logger.warning("CloudProvider.list_vms is not implemented") - return [] + url = f"{self.api_base}/v1/vms" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Accept": "application/json", + } + async with aiohttp.ClientSession() as session: + async with session.get(url, headers=headers) as resp: + if resp.status == 200: + try: + data = await resp.json(content_type=None) + except Exception: + text = await resp.text() + logger.error(f"Failed to parse list_vms JSON: {text}") + return [] + if isinstance(data, list): + return data + logger.warning("Unexpected response for list_vms; expected list") + return [] + elif resp.status == 401: + logger.error("Unauthorized: invalid CUA API key for list_vms") + return [] + else: + text = await resp.text() + logger.error(f"list_vms failed: HTTP {resp.status} - {text}") + return [] async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]: - # logger.warning("CloudProvider.run_vm is not implemented") - return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"} + """Start a VM via public API. Returns a minimal status.""" + url = f"{self.api_base}/v1/vms/{name}/start" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Accept": "application/json", + } + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers) as resp: + if resp.status in (200, 201, 202, 204): + return {"name": name, "status": "starting"} + elif resp.status == 404: + return {"name": name, "status": "not_found"} + elif resp.status == 401: + return {"name": name, "status": "unauthorized"} + else: + text = await resp.text() + return {"name": name, "status": "error", "message": text} async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]: - logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()") - return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"} + """Stop a VM via public API.""" + url = f"{self.api_base}/v1/vms/{name}/stop" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Accept": "application/json", + } + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers) as resp: + if resp.status in (200, 202): + # Spec says 202 with {"status":"stopping"} + body_status: Optional[str] = None + try: + data = await resp.json(content_type=None) + body_status = data.get("status") if isinstance(data, dict) else None + except Exception: + body_status = None + return {"name": name, "status": body_status or "stopping"} + elif resp.status == 404: + return {"name": name, "status": "not_found"} + elif resp.status == 401: + return {"name": name, "status": "unauthorized"} + else: + text = await resp.text() + return {"name": name, "status": "error", "message": text} + + async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]: + """Restart a VM via public API.""" + url = f"{self.api_base}/v1/vms/{name}/restart" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Accept": "application/json", + } + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers) as resp: + if resp.status in (200, 202): + # Spec says 202 with {"status":"restarting"} + body_status: Optional[str] = None + try: + data = await resp.json(content_type=None) + body_status = data.get("status") if isinstance(data, dict) else None + except Exception: + body_status = None + return {"name": name, "status": body_status or "restarting"} + elif resp.status == 404: + return {"name": name, "status": "not_found"} + elif resp.status == 401: + return {"name": name, "status": "unauthorized"} + else: + text = await resp.text() + return {"name": name, "status": "error", "message": text} async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]: - logger.warning("CloudProvider.update_vm is not implemented") - return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"} + logger.warning("CloudProvider.update_vm is not implemented via public API") + return {"name": name, "status": "unchanged", "message": "update_vm not supported by public API"} async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str: """