mirror of
https://github.com/trycua/computer.git
synced 2026-01-02 19:40:18 -06:00
Merge branch 'main' into feat/gemini-2_5-cua
This commit is contained in:
@@ -226,6 +226,13 @@ Examples:
|
||||
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
choices=["cloud", "lume", "winsandbox", "docker"],
|
||||
default="cloud",
|
||||
help="Computer provider to use: cloud (default), lume, winsandbox, or docker"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--images",
|
||||
type=int,
|
||||
@@ -257,6 +264,12 @@ Examples:
|
||||
help="Initial prompt to send to the agent. Leave blank for interactive mode."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt-file",
|
||||
type=Path,
|
||||
help="Path to a UTF-8 text file whose contents will be used as the initial prompt. If provided, overrides --prompt."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--predict-click",
|
||||
dest="predict_click",
|
||||
@@ -289,33 +302,35 @@ Examples:
|
||||
container_name = os.getenv("CUA_CONTAINER_NAME")
|
||||
cua_api_key = os.getenv("CUA_API_KEY")
|
||||
|
||||
# Prompt for missing environment variables
|
||||
# Prompt for missing environment variables (container name always required)
|
||||
if not container_name:
|
||||
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
|
||||
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
|
||||
container_name = input("Enter your CUA container name: ").strip()
|
||||
if not container_name:
|
||||
print_colored("❌ Container name is required.")
|
||||
sys.exit(1)
|
||||
|
||||
if not cua_api_key:
|
||||
if args.provider == "cloud":
|
||||
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
|
||||
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
|
||||
container_name = input("Enter your CUA container name: ").strip()
|
||||
if not container_name:
|
||||
print_colored("❌ Container name is required.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
container_name = "cli-sandbox"
|
||||
|
||||
# Only require API key for cloud provider
|
||||
if args.provider == "cloud" and not cua_api_key:
|
||||
print_colored("CUA_API_KEY not set.", dim=True)
|
||||
cua_api_key = input("Enter your CUA API key: ").strip()
|
||||
if not cua_api_key:
|
||||
print_colored("❌ API key is required.")
|
||||
print_colored("❌ API key is required for cloud provider.")
|
||||
sys.exit(1)
|
||||
|
||||
# Check for provider-specific API keys based on model
|
||||
provider_api_keys = {
|
||||
"openai/": "OPENAI_API_KEY",
|
||||
"anthropic/": "ANTHROPIC_API_KEY",
|
||||
"omniparser+": "OPENAI_API_KEY",
|
||||
"omniparser+": "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
# Find matching provider and check for API key
|
||||
for prefix, env_var in provider_api_keys.items():
|
||||
if args.model.startswith(prefix):
|
||||
if prefix in args.model:
|
||||
if not os.getenv(env_var):
|
||||
print_colored(f"{env_var} not set.", dim=True)
|
||||
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
|
||||
@@ -335,13 +350,25 @@ Examples:
|
||||
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
|
||||
sys.exit(1)
|
||||
|
||||
# Resolve provider -> os_type, provider_type, api key requirement
|
||||
provider_map = {
|
||||
"cloud": ("linux", "cloud", True),
|
||||
"lume": ("macos", "lume", False),
|
||||
"winsandbox": ("windows", "winsandbox", False),
|
||||
"docker": ("linux", "docker", False),
|
||||
}
|
||||
os_type, provider_type, needs_api_key = provider_map[args.provider]
|
||||
|
||||
computer_kwargs = {
|
||||
"os_type": os_type,
|
||||
"provider_type": provider_type,
|
||||
"name": container_name,
|
||||
}
|
||||
if needs_api_key:
|
||||
computer_kwargs["api_key"] = cua_api_key # type: ignore
|
||||
|
||||
# Create computer instance
|
||||
async with Computer(
|
||||
os_type="linux",
|
||||
provider_type="cloud",
|
||||
name=container_name,
|
||||
api_key=cua_api_key
|
||||
) as computer:
|
||||
async with Computer(**computer_kwargs) as computer: # type: ignore
|
||||
|
||||
# Create agent
|
||||
agent_kwargs = {
|
||||
@@ -442,8 +469,17 @@ Examples:
|
||||
# Done
|
||||
sys.exit(0)
|
||||
|
||||
# Resolve initial prompt from --prompt-file or --prompt
|
||||
initial_prompt = args.prompt or ""
|
||||
if args.prompt_file:
|
||||
try:
|
||||
initial_prompt = args.prompt_file.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
print_colored(f"❌ Failed to read --prompt-file: {e}", Colors.RED, bold=True)
|
||||
sys.exit(1)
|
||||
|
||||
# Start chat loop (default interactive mode)
|
||||
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
|
||||
await chat_loop(agent, args.model, container_name, initial_prompt, args.usage)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ cli = [
|
||||
"yaspin>=3.1.0",
|
||||
]
|
||||
hud = [
|
||||
"hud-python==0.4.26",
|
||||
"hud-python==0.4.52",
|
||||
]
|
||||
gemini = [
|
||||
"google-genai>=1.41.0",
|
||||
@@ -91,7 +91,7 @@ all = [
|
||||
# cli requirements
|
||||
"yaspin>=3.1.0",
|
||||
# hud requirements
|
||||
"hud-python==0.4.26",
|
||||
"hud-python==0.4.52",
|
||||
# gemini requirements
|
||||
"google-genai>=1.41.0",
|
||||
]
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer-server"
|
||||
version = "0.1.0"
|
||||
version = "0.1.24"
|
||||
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
|
||||
authors = [
|
||||
{ name = "TryCua", email = "gh@trycua.com" }
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
|
||||
import asyncio
|
||||
from .models import Computer as ComputerConfig, Display
|
||||
@@ -451,6 +452,7 @@ class Computer:
|
||||
raise RuntimeError(f"VM failed to become ready: {wait_error}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize computer: {e}")
|
||||
self.logger.error(traceback.format_exc())
|
||||
raise RuntimeError(f"Failed to initialize computer: {e}")
|
||||
|
||||
try:
|
||||
@@ -558,6 +560,102 @@ class Computer:
|
||||
self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")
|
||||
return
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the computer."""
|
||||
await self.run()
|
||||
|
||||
async def restart(self) -> None:
|
||||
"""Restart the computer.
|
||||
|
||||
If using a VM provider that supports restart, this will issue a restart
|
||||
without tearing down the provider context, then reconnect the interface.
|
||||
Falls back to stop()+run() when a provider restart is not available.
|
||||
"""
|
||||
# Host computer server: just disconnect and run again
|
||||
if self.use_host_computer_server:
|
||||
try:
|
||||
await self.disconnect()
|
||||
finally:
|
||||
await self.run()
|
||||
return
|
||||
|
||||
# If no VM provider context yet, fall back to full run
|
||||
if not getattr(self, "_provider_context", None) or self.config.vm_provider is None:
|
||||
self.logger.info("No provider context active; performing full restart via run()")
|
||||
await self.run()
|
||||
return
|
||||
|
||||
# Gracefully close current interface connection if present
|
||||
if self._interface:
|
||||
try:
|
||||
self._interface.close()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error closing interface prior to restart: {e}")
|
||||
|
||||
# Attempt provider-level restart if implemented
|
||||
try:
|
||||
storage_param = "ephemeral" if self.ephemeral else self.storage
|
||||
if hasattr(self.config.vm_provider, "restart_vm"):
|
||||
self.logger.info(f"Restarting VM {self.config.name} via provider...")
|
||||
await self.config.vm_provider.restart_vm(name=self.config.name, storage=storage_param)
|
||||
else:
|
||||
# Fallback: stop then start without leaving provider context
|
||||
self.logger.info(f"Provider has no restart_vm; performing stop+start for {self.config.name}...")
|
||||
await self.config.vm_provider.stop_vm(name=self.config.name, storage=storage_param)
|
||||
await self.config.vm_provider.run_vm(image=self.image, name=self.config.name, run_opts={}, storage=storage_param)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to restart VM via provider: {e}")
|
||||
# As a last resort, do a full stop (with provider context exit) and run
|
||||
try:
|
||||
await self.stop()
|
||||
finally:
|
||||
await self.run()
|
||||
return
|
||||
|
||||
# Wait for VM to be ready and reconnect interface
|
||||
try:
|
||||
self.logger.info("Waiting for VM to be ready after restart...")
|
||||
if self.provider_type == VMProviderType.LUMIER:
|
||||
max_retries = 60
|
||||
retry_delay = 3
|
||||
else:
|
||||
max_retries = 30
|
||||
retry_delay = 2
|
||||
ip_address = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
self.logger.info(f"Re-initializing interface for {self.os_type} at {ip_address}")
|
||||
from .interface.base import BaseComputerInterface
|
||||
|
||||
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address,
|
||||
api_key=self.api_key,
|
||||
vm_name=self.config.name,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address,
|
||||
),
|
||||
)
|
||||
|
||||
self.logger.info("Connecting to WebSocket interface after restart...")
|
||||
await self._interface.wait_for_ready(timeout=30)
|
||||
self.logger.info("Computer reconnected and ready after restart")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to reconnect after restart: {e}")
|
||||
# Try a full reset if reconnection failed
|
||||
try:
|
||||
await self.stop()
|
||||
finally:
|
||||
await self.run()
|
||||
|
||||
# @property
|
||||
async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
|
||||
"""Get the IP address of the VM or localhost if using host computer server.
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
from typing import Dict, List, Optional, Any, AsyncContextManager
|
||||
from typing import Dict, Optional, Any, AsyncContextManager
|
||||
|
||||
from .types import ListVMsResponse
|
||||
|
||||
|
||||
class VMProviderType(StrEnum):
|
||||
@@ -42,8 +44,13 @@ class BaseVMProvider(AsyncContextManager):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
"""List all available VMs."""
|
||||
async def list_vms(self) -> ListVMsResponse:
|
||||
"""List all available VMs.
|
||||
|
||||
Returns:
|
||||
ListVMsResponse: A list of minimal VM objects as defined in
|
||||
`computer.providers.types.MinimalVM`.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -76,6 +83,20 @@ class BaseVMProvider(AsyncContextManager):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Restart a VM by name.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to restart
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM restart status and information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Update VM configuration.
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""Cloud VM provider implementation.
|
||||
"""Cloud VM provider implementation using CUA Public API.
|
||||
|
||||
This module contains a stub implementation for a future cloud VM provider.
|
||||
Implements the following public API endpoints:
|
||||
|
||||
- GET /v1/vms
|
||||
- POST /v1/vms/:name/start
|
||||
- POST /v1/vms/:name/stop
|
||||
- POST /v1/vms/:name/restart
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from ..base import BaseVMProvider, VMProviderType
|
||||
from ..types import ListVMsResponse, MinimalVM
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -14,6 +20,10 @@ logger = logging.getLogger(__name__)
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from urllib.parse import urlparse
|
||||
import os
|
||||
|
||||
|
||||
DEFAULT_API_BASE = os.getenv("CUA_API_BASE", "https://api.cua.ai")
|
||||
|
||||
class CloudProvider(BaseVMProvider):
|
||||
"""Cloud VM Provider implementation."""
|
||||
@@ -21,6 +31,7 @@ class CloudProvider(BaseVMProvider):
|
||||
self,
|
||||
api_key: str,
|
||||
verbose: bool = False,
|
||||
api_base: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -32,6 +43,7 @@ class CloudProvider(BaseVMProvider):
|
||||
assert api_key, "api_key required for CloudProvider"
|
||||
self.api_key = api_key
|
||||
self.verbose = verbose
|
||||
self.api_base = (api_base or DEFAULT_API_BASE).rstrip("/")
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
@@ -44,24 +56,162 @@ class CloudProvider(BaseVMProvider):
|
||||
pass
|
||||
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM VNC URL by name using the cloud API."""
|
||||
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
|
||||
"""Get VM information by querying the VM status endpoint.
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
logger.warning("CloudProvider.list_vms is not implemented")
|
||||
return []
|
||||
- Build hostname via get_ip(name) → "{name}.containers.cloud.trycua.com"
|
||||
- Probe https://{hostname}:8443/status with a short timeout
|
||||
- If JSON contains a "status" field, return it; otherwise infer
|
||||
- Fallback to DNS resolve check to distinguish unknown vs not_found
|
||||
"""
|
||||
hostname = await self.get_ip(name=name)
|
||||
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
# logger.warning("CloudProvider.run_vm is not implemented")
|
||||
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
|
||||
# Try HTTPS probe to the computer-server status endpoint (8443)
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=3)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
url = f"https://{hostname}:8443/status"
|
||||
async with session.get(url, allow_redirects=False) as resp:
|
||||
status_code = resp.status
|
||||
vm_status: str
|
||||
vm_os_type: Optional[str] = None
|
||||
if status_code == 200:
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
vm_status = str(data.get("status", "ok"))
|
||||
vm_os_type = str(data.get("os_type"))
|
||||
except Exception:
|
||||
vm_status = "unknown"
|
||||
elif status_code < 500:
|
||||
vm_status = "unknown"
|
||||
else:
|
||||
vm_status = "unknown"
|
||||
return {
|
||||
"name": name,
|
||||
"status": "running" if vm_status == "ok" else vm_status,
|
||||
"api_url": f"https://{hostname}:8443",
|
||||
"os_type": vm_os_type,
|
||||
}
|
||||
except Exception:
|
||||
return {"name": name, "status": "not_found", "api_url": f"https://{hostname}:8443"}
|
||||
|
||||
async def list_vms(self) -> ListVMsResponse:
|
||||
url = f"{self.api_base}/v1/vms"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
if resp.status == 200:
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
except Exception:
|
||||
text = await resp.text()
|
||||
logger.error(f"Failed to parse list_vms JSON: {text}")
|
||||
return []
|
||||
if isinstance(data, list):
|
||||
# Enrich with convenience URLs when possible.
|
||||
enriched: List[Dict[str, Any]] = []
|
||||
for item in data:
|
||||
vm = dict(item) if isinstance(item, dict) else {}
|
||||
name = vm.get("name")
|
||||
password = vm.get("password")
|
||||
if isinstance(name, str) and name:
|
||||
host = f"{name}.containers.cloud.trycua.com"
|
||||
# api_url: always set if missing
|
||||
if not vm.get("api_url"):
|
||||
vm["api_url"] = f"https://{host}:8443"
|
||||
# vnc_url: only when password available
|
||||
if not vm.get("vnc_url") and isinstance(password, str) and password:
|
||||
vm[
|
||||
"vnc_url"
|
||||
] = f"https://{host}/vnc.html?autoconnect=true&password={password}"
|
||||
enriched.append(vm)
|
||||
return enriched # type: ignore[return-value]
|
||||
logger.warning("Unexpected response for list_vms; expected list")
|
||||
return []
|
||||
elif resp.status == 401:
|
||||
logger.error("Unauthorized: invalid CUA API key for list_vms")
|
||||
return []
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"list_vms failed: HTTP {resp.status} - {text}")
|
||||
return []
|
||||
|
||||
async def run_vm(self, name: str, image: Optional[str] = None, run_opts: Optional[Dict[str, Any]] = None, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Start a VM via public API. Returns a minimal status."""
|
||||
url = f"{self.api_base}/v1/vms/{name}/start"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers) as resp:
|
||||
if resp.status in (200, 201, 202, 204):
|
||||
return {"name": name, "status": "starting"}
|
||||
elif resp.status == 404:
|
||||
return {"name": name, "status": "not_found"}
|
||||
elif resp.status == 401:
|
||||
return {"name": name, "status": "unauthorized"}
|
||||
else:
|
||||
text = await resp.text()
|
||||
return {"name": name, "status": "error", "message": text}
|
||||
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
|
||||
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
|
||||
"""Stop a VM via public API."""
|
||||
url = f"{self.api_base}/v1/vms/{name}/stop"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers) as resp:
|
||||
if resp.status in (200, 202):
|
||||
# Spec says 202 with {"status":"stopping"}
|
||||
body_status: Optional[str] = None
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
body_status = data.get("status") if isinstance(data, dict) else None
|
||||
except Exception:
|
||||
body_status = None
|
||||
return {"name": name, "status": body_status or "stopping"}
|
||||
elif resp.status == 404:
|
||||
return {"name": name, "status": "not_found"}
|
||||
elif resp.status == 401:
|
||||
return {"name": name, "status": "unauthorized"}
|
||||
else:
|
||||
text = await resp.text()
|
||||
return {"name": name, "status": "error", "message": text}
|
||||
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Restart a VM via public API."""
|
||||
url = f"{self.api_base}/v1/vms/{name}/restart"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers) as resp:
|
||||
if resp.status in (200, 202):
|
||||
# Spec says 202 with {"status":"restarting"}
|
||||
body_status: Optional[str] = None
|
||||
try:
|
||||
data = await resp.json(content_type=None)
|
||||
body_status = data.get("status") if isinstance(data, dict) else None
|
||||
except Exception:
|
||||
body_status = None
|
||||
return {"name": name, "status": body_status or "restarting"}
|
||||
elif resp.status == 404:
|
||||
return {"name": name, "status": "not_found"}
|
||||
elif resp.status == 401:
|
||||
return {"name": name, "status": "unauthorized"}
|
||||
else:
|
||||
text = await resp.text()
|
||||
return {"name": name, "status": "error", "message": text}
|
||||
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
logger.warning("CloudProvider.update_vm is not implemented")
|
||||
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
|
||||
logger.warning("CloudProvider.update_vm is not implemented via public API")
|
||||
return {"name": name, "status": "unchanged", "message": "update_vm not supported by public API"}
|
||||
|
||||
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""
|
||||
|
||||
@@ -36,7 +36,7 @@ class DockerProvider(BaseVMProvider):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
port: Optional[int] = 8000,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None,
|
||||
@@ -47,13 +47,16 @@ class DockerProvider(BaseVMProvider):
|
||||
vnc_port: Optional[int] = 6901,
|
||||
):
|
||||
"""Initialize the Docker VM Provider.
|
||||
|
||||
|
||||
Args:
|
||||
port: Currently unused (VM provider port)
|
||||
host: Hostname for the API server (default: localhost)
|
||||
storage: Path for persistent VM storage
|
||||
shared_path: Path for shared folder between host and container
|
||||
image: Docker image to use (default: "trycua/cua-ubuntu:latest")
|
||||
Supported images:
|
||||
- "trycua/cua-ubuntu:latest" (Kasm-based)
|
||||
- "trycua/cua-docker-xfce:latest" (vanilla XFCE)
|
||||
verbose: Enable verbose logging
|
||||
ephemeral: Use ephemeral (temporary) storage
|
||||
vnc_port: Port for VNC interface (default: 6901)
|
||||
@@ -62,19 +65,35 @@ class DockerProvider(BaseVMProvider):
|
||||
self.api_port = 8000
|
||||
self.vnc_port = vnc_port
|
||||
self.ephemeral = ephemeral
|
||||
|
||||
|
||||
# Handle ephemeral storage (temporary directory)
|
||||
if ephemeral:
|
||||
self.storage = "ephemeral"
|
||||
else:
|
||||
self.storage = storage
|
||||
|
||||
|
||||
self.shared_path = shared_path
|
||||
self.image = image
|
||||
self.verbose = verbose
|
||||
self._container_id = None
|
||||
self._running_containers = {} # Track running containers by name
|
||||
|
||||
# Detect image type and configure user directory accordingly
|
||||
self._detect_image_config()
|
||||
|
||||
def _detect_image_config(self):
|
||||
"""Detect image type and configure paths accordingly."""
|
||||
# Detect if this is a docker-xfce image or Kasm image
|
||||
if "docker-xfce" in self.image.lower() or "xfce" in self.image.lower():
|
||||
self._home_dir = "/home/cua"
|
||||
self._image_type = "docker-xfce"
|
||||
logger.info(f"Detected docker-xfce image: using {self._home_dir}")
|
||||
else:
|
||||
# Default to Kasm configuration
|
||||
self._home_dir = "/home/kasm-user"
|
||||
self._image_type = "kasm"
|
||||
logger.info(f"Detected Kasm image: using {self._home_dir}")
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
"""Return the provider type."""
|
||||
@@ -277,12 +296,13 @@ class DockerProvider(BaseVMProvider):
|
||||
# Add volume mounts if storage is specified
|
||||
storage_path = storage or self.storage
|
||||
if storage_path and storage_path != "ephemeral":
|
||||
# Mount storage directory
|
||||
cmd.extend(["-v", f"{storage_path}:/home/kasm-user/storage"])
|
||||
|
||||
# Mount storage directory using detected home directory
|
||||
cmd.extend(["-v", f"{storage_path}:{self._home_dir}/storage"])
|
||||
|
||||
# Add shared path if specified
|
||||
if self.shared_path:
|
||||
cmd.extend(["-v", f"{self.shared_path}:/home/kasm-user/shared"])
|
||||
# Mount shared directory using detected home directory
|
||||
cmd.extend(["-v", f"{self.shared_path}:{self._home_dir}/shared"])
|
||||
|
||||
# Add environment variables
|
||||
cmd.extend(["-e", "VNC_PW=password"]) # Set VNC password
|
||||
@@ -405,6 +425,9 @@ class DockerProvider(BaseVMProvider):
|
||||
"provider": "docker"
|
||||
}
|
||||
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
raise NotImplementedError("DockerProvider does not support restarting VMs.")
|
||||
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Update VM configuration.
|
||||
|
||||
|
||||
@@ -486,6 +486,9 @@ class LumeProvider(BaseVMProvider):
|
||||
"""Update VM configuration."""
|
||||
return self._lume_api_update(name, update_opts, debug=self.verbose)
|
||||
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
raise NotImplementedError("LumeProvider does not support restarting VMs.")
|
||||
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
|
||||
@@ -836,6 +836,9 @@ class LumierProvider(BaseVMProvider):
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
raise NotImplementedError("LumierProvider does not support restarting VMs.")
|
||||
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
|
||||
36
libs/python/computer/computer/providers/types.py
Normal file
36
libs/python/computer/computer/providers/types.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Shared provider type definitions for VM metadata and responses.
|
||||
|
||||
These base types describe the common shape of objects returned by provider
|
||||
methods like `list_vms()`.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, TypedDict, NotRequired
|
||||
|
||||
# Core status values per product docs
|
||||
VMStatus = Literal[
|
||||
"pending", # VM deployment in progress
|
||||
"running", # VM is active and accessible
|
||||
"stopped", # VM is stopped but not terminated
|
||||
"terminated", # VM has been permanently destroyed
|
||||
"failed", # VM deployment or operation failed
|
||||
]
|
||||
|
||||
OSType = Literal["macos", "linux", "windows"]
|
||||
|
||||
class MinimalVM(TypedDict):
|
||||
"""Minimal VM object shape returned by list calls.
|
||||
|
||||
Providers may include additional fields. Optional fields below are
|
||||
common extensions some providers expose or that callers may compute.
|
||||
"""
|
||||
name: str
|
||||
status: VMStatus
|
||||
# Not always included by all providers
|
||||
password: NotRequired[str]
|
||||
vnc_url: NotRequired[str]
|
||||
api_url: NotRequired[str]
|
||||
|
||||
|
||||
# Convenience alias for list_vms() responses
|
||||
ListVMsResponse = list[MinimalVM]
|
||||
@@ -390,6 +390,9 @@ class WinSandboxProvider(BaseVMProvider):
|
||||
"error": "Windows Sandbox does not support runtime configuration updates. "
|
||||
"Please stop and restart the sandbox with new configuration."
|
||||
}
|
||||
|
||||
async def restart_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
raise NotImplementedError("WinSandboxProvider does not support restarting VMs.")
|
||||
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer"
|
||||
version = "0.4.0"
|
||||
version = "0.4.8"
|
||||
description = "Computer-Use Interface (CUI) framework powering Cua"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
Reference in New Issue
Block a user