mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 18:40:04 -06:00
Add /start, /stop, and /restart to cloud provider
This commit is contained in:
@@ -1,6 +1,11 @@
|
||||
"""Cloud VM provider implementation.
|
||||
"""Cloud VM provider implementation using CUA Public API.
|
||||
|
||||
This module contains a stub implementation for a future cloud VM provider.
|
||||
Implements the following public API endpoints:
|
||||
|
||||
- GET /v1/vms
|
||||
- POST /v1/vms/:name/start
|
||||
- POST /v1/vms/:name/stop
|
||||
- POST /v1/vms/:name/restart
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -14,6 +19,10 @@ logger = logging.getLogger(__name__)
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
|
||||
|
||||
DEFAULT_API_BASE = os.getenv("CUA_API_BASE", "https://api.cua.ai")
|
||||
|
||||
class CloudProvider(BaseVMProvider):
|
||||
"""Cloud VM Provider implementation."""
|
||||
@@ -21,6 +30,7 @@ class CloudProvider(BaseVMProvider):
|
||||
self,
|
||||
api_key: str,
|
||||
verbose: bool = False,
|
||||
api_base: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -32,6 +42,7 @@ class CloudProvider(BaseVMProvider):
|
||||
assert api_key, "api_key required for CloudProvider"
|
||||
self.api_key = api_key
|
||||
self.verbose = verbose
|
||||
self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
@@ -44,24 +55,158 @@ class CloudProvider(BaseVMProvider):
|
||||
pass
|
||||
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM VNC URL by name using the cloud API."""
|
||||
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
|
||||
"""Get VM information by querying the VM status endpoint.
|
||||
|
||||
- Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com"
|
||||
- Probe https://{hostname}:8443/status with a short timeout
|
||||
- If JSON contains a "status" field, return it; otherwise infer
|
||||
- Fallback to DNS resolve check to distinguish unknown vs not_found
|
||||
"""
|
||||
hostname = await self.get_ip(name=name)
|
||||
|
||||
# Try HTTPS probe to the computer-server status endpoint (8443)
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=3)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
url = f"https://{hostname}:8443/status"
|
||||
async with session.get(url, allow_redirects=False) as resp:
|
||||
status_code = resp.status
|
||||
vm_status: str
|
||||
vm_os_type: Optional[str] = None
|
||||
if status_code == 200:
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
vm_status = str(data.get("status", "ok"))
|
||||
if isinstance(data, dict) and "os_type" in data:
|
||||
vm_os_type = str(data.get("os_type"))
|
||||
except Exception:
|
||||
vm_status = "unknown"
|
||||
elif status_code < 500:
|
||||
vm_status = "unknown"
|
||||
else:
|
||||
vm_status = "unknown"
|
||||
return {
|
||||
"name": name,
|
||||
"status": "running" if vm_status == "ok" else vm_status,
|
||||
"hostname": hostname,
|
||||
"os_type": vm_os_type,
|
||||
}
|
||||
except Exception:
|
||||
# Fall back to a DNS resolve check
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.getaddrinfo(hostname, 443)
|
||||
# Host resolves, but HTTPS probe failed → treat as unknown
|
||||
return {
|
||||
"name": name,
|
||||
"status": "unknown",
|
||||
"hostname": hostname,
|
||||
}
|
||||
except Exception:
|
||||
# Host does not resolve → not found
|
||||
return {"name": name, "status": "not_found", "hostname": hostname}
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
logger.warning("CloudProvider.list_vms is not implemented")
|
||||
return []
|
||||
url = f"{self.api_base}/v1/vms"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
except Exception:
|
||||
text = await resp.text()
|
||||
logger.error(f"Failed to parse list_vms JSON: {text}")
|
||||
return []
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
logger.warning("Unexpected response for list_vms; expected list")
|
||||
return []
|
||||
elif resp.status == 401:
|
||||
logger.error("Unauthorized: invalid CUA API key for list_vms")
|
||||
return []
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"list_vms failed: HTTP {resp.status} - {text}")
|
||||
return []
|
||||
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
# logger.warning("CloudProvider.run_vm is not implemented")
|
||||
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
|
||||
"""Start a VM via public API. Returns a minimal status."""
|
||||
url = f"{self.api_base}/v1/vms/{name}/start"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers) as resp:
|
||||
if resp.status in (200, 201, 202, 204):
|
||||
return {"name": name, "status": "starting"}
|
||||
elif resp.status == 404:
|
||||
return {"name": name, "status": "not_found"}
|
||||
elif resp.status == 401:
|
||||
return {"name": name, "status": "unauthorized"}
|
||||
else:
|
||||
text = await resp.text()
|
||||
return {"name": name, "status": "error", "message": text}
|
||||
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
|
||||
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
|
||||
"""Stop a VM via public API."""
|
||||
url = f"{self.api_base}/v1/vms/{name}/stop"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers) as resp:
|
||||
if resp.status in (200, 202):
|
||||
# Spec says 202 with {"status":"stopping"}
|
||||
body_status: Optional[str] = None
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
body_status = data.get("status") if isinstance(data, dict) else None
|
||||
except Exception:
|
||||
body_status = None
|
||||
return {"name": name, "status": body_status or "stopping"}
|
||||
elif resp.status == 404:
|
||||
return {"name": name, "status": "not_found"}
|
||||
elif resp.status == 401:
|
||||
return {"name": name, "status": "unauthorized"}
|
||||
else:
|
||||
text = await resp.text()
|
||||
return {"name": name, "status": "error", "message": text}
|
||||
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Restart a VM via public API."""
|
||||
url = f"{self.api_base}/v1/vms/{name}/restart"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers) as resp:
|
||||
if resp.status in (200, 202):
|
||||
# Spec says 202 with {"status":"restarting"}
|
||||
body_status: Optional[str] = None
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
body_status = data.get("status") if isinstance(data, dict) else None
|
||||
except Exception:
|
||||
body_status = None
|
||||
return {"name": name, "status": body_status or "restarting"}
|
||||
elif resp.status == 404:
|
||||
return {"name": name, "status": "not_found"}
|
||||
elif resp.status == 401:
|
||||
return {"name": name, "status": "unauthorized"}
|
||||
else:
|
||||
text = await resp.text()
|
||||
return {"name": name, "status": "error", "message": text}
|
||||
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
logger.warning("CloudProvider.update_vm is not implemented")
|
||||
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
|
||||
logger.warning("CloudProvider.update_vm is not implemented via public API")
|
||||
return {"name": name, "status": "unchanged", "message": "update_vm not supported by public API"}
|
||||
|
||||
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user