Merge branch 'main' into feat/gemini-2_5-cua

This commit is contained in:
Dillon DuPont
2025-10-14 11:33:24 -04:00
56 changed files with 2562 additions and 627 deletions

View File

@@ -226,6 +226,13 @@ Examples:
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
)
parser.add_argument(
"--provider",
choices=["cloud", "lume", "winsandbox", "docker"],
default="cloud",
help="Computer provider to use: cloud (default), lume, winsandbox, or docker"
)
parser.add_argument(
"--images",
type=int,
@@ -257,6 +264,12 @@ Examples:
help="Initial prompt to send to the agent. Leave blank for interactive mode."
)
parser.add_argument(
"--prompt-file",
type=Path,
help="Path to a UTF-8 text file whose contents will be used as the initial prompt. If provided, overrides --prompt."
)
parser.add_argument(
"--predict-click",
dest="predict_click",
@@ -289,33 +302,35 @@ Examples:
container_name = os.getenv("CUA_CONTAINER_NAME")
cua_api_key = os.getenv("CUA_API_KEY")
# Prompt for missing environment variables
# Prompt for missing environment variables (container name always required)
if not container_name:
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
if not cua_api_key:
if args.provider == "cloud":
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
else:
container_name = "cli-sandbox"
# Only require API key for cloud provider
if args.provider == "cloud" and not cua_api_key:
print_colored("CUA_API_KEY not set.", dim=True)
cua_api_key = input("Enter your CUA API key: ").strip()
if not cua_api_key:
print_colored("❌ API key is required.")
print_colored("❌ API key is required for cloud provider.")
sys.exit(1)
# Check for provider-specific API keys based on model
provider_api_keys = {
"openai/": "OPENAI_API_KEY",
"anthropic/": "ANTHROPIC_API_KEY",
"omniparser+": "OPENAI_API_KEY",
"omniparser+": "ANTHROPIC_API_KEY",
}
# Find matching provider and check for API key
for prefix, env_var in provider_api_keys.items():
if args.model.startswith(prefix):
if prefix in args.model:
if not os.getenv(env_var):
print_colored(f"{env_var} not set.", dim=True)
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
@@ -335,13 +350,25 @@ Examples:
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
sys.exit(1)
# Resolve provider -> os_type, provider_type, api key requirement
provider_map = {
"cloud": ("linux", "cloud", True),
"lume": ("macos", "lume", False),
"winsandbox": ("windows", "winsandbox", False),
"docker": ("linux", "docker", False),
}
os_type, provider_type, needs_api_key = provider_map[args.provider]
computer_kwargs = {
"os_type": os_type,
"provider_type": provider_type,
"name": container_name,
}
if needs_api_key:
computer_kwargs["api_key"] = cua_api_key # type: ignore
# Create computer instance
async with Computer(
os_type="linux",
provider_type="cloud",
name=container_name,
api_key=cua_api_key
) as computer:
async with Computer(**computer_kwargs) as computer: # type: ignore
# Create agent
agent_kwargs = {
@@ -442,8 +469,17 @@ Examples:
# Done
sys.exit(0)
# Resolve initial prompt from --prompt-file or --prompt
initial_prompt = args.prompt or ""
if args.prompt_file:
try:
initial_prompt = args.prompt_file.read_text(encoding="utf-8")
except Exception as e:
print_colored(f"❌ Failed to read --prompt-file: {e}", Colors.RED, bold=True)
sys.exit(1)
# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
await chat_loop(agent, args.model, container_name, initial_prompt, args.usage)

View File

@@ -68,7 +68,7 @@ cli = [
"yaspin>=3.1.0",
]
hud = [
"hud-python==0.4.26",
"hud-python==0.4.52",
]
gemini = [
"google-genai>=1.41.0",
@@ -91,7 +91,7 @@ all = [
# cli requirements
"yaspin>=3.1.0",
# hud requirements
"hud-python==0.4.26",
"hud-python==0.4.52",
# gemini requirements
"google-genai>=1.41.0",
]

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer-server"
version = "0.1.0"
version = "0.1.24"
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
authors = [
{ name = "TryCua", email = "gh@trycua.com" }

View File

@@ -1,3 +1,4 @@
import traceback
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
import asyncio
from .models import Computer as ComputerConfig, Display
@@ -451,6 +452,7 @@ class Computer:
raise RuntimeError(f"VM failed to become ready: {wait_error}")
except Exception as e:
self.logger.error(f"Failed to initialize computer: {e}")
self.logger.error(traceback.format_exc())
raise RuntimeError(f"Failed to initialize computer: {e}")
try:
@@ -558,6 +560,102 @@ class Computer:
self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")
return
async def start(self) -> None:
"""Start the computer."""
await self.run()
async def restart(self) -> None:
"""Restart the computer.
If using a VM provider that supports restart, this will issue a restart
without tearing down the provider context, then reconnect the interface.
Falls back to stop()+run() when a provider restart is not available.
"""
# Host computer server: just disconnect and run again
if self.use_host_computer_server:
try:
await self.disconnect()
finally:
await self.run()
return
# If no VM provider context yet, fall back to full run
if not getattr(self, "_provider_context", None) or self.config.vm_provider is None:
self.logger.info("No provider context active; performing full restart via run()")
await self.run()
return
# Gracefully close current interface connection if present
if self._interface:
try:
self._interface.close()
except Exception as e:
self.logger.debug(f"Error closing interface prior to restart: {e}")
# Attempt provider-level restart if implemented
try:
storage_param = "ephemeral" if self.ephemeral else self.storage
if hasattr(self.config.vm_provider, "restart_vm"):
self.logger.info(f"Restarting VM {self.config.name} via provider...")
await self.config.vm_provider.restart_vm(name=self.config.name, storage=storage_param)
else:
# Fallback: stop then start without leaving provider context
self.logger.info(f"Provider has no restart_vm; performing stop+start for {self.config.name}...")
await self.config.vm_provider.stop_vm(name=self.config.name, storage=storage_param)
await self.config.vm_provider.run_vm(image=self.image, name=self.config.name, run_opts={}, storage=storage_param)
except Exception as e:
self.logger.error(f"Failed to restart VM via provider: {e}")
# As a last resort, do a full stop (with provider context exit) and run
try:
await self.stop()
finally:
await self.run()
return
# Wait for VM to be ready and reconnect interface
try:
self.logger.info("Waiting for VM to be ready after restart...")
if self.provider_type == VMProviderType.LUMIER:
max_retries = 60
retry_delay = 3
else:
max_retries = 30
retry_delay = 2
ip_address = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
self.logger.info(f"Re-initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name,
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
),
)
self.logger.info("Connecting to WebSocket interface after restart...")
await self._interface.wait_for_ready(timeout=30)
self.logger.info("Computer reconnected and ready after restart")
except Exception as e:
self.logger.error(f"Failed to reconnect after restart: {e}")
# Try a full reset if reconnection failed
try:
await self.stop()
finally:
await self.run()
# @property
async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
"""Get the IP address of the VM or localhost if using host computer server.

View File

@@ -2,7 +2,9 @@
import abc
from enum import StrEnum
from typing import Dict, List, Optional, Any, AsyncContextManager
from typing import Dict, Optional, Any, AsyncContextManager
from .types import ListVMsResponse
class VMProviderType(StrEnum):
@@ -42,8 +44,13 @@ class BaseVMProvider(AsyncContextManager):
pass
@abc.abstractmethod
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all available VMs."""
async def list_vms(self) -> ListVMsResponse:
"""List all available VMs.
Returns:
ListVMsResponse: A list of minimal VM objects as defined in
`computer.providers.types.MinimalVM`.
"""
pass
@abc.abstractmethod
@@ -76,6 +83,20 @@ class BaseVMProvider(AsyncContextManager):
"""
pass
@abc.abstractmethod
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Restart a VM by name.
Args:
name: Name of the VM to restart
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM restart status and information
"""
pass
@abc.abstractmethod
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration.

View File

@@ -1,12 +1,18 @@
"""Cloud VM provider implementation.
"""Cloud VM provider implementation using CUA Public API.
This module contains a stub implementation for a future cloud VM provider.
Implements the following public API endpoints:
- GET /v1/vms
- POST /v1/vms/:name/start
- POST /v1/vms/:name/stop
- POST /v1/vms/:name/restart
"""
import logging
from typing import Dict, List, Optional, Any
from ..base import BaseVMProvider, VMProviderType
from ..types import ListVMsResponse, MinimalVM
# Setup logging
logger = logging.getLogger(__name__)
@@ -14,6 +20,10 @@ logger = logging.getLogger(__name__)
import asyncio
import aiohttp
from urllib.parse import urlparse
import os
DEFAULT_API_BASE = os.getenv("CUA_API_BASE", "https://api.cua.ai")
class CloudProvider(BaseVMProvider):
"""Cloud VM Provider implementation."""
@@ -21,6 +31,7 @@ class CloudProvider(BaseVMProvider):
self,
api_key: str,
verbose: bool = False,
api_base: Optional[str] = None,
**kwargs,
):
"""
@@ -32,6 +43,7 @@ class CloudProvider(BaseVMProvider):
assert api_key, "api_key required for CloudProvider"
self.api_key = api_key
self.verbose = verbose
self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
@property
def provider_type(self) -> VMProviderType:
@@ -44,24 +56,162 @@ class CloudProvider(BaseVMProvider):
pass
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM VNC URL by name using the cloud API."""
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
"""Get VM information by querying the VM status endpoint.
async def list_vms(self) -> List[Dict[str, Any]]:
logger.warning("CloudProvider.list_vms is not implemented")
return []
- Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com"
- Probe https://{hostname}:8443/status with a short timeout
- If JSON contains a "status" field, return it; otherwise infer
- Fallback to DNS resolve check to distinguish unknown vs not_found
"""
hostname = await self.get_ip(name=name)
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
# logger.warning("CloudProvider.run_vm is not implemented")
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
# Try HTTPS probe to the computer-server status endpoint (8443)
try:
timeout = aiohttp.ClientTimeout(total=3)
async with aiohttp.ClientSession(timeout=timeout) as session:
url = f"https://{hostname}:8443/status"
async with session.get(url, allow_redirects=False) as resp:
status_code = resp.status
vm_status: str
vm_os_type: Optional[str] = None
if status_code == 200:
try:
data = await resp.json(content_type=None)
vm_status = str(data.get("status", "ok"))
vm_os_type = str(data.get("os_type"))
except Exception:
vm_status = "unknown"
elif status_code < 500:
vm_status = "unknown"
else:
vm_status = "unknown"
return {
"name": name,
"status": "running" if vm_status == "ok" else vm_status,
"api_url": f"https://{hostname}:8443",
"os_type": vm_os_type,
}
except Exception:
return {"name": name, "status": "not_found", "api_url": f"https://{hostname}:8443"}
async def list_vms(self) -> ListVMsResponse:
url = f"{self.api_base}/v1/vms"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as resp:
if resp.status == 200:
try:
data = await resp.json(content_type=None)
except Exception:
text = await resp.text()
logger.error(f"Failed to parse list_vms JSON: {text}")
return []
if isinstance(data, list):
# Enrich with convenience URLs when possible.
enriched: List[Dict[str, Any]] = []
for item in data:
vm = dict(item) if isinstance(item, dict) else {}
name = vm.get("name")
password = vm.get("password")
if isinstance(name, str) and name:
host = f"{name}.containers.cloud.trycua.com"
# api_url: always set if missing
if not vm.get("api_url"):
vm["api_url"] = f"https://{host}:8443"
# vnc_url: only when password available
if not vm.get("vnc_url") and isinstance(password, str) and password:
vm[
"vnc_url"
] = f"https://{host}/vnc.html?autoconnect=true&password={password}"
enriched.append(vm)
return enriched # type: ignore[return-value]
logger.warning("Unexpected response for list_vms; expected list")
return []
elif resp.status == 401:
logger.error("Unauthorized: invalid CUA API key for list_vms")
return []
else:
text = await resp.text()
logger.error(f"list_vms failed: HTTP {resp.status} - {text}")
return []
async def run_vm(self, name: str, image: Optional[str] = None, run_opts: Optional[Dict[str, Any]] = None, storage: Optional[str] = None) -> Dict[str, Any]:
"""Start a VM via public API. Returns a minimal status."""
url = f"{self.api_base}/v1/vms/{name}/start"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 201, 202, 204):
return {"name": name, "status": "starting"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
"""Stop a VM via public API."""
url = f"{self.api_base}/v1/vms/{name}/stop"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 202):
# Spec says 202 with {"status":"stopping"}
body_status: Optional[str] = None
try:
data = await resp.json(content_type=None)
body_status = data.get("status") if isinstance(data, dict) else None
except Exception:
body_status = None
return {"name": name, "status": body_status or "stopping"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Restart a VM via public API."""
url = f"{self.api_base}/v1/vms/{name}/restart"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Accept": "application/json",
}
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=headers) as resp:
if resp.status in (200, 202):
# Spec says 202 with {"status":"restarting"}
body_status: Optional[str] = None
try:
data = await resp.json(content_type=None)
body_status = data.get("status") if isinstance(data, dict) else None
except Exception:
body_status = None
return {"name": name, "status": body_status or "restarting"}
elif resp.status == 404:
return {"name": name, "status": "not_found"}
elif resp.status == 401:
return {"name": name, "status": "unauthorized"}
else:
text = await resp.text()
return {"name": name, "status": "error", "message": text}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.update_vm is not implemented")
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
logger.warning("CloudProvider.update_vm is not implemented via public API")
return {"name": name, "status": "unchanged", "message": "update_vm not supported by public API"}
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""

View File

@@ -36,7 +36,7 @@ class DockerProvider(BaseVMProvider):
"""
def __init__(
self,
self,
port: Optional[int] = 8000,
host: str = "localhost",
storage: Optional[str] = None,
@@ -47,13 +47,16 @@ class DockerProvider(BaseVMProvider):
vnc_port: Optional[int] = 6901,
):
"""Initialize the Docker VM Provider.
Args:
port: Currently unused (VM provider port)
host: Hostname for the API server (default: localhost)
storage: Path for persistent VM storage
shared_path: Path for shared folder between host and container
image: Docker image to use (default: "trycua/cua-ubuntu:latest")
Supported images:
- "trycua/cua-ubuntu:latest" (Kasm-based)
- "trycua/cua-docker-xfce:latest" (vanilla XFCE)
verbose: Enable verbose logging
ephemeral: Use ephemeral (temporary) storage
vnc_port: Port for VNC interface (default: 6901)
@@ -62,19 +65,35 @@ class DockerProvider(BaseVMProvider):
self.api_port = 8000
self.vnc_port = vnc_port
self.ephemeral = ephemeral
# Handle ephemeral storage (temporary directory)
if ephemeral:
self.storage = "ephemeral"
else:
self.storage = storage
self.shared_path = shared_path
self.image = image
self.verbose = verbose
self._container_id = None
self._running_containers = {} # Track running containers by name
# Detect image type and configure user directory accordingly
self._detect_image_config()
def _detect_image_config(self):
"""Detect image type and configure paths accordingly."""
# Detect if this is a docker-xfce image or Kasm image
if "docker-xfce" in self.image.lower() or "xfce" in self.image.lower():
self._home_dir = "/home/cua"
self._image_type = "docker-xfce"
logger.info(f"Detected docker-xfce image: using {self._home_dir}")
else:
# Default to Kasm configuration
self._home_dir = "/home/kasm-user"
self._image_type = "kasm"
logger.info(f"Detected Kasm image: using {self._home_dir}")
@property
def provider_type(self) -> VMProviderType:
"""Return the provider type."""
@@ -277,12 +296,13 @@ class DockerProvider(BaseVMProvider):
# Add volume mounts if storage is specified
storage_path = storage or self.storage
if storage_path and storage_path != "ephemeral":
# Mount storage directory
cmd.extend(["-v", f"{storage_path}:/home/kasm-user/storage"])
# Mount storage directory using detected home directory
cmd.extend(["-v", f"{storage_path}:{self._home_dir}/storage"])
# Add shared path if specified
if self.shared_path:
cmd.extend(["-v", f"{self.shared_path}:/home/kasm-user/shared"])
# Mount shared directory using detected home directory
cmd.extend(["-v", f"{self.shared_path}:{self._home_dir}/shared"])
# Add environment variables
cmd.extend(["-e", "VNC_PW=password"]) # Set VNC password
@@ -405,6 +425,9 @@ class DockerProvider(BaseVMProvider):
"provider": "docker"
}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("DockerProvider does not support restarting VMs.")
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration.

View File

@@ -486,6 +486,9 @@ class LumeProvider(BaseVMProvider):
"""Update VM configuration."""
return self._lume_api_update(name, update_opts, debug=self.verbose)
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("LumeProvider does not support restarting VMs.")
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.

View File

@@ -836,6 +836,9 @@ class LumierProvider(BaseVMProvider):
logger.error(error_msg)
return error_msg
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("LumierProvider does not support restarting VMs.")
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.

View File

@@ -0,0 +1,36 @@
"""Shared provider type definitions for VM metadata and responses.
These base types describe the common shape of objects returned by provider
methods like `list_vms()`.
"""
from __future__ import annotations
from typing import Literal, TypedDict, NotRequired
# Core status values per product docs
VMStatus = Literal[
"pending", # VM deployment in progress
"running", # VM is active and accessible
"stopped", # VM is stopped but not terminated
"terminated", # VM has been permanently destroyed
"failed", # VM deployment or operation failed
]
OSType = Literal["macos", "linux", "windows"]
class MinimalVM(TypedDict):
"""Minimal VM object shape returned by list calls.
Providers may include additional fields. Optional fields below are
common extensions some providers expose or that callers may compute.
"""
name: str
status: VMStatus
# Not always included by all providers
password: NotRequired[str]
vnc_url: NotRequired[str]
api_url: NotRequired[str]
# Convenience alias for list_vms() responses
ListVMsResponse = list[MinimalVM]

View File

@@ -390,6 +390,9 @@ class WinSandboxProvider(BaseVMProvider):
"error": "Windows Sandbox does not support runtime configuration updates. "
"Please stop and restart the sandbox with new configuration."
}
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
raise NotImplementedError("WinSandboxProvider does not support restarting VMs.")
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer"
version = "0.4.0"
version = "0.4.8"
description = "Computer-Use Interface (CUI) framework powering Cua"
readme = "README.md"
authors = [