Merge pull request #452 from trycua/feat/extended_cloud_api

Add Cloud VM Management API
This commit is contained in:
ddupont
2025-10-14 11:31:18 -04:00
committed by GitHub
11 changed files with 645 additions and 17 deletions

View File

@@ -1,3 +1,4 @@
import traceback
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
import asyncio
from .models import Computer as ComputerConfig, Display
@@ -451,6 +452,7 @@ class Computer:
raise RuntimeError(f"VM failed to become ready: {wait_error}")
except Exception as e:
self.logger.error(f"Failed to initialize computer: {e}")
self.logger.error(traceback.format_exc())
raise RuntimeError(f"Failed to initialize computer: {e}")
try:
@@ -558,6 +560,102 @@ class Computer:
self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")
return
async def start(self) -> None:
"""Start the computer."""
await self.run()
async def restart(self) -> None:
"""Restart the computer.
If using a VM provider that supports restart, this will issue a restart
without tearing down the provider context, then reconnect the interface.
Falls back to stop()+run() when a provider restart is not available.
"""
# Host computer server: just disconnect and run again
if self.use_host_computer_server:
try:
await self.disconnect()
finally:
await self.run()
return
# If no VM provider context yet, fall back to full run
if not getattr(self, "_provider_context", None) or self.config.vm_provider is None:
self.logger.info("No provider context active; performing full restart via run()")
await self.run()
return
# Gracefully close current interface connection if present
if self._interface:
try:
self._interface.close()
except Exception as e:
self.logger.debug(f"Error closing interface prior to restart: {e}")
# Attempt provider-level restart if implemented
try:
storage_param = "ephemeral" if self.ephemeral else self.storage
if hasattr(self.config.vm_provider, "restart_vm"):
self.logger.info(f"Restarting VM {self.config.name} via provider...")
await self.config.vm_provider.restart_vm(name=self.config.name, storage=storage_param)
else:
# Fallback: stop then start without leaving provider context
self.logger.info(f"Provider has no restart_vm; performing stop+start for {self.config.name}...")
await self.config.vm_provider.stop_vm(name=self.config.name, storage=storage_param)
await self.config.vm_provider.run_vm(image=self.image, name=self.config.name, run_opts={}, storage=storage_param)
except Exception as e:
self.logger.error(f"Failed to restart VM via provider: {e}")
# As a last resort, do a full stop (with provider context exit) and run
try:
await self.stop()
finally:
await self.run()
return
# Wait for VM to be ready and reconnect interface
try:
self.logger.info("Waiting for VM to be ready after restart...")
if self.provider_type == VMProviderType.LUMIER:
max_retries = 60
retry_delay = 3
else:
max_retries = 30
retry_delay = 2
ip_address = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
self.logger.info(f"Re-initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name,
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
),
)
self.logger.info("Connecting to WebSocket interface after restart...")
await self._interface.wait_for_ready(timeout=30)
self.logger.info("Computer reconnected and ready after restart")
except Exception as e:
self.logger.error(f"Failed to reconnect after restart: {e}")
# Try a full reset if reconnection failed
try:
await self.stop()
finally:
await self.run()
# @property
async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
"""Get the IP address of the VM or localhost if using host computer server.

View File

@@ -2,7 +2,9 @@
import abc
from enum import StrEnum
from typing import Dict, List, Optional, Any, AsyncContextManager
from typing import Dict, Optional, Any, AsyncContextManager
from .types import ListVMsResponse
class VMProviderType(StrEnum):
@@ -42,8 +44,13 @@ class BaseVMProvider(AsyncContextManager):
pass
@abc.abstractmethod
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all available VMs."""
async def list_vms(self) -> ListVMsResponse:
"""List all available VMs.
Returns:
ListVMsResponse: A list of minimal VM objects as defined in
`computer.providers.types.MinimalVM`.
"""
pass
@abc.abstractmethod
@@ -76,6 +83,20 @@ class BaseVMProvider(AsyncContextManager):
"""
pass
@abc.abstractmethod
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Restart a VM by name.
Args:
name: Name of the VM to restart
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM restart status and information
"""
pass
@abc.abstractmethod
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration.

View File

@@ -1,12 +1,18 @@
"""Cloud VM provider implementation.
"""Cloud VM provider implementation using CUA Public API.
This module contains a stub implementation for a future cloud VM provider.
Implements the following public API endpoints:
- GET /v1/vms
- POST /v1/vms/:name/start
- POST /v1/vms/:name/stop
- POST /v1/vms/:name/restart
"""
import logging
from typing import Dict, List, Optional, Any
from ..base import BaseVMProvider, VMProviderType
from ..types import ListVMsResponse, MinimalVM
# Setup logging
logger = logging.getLogger(__name__)
@@ -14,6 +20,10 @@ logger = logging.getLogger(__name__)
import asyncio
import aiohttp
from urllib.parse import urlparse
import os
DEFAULT_API_BASE = os.getenv("CUA_API_BASE", "https://api.cua.ai")
class CloudProvider(BaseVMProvider):
"""Cloud VM Provider implementation."""
@@ -21,6 +31,7 @@ class CloudProvider(BaseVMProvider):
self,
api_key: str,
verbose: bool = False,
api_base: Optional[str] = None,
**kwargs,
):
"""
@@ -32,6 +43,7 @@ class CloudProvider(BaseVMProvider):
assert api_key, "api_key required for CloudProvider"
self.api_key = api_key
self.verbose = verbose
self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
@property
def provider_type(self) -> VMProviderType:
@@ -44,24 +56,162 @@ class CloudProvider(BaseVMProvider):
pass
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM VNC URL by name using the cloud API."""
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
"""Get VM information by querying the VM status endpoint.
async def list_vms(self) -> List[Dict[str, Any]]:
logger.warning("CloudProvider.list_vms is not implemented")
return []
- Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com"
- Probe https://{hostname}:8443/status with a short timeout
- If JSON contains a "status" field, return it; otherwise infer
- Fallback to DNS resolve check to distinguish unknown vs not_found
"""
hostname = await self.get_ip(name=name)
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
# logger.warning("CloudProvider.run_vm is not implemented")
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
# Try HTTPS probe to the computer-server status endpoint (8443)
try:
timeout = aiohttp.ClientTimeout(total=3)
async with aiohttp.ClientSession(timeout=timeout) as session:
url = f"https://{hostname}:8443/status"
async with session.get(url, allow_redirects=False) as resp:
status_code = resp.status
vm_status: str
vm_os_type: Optional[str] = None
if status_code == 200:
try:
data = await resp.json(content_type=None)
vm_status = str(data.get("status", "ok"))
vm_os_type = str(data.get("os_type"))
except Exception:
vm_status = "unknown"
elif status_code < 500:
vm_status = "unknown"
else:
vm_status = "unknown"
return {
"name": name,
"status": "running" if vm_status == "ok" else vm_status,
"api_url": f"https://{hostname}:8443",
"os_type": vm_os_type,
}
except Exception:
return {"name": name, "status": "not_found", "api_url": f"https://{hostname}:8443"}
async def list_vms(self) -> ListVMsResponse:
url = f"{self.api_base}/v1/vms"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
try:
data = await resp.json(content_type=None)
except Exception:
text = await resp.text()
logger.error(f"Failed to parse list_vms JSON: {text}")
return []
if isinstance(data, list):
# Enrich with convenience URLs when possible.
enriched: List[Dict[str, Any]] = []
for item in data:
vm = dict(item) if isinstance(item, dict) else {}
name = vm.get("name")
password = vm.get("password")
if isinstance(name, str) and name:
host = f"{name}.containers.cloud.trycua.com"
# api_url: always set if missing
if not vm.get("api_url"):
vm["api_url"] = f"https://{host}:8443"
# vnc_url: only when password available
if not vm.get("vnc_url") and isinstance(password, str) and password:
vm[
"vnc_url"
] = f"https://{host}/vnc.html?autoconnect=true&password={password}"
enriched.append(vm)
return enriched # type: ignore[return-value]
logger.warning("Unexpected response for list_vms; expected list")
return []
elif resp.status == 401:
logger.error("Unauthorized: invalid CUA API key for list_vms")
return []
else:
text = await resp.text()
logger.error(f"list_vms failed: HTTP {resp.status} - {text}")
return []
async def run_vm(self, name: str, image: Optional[str] = None, run_opts: Optional[Dict[str, Any]] = None, storage: Optional[str] = None) -> Dict[str, Any]:
"""Start a VM via public API. Returns a minimal status."""
url = f"{self.api_base}/v1/vms/{name}/start"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 201, 202, 204):
return {"name": name, "status": "starting"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
"""Stop a VM via public API."""
url = f"{self.api_base}/v1/vms/{name}/stop"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 202):
# Spec says 202 with {"status":"stopping"}
body_status: Optional[str] = None
try:
data = await resp.json(content_type=None)
body_status = data.get("status") if isinstance(data, dict) else None
except Exception:
body_status = None
return {"name": name, "status": body_status or "stopping"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Restart a VM via public API."""
url = f"{self.api_base}/v1/vms/{name}/restart"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 202):
# Spec says 202 with {"status":"restarting"}
body_status: Optional[str] = None
try:
data = await resp.json(content_type=None)
body_status = data.get("status") if isinstance(data, dict) else None
except Exception:
body_status = None
return {"name": name, "status": body_status or "restarting"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.update_vm is not implemented")
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
logger.warning("CloudProvider.update_vm is not implemented via public API")
return {"name": name, "status": "unchanged", "message": "update_vm not supported by public API"}
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""

View File

@@ -425,6 +425,9 @@ class DockerProvider(BaseVMProvider):
"provider": "docker"
}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("DockerProvider does not support restarting VMs.")
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration.

View File

@@ -486,6 +486,9 @@ class LumeProvider(BaseVMProvider):
"""Update VM configuration."""
return self._lume_api_update(name, update_opts, debug=self.verbose)
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("LumeProvider does not support restarting VMs.")
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.

View File

@@ -836,6 +836,9 @@ class LumierProvider(BaseVMProvider):
logger.error(error_msg)
return error_msg
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("LumierProvider does not support restarting VMs.")
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.

View File

@@ -0,0 +1,36 @@
"""Shared provider type definitions for VM metadata and responses.
These base types describe the common shape of objects returned by provider
methods like `list_vms()`.
"""
from __future__ import annotations
from typing import Literal, TypedDict, NotRequired
# Core status values per product docs
VMStatus = Literal[
"pending", # VM deployment in progress
"running", # VM is active and accessible
"stopped", # VM is stopped but not terminated
"terminated", # VM has been permanently destroyed
"failed", # VM deployment or operation failed
]
OSType = Literal["macos", "linux", "windows"]
class MinimalVM(TypedDict):
"""Minimal VM object shape returned by list calls.
Providers may include additional fields. Optional fields below are
common extensions some providers expose or that callers may compute.
"""
name: str
status: VMStatus
# Not always included by all providers
password: NotRequired[str]
vnc_url: NotRequired[str]
api_url: NotRequired[str]
# Convenience alias for list_vms() responses
ListVMsResponse = list[MinimalVM]

View File

@@ -390,6 +390,9 @@ class WinSandboxProvider(BaseVMProvider):
"error": "Windows Sandbox does not support runtime configuration updates. "
"Please stop and restart the sandbox with new configuration."
}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("WinSandboxProvider does not support restarting VMs.")
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.