Add /start, /stop, and /restart to cloud provider

This commit is contained in:
Dillon DuPont
2025-10-09 12:01:45 -04:00
parent 003c10a846
commit 0ede822990

View File

@@ -1,6 +1,11 @@
"""Cloud VM provider implementation.
"""Cloud VM provider implementation using CUA Public API.
This module contains a stub implementation for a future cloud VM provider.
Implements the following public API endpoints:
- GET /v1/vms
- POST /v1/vms/:name/start
- POST /v1/vms/:name/stop
- POST /v1/vms/:name/restart
"""
import logging
@@ -14,6 +19,10 @@ logger = logging.getLogger(__name__)
import asyncio
import aiohttp
from urllib.parse import urlparse
import os
DEFAULT_API_BASE = os.getenv("CUA_API_BASE", "https://api.cua.ai")
class CloudProvider(BaseVMProvider):
"""Cloud VM Provider implementation."""
@@ -21,6 +30,7 @@ class CloudProvider(BaseVMProvider):
self,
api_key: str,
verbose: bool = False,
api_base: Optional[str] = None,
**kwargs,
):
"""
@@ -32,6 +42,7 @@ class CloudProvider(BaseVMProvider):
assert api_key, "api_key required for CloudProvider"
self.api_key = api_key
self.verbose = verbose
self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
@property
def provider_type(self) -> VMProviderType:
@@ -44,24 +55,158 @@ class CloudProvider(BaseVMProvider):
pass
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM VNC URL by name using the cloud API."""
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
"""Get VM information by querying the VM status endpoint.
- Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com"
- Probe https://{hostname}:8443/status with a short timeout
- If JSON contains a "status" field, return it; otherwise infer
- Fallback to DNS resolve check to distinguish unknown vs not_found
"""
hostname = await self.get_ip(name=name)
# Try HTTPS probe to the computer-server status endpoint (8443)
try:
timeout = aiohttp.ClientTimeout(total=3)
async with aiohttp.ClientSession(timeout=timeout) as session:
url = f"https://{hostname}:8443/status"
async with session.get(url, allow_redirects=False) as resp:
status_code = resp.status
vm_status: str
vm_os_type: Optional[str] = None
if status_code == 200:
try:
data = await resp.json(content_type=None)
vm_status = str(data.get("status", "ok"))
if isinstance(data, dict) and "os_type" in data:
vm_os_type = str(data.get("os_type"))
except Exception:
vm_status = "unknown"
elif status_code < 500:
vm_status = "unknown"
else:
vm_status = "unknown"
return {
"name": name,
"status": "running" if vm_status == "ok" else vm_status,
"hostname": hostname,
"os_type": vm_os_type,
}
except Exception:
# Fall back to a DNS resolve check
try:
loop = asyncio.get_event_loop()
await loop.getaddrinfo(hostname, 443)
# Host resolves, but HTTPS probe failed → treat as unknown
return {
"name": name,
"status": "unknown",
"hostname": hostname,
}
except Exception:
# Host does not resolve → not found
return {"name": name, "status": "not_found", "hostname": hostname}
async def list_vms(self) -> List[Dict[str, Any]]:
logger.warning("CloudProvider.list_vms is not implemented")
return []
url = f"{self.api_base}/v1/vms"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
try:
data = await resp.json(content_type=None)
except Exception:
text = await resp.text()
logger.error(f"Failed to parse list_vms JSON: {text}")
return []
if isinstance(data, list):
return data
logger.warning("Unexpected response for list_vms; expected list")
return []
elif resp.status == 401:
logger.error("Unauthorized: invalid CUA API key for list_vms")
return []
else:
text = await resp.text()
logger.error(f"list_vms failed: HTTP {resp.status} - {text}")
return []
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
# logger.warning("CloudProvider.run_vm is not implemented")
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
"""Start a VM via public API. Returns a minimal status."""
url = f"{self.api_base}/v1/vms/{name}/start"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 201, 202, 204):
return {"name": name, "status": "starting"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
"""Stop a VM via public API."""
url = f"{self.api_base}/v1/vms/{name}/stop"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 202):
# Spec says 202 with {"status":"stopping"}
body_status: Optional[str] = None
try:
data = await resp.json(content_type=None)
body_status = data.get("status") if isinstance(data, dict) else None
except Exception:
body_status = None
return {"name": name, "status": body_status or "stopping"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Restart a VM via public API."""
url = f"{self.api_base}/v1/vms/{name}/restart"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 202):
# Spec says 202 with {"status":"restarting"}
body_status: Optional[str] = None
try:
data = await resp.json(content_type=None)
body_status = data.get("status") if isinstance(data, dict) else None
except Exception:
body_status = None
return {"name": name, "status": body_status or "restarting"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.update_vm is not implemented")
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
logger.warning("CloudProvider.update_vm is not implemented via public API")
return {"name": name, "status": "unchanged", "message": "update_vm not supported by public API"}
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""