Reorganize lib folder w/typescript and python roots, initialize core library.

This commit is contained in:
Morgan Dean
2025-06-23 10:22:36 -07:00
parent d799bef5d9
commit 0246d18347
322 changed files with 6052 additions and 237 deletions


@@ -0,0 +1,144 @@
<div align="center">
<h1>
<div class="image-wrapper" style="display: inline-block;">
<picture>
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../img/logo_white.png" style="display: block; margin: auto;">
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../img/logo_black.png" style="display: block; margin: auto;">
<img alt="logo" src="../../img/logo_black.png" height="150" style="display: block; margin: auto;">
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/)
</h1>
</div>
**cua-computer** is a Computer-Use Interface (CUI) framework that powers Cua. It provides a PyAutoGUI-compatible interface for interacting with local macOS and Linux sandboxes and plugs into any AI agent system (Cua, LangChain, CrewAI, AutoGen). Computer relies on [Lume](https://github.com/trycua/lume) for creating and managing sandbox environments.
### Get started with Computer
<div align="center">
<img src="../../img/computer.png"/>
</div>
```python
from computer import Computer

computer = Computer(os_type="macos", display="1024x768", memory="8GB", cpu="4")
try:
    await computer.run()

    # Take and save a screenshot
    screenshot = await computer.interface.screenshot()
    with open("screenshot.png", "wb") as f:
        f.write(screenshot)

    # Mouse actions
    await computer.interface.move_cursor(100, 100)
    await computer.interface.left_click()
    await computer.interface.right_click(300, 300)
    await computer.interface.double_click(400, 400)

    # Keyboard actions
    await computer.interface.type("Hello, World!")
    await computer.interface.press_key("enter")

    # Clipboard actions
    await computer.interface.set_clipboard("Test clipboard")
    content = await computer.interface.copy_to_clipboard()
    print(f"Clipboard content: {content}")
finally:
    await computer.stop()
```
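Besides a string, `display` also accepts a `Display` object or a width/height dict (per the `Computer` constructor later in this commit). When a string is given, it follows the `"WIDTHxHEIGHT"` convention. A minimal sketch of that parsing rule (`parse_display` is a hypothetical helper for illustration; the library parses this internally):

```python
import re

def parse_display(display: str) -> dict:
    # Mirror the "WIDTHxHEIGHT" convention accepted by Computer(display=...)
    match = re.match(r"(\d+)x(\d+)", display)
    if not match:
        raise ValueError("Display string must be in format 'WIDTHxHEIGHT' (e.g. '1024x768')")
    width, height = map(int, match.groups())
    return {"width": width, "height": height}

print(parse_display("1024x768"))  # → {'width': 1024, 'height': 768}
```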
## Install

To install the Computer-Use Interface (CUI):

```bash
pip install "cua-computer[all]"
```

The `cua-computer` PyPI package automatically pulls the latest executable version of Lume through [pylume](https://github.com/trycua/pylume).
## Run

Refer to this notebook for a step-by-step guide on how to use the Computer-Use Interface (CUI):

- [Computer-Use Interface (CUI)](../../notebooks/computer_nb.ipynb)

## Using the Gradio Computer UI

The computer module includes a Gradio UI for creating and sharing demonstration data. An upload-to-Hugging-Face feature makes it easy to build community datasets for training better computer-use models.
```bash
# Install with UI support
pip install "cua-computer[ui]"
```
> **Note:** For precise control of the computer, we recommend using VNC or Screen Sharing instead of the Computer Gradio UI.
### Building and Sharing Demonstrations with Hugging Face

Follow these steps to contribute your own demonstrations:

#### 1. Set up Hugging Face access

Set your `HF_TOKEN` in a `.env` file or in your environment variables:
```bash
# In .env file
HF_TOKEN=your_huggingface_token
```
#### 2. Launch the Computer UI
```python
# launch_ui.py
from computer.ui.gradio.app import create_gradio_ui
from dotenv import load_dotenv
load_dotenv('.env')
app = create_gradio_ui()
app.launch(share=False)
```
For examples, see [Computer UI Examples](../../examples/computer_ui_examples.py)
#### 3. Record Your Tasks
<details open>
<summary>View demonstration video</summary>
<video src="https://github.com/user-attachments/assets/de3c3477-62fe-413c-998d-4063e48de176" controls width="600"></video>
</details>
Record yourself performing various computer tasks using the UI.
#### 4. Save Your Demonstrations
<details open>
<summary>View demonstration video</summary>
<video src="https://github.com/user-attachments/assets/5ad1df37-026a-457f-8b49-922ae805faef" controls width="600"></video>
</details>
Save each task by picking a descriptive name and adding relevant tags (e.g., "office", "web-browsing", "coding").
#### 5. Record Additional Demonstrations
Repeat steps 3 and 4 until you have a good number of demonstrations covering different tasks and scenarios.
#### 6. Upload to Hugging Face
<details open>
<summary>View demonstration video</summary>
<video src="https://github.com/user-attachments/assets/c586d460-3877-4b5f-a736-3248886d2134" controls width="600"></video>
</details>
Upload your dataset to Hugging Face by:

- Naming it as `{your_username}/{dataset_name}`
- Choosing public or private visibility
- Optionally selecting specific tags to upload only tasks with certain tags
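Tag-based filtering at upload time can be pictured as a simple set intersection over each task's tags. A rough sketch (the `demos` records and `filter_by_tags` helper are hypothetical, not the UI's actual data model):

```python
# Hypothetical saved demonstrations; the real UI stores these internally.
demos = [
    {"name": "write_email", "tags": ["office"]},
    {"name": "search_web", "tags": ["web-browsing"]},
    {"name": "fix_bug", "tags": ["coding"]},
]

def filter_by_tags(demos, tags):
    """Keep demonstrations that carry at least one of the requested tags."""
    wanted = set(tags)
    return [d for d in demos if wanted & set(d["tags"])]

print([d["name"] for d in filter_by_tags(demos, ["office", "coding"])])  # → ['write_email', 'fix_bug']
```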
#### Examples and Resources

- Example dataset: [ddupont/test-dataset](https://huggingface.co/datasets/ddupont/test-dataset)
- Find community datasets: 🔍 [Browse CUA datasets on Hugging Face](https://huggingface.co/datasets?other=cua)


@@ -0,0 +1,51 @@
"""CUA Computer Interface for cross-platform computer control."""
import logging
import sys
__version__ = "0.1.0"
# Initialize logging
logger = logging.getLogger("computer")
# Initialize telemetry when the package is imported
try:
# Import from core telemetry
from core.telemetry import (
flush,
is_telemetry_enabled,
record_event,
)
# Check if telemetry is enabled
if is_telemetry_enabled():
logger.info("Telemetry is enabled")
# Record package initialization
record_event(
"module_init",
{
"module": "computer",
"version": __version__,
"python_version": sys.version,
},
)
# Flush events to ensure they're sent
flush()
else:
logger.info("Telemetry is disabled")
except ImportError as e:
# Telemetry not available
logger.warning(f"Telemetry not available: {e}")
except Exception as e:
# Other issues with telemetry
logger.warning(f"Error initializing telemetry: {e}")
# Core components
from .computer import Computer
# Provider components
from .providers.base import VMProviderType
__all__ = ["Computer", "VMProviderType"]


@@ -0,0 +1,926 @@
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
import asyncio
import io
import json
import logging
import os
import re
import time

from PIL import Image

from .models import Computer as ComputerConfig, Display
from .interface.factory import InterfaceFactory
from .logger import Logger, LogLevel
from .telemetry import record_computer_initialization
from . import helpers

# Import provider related modules
from .providers.base import VMProviderType
from .providers.factory import VMProviderFactory

OSType = Literal["macos", "linux", "windows"]


class Computer:
    """Computer is the main class for interacting with the computer."""

    def create_desktop_from_apps(self, apps):
        """Create a virtual desktop from a list of app names, returning a DioramaComputer
        that proxies Diorama.Interface but uses diorama_cmds via the computer interface.

        Args:
            apps (list[str]): List of application names to include in the desktop.

        Returns:
            DioramaComputer: A proxy object with the Diorama interface, but using diorama_cmds.
        """
        assert "app-use" in self.experiments, (
            "App Usage is an experimental feature. "
            "Enable it by passing experiments=['app-use'] to Computer()"
        )
        from .diorama_computer import DioramaComputer

        return DioramaComputer(self, apps)
    def __init__(
        self,
        display: Union[Display, Dict[str, int], str] = "1024x768",
        memory: str = "8GB",
        cpu: str = "4",
        os_type: OSType = "macos",
        name: str = "",
        image: str = "macos-sequoia-cua:latest",
        shared_directories: Optional[List[str]] = None,
        use_host_computer_server: bool = False,
        verbosity: Union[int, LogLevel] = logging.INFO,
        telemetry_enabled: bool = True,
        provider_type: Union[str, VMProviderType] = VMProviderType.LUME,
        port: Optional[int] = 7777,
        noVNC_port: Optional[int] = 8006,
        host: str = os.environ.get("PYLUME_HOST", "localhost"),
        storage: Optional[str] = None,
        ephemeral: bool = False,
        api_key: Optional[str] = None,
        experiments: Optional[List[str]] = None,
    ):
        """Initialize a new Computer instance.

        Args:
            display: The display configuration. Can be:
                - A Display object
                - A dict with 'width' and 'height'
                - A string in format "WIDTHxHEIGHT" (e.g. "1920x1080")
                Defaults to "1024x768"
            memory: The VM memory allocation. Defaults to "8GB"
            cpu: The VM CPU allocation. Defaults to "4"
            os_type: The operating system type ('macos' or 'linux')
            name: The VM name
            image: The VM image name
            shared_directories: Optional list of directory paths to share with the VM
            use_host_computer_server: If True, target localhost instead of starting a VM
            verbosity: Logging level (standard Python logging levels: logging.DEBUG,
                logging.INFO, etc.). LogLevel enum values are still accepted for
                backward compatibility.
            telemetry_enabled: Whether to enable telemetry tracking. Defaults to True.
            provider_type: The VM provider type to use (lume, qemu, cloud)
            port: Optional port to use for the VM provider server
            noVNC_port: Optional port for the noVNC web interface (Lumier provider)
            host: Host to use for VM provider connections (e.g. "localhost", "host.docker.internal")
            storage: Optional path for persistent VM storage (Lumier provider)
            ephemeral: Whether to use ephemeral storage
            api_key: Optional API key for cloud providers
            experiments: Optional list of experimental features to enable (e.g. ["app-use"])
        """
        self.logger = Logger("computer", verbosity)
        self.logger.info("Initializing Computer...")

        # Store original parameters
        self.image = image
        self.port = port
        self.noVNC_port = noVNC_port
        self.host = host
        self.os_type = os_type
        self.provider_type = provider_type
        self.ephemeral = ephemeral
        self.api_key = api_key
        self.experiments = experiments or []

        if "app-use" in self.experiments:
            assert self.os_type == "macos", "App use experiment is only supported on macOS"

        # The default is currently to use non-ephemeral storage
        if storage and ephemeral and storage != "ephemeral":
            raise ValueError("Storage path and ephemeral flag cannot be used together")

        # Windows Sandbox always uses ephemeral storage
        if self.provider_type == VMProviderType.WINSANDBOX:
            if not ephemeral and storage is not None and storage != "ephemeral":
                self.logger.warning("Windows Sandbox storage is always ephemeral. Setting ephemeral=True.")
            self.ephemeral = True
            self.storage = "ephemeral"
        else:
            self.storage = "ephemeral" if ephemeral else storage

        # For the Lumier provider, store the first shared directory path to use
        # for VM file sharing
        self.shared_path = None
        if shared_directories and len(shared_directories) > 0:
            self.shared_path = shared_directories[0]
            self.logger.info(f"Using first shared directory for VM file sharing: {self.shared_path}")

        # Store telemetry preference
        self._telemetry_enabled = telemetry_enabled

        # Set initialization flags
        self._initialized = False
        self._running = False

        # Configure root logger
        self.verbosity = verbosity
        self.logger = Logger("computer", verbosity)

        # Configure component loggers with proper hierarchy
        self.vm_logger = Logger("computer.vm", verbosity)
        self.interface_logger = Logger("computer.interface", verbosity)

        if not use_host_computer_server:
            if ":" not in image or len(image.split(":")) != 2:
                raise ValueError("Image must be in the format <image_name>:<tag>")
            if not name:
                # Normalize the name to be used for the VM
                name = image.replace(":", "_")

            # Convert display parameter to Display object
            if isinstance(display, str):
                # Parse string format "WIDTHxHEIGHT"
                match = re.match(r"(\d+)x(\d+)", display)
                if not match:
                    raise ValueError(
                        "Display string must be in format 'WIDTHxHEIGHT' (e.g. '1024x768')"
                    )
                width, height = map(int, match.groups())
                display_config = Display(width=width, height=height)
            elif isinstance(display, dict):
                display_config = Display(**display)
            else:
                display_config = display

            self.config = ComputerConfig(
                image=image.split(":")[0],
                tag=image.split(":")[1],
                name=name,
                display=display_config,
                memory=memory,
                cpu=cpu,
            )
            # Initialize VM provider but don't start it yet - we'll do that in run()
            self.config.vm_provider = None  # Will be initialized in run()

        # Store shared directories config
        self.shared_directories = shared_directories or []

        # Placeholder for VM provider context manager
        self._provider_context = None

        # Initialize with proper typing - None at first, will be set in run()
        self._interface = None
        self.use_host_computer_server = use_host_computer_server

        # Record initialization in telemetry (if enabled)
        if telemetry_enabled:
            record_computer_initialization()
        else:
            self.logger.debug("Telemetry disabled - skipping initialization tracking")

    async def __aenter__(self):
        """Start the computer."""
        await self.run()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the computer."""
        await self.disconnect()

    def __enter__(self):
        """Start the computer."""
        # Run the event loop to call the async enter method
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.__aenter__())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the computer."""
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.__aexit__(exc_type, exc_val, exc_tb))

    async def run(self) -> Optional[str]:
        """Initialize the VM and computer interface."""
        if TYPE_CHECKING:
            from .interface.base import BaseComputerInterface

        # If already initialized, just log and return
        if hasattr(self, "_initialized") and self._initialized:
            self.logger.info("Computer already initialized, skipping initialization")
            return

        self.logger.info("Starting computer...")
        start_time = time.time()

        try:
            # If using host computer server
            if self.use_host_computer_server:
                self.logger.info("Using host computer server")
                # Set ip_address for host computer server mode
                ip_address = "localhost"
                # Create the interface with explicit type annotation
                from .interface.base import BaseComputerInterface

                self._interface = cast(
                    BaseComputerInterface,
                    InterfaceFactory.create_interface_for_os(
                        os=self.os_type, ip_address=ip_address  # type: ignore[arg-type]
                    ),
                )
                self.logger.info("Waiting for host computer server to be ready...")
                await self._interface.wait_for_ready()
                self.logger.info("Host computer server ready")
            else:
                # Start or connect to VM
                self.logger.info(f"Starting VM: {self.image}")
                if not self._provider_context:
                    try:
                        provider_type_name = (
                            self.provider_type.name
                            if isinstance(self.provider_type, VMProviderType)
                            else self.provider_type
                        )
                        self.logger.verbose(f"Initializing {provider_type_name} provider context...")

                        # Explicitly set provider parameters
                        storage = "ephemeral" if self.ephemeral else self.storage
                        verbose = self.verbosity >= LogLevel.DEBUG
                        ephemeral = self.ephemeral
                        port = self.port if self.port is not None else 7777
                        host = self.host if self.host else "localhost"
                        image = self.image
                        shared_path = self.shared_path
                        noVNC_port = self.noVNC_port

                        # Create VM provider instance with explicit parameters
                        try:
                            if self.provider_type == VMProviderType.LUMIER:
                                self.logger.info(f"Using VM image for Lumier provider: {image}")
                                if shared_path:
                                    self.logger.info(f"Using shared path for Lumier provider: {shared_path}")
                                if noVNC_port:
                                    self.logger.info(f"Using noVNC port for Lumier provider: {noVNC_port}")
                                self.config.vm_provider = VMProviderFactory.create_provider(
                                    self.provider_type,
                                    port=port,
                                    host=host,
                                    storage=storage,
                                    shared_path=shared_path,
                                    image=image,
                                    verbose=verbose,
                                    ephemeral=ephemeral,
                                    noVNC_port=noVNC_port,
                                )
                            elif self.provider_type == VMProviderType.LUME:
                                self.config.vm_provider = VMProviderFactory.create_provider(
                                    self.provider_type,
                                    port=port,
                                    host=host,
                                    storage=storage,
                                    verbose=verbose,
                                    ephemeral=ephemeral,
                                )
                            elif self.provider_type == VMProviderType.CLOUD:
                                self.config.vm_provider = VMProviderFactory.create_provider(
                                    self.provider_type,
                                    api_key=self.api_key,
                                    verbose=verbose,
                                )
                            elif self.provider_type == VMProviderType.WINSANDBOX:
                                self.config.vm_provider = VMProviderFactory.create_provider(
                                    self.provider_type,
                                    port=port,
                                    host=host,
                                    storage=storage,
                                    verbose=verbose,
                                    ephemeral=ephemeral,
                                )
                            else:
                                raise ValueError(f"Unsupported provider type: {self.provider_type}")

                            self._provider_context = await self.config.vm_provider.__aenter__()
                            self.logger.verbose("VM provider context initialized successfully")
                        except ImportError as ie:
                            self.logger.error(f"Failed to import provider dependencies: {ie}")
                            if "lume" in str(ie) and "lumier" not in str(ie):
                                self.logger.error("Please install with: pip install cua-computer[lume]")
                            elif "lumier" in str(ie) or "docker" in str(ie):
                                self.logger.error(
                                    "Please install with: pip install cua-computer[lumier] "
                                    "and make sure Docker is installed"
                                )
                            elif "cloud" in str(ie):
                                self.logger.error("Please install with: pip install cua-computer[cloud]")
                            raise
                    except Exception as e:
                        self.logger.error(f"Failed to initialize provider context: {e}")
                        raise RuntimeError(f"Failed to initialize VM provider: {e}")

                # Check if VM exists or create it
                is_running = False
                try:
                    if self.config.vm_provider is None:
                        raise RuntimeError(f"VM provider not initialized for {self.config.name}")
                    vm = await self.config.vm_provider.get_vm(self.config.name)
                    self.logger.verbose(f"Found existing VM: {self.config.name}")
                    is_running = vm.get("status") == "running"
                except Exception as e:
                    self.logger.error(f"VM not found: {self.config.name}")
                    self.logger.error(f"Error: {e}")
                    raise RuntimeError(
                        f"VM {self.config.name} could not be found or created."
                    )

                # Start the VM if it's not running
                if not is_running:
                    self.logger.info(f"VM {self.config.name} is not running, starting it...")

                    # Convert paths to dictionary format for shared directories
                    shared_dirs = []
                    for path in self.shared_directories:
                        self.logger.verbose(f"Adding shared directory: {path}")
                        path = os.path.abspath(os.path.expanduser(path))
                        if os.path.exists(path):
                            # Add path in format expected by Lume API
                            shared_dirs.append({
                                "hostPath": path,
                                "readOnly": False,
                            })
                        else:
                            self.logger.warning(f"Shared directory does not exist: {path}")

                    # Prepare run options to pass to the provider
                    run_opts = {}

                    # Add display information if available
                    if self.config.display is not None:
                        display_info = {
                            "width": self.config.display.width,
                            "height": self.config.display.height,
                        }
                        # Check if scale_factor exists before adding it
                        if hasattr(self.config.display, "scale_factor"):
                            display_info["scale_factor"] = self.config.display.scale_factor
                        run_opts["display"] = display_info

                    # Add shared directories if available
                    if self.shared_directories:
                        run_opts["shared_directories"] = shared_dirs.copy()

                    # Run the VM with the provider
                    try:
                        if self.config.vm_provider is None:
                            raise RuntimeError(f"VM provider not initialized for {self.config.name}")
                        # Handle ephemeral storage for run_vm as well
                        storage_param = "ephemeral" if self.ephemeral else self.storage
                        # Log the image being used
                        self.logger.info(f"Running VM using image: {self.image}")
                        # Call provider.run_vm with an explicit image parameter
                        response = await self.config.vm_provider.run_vm(
                            image=self.image,
                            name=self.config.name,
                            run_opts=run_opts,
                            storage=storage_param,
                        )
                        self.logger.info(f"VM run response: {response if response else 'None'}")
                    except Exception as run_error:
                        self.logger.error(f"Failed to run VM: {run_error}")
                        raise RuntimeError(f"Failed to start VM: {run_error}")

                # Wait for the VM to be ready with a valid IP address
                self.logger.info("Waiting for VM to be ready with a valid IP address...")
                try:
                    if self.provider_type == VMProviderType.LUMIER:
                        max_retries = 60  # Increased for Lumier VM startup, which takes longer
                        retry_delay = 3   # 3 seconds between retries for Lumier
                    else:
                        max_retries = 30  # Default for other providers
                        retry_delay = 2   # 2 seconds between retries

                    self.logger.info(f"Waiting up to {max_retries * retry_delay} seconds for VM to be ready...")
                    ip = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
                    # If we get here, we have a valid IP
                    self.logger.info(f"VM is ready with IP: {ip}")
                    ip_address = ip
                except TimeoutError as timeout_error:
                    self.logger.error(str(timeout_error))
                    raise RuntimeError(f"VM startup timed out: {timeout_error}")
                except Exception as wait_error:
                    self.logger.error(f"Error waiting for VM: {wait_error}")
                    raise RuntimeError(f"VM failed to become ready: {wait_error}")
        except Exception as e:
            self.logger.error(f"Failed to initialize computer: {e}")
            raise RuntimeError(f"Failed to initialize computer: {e}")

        try:
            # Verify we have a valid IP before initializing the interface
            if not ip_address or ip_address == "unknown" or ip_address == "0.0.0.0":
                raise RuntimeError(f"Cannot initialize interface - invalid IP address: {ip_address}")

            # Initialize the interface using the factory with the specified OS
            self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
            from .interface.base import BaseComputerInterface

            # Pass authentication credentials if using the cloud provider
            if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
                self._interface = cast(
                    BaseComputerInterface,
                    InterfaceFactory.create_interface_for_os(
                        os=self.os_type,
                        ip_address=ip_address,
                        api_key=self.api_key,
                        vm_name=self.config.name,
                    ),
                )
            else:
                self._interface = cast(
                    BaseComputerInterface,
                    InterfaceFactory.create_interface_for_os(
                        os=self.os_type,
                        ip_address=ip_address,
                    ),
                )

            # Wait for the WebSocket interface to be ready
            self.logger.info("Connecting to WebSocket interface...")
            try:
                # Use a single timeout for the entire connection process.
                # The VM should already be ready at this point, so we're just
                # establishing the connection.
                await self._interface.wait_for_ready(timeout=30)
                self.logger.info("WebSocket interface connected successfully")
            except TimeoutError as e:
                self.logger.error(f"Failed to connect to WebSocket interface at {ip_address}")
                raise TimeoutError(
                    f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}"
                )

            # Create an event to keep the VM running in the background if needed
            if not self.use_host_computer_server:
                self._stop_event = asyncio.Event()
                self._keep_alive_task = asyncio.create_task(self._stop_event.wait())

            self.logger.info("Computer is ready")

            # Set the initialization flag
            self._initialized = True

            # Set this instance as the default computer for remote decorators
            helpers.set_default_computer(self)

            self.logger.info("Computer successfully initialized")
        except Exception:
            raise
        finally:
            # Log initialization time for performance monitoring
            duration_ms = (time.time() - start_time) * 1000
            self.logger.debug(f"Computer initialization took {duration_ms:.2f}ms")

        return

    async def disconnect(self) -> None:
        """Disconnect from the computer's WebSocket interface."""
        if self._interface:
            self._interface.close()

    async def stop(self) -> None:
        """Disconnect from the computer's WebSocket interface and stop the computer."""
        start_time = time.time()
        try:
            self.logger.info("Stopping Computer...")
            # In VM mode, first explicitly stop the VM, then exit the provider context
            if (
                not self.use_host_computer_server
                and self._provider_context
                and self.config.vm_provider is not None
            ):
                try:
                    self.logger.info(f"Stopping VM {self.config.name}...")
                    await self.config.vm_provider.stop_vm(
                        name=self.config.name,
                        storage=self.storage,  # Pass storage explicitly for clarity
                    )
                except Exception as e:
                    self.logger.error(f"Error stopping VM: {e}")

                self.logger.verbose("Closing VM provider context...")
                await self.config.vm_provider.__aexit__(None, None, None)
                self._provider_context = None

            await self.disconnect()
            self.logger.info("Computer stopped")
        except Exception as e:
            # Log as debug since errors here might be expected during cleanup
            self.logger.debug(f"Error during cleanup: {e}")
        finally:
            # Log stop time for performance monitoring
            duration_ms = (time.time() - start_time) * 1000
            self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")

        return

    async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
        """Get the IP address of the VM, or localhost if using the host computer server.

        This method delegates to the provider's get_ip method, which waits
        until the VM has a valid IP address.

        Args:
            max_retries: Unused parameter, kept for backward compatibility
            retry_delay: Delay between retries in seconds

        Returns:
            IP address of the VM, or localhost if using the host computer server
        """
        # For the host computer server, always return localhost immediately
        if self.use_host_computer_server:
            return "127.0.0.1"

        # Get the IP from the provider - each provider implements its own waiting logic
        if self.config.vm_provider is None:
            raise RuntimeError("VM provider is not initialized")

        self.logger.info(f"Waiting for VM {self.config.name} to get an IP address...")
        storage_param = "ephemeral" if self.ephemeral else self.storage
        ip = await self.config.vm_provider.get_ip(
            name=self.config.name,
            storage=storage_param,
            retry_delay=retry_delay,
        )
        self.logger.info(f"VM {self.config.name} has IP address: {ip}")
        return ip

    async def wait_vm_ready(self) -> Optional[Dict[str, Any]]:
        """Wait for the VM to be ready with an IP address.

        Returns:
            VM status information, or None if using the host computer server.
        """
        if self.use_host_computer_server:
            return None

        timeout = 600  # 10 minutes timeout (increased from 4 minutes)
        interval = 2.0  # 2 seconds between checks (increased to reduce API load)
        start_time = time.time()
        last_status = None
        attempts = 0

        self.logger.info(f"Waiting for VM {self.config.name} to be ready (timeout: {timeout}s)...")
        while time.time() - start_time < timeout:
            attempts += 1
            elapsed = time.time() - start_time
            try:
                # Keep polling for VM info
                if self.config.vm_provider is None:
                    self.logger.error("VM provider is not initialized")
                    vm = None
                else:
                    vm = await self.config.vm_provider.get_vm(self.config.name)

                # Log full VM properties for debugging (every 30 attempts)
                if attempts % 30 == 0:
                    self.logger.info(
                        f"VM properties at attempt {attempts}: {vars(vm) if vm else 'None'}"
                    )

                # Get the current status for logging
                current_status = getattr(vm, "status", None) if vm else None
                if current_status != last_status:
                    self.logger.info(
                        f"VM status changed to: {current_status} (after {elapsed:.1f}s)"
                    )
                    last_status = current_status

                # Check for an IP address - ensure it's not None or empty
                ip = getattr(vm, "ip_address", None) if vm else None
                if ip and ip.strip():  # Check for a non-empty string
                    self.logger.info(
                        f"VM {self.config.name} got IP address: {ip} (after {elapsed:.1f}s)"
                    )
                    return vm

                if attempts % 10 == 0:  # Log every 10 attempts to avoid flooding
                    self.logger.info(
                        f"Still waiting for VM IP address... (elapsed: {elapsed:.1f}s)"
                    )
                else:
                    self.logger.debug(
                        f"Waiting for VM IP address... Current IP: {ip}, Status: {current_status}"
                    )
            except Exception as e:
                self.logger.warning(f"Error checking VM status (attempt {attempts}): {str(e)}")
                # If we've been getting errors for a while, log more details
                if elapsed > 60:  # After 1 minute of errors, log more details
                    self.logger.error(f"Persistent error getting VM status: {str(e)}")
                    self.logger.info("Trying to get VM list for debugging...")
                    try:
                        if self.config.vm_provider is not None:
                            vms = await self.config.vm_provider.list_vms()
                            self.logger.info(
                                f"Available VMs: {[getattr(vm, 'name', None) for vm in vms if hasattr(vm, 'name')]}"
                            )
                    except Exception as list_error:
                        self.logger.error(f"Failed to list VMs: {str(list_error)}")

            await asyncio.sleep(interval)

        # If we get here, we've timed out
        elapsed = time.time() - start_time
        self.logger.error(f"VM {self.config.name} not ready after {elapsed:.1f} seconds")

        # Try to get the final VM status for debugging
        try:
            if self.config.vm_provider is not None:
                vm = await self.config.vm_provider.get_vm(self.config.name)
                # VM data is returned as a dictionary from the Lumier provider
                status = vm.get('status', 'unknown') if vm else "unknown"
                ip = vm.get('ip_address') if vm else None
            else:
                status = "unknown"
                ip = None
            self.logger.error(f"Final VM status: {status}, IP: {ip}")
        except Exception as e:
            self.logger.error(f"Failed to get final VM status: {str(e)}")

        raise TimeoutError(
            f"VM {self.config.name} not ready after {elapsed:.1f} seconds - IP address not assigned"
        )

    async def update(self, cpu: Optional[int] = None, memory: Optional[str] = None):
        """Update VM settings."""
        self.logger.info(
            f"Updating VM settings: CPU={cpu or self.config.cpu}, Memory={memory or self.config.memory}"
        )
        update_opts = {
            "cpu": cpu or int(self.config.cpu),
            "memory": memory or self.config.memory,
        }
        if self.config.vm_provider is not None:
            await self.config.vm_provider.update_vm(
                name=self.config.name,
                update_opts=update_opts,
                storage=self.storage,  # Pass storage explicitly for clarity
            )
        else:
            raise RuntimeError("VM provider not initialized")

    def get_screenshot_size(self, screenshot: bytes) -> Dict[str, int]:
        """Get the dimensions of a screenshot.

        Args:
            screenshot: The screenshot bytes

        Returns:
            Dict[str, int]: Dictionary containing 'width' and 'height' of the image
        """
        image = Image.open(io.BytesIO(screenshot))
        width, height = image.size
        return {"width": width, "height": height}

    @property
    def interface(self):
        """Get the computer interface for interacting with the VM.

        Returns:
            The computer interface
        """
        if not hasattr(self, "_interface") or self._interface is None:
            error_msg = "Computer interface not initialized. Call run() first."
            self.logger.error(error_msg)
            self.logger.error(
                "Make sure to call await computer.run() before using any interface methods."
            )
            raise RuntimeError(error_msg)
        return self._interface

    @property
    def telemetry_enabled(self) -> bool:
        """Check if telemetry is enabled for this computer instance.

        Returns:
            bool: True if telemetry is enabled, False otherwise
        """
        return self._telemetry_enabled

    async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
        """Convert normalized coordinates to screen coordinates.

        Args:
            x: X coordinate between 0 and 1
            y: Y coordinate between 0 and 1

        Returns:
            tuple[float, float]: Screen coordinates (x, y)
        """
        return await self.interface.to_screen_coordinates(x, y)

    async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
        """Convert screen coordinates to screenshot coordinates.

        Args:
            x: X coordinate in screen space
            y: Y coordinate in screen space

        Returns:
            tuple[float, float]: (x, y) coordinates in screenshot space
        """
        return await self.interface.to_screenshot_coordinates(x, y)

    # Virtual environment management functions exposed on the computer interface

    async def venv_install(self, venv_name: str, requirements: list[str]) -> tuple[str, str]:
        """Install packages in a virtual environment.

        Args:
            venv_name: Name of the virtual environment
            requirements: List of package requirements to install

        Returns:
            Tuple of (stdout, stderr) from the installation command
        """
        requirements = requirements or []

        # Create the virtual environment if it doesn't exist
        venv_path = f"~/.venvs/{venv_name}"
        create_cmd = f"mkdir -p ~/.venvs && python3 -m venv {venv_path}"
        check_cmd = f"test -d {venv_path} || ({create_cmd})"
        _, _ = await self.interface.run_command(check_cmd)

        # Install packages
        requirements_str = " ".join(requirements)
        install_cmd = f". {venv_path}/bin/activate && pip install {requirements_str}"
        return await self.interface.run_command(install_cmd)

    async def venv_cmd(self, venv_name: str, command: str) -> tuple[str, str]:
        """Execute a shell command in a virtual environment.

        Args:
            venv_name: Name of the virtual environment
            command: Shell command to execute in the virtual environment

        Returns:
            Tuple of (stdout, stderr) from the command execution
        """
        venv_path = f"~/.venvs/{venv_name}"

        # Check if the virtual environment exists
        check_cmd = f"test -d {venv_path}"
        stdout, stderr = await self.interface.run_command(check_cmd)
        if stderr or "test:" in stdout:  # venv doesn't exist
            return "", (
                f"Virtual environment '{venv_name}' does not exist. "
                "Create it first using venv_install."
            )

        # Activate the virtual environment and run the command
        full_command = f". {venv_path}/bin/activate && {command}"
        return await self.interface.run_command(full_command)
async def venv_exec(self, venv_name: str, python_func, *args, **kwargs):
"""Execute Python function in a virtual environment using source code extraction.
Args:
venv_name: Name of the virtual environment
python_func: A callable function to execute
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function
Returns:
The result of the function execution, or raises any exception that occurred
"""
import base64
import inspect
import json
import textwrap
try:
# Get function source code using inspect.getsource
source = inspect.getsource(python_func)
# Remove common leading whitespace (dedent)
func_source = textwrap.dedent(source).strip()
# Remove decorators
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
# Get function name for execution
func_name = python_func.__name__
# Serialize args and kwargs as JSON (safer than dill for cross-version compatibility)
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
except OSError as e:
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
except Exception as e:
raise Exception(f"Failed to reconstruct function source: {e}")
# Create Python code that will define and execute the function
python_code = f'''
import json
import traceback

try:
    # Define the function from source
{textwrap.indent(func_source, "    ")}

    # Deserialize args and kwargs from JSON
    args_json = """{args_json}"""
    kwargs_json = """{kwargs_json}"""
    args = json.loads(args_json)
    kwargs = json.loads(kwargs_json)

    # Execute the function
    result = {func_name}(*args, **kwargs)

    # Create success output payload
    output_payload = {{
        "success": True,
        "result": result,
        "error": None
    }}
except Exception as e:
    # Create error output payload
    output_payload = {{
        "success": False,
        "result": None,
        "error": {{
            "type": type(e).__name__,
            "message": str(e),
            "traceback": traceback.format_exc()
        }}
    }}

# Serialize the output payload as JSON
output_json = json.dumps(output_payload, default=str)

# Print the JSON output with markers
print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
'''
# Encode the Python code in base64 to avoid shell escaping issues
encoded_code = base64.b64encode(python_code.encode('utf-8')).decode('ascii')
# Execute the Python code in the virtual environment
python_command = f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
stdout, stderr = await self.venv_cmd(venv_name, python_command)
# Parse the output to extract the payload
start_marker = "<<<VENV_EXEC_START>>>"
end_marker = "<<<VENV_EXEC_END>>>"
# Print any stdout emitted before the payload marker (print everything
# when the marker is absent, rather than silently dropping the last character)
marker_idx = stdout.find(start_marker)
print(stdout[:marker_idx] if marker_idx != -1 else stdout)
if start_marker in stdout and end_marker in stdout:
start_idx = stdout.find(start_marker) + len(start_marker)
end_idx = stdout.find(end_marker)
if start_idx < end_idx:
output_json = stdout[start_idx:end_idx]
try:
# Decode and deserialize the output payload from JSON
output_payload = json.loads(output_json)
except Exception as e:
raise Exception(f"Failed to decode output payload: {e}")
if output_payload["success"]:
return output_payload["result"]
else:
# Recreate and raise the original exception (look the type up on builtins
# instead of eval-ing it; fall back to Exception for non-builtin types)
error_info = output_payload["error"]
import builtins
error_class = getattr(builtins, error_info["type"], Exception)
raise error_class(error_info["message"])
else:
raise Exception("Invalid output format: markers found but no content between them")
else:
# Fallback: return stdout/stderr if no payload markers found
raise Exception(f"No output payload found. stdout: {stdout}, stderr: {stderr}")


@@ -0,0 +1,104 @@
import asyncio
from .interface.models import KeyType, Key
class DioramaComputer:
"""
A Computer-compatible proxy for Diorama that sends commands over the ComputerInterface.
"""
def __init__(self, computer, apps):
self.computer = computer
self.apps = apps
self.interface = DioramaComputerInterface(computer, apps)
self._initialized = False
async def __aenter__(self):
self._initialized = True
return self
async def run(self):
if not self._initialized:
await self.__aenter__()
return self
class DioramaComputerInterface:
"""
Diorama Interface proxy that sends diorama_cmds via the Computer's interface.
"""
def __init__(self, computer, apps):
self.computer = computer
self.apps = apps
self._scene_size = None
async def _send_cmd(self, action, arguments=None):
arguments = arguments or {}
arguments = {"app_list": self.apps, **arguments}
# Use the computer's interface (must be initialized)
iface = getattr(self.computer, "_interface", None)
if iface is None:
raise RuntimeError("Computer interface not initialized. Call run() first.")
result = await iface.diorama_cmd(action, arguments)
if not result.get("success"):
raise RuntimeError(f"Diorama command failed: {result.get('error')}\n{result.get('trace')}")
return result.get("result")
async def screenshot(self, as_bytes=True):
import base64
import io

from PIL import Image

result = await self._send_cmd("screenshot")
# assume result is a b64 string of an image
img_bytes = base64.b64decode(result)
img = Image.open(io.BytesIO(img_bytes))
self._scene_size = img.size
return img_bytes if as_bytes else img
async def get_screen_size(self):
if not self._scene_size:
await self.screenshot(as_bytes=False)
return {"width": self._scene_size[0], "height": self._scene_size[1]}
async def move_cursor(self, x, y):
await self._send_cmd("move_cursor", {"x": x, "y": y})
async def left_click(self, x=None, y=None):
await self._send_cmd("left_click", {"x": x, "y": y})
async def right_click(self, x=None, y=None):
await self._send_cmd("right_click", {"x": x, "y": y})
async def double_click(self, x=None, y=None):
await self._send_cmd("double_click", {"x": x, "y": y})
async def scroll_up(self, clicks=1):
await self._send_cmd("scroll_up", {"clicks": clicks})
async def scroll_down(self, clicks=1):
await self._send_cmd("scroll_down", {"clicks": clicks})
async def drag_to(self, x, y, duration=0.5):
await self._send_cmd("drag_to", {"x": x, "y": y, "duration": duration})
async def get_cursor_position(self):
return await self._send_cmd("get_cursor_position")
async def type_text(self, text):
await self._send_cmd("type_text", {"text": text})
async def press_key(self, key):
await self._send_cmd("press_key", {"key": key})
async def hotkey(self, *keys):
actual_keys = []
for key in keys:
if isinstance(key, Key):
actual_keys.append(key.value)
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_cmd("hotkey", {"keys": actual_keys})
async def to_screen_coordinates(self, x, y):
return await self._send_cmd("to_screen_coordinates", {"x": x, "y": y})


@@ -0,0 +1,52 @@
"""
Helper functions and decorators for the Computer module.
"""
import logging
import asyncio
from functools import wraps
# Global reference to the default computer instance
_default_computer = None
logger = logging.getLogger(__name__)
def set_default_computer(computer):
"""
Set the default computer instance to be used by the remote decorator.
Args:
computer: The computer instance to use as default
"""
global _default_computer
_default_computer = computer
def sandboxed(venv_name: str = "default", computer="default", max_retries: int = 3):
"""
Decorator that wraps a function to be executed remotely via computer.venv_exec
Args:
venv_name: Name of the virtual environment to execute in
computer: The computer instance to use, or "default" to use the globally set default
max_retries: Maximum number of retries for the remote execution
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# Determine which computer instance to use
comp = computer if computer != "default" else _default_computer
if comp is None:
raise RuntimeError("No computer instance available. Either specify a computer instance or call set_default_computer() first.")
for i in range(max_retries):
    try:
        return await comp.venv_exec(venv_name, func, *args, **kwargs)
    except Exception as e:
        logger.error(f"Attempt {i+1} failed: {e}")
        if i == max_retries - 1:
            raise
        await asyncio.sleep(1)
return wrapper
return decorator
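The retry behavior of `sandboxed` can be reproduced with a purely local decorator. This sketch (hypothetical `retried` and `flaky` names; no remote execution, and a no-op sleep in place of a real backoff) shows the same retry-then-reraise shape:

```python
import asyncio
import logging
from functools import wraps

logger = logging.getLogger(__name__)

def retried(max_retries: int = 3):
    """Local analogue of @sandboxed's retry loop (no remote execution)."""
    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return await func(*args, **kwargs)
                except Exception as exc:
                    logger.error("Attempt %d failed: %s", attempt + 1, exc)
                    if attempt == max_retries - 1:
                        raise
                    await asyncio.sleep(0)  # stand-in for a real retry delay
        return wrapper
    return decorator

calls = {"n": 0}

@retried(max_retries=3)
async def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

result = asyncio.run(flaky())
```

The final attempt re-raises rather than sleeping first, so callers see the original exception immediately.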


@@ -0,0 +1,13 @@
"""
Interface package for Computer SDK.
"""
from .factory import InterfaceFactory
from .base import BaseComputerInterface
from .macos import MacOSComputerInterface
__all__ = [
"InterfaceFactory",
"BaseComputerInterface",
"MacOSComputerInterface",
]


@@ -0,0 +1,271 @@
"""Base interface for computer control."""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple, List
from ..logger import Logger, LogLevel
from .models import MouseButton
class BaseComputerInterface(ABC):
"""Base class for computer control interfaces."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
"""Initialize interface.
Args:
ip_address: IP address of the computer to control
username: Username for authentication
password: Password for authentication
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
"""
self.ip_address = ip_address
self.username = username
self.password = password
self.api_key = api_key
self.vm_name = vm_name
self.logger = Logger("cua.interface", LogLevel.NORMAL)
@abstractmethod
async def wait_for_ready(self, timeout: int = 60) -> None:
"""Wait for interface to be ready.
Args:
timeout: Maximum time to wait in seconds
Raises:
TimeoutError: If interface is not ready within timeout
"""
pass
@abstractmethod
def close(self) -> None:
"""Close the interface connection."""
pass
def force_close(self) -> None:
"""Force close the interface connection.
By default, this just calls close(), but subclasses can override
to provide more forceful cleanup.
"""
self.close()
# Mouse Actions
@abstractmethod
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
"""Press and hold a mouse button."""
pass
@abstractmethod
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
"""Release a mouse button."""
pass
@abstractmethod
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Perform a left click."""
pass
@abstractmethod
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Perform a right click."""
pass
@abstractmethod
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Perform a double click."""
pass
@abstractmethod
async def move_cursor(self, x: int, y: int) -> None:
"""Move the cursor to specified position."""
pass
@abstractmethod
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None:
"""Drag from current position to specified coordinates.
Args:
x: The x coordinate to drag to
y: The y coordinate to drag to
button: The mouse button to use ('left', 'middle', 'right')
duration: How long the drag should take in seconds
"""
pass
@abstractmethod
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
"""Drag the cursor along a path of coordinates.
Args:
path: List of (x, y) coordinate tuples defining the drag path
button: The mouse button to use ('left', 'middle', 'right')
duration: Total time in seconds that the drag operation should take
"""
pass
# Keyboard Actions
@abstractmethod
async def key_down(self, key: str) -> None:
"""Press and hold a key."""
pass
@abstractmethod
async def key_up(self, key: str) -> None:
"""Release a key."""
pass
@abstractmethod
async def type_text(self, text: str) -> None:
"""Type the specified text."""
pass
@abstractmethod
async def press_key(self, key: str) -> None:
"""Press a single key."""
pass
@abstractmethod
async def hotkey(self, *keys: str) -> None:
"""Press multiple keys simultaneously."""
pass
# Scrolling Actions
@abstractmethod
async def scroll(self, x: int, y: int) -> None:
"""Scroll the mouse wheel."""
pass
@abstractmethod
async def scroll_down(self, clicks: int = 1) -> None:
"""Scroll down."""
pass
@abstractmethod
async def scroll_up(self, clicks: int = 1) -> None:
"""Scroll up."""
pass
# Screen Actions
@abstractmethod
async def screenshot(self) -> bytes:
"""Take a screenshot.
Returns:
Raw bytes of the screenshot image
"""
pass
@abstractmethod
async def get_screen_size(self) -> Dict[str, int]:
"""Get the screen dimensions.
Returns:
Dict with 'width' and 'height' keys
"""
pass
@abstractmethod
async def get_cursor_position(self) -> Dict[str, int]:
"""Get current cursor position."""
pass
# Clipboard Actions
@abstractmethod
async def copy_to_clipboard(self) -> str:
"""Get clipboard content."""
pass
@abstractmethod
async def set_clipboard(self, text: str) -> None:
"""Set clipboard content."""
pass
# File System Actions
@abstractmethod
async def file_exists(self, path: str) -> bool:
"""Check if file exists."""
pass
@abstractmethod
async def directory_exists(self, path: str) -> bool:
"""Check if directory exists."""
pass
@abstractmethod
async def list_dir(self, path: str) -> List[str]:
"""List directory contents."""
pass
@abstractmethod
async def read_text(self, path: str) -> str:
"""Read file text contents."""
pass
@abstractmethod
async def write_text(self, path: str, content: str) -> None:
"""Write file text contents."""
pass
@abstractmethod
async def read_bytes(self, path: str) -> bytes:
"""Read file binary contents."""
pass
@abstractmethod
async def write_bytes(self, path: str, content: bytes) -> None:
"""Write file binary contents."""
pass
@abstractmethod
async def delete_file(self, path: str) -> None:
"""Delete file."""
pass
@abstractmethod
async def create_dir(self, path: str) -> None:
"""Create directory."""
pass
@abstractmethod
async def delete_dir(self, path: str) -> None:
"""Delete directory."""
pass
@abstractmethod
async def run_command(self, command: str) -> Tuple[str, str]:
"""Run shell command."""
pass
# Accessibility Actions
@abstractmethod
async def get_accessibility_tree(self) -> Dict:
"""Get the accessibility tree of the current screen."""
pass
@abstractmethod
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
Args:
x: X coordinate in screenshot space
y: Y coordinate in screenshot space
Returns:
tuple[float, float]: (x, y) coordinates in screen space
"""
pass
@abstractmethod
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screen coordinates to screenshot coordinates.
Args:
x: X coordinate in screen space
y: Y coordinate in screen space
Returns:
tuple[float, float]: (x, y) coordinates in screenshot space
"""
pass


@@ -0,0 +1,42 @@
"""Factory for creating computer interfaces."""
from typing import Literal, Optional
from .base import BaseComputerInterface
class InterfaceFactory:
"""Factory for creating OS-specific computer interfaces."""
@staticmethod
def create_interface_for_os(
os: Literal['macos', 'linux', 'windows'],
ip_address: str,
api_key: Optional[str] = None,
vm_name: Optional[str] = None
) -> BaseComputerInterface:
"""Create an interface for the specified OS.
Args:
os: Operating system type ('macos', 'linux', or 'windows')
ip_address: IP address of the computer to control
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
Returns:
BaseComputerInterface: The appropriate interface for the OS
Raises:
ValueError: If the OS type is not supported
"""
# Import implementations here to avoid circular imports
from .macos import MacOSComputerInterface
from .linux import LinuxComputerInterface
from .windows import WindowsComputerInterface
if os == 'macos':
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'linux':
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'windows':
return WindowsComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
else:
raise ValueError(f"Unsupported OS type: {os}")


@@ -0,0 +1,688 @@
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import websockets
from ..logger import Logger, LogLevel
from .base import BaseComputerInterface
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
from .models import Key, KeyType, MouseButton
class LinuxComputerInterface(BaseComputerInterface):
"""Interface for Linux."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
self._last_ping = 0
self._ping_interval = 5 # Send ping every 5 seconds
self._ping_timeout = 120 # Wait 120 seconds for pong response
self._reconnect_delay = 1 # Start with 1 second delay
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
self._log_connection_attempts = True # Flag to control connection attempt logging
self._authenticated = False # Track authentication status
self._command_lock = asyncio.Lock() # Lock to ensure only one command at a time
# Set logger name for Linux interface
self.logger = Logger("computer.interface.linux", LogLevel.NORMAL)
@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.
Returns:
WebSocket URI for the Computer API Server
"""
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/ws"
async def _keep_alive(self):
"""Keep the WebSocket connection alive with automatic reconnection."""
retry_count = 0
max_log_attempts = 1 # Only log the first attempt at INFO level
log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
last_warning_time = 0
min_warning_interval = 30 # Minimum seconds between connection lost warnings
min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
while not self._closed:
try:
if self._ws is None or (
self._ws and self._ws.state == websockets.protocol.State.CLOSED
):
try:
retry_count += 1
# Add a minimum delay between connection attempts to avoid flooding
if retry_count > 1:
await asyncio.sleep(min_retry_delay)
# Only log the first attempt at INFO level, then every Nth attempt
if retry_count == 1:
self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
elif retry_count % log_interval == 0:
self.logger.info(
f"Still attempting WebSocket connection (attempt {retry_count})..."
)
else:
# All other attempts are logged at DEBUG level
self.logger.debug(
f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
)
self._ws = await asyncio.wait_for(
websockets.connect(
self.ws_uri,
max_size=1024 * 1024 * 10, # 10MB limit
max_queue=32,
ping_interval=self._ping_interval,
ping_timeout=self._ping_timeout,
close_timeout=5,
compression=None, # Disable compression to reduce overhead
),
timeout=120,
)
self.logger.info("WebSocket connection established")
# Authentication will be handled by the first command that needs it
# Don't do authentication here to avoid recv conflicts
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection
self._authenticated = False # Reset auth status on new connection
except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
next_retry = self._reconnect_delay
# Only log the first error at WARNING level, then every Nth attempt
if retry_count == 1:
self.logger.warning(
f"Computer API Server not ready yet. Will retry automatically."
)
elif retry_count % log_interval == 0:
self.logger.warning(
f"Still waiting for Computer API Server (attempt {retry_count})..."
)
else:
# All other errors are logged at DEBUG level
self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
if self._ws:
try:
await self._ws.close()
except:
pass
self._ws = None
# Regular ping to check connection
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
try:
if time.time() - self._last_ping >= self._ping_interval:
pong_waiter = await self._ws.ping()
await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
self._last_ping = time.time()
except Exception as e:
self.logger.debug(f"Ping failed: {e}")
if self._ws:
try:
await self._ws.close()
except:
pass
self._ws = None
continue
await asyncio.sleep(1)
except Exception as e:
current_time = time.time()
# Only log connection lost warnings at most once every min_warning_interval seconds
if current_time - last_warning_time >= min_warning_interval:
self.logger.warning(
f"Computer API Server connection lost. Will retry automatically."
)
last_warning_time = current_time
else:
# Log at debug level instead
self.logger.debug(f"Connection lost: {e}")
if self._ws:
try:
await self._ws.close()
except:
pass
self._ws = None
async def _ensure_connection(self):
"""Ensure WebSocket connection is established."""
if self._reconnect_task is None or self._reconnect_task.done():
self._reconnect_task = asyncio.create_task(self._keep_alive())
retry_count = 0
max_retries = 5
while retry_count < max_retries:
try:
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
return
retry_count += 1
await asyncio.sleep(1)
except Exception as e:
# Only log at ERROR level for the last retry attempt
if retry_count == max_retries - 1:
self.logger.error(
f"Persistent connection check error after {retry_count} attempts: {e}"
)
else:
self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
retry_count += 1
await asyncio.sleep(1)
continue
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through WebSocket."""
max_retries = 3
retry_count = 0
last_error = None
# Acquire lock to ensure only one command is processed at a time
async with self._command_lock:
self.logger.debug(f"Acquired lock for command: {command}")
while retry_count < max_retries:
try:
await self._ensure_connection()
if not self._ws:
raise ConnectionError("WebSocket connection is not established")
# Handle authentication if needed
if self.api_key and self.vm_name and not self._authenticated:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"container_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
self._authenticated = False
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._authenticated = True
message = {"command": command, "params": params or {}}
await self._ws.send(json.dumps(message))
response = await asyncio.wait_for(self._ws.recv(), timeout=30)
self.logger.debug(f"Completed command: {command}")
return json.loads(response)
except Exception as e:
last_error = e
retry_count += 1
if retry_count < max_retries:
# Only log at debug level for intermediate retries
self.logger.debug(
f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
)
await asyncio.sleep(1)
continue
else:
# Only log at error level for the final failure
self.logger.error(
f"Failed to send command '{command}' after {max_retries} retries"
)
self.logger.debug(f"Command failure details: {e}")
raise last_error if last_error else RuntimeError("Failed to send command")
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
"""Wait for WebSocket connection to become available."""
start_time = time.time()
last_error = None
attempt_count = 0
progress_interval = 10 # Log progress every 10 seconds
last_progress_time = start_time
# Disable detailed logging for connection attempts
self._log_connection_attempts = False
try:
self.logger.info(
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
)
# Start the keep-alive task if it's not already running
if self._reconnect_task is None or self._reconnect_task.done():
self._reconnect_task = asyncio.create_task(self._keep_alive())
# Wait for the connection to be established
while time.time() - start_time < timeout:
try:
attempt_count += 1
current_time = time.time()
# Log progress periodically without flooding logs
if current_time - last_progress_time >= progress_interval:
elapsed = current_time - start_time
self.logger.info(
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
)
last_progress_time = current_time
# Check if we have a connection
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
# Test the connection with a simple command
try:
await self._send_command("get_screen_size")
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
)
return # Connection is fully working
except Exception as e:
last_error = e
self.logger.debug(f"Connection test failed: {e}")
# Wait before trying again
await asyncio.sleep(interval)
except Exception as e:
last_error = e
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
await asyncio.sleep(interval)
# If we get here, we've timed out
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
if last_error:
error_msg += f": {str(last_error)}"
self.logger.error(error_msg)
raise TimeoutError(error_msg)
finally:
# Reset to default logging behavior
self._log_connection_attempts = False
def close(self):
"""Close WebSocket connection.
Note: In host computer server mode, we leave the connection open
to allow other clients to connect to the same server. The server
will handle cleaning up idle connections.
"""
# Only cancel the reconnect task
if self._reconnect_task:
self._reconnect_task.cancel()
# Don't set closed flag or close websocket by default
# This allows the server to stay connected for other clients
# self._closed = True
# if self._ws:
# asyncio.create_task(self._ws.close())
# self._ws = None
def force_close(self):
"""Force close the WebSocket connection.
This method should be called when you want to completely
shut down the connection, not just for regular cleanup.
"""
self._closed = True
if self._reconnect_task:
self._reconnect_task.cancel()
if self._ws:
asyncio.create_task(self._ws.close())
self._ws = None
# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("left_click", {"x": x, "y": y})
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("right_click", {"x": x, "y": y})
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("double_click", {"x": x, "y": y})
async def move_cursor(self, x: int, y: int) -> None:
await self._send_command("move_cursor", {"x": x, "y": y})
async def drag_to(self, x: int, y: int, button: "MouseButton" = "left", duration: float = 0.5) -> None:
await self._send_command(
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
)
async def drag(self, path: List[Tuple[int, int]], button: "MouseButton" = "left", duration: float = 0.5) -> None:
await self._send_command(
"drag", {"path": path, "button": button, "duration": duration}
)
# Keyboard Actions
async def key_down(self, key: "KeyType") -> None:
await self._send_command("key_down", {"key": key})
async def key_up(self, key: "KeyType") -> None:
await self._send_command("key_up", {"key": key})
async def type_text(self, text: str) -> None:
# Temporary fix for https://github.com/trycua/cua/issues/165
# Check if text contains Unicode characters
if any(ord(char) > 127 for char in text):
# For Unicode text, use the clipboard and paste (Ctrl+V on Linux;
# hotkey() accepts plain key-name strings as well as Key enum values)
await self.set_clipboard(text)
await self.hotkey("ctrl", "v")
else:
# For ASCII text, use the regular typing method
await self._send_command("type_text", {"text": text})
async def press(self, key: "KeyType") -> None:
"""Press a single key.
Args:
key: The key to press. Can be any of:
- A Key enum value (recommended), e.g. Key.PAGE_DOWN
- A direct key value string, e.g. 'pagedown'
- A single character string, e.g. 'a'
Examples:
```python
# Using enum (recommended)
await interface.press(Key.PAGE_DOWN)
await interface.press(Key.ENTER)
# Using direct values
await interface.press('pagedown')
await interface.press('enter')
# Using single characters
await interface.press('a')
```
Raises:
ValueError: If the key type is invalid or the key is not recognized
"""
if isinstance(key, Key):
actual_key = key.value
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("press_key", {"key": actual_key})
async def press_key(self, key: "KeyType") -> None:
"""DEPRECATED: Use press() instead.
This method is kept for backward compatibility but will be removed in a future version.
Please use the press() method instead.
"""
await self.press(key)
async def hotkey(self, *keys: "KeyType") -> None:
"""Press multiple keys simultaneously.
Args:
*keys: Multiple keys to press simultaneously. Each key can be any of:
- A Key enum value (recommended), e.g. Key.COMMAND
- A direct key value string, e.g. 'command'
- A single character string, e.g. 'a'
Examples:
```python
# Using enums (recommended)
await interface.hotkey(Key.COMMAND, Key.C) # Copy
await interface.hotkey(Key.COMMAND, Key.V) # Paste
# Using mixed formats
await interface.hotkey(Key.COMMAND, 'a') # Select all
```
Raises:
ValueError: If any key type is invalid or not recognized
"""
actual_keys = []
for key in keys:
if isinstance(key, Key):
actual_keys.append(key.value)
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("hotkey", {"keys": actual_keys})
# Scrolling Actions
async def scroll(self, x: int, y: int) -> None:
await self._send_command("scroll", {"x": x, "y": y})
async def scroll_down(self, clicks: int = 1) -> None:
await self._send_command("scroll_down", {"clicks": clicks})
async def scroll_up(self, clicks: int = 1) -> None:
await self._send_command("scroll_up", {"clicks": clicks})
# Screen Actions
async def screenshot(
self,
boxes: Optional[List[Tuple[int, int, int, int]]] = None,
box_color: str = "#FF0000",
box_thickness: int = 2,
scale_factor: float = 1.0,
) -> bytes:
"""Take a screenshot with optional box drawing and scaling.
Args:
boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
box_color: Color of the boxes in hex format (default: "#FF0000" red)
box_thickness: Thickness of the box borders in pixels (default: 2)
scale_factor: Factor to scale the final image by (default: 1.0)
Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
Returns:
bytes: The screenshot image data, optionally with boxes drawn on it and scaled
"""
result = await self._send_command("screenshot")
if not result.get("image_data"):
raise RuntimeError("Failed to take screenshot")
screenshot = decode_base64_image(result["image_data"])
if boxes:
# Get the natural scaling between screen and screenshot
screen_size = await self.get_screen_size()
screenshot_width, screenshot_height = bytes_to_image(screenshot).size
width_scale = screenshot_width / screen_size["width"]
height_scale = screenshot_height / screen_size["height"]
# Scale box coordinates from screen space to screenshot space
for box in boxes:
scaled_box = (
int(box[0] * width_scale), # x
int(box[1] * height_scale), # y
int(box[2] * width_scale), # width
int(box[3] * height_scale), # height
)
screenshot = draw_box(
screenshot,
x=scaled_box[0],
y=scaled_box[1],
width=scaled_box[2],
height=scaled_box[3],
color=box_color,
thickness=box_thickness,
)
if scale_factor != 1.0:
screenshot = resize_image(screenshot, scale_factor)
return screenshot
async def get_screen_size(self) -> Dict[str, int]:
result = await self._send_command("get_screen_size")
if result["success"] and result["size"]:
return result["size"]
raise RuntimeError("Failed to get screen size")
async def get_cursor_position(self) -> Dict[str, int]:
result = await self._send_command("get_cursor_position")
if result["success"] and result["position"]:
return result["position"]
raise RuntimeError("Failed to get cursor position")
# Clipboard Actions
async def copy_to_clipboard(self) -> str:
result = await self._send_command("copy_to_clipboard")
if result["success"] and result["content"]:
return result["content"]
raise RuntimeError("Failed to get clipboard content")
async def set_clipboard(self, text: str) -> None:
await self._send_command("set_clipboard", {"text": text})
# File System Actions
async def file_exists(self, path: str) -> bool:
result = await self._send_command("file_exists", {"path": path})
return result.get("exists", False)
async def directory_exists(self, path: str) -> bool:
result = await self._send_command("directory_exists", {"path": path})
return result.get("exists", False)
async def list_dir(self, path: str) -> list[str]:
result = await self._send_command("list_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to list directory"))
return result.get("files", [])
async def read_text(self, path: str) -> str:
result = await self._send_command("read_text", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
return result.get("content", "")
async def write_text(self, path: str, content: str) -> None:
result = await self._send_command("write_text", {"path": path, "content": content})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def read_bytes(self, path: str) -> bytes:
result = await self._send_command("read_bytes", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
content_b64 = result.get("content_b64", "")
return decode_base64_image(content_b64)
async def write_bytes(self, path: str, content: bytes) -> None:
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content)})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def delete_file(self, path: str) -> None:
result = await self._send_command("delete_file", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to delete file"))
async def create_dir(self, path: str) -> None:
result = await self._send_command("create_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to create directory"))
async def delete_dir(self, path: str) -> None:
result = await self._send_command("delete_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to delete directory"))
async def run_command(self, command: str) -> Tuple[str, str]:
result = await self._send_command("run_command", {"command": command})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to run command"))
return result.get("stdout", ""), result.get("stderr", "")
# Accessibility Actions
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current screen."""
result = await self._send_command("get_accessibility_tree")
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
return result
async def get_active_window_bounds(self) -> Dict[str, int]:
"""Get the bounds of the currently active window."""
result = await self._send_command("get_active_window_bounds")
if result["success"] and result["bounds"]:
return result["bounds"]
raise RuntimeError("Failed to get active window bounds")
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
Args:
x: X coordinate in screenshot space
y: Y coordinate in screenshot space
Returns:
tuple[float, float]: (x, y) coordinates in screen space
"""
screen_size = await self.get_screen_size()
screenshot = await self.screenshot()
screenshot_img = bytes_to_image(screenshot)
screenshot_width, screenshot_height = screenshot_img.size
# Calculate scaling factors
width_scale = screen_size["width"] / screenshot_width
height_scale = screen_size["height"] / screenshot_height
# Convert coordinates
screen_x = x * width_scale
screen_y = y * height_scale
return screen_x, screen_y
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screen coordinates to screenshot coordinates.
Args:
x: X coordinate in screen space
y: Y coordinate in screen space
Returns:
tuple[float, float]: (x, y) coordinates in screenshot space
"""
screen_size = await self.get_screen_size()
screenshot = await self.screenshot()
screenshot_img = bytes_to_image(screenshot)
screenshot_width, screenshot_height = screenshot_img.size
# Calculate scaling factors
width_scale = screenshot_width / screen_size["width"]
height_scale = screenshot_height / screen_size["height"]
# Convert coordinates
screenshot_x = x * width_scale
screenshot_y = y * height_scale
return screenshot_x, screenshot_y
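The `to_screen_coordinates` / `to_screenshot_coordinates` pair above is a plain scaling by the ratio between the physical screen size and the screenshot size, so the two conversions are exact inverses. A minimal standalone sketch of that arithmetic, using hypothetical sizes (a 2x Retina-style display), not real device queries:

```python
# Hypothetical sizes: a 1512x982 logical screen captured as a
# 3024x1964 (2x scale) screenshot. The conversions are inverse scalings.

def to_screen(x: float, y: float, screen: tuple, shot: tuple) -> tuple:
    """Map screenshot-space coordinates into screen space."""
    return x * screen[0] / shot[0], y * screen[1] / shot[1]

def to_screenshot(x: float, y: float, screen: tuple, shot: tuple) -> tuple:
    """Map screen-space coordinates into screenshot space."""
    return x * shot[0] / screen[0], y * shot[1] / screen[1]

screen_size, shot_size = (1512, 982), (3024, 1964)
sx, sy = to_screen(600, 400, screen_size, shot_size)
assert (sx, sy) == (300.0, 200.0)  # half the pixel coordinates
# Round-tripping recovers the original screenshot coordinates.
assert to_screenshot(sx, sy, screen_size, shot_size) == (600.0, 400.0)
```

The same width/height ratios are what `screenshot()` uses to scale box coordinates from screen space into screenshot space before drawing.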


@@ -0,0 +1,695 @@
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import websockets
from ..logger import Logger, LogLevel
from .base import BaseComputerInterface
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
from .models import Key, KeyType, MouseButton
class MacOSComputerInterface(BaseComputerInterface):
"""Interface for macOS."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
self._last_ping = 0
self._ping_interval = 5 # Send ping every 5 seconds
self._ping_timeout = 120 # Wait 120 seconds for pong response
self._reconnect_delay = 1 # Start with 1 second delay
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
self._log_connection_attempts = True # Flag to control connection attempt logging
self._command_lock = asyncio.Lock() # Lock to ensure only one command at a time
# Set logger name for macOS interface
self.logger = Logger("computer.interface.macos", LogLevel.NORMAL)
@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.
Returns:
WebSocket URI for the Computer API Server
"""
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/ws"
async def _keep_alive(self):
"""Keep the WebSocket connection alive with automatic reconnection."""
retry_count = 0
max_log_attempts = 1 # Only log the first attempt at INFO level
log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
last_warning_time = 0
min_warning_interval = 30 # Minimum seconds between connection lost warnings
min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
while not self._closed:
try:
if self._ws is None or (
self._ws and self._ws.state == websockets.protocol.State.CLOSED
):
try:
retry_count += 1
# Add a minimum delay between connection attempts to avoid flooding
if retry_count > 1:
await asyncio.sleep(min_retry_delay)
# Only log the first attempt at INFO level, then every Nth attempt
if retry_count == 1:
self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
elif retry_count % log_interval == 0:
self.logger.info(
f"Still attempting WebSocket connection (attempt {retry_count})..."
)
else:
# All other attempts are logged at DEBUG level
self.logger.debug(
f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
)
self._ws = await asyncio.wait_for(
websockets.connect(
self.ws_uri,
max_size=1024 * 1024 * 10, # 10MB limit
max_queue=32,
ping_interval=self._ping_interval,
ping_timeout=self._ping_timeout,
close_timeout=5,
compression=None, # Disable compression to reduce overhead
),
timeout=120,
)
self.logger.info("WebSocket connection established")
# If api_key and vm_name are provided, perform authentication handshake
if self.api_key and self.vm_name:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"container_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
await self._ws.close()
self._ws = None
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection
except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
# Only log the first error at WARNING level, then every Nth attempt
if retry_count == 1:
self.logger.warning(
f"Computer API Server not ready yet. Will retry automatically."
)
elif retry_count % log_interval == 0:
self.logger.warning(
f"Still waiting for Computer API Server (attempt {retry_count})..."
)
else:
# All other errors are logged at DEBUG level
self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
if self._ws:
try:
await self._ws.close()
except Exception:
pass
self._ws = None
# Use exponential backoff for connection retries
await asyncio.sleep(self._reconnect_delay)
self._reconnect_delay = min(
self._reconnect_delay * 2, self._max_reconnect_delay
)
continue
# Regular ping to check connection
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
try:
if time.time() - self._last_ping >= self._ping_interval:
pong_waiter = await self._ws.ping()
await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
self._last_ping = time.time()
except Exception as e:
self.logger.debug(f"Ping failed: {e}")
if self._ws:
try:
await self._ws.close()
except Exception:
pass
self._ws = None
continue
await asyncio.sleep(1)
except Exception as e:
current_time = time.time()
# Only log connection lost warnings at most once every min_warning_interval seconds
if current_time - last_warning_time >= min_warning_interval:
self.logger.warning(
f"Computer API Server connection lost. Will retry automatically."
)
last_warning_time = current_time
else:
# Log at debug level instead
self.logger.debug(f"Connection lost: {e}")
if self._ws:
try:
await self._ws.close()
except Exception:
pass
self._ws = None
async def _ensure_connection(self):
"""Ensure WebSocket connection is established."""
if self._reconnect_task is None or self._reconnect_task.done():
self._reconnect_task = asyncio.create_task(self._keep_alive())
retry_count = 0
max_retries = 5
while retry_count < max_retries:
try:
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
return
retry_count += 1
await asyncio.sleep(1)
except Exception as e:
# Only log at ERROR level for the last retry attempt
if retry_count == max_retries - 1:
self.logger.error(
f"Persistent connection check error after {retry_count} attempts: {e}"
)
else:
self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
retry_count += 1
await asyncio.sleep(1)
continue
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through WebSocket."""
max_retries = 3
retry_count = 0
last_error = None
# Acquire lock to ensure only one command is processed at a time
async with self._command_lock:
self.logger.debug(f"Acquired lock for command: {command}")
while retry_count < max_retries:
try:
await self._ensure_connection()
if not self._ws:
raise ConnectionError("WebSocket connection is not established")
message = {"command": command, "params": params or {}}
await self._ws.send(json.dumps(message))
response = await asyncio.wait_for(self._ws.recv(), timeout=120)
self.logger.debug(f"Completed command: {command}")
return json.loads(response)
except Exception as e:
last_error = e
retry_count += 1
if retry_count < max_retries:
# Only log at debug level for intermediate retries
self.logger.debug(
f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
)
await asyncio.sleep(1)
continue
else:
# Only log at error level for the final failure
self.logger.error(
f"Failed to send command '{command}' after {max_retries} retries"
)
self.logger.debug(f"Command failure details: {e}")
raise
raise last_error if last_error else RuntimeError("Failed to send command")
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
"""Wait for WebSocket connection to become available."""
start_time = time.time()
last_error = None
attempt_count = 0
progress_interval = 10 # Log progress every 10 seconds
last_progress_time = start_time
# Disable detailed logging for connection attempts
self._log_connection_attempts = False
try:
self.logger.info(
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
)
# Start the keep-alive task if it's not already running
if self._reconnect_task is None or self._reconnect_task.done():
self._reconnect_task = asyncio.create_task(self._keep_alive())
# Wait for the connection to be established
while time.time() - start_time < timeout:
try:
attempt_count += 1
current_time = time.time()
# Log progress periodically without flooding logs
if current_time - last_progress_time >= progress_interval:
elapsed = current_time - start_time
self.logger.info(
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
)
last_progress_time = current_time
# Check if we have a connection
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
# Test the connection with a simple command
try:
await self._send_command("get_screen_size")
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
)
return # Connection is fully working
except Exception as e:
last_error = e
self.logger.debug(f"Connection test failed: {e}")
# Wait before trying again
await asyncio.sleep(interval)
except Exception as e:
last_error = e
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
await asyncio.sleep(interval)
# If we get here, we've timed out
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
if last_error:
error_msg += f": {str(last_error)}"
self.logger.error(error_msg)
raise TimeoutError(error_msg)
finally:
# Reset to default logging behavior
self._log_connection_attempts = True
def close(self):
"""Close WebSocket connection.
Note: In host computer server mode, we leave the connection open
to allow other clients to connect to the same server. The server
will handle cleaning up idle connections.
"""
# Only cancel the reconnect task
if self._reconnect_task:
self._reconnect_task.cancel()
# Don't set closed flag or close websocket by default
# This allows the server to stay connected for other clients
# self._closed = True
# if self._ws:
# asyncio.create_task(self._ws.close())
# self._ws = None
def force_close(self):
"""Force close the WebSocket connection.
This method should be called when you want to completely
shut down the connection, not just for regular cleanup.
"""
self._closed = True
if self._reconnect_task:
self._reconnect_task.cancel()
if self._ws:
asyncio.create_task(self._ws.close())
self._ws = None
async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
"""Send a diorama command to the server (macOS only)."""
return await self._send_command("diorama_cmd", {"action": action, "arguments": arguments or {}})
# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("left_click", {"x": x, "y": y})
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("right_click", {"x": x, "y": y})
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("double_click", {"x": x, "y": y})
async def move_cursor(self, x: int, y: int) -> None:
await self._send_command("move_cursor", {"x": x, "y": y})
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None:
await self._send_command(
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
)
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
await self._send_command(
"drag", {"path": path, "button": button, "duration": duration}
)
# Keyboard Actions
async def key_down(self, key: "KeyType") -> None:
await self._send_command("key_down", {"key": key})
async def key_up(self, key: "KeyType") -> None:
await self._send_command("key_up", {"key": key})
async def type_text(self, text: str) -> None:
# Temporary fix for https://github.com/trycua/cua/issues/165
# Check if text contains Unicode characters
if any(ord(char) > 127 for char in text):
# For Unicode text, use clipboard and paste
await self.set_clipboard(text)
await self.hotkey(Key.COMMAND, 'v')
else:
# For ASCII text, use the regular typing method
await self._send_command("type_text", {"text": text})
async def press(self, key: "KeyType") -> None:
"""Press a single key.
Args:
key: The key to press. Can be any of:
- A Key enum value (recommended), e.g. Key.PAGE_DOWN
- A direct key value string, e.g. 'pagedown'
- A single character string, e.g. 'a'
Examples:
```python
# Using enum (recommended)
await interface.press(Key.PAGE_DOWN)
await interface.press(Key.ENTER)
# Using direct values
await interface.press('pagedown')
await interface.press('enter')
# Using single characters
await interface.press('a')
```
Raises:
ValueError: If the key type is invalid or the key is not recognized
"""
if isinstance(key, Key):
actual_key = key.value
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("press_key", {"key": actual_key})
async def press_key(self, key: "KeyType") -> None:
"""DEPRECATED: Use press() instead.
This method is kept for backward compatibility but will be removed in a future version.
Please use the press() method instead.
"""
await self.press(key)
async def hotkey(self, *keys: "KeyType") -> None:
"""Press multiple keys simultaneously.
Args:
*keys: Multiple keys to press simultaneously. Each key can be any of:
- A Key enum value (recommended), e.g. Key.COMMAND
- A direct key value string, e.g. 'command'
- A single character string, e.g. 'a'
Examples:
```python
# Using enums (recommended)
await interface.hotkey(Key.COMMAND, Key.C) # Copy
await interface.hotkey(Key.COMMAND, Key.V) # Paste
# Using mixed formats
await interface.hotkey(Key.COMMAND, 'a') # Select all
```
Raises:
ValueError: If any key type is invalid or not recognized
"""
actual_keys = []
for key in keys:
if isinstance(key, Key):
actual_keys.append(key.value)
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("hotkey", {"keys": actual_keys})
# Scrolling Actions
async def scroll(self, x: int, y: int) -> None:
await self._send_command("scroll", {"x": x, "y": y})
async def scroll_down(self, clicks: int = 1) -> None:
await self._send_command("scroll_down", {"clicks": clicks})
async def scroll_up(self, clicks: int = 1) -> None:
await self._send_command("scroll_up", {"clicks": clicks})
# Screen Actions
async def screenshot(
self,
boxes: Optional[List[Tuple[int, int, int, int]]] = None,
box_color: str = "#FF0000",
box_thickness: int = 2,
scale_factor: float = 1.0,
) -> bytes:
"""Take a screenshot with optional box drawing and scaling.
Args:
boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
box_color: Color of the boxes in hex format (default: "#FF0000" red)
box_thickness: Thickness of the box borders in pixels (default: 2)
scale_factor: Factor to scale the final image by (default: 1.0)
Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
Returns:
bytes: The screenshot image data, optionally with boxes drawn on it and scaled
"""
result = await self._send_command("screenshot")
if not result.get("image_data"):
raise RuntimeError("Failed to take screenshot")
screenshot = decode_base64_image(result["image_data"])
if boxes:
# Get the natural scaling between screen and screenshot
screen_size = await self.get_screen_size()
screenshot_width, screenshot_height = bytes_to_image(screenshot).size
width_scale = screenshot_width / screen_size["width"]
height_scale = screenshot_height / screen_size["height"]
# Scale box coordinates from screen space to screenshot space
for box in boxes:
scaled_box = (
int(box[0] * width_scale), # x
int(box[1] * height_scale), # y
int(box[2] * width_scale), # width
int(box[3] * height_scale), # height
)
screenshot = draw_box(
screenshot,
x=scaled_box[0],
y=scaled_box[1],
width=scaled_box[2],
height=scaled_box[3],
color=box_color,
thickness=box_thickness,
)
if scale_factor != 1.0:
screenshot = resize_image(screenshot, scale_factor)
return screenshot
async def get_screen_size(self) -> Dict[str, int]:
result = await self._send_command("get_screen_size")
if result["success"] and result["size"]:
return result["size"]
raise RuntimeError("Failed to get screen size")
async def get_cursor_position(self) -> Dict[str, int]:
result = await self._send_command("get_cursor_position")
if result["success"] and result["position"]:
return result["position"]
raise RuntimeError("Failed to get cursor position")
# Clipboard Actions
async def copy_to_clipboard(self) -> str:
result = await self._send_command("copy_to_clipboard")
if result["success"] and result["content"]:
return result["content"]
raise RuntimeError("Failed to get clipboard content")
async def set_clipboard(self, text: str) -> None:
await self._send_command("set_clipboard", {"text": text})
# File System Actions
async def file_exists(self, path: str) -> bool:
result = await self._send_command("file_exists", {"path": path})
return result.get("exists", False)
async def directory_exists(self, path: str) -> bool:
result = await self._send_command("directory_exists", {"path": path})
return result.get("exists", False)
async def list_dir(self, path: str) -> list[str]:
result = await self._send_command("list_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to list directory"))
return result.get("files", [])
async def read_text(self, path: str) -> str:
result = await self._send_command("read_text", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
return result.get("content", "")
async def write_text(self, path: str, content: str) -> None:
result = await self._send_command("write_text", {"path": path, "content": content})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def read_bytes(self, path: str) -> bytes:
result = await self._send_command("read_bytes", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
content_b64 = result.get("content_b64", "")
return decode_base64_image(content_b64)
async def write_bytes(self, path: str, content: bytes) -> None:
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content)})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def delete_file(self, path: str) -> None:
result = await self._send_command("delete_file", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to delete file"))
async def create_dir(self, path: str) -> None:
result = await self._send_command("create_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to create directory"))
async def delete_dir(self, path: str) -> None:
result = await self._send_command("delete_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to delete directory"))
async def run_command(self, command: str) -> Tuple[str, str]:
result = await self._send_command("run_command", {"command": command})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to run command"))
return result.get("stdout", ""), result.get("stderr", "")
# Accessibility Actions
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current screen."""
result = await self._send_command("get_accessibility_tree")
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
return result
async def get_active_window_bounds(self) -> Dict[str, int]:
"""Get the bounds of the currently active window."""
result = await self._send_command("get_active_window_bounds")
if result["success"] and result["bounds"]:
return result["bounds"]
raise RuntimeError("Failed to get active window bounds")
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
Args:
x: X coordinate in screenshot space
y: Y coordinate in screenshot space
Returns:
tuple[float, float]: (x, y) coordinates in screen space
"""
screen_size = await self.get_screen_size()
screenshot = await self.screenshot()
screenshot_img = bytes_to_image(screenshot)
screenshot_width, screenshot_height = screenshot_img.size
# Calculate scaling factors
width_scale = screen_size["width"] / screenshot_width
height_scale = screen_size["height"] / screenshot_height
# Convert coordinates
screen_x = x * width_scale
screen_y = y * height_scale
return screen_x, screen_y
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screen coordinates to screenshot coordinates.
Args:
x: X coordinate in screen space
y: Y coordinate in screen space
Returns:
tuple[float, float]: (x, y) coordinates in screenshot space
"""
screen_size = await self.get_screen_size()
screenshot = await self.screenshot()
screenshot_img = bytes_to_image(screenshot)
screenshot_width, screenshot_height = screenshot_img.size
# Calculate scaling factors
width_scale = screenshot_width / screen_size["width"]
height_scale = screenshot_height / screen_size["height"]
# Convert coordinates
screenshot_x = x * width_scale
screenshot_y = y * height_scale
return screenshot_x, screenshot_y
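The reconnect loop in `_keep_alive` above doubles its delay after each failed connection attempt and caps it at `_max_reconnect_delay`, then resets to the initial delay on success. A small standalone sketch of that backoff schedule, mirroring the constructor defaults (1 s initial, 30 s cap); the helper name is illustrative, not part of the interface:

```python
# Compute the sleep durations the reconnect loop would use for a run
# of consecutive failures: 1s doubling up to a 30s ceiling.

def backoff_schedule(attempts: int, initial: float = 1.0, maximum: float = 30.0) -> list:
    delays, delay = [], initial
    for _ in range(attempts):
        delays.append(delay)
        delay = min(delay * 2, maximum)  # exponential growth, capped
    return delays

assert backoff_schedule(7) == [1.0, 2.0, 4.0, 8.0, 16.0, 30.0, 30.0]
```

Capping the delay keeps a long-unavailable Computer API Server from pushing retries arbitrarily far apart, while the exponential ramp avoids flooding it during startup.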


@@ -0,0 +1,124 @@
from enum import Enum
from typing import Dict, List, Any, TypedDict, Union, Literal
# Navigation key literals
NavigationKey = Literal['pagedown', 'pageup', 'home', 'end', 'left', 'right', 'up', 'down']
# Special key literals
SpecialKey = Literal['enter', 'esc', 'tab', 'space', 'backspace', 'del']
# Modifier key literals
ModifierKey = Literal['ctrl', 'alt', 'shift', 'win', 'command', 'option']
# Function key literals
FunctionKey = Literal['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12']
class Key(Enum):
"""Keyboard keys that can be used with press_key.
These key names map to PyAutoGUI's expected key names.
"""
# Navigation
PAGE_DOWN = 'pagedown'
PAGE_UP = 'pageup'
HOME = 'home'
END = 'end'
LEFT = 'left'
RIGHT = 'right'
UP = 'up'
DOWN = 'down'
# Special keys
RETURN = 'enter'
ENTER = 'enter'
ESCAPE = 'esc'
ESC = 'esc'
TAB = 'tab'
SPACE = 'space'
BACKSPACE = 'backspace'
DELETE = 'del'
# Modifier keys
ALT = 'alt'
CTRL = 'ctrl'
SHIFT = 'shift'
WIN = 'win'
COMMAND = 'command'
OPTION = 'option'
# Function keys
F1 = 'f1'
F2 = 'f2'
F3 = 'f3'
F4 = 'f4'
F5 = 'f5'
F6 = 'f6'
F7 = 'f7'
F8 = 'f8'
F9 = 'f9'
F10 = 'f10'
F11 = 'f11'
F12 = 'f12'
@classmethod
def from_string(cls, key: str) -> 'Key | str':
"""Convert a string key name to a Key enum value.
Args:
key: String key name to convert
Returns:
Key enum value if the string matches a known key,
otherwise returns the original string for single character keys
"""
# Map common alternative names to enum values
key_mapping = {
'page_down': cls.PAGE_DOWN,
'page down': cls.PAGE_DOWN,
'pagedown': cls.PAGE_DOWN,
'page_up': cls.PAGE_UP,
'page up': cls.PAGE_UP,
'pageup': cls.PAGE_UP,
'return': cls.RETURN,
'enter': cls.ENTER,
'escape': cls.ESCAPE,
'esc': cls.ESC,
'delete': cls.DELETE,
'del': cls.DELETE,
# Modifier key mappings
'alt': cls.ALT,
'ctrl': cls.CTRL,
'control': cls.CTRL,
'shift': cls.SHIFT,
'win': cls.WIN,
'windows': cls.WIN,
'super': cls.WIN,
'command': cls.COMMAND,
'cmd': cls.COMMAND,
'⌘': cls.COMMAND,
'option': cls.OPTION,
'⌥': cls.OPTION,
}
normalized = key.lower().strip()
return key_mapping.get(normalized, key)
# Combined key type
KeyType = Union[Key, NavigationKey, SpecialKey, ModifierKey, FunctionKey, str]
# Key type for mouse actions
MouseButton = Literal['left', 'right', 'middle']
class AccessibilityWindow(TypedDict):
"""Information about a window in the accessibility tree."""
app_name: str
pid: int
frontmost: bool
has_windows: bool
windows: List[Dict[str, Any]]
class AccessibilityTree(TypedDict):
"""Complete accessibility tree information."""
success: bool
frontmost_application: str
windows: List[AccessibilityWindow]
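The `Key.from_string` normalization above can be sketched standalone; this is a minimal illustrative subset of the full mapping, not the library's complete enum:

```python
from enum import Enum

class Key(Enum):
    # Illustrative subset of the full enum above
    PAGE_DOWN = 'pagedown'
    ENTER = 'enter'

    @classmethod
    def from_string(cls, key: str) -> 'Key | str':
        mapping = {
            'page_down': cls.PAGE_DOWN, 'page down': cls.PAGE_DOWN,
            'pagedown': cls.PAGE_DOWN,
            'return': cls.ENTER, 'enter': cls.ENTER,
        }
        # Unknown keys (e.g. single characters) pass through unchanged
        return mapping.get(key.lower().strip(), key)

print(Key.from_string('Page Down'))  # Key.PAGE_DOWN
print(Key.from_string('a'))          # a
```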


@@ -0,0 +1,687 @@
import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import websockets
from ..logger import Logger, LogLevel
from .base import BaseComputerInterface
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
from .models import Key, KeyType, MouseButton
class WindowsComputerInterface(BaseComputerInterface):
"""Interface for Windows."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
self._closed = False
self._last_ping = 0
self._ping_interval = 5 # Send ping every 5 seconds
self._ping_timeout = 120 # Wait 120 seconds for pong response
self._reconnect_delay = 1 # Start with 1 second delay
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
self._log_connection_attempts = True # Flag to control connection attempt logging
self._authenticated = False # Track authentication status
self._command_lock = asyncio.Lock() # Lock to ensure only one command at a time
# Set logger name for Windows interface
self.logger = Logger("computer.interface.windows", LogLevel.NORMAL)
@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.
Returns:
WebSocket URI for the Computer API Server
"""
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/ws"
async def _keep_alive(self):
"""Keep the WebSocket connection alive with automatic reconnection."""
retry_count = 0
max_log_attempts = 1 # Only log the first attempt at INFO level
log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
last_warning_time = 0
min_warning_interval = 30 # Minimum seconds between connection lost warnings
min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
while not self._closed:
try:
if self._ws is None or (
self._ws and self._ws.state == websockets.protocol.State.CLOSED
):
try:
retry_count += 1
# Add a minimum delay between connection attempts to avoid flooding
if retry_count > 1:
await asyncio.sleep(min_retry_delay)
# Only log the first attempt at INFO level, then every Nth attempt
if retry_count == 1:
self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
elif retry_count % log_interval == 0:
self.logger.info(
f"Still attempting WebSocket connection (attempt {retry_count})..."
)
else:
# All other attempts are logged at DEBUG level
self.logger.debug(
f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
)
self._ws = await asyncio.wait_for(
websockets.connect(
self.ws_uri,
max_size=1024 * 1024 * 10, # 10MB limit
max_queue=32,
ping_interval=self._ping_interval,
ping_timeout=self._ping_timeout,
close_timeout=5,
compression=None, # Disable compression to reduce overhead
),
timeout=120,
)
self.logger.info("WebSocket connection established")
# Authentication will be handled by the first command that needs it
# Don't do authentication here to avoid recv conflicts
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection
self._authenticated = False # Reset auth status on new connection
except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
next_retry = self._reconnect_delay
# Only log the first error at WARNING level, then every Nth attempt
if retry_count == 1:
self.logger.warning(
f"Computer API Server not ready yet. Will retry automatically."
)
elif retry_count % log_interval == 0:
self.logger.warning(
f"Still waiting for Computer API Server (attempt {retry_count})..."
)
else:
# All other errors are logged at DEBUG level
self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
if self._ws:
try:
await self._ws.close()
except Exception:
pass
self._ws = None
# Regular ping to check connection
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
try:
if time.time() - self._last_ping >= self._ping_interval:
pong_waiter = await self._ws.ping()
await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
self._last_ping = time.time()
except Exception as e:
self.logger.debug(f"Ping failed: {e}")
if self._ws:
try:
await self._ws.close()
except Exception:
pass
self._ws = None
continue
await asyncio.sleep(1)
except Exception as e:
current_time = time.time()
# Only log connection lost warnings at most once every min_warning_interval seconds
if current_time - last_warning_time >= min_warning_interval:
self.logger.warning(
f"Computer API Server connection lost. Will retry automatically."
)
last_warning_time = current_time
else:
# Log at debug level instead
self.logger.debug(f"Connection lost: {e}")
if self._ws:
try:
await self._ws.close()
except Exception:
pass
self._ws = None
async def _ensure_connection(self):
"""Ensure WebSocket connection is established."""
if self._reconnect_task is None or self._reconnect_task.done():
self._reconnect_task = asyncio.create_task(self._keep_alive())
retry_count = 0
max_retries = 5
while retry_count < max_retries:
try:
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
return
retry_count += 1
await asyncio.sleep(1)
except Exception as e:
# Only log at ERROR level for the last retry attempt
if retry_count == max_retries - 1:
self.logger.error(
f"Persistent connection check error after {retry_count} attempts: {e}"
)
else:
self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
retry_count += 1
await asyncio.sleep(1)
continue
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command through WebSocket."""
max_retries = 3
retry_count = 0
last_error = None
# Acquire lock to ensure only one command is processed at a time
async with self._command_lock:
self.logger.debug(f"Acquired lock for command: {command}")
while retry_count < max_retries:
try:
await self._ensure_connection()
if not self._ws:
raise ConnectionError("WebSocket connection is not established")
# Handle authentication if needed
if self.api_key and self.vm_name and not self._authenticated:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"container_name": self.vm_name
}
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
self._authenticated = False
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._authenticated = True
message = {"command": command, "params": params or {}}
await self._ws.send(json.dumps(message))
response = await asyncio.wait_for(self._ws.recv(), timeout=30)
self.logger.debug(f"Completed command: {command}")
return json.loads(response)
except Exception as e:
last_error = e
retry_count += 1
if retry_count < max_retries:
# Only log at debug level for intermediate retries
self.logger.debug(
f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
)
await asyncio.sleep(1)
continue
else:
# Only log at error level for the final failure
self.logger.error(
f"Failed to send command '{command}' after {max_retries} retries"
)
self.logger.debug(f"Command failure details: {e}")
raise last_error if last_error else RuntimeError("Failed to send command")
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
"""Wait for WebSocket connection to become available."""
start_time = time.time()
last_error = None
attempt_count = 0
progress_interval = 10 # Log progress every 10 seconds
last_progress_time = start_time
# Disable detailed logging for connection attempts
self._log_connection_attempts = False
try:
self.logger.info(
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
)
# Start the keep-alive task if it's not already running
if self._reconnect_task is None or self._reconnect_task.done():
self._reconnect_task = asyncio.create_task(self._keep_alive())
# Wait for the connection to be established
while time.time() - start_time < timeout:
try:
attempt_count += 1
current_time = time.time()
# Log progress periodically without flooding logs
if current_time - last_progress_time >= progress_interval:
elapsed = current_time - start_time
self.logger.info(
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
)
last_progress_time = current_time
# Check if we have a connection
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
# Test the connection with a simple command
try:
await self._send_command("get_screen_size")
elapsed = time.time() - start_time
self.logger.info(
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
)
return # Connection is fully working
except Exception as e:
last_error = e
self.logger.debug(f"Connection test failed: {e}")
# Wait before trying again
await asyncio.sleep(interval)
except Exception as e:
last_error = e
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
await asyncio.sleep(interval)
# If we get here, we've timed out
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
if last_error:
error_msg += f": {str(last_error)}"
self.logger.error(error_msg)
raise TimeoutError(error_msg)
finally:
# Restore the default connection-attempt logging behavior
self._log_connection_attempts = True
def close(self):
"""Close WebSocket connection.
Note: In host computer server mode, we leave the connection open
to allow other clients to connect to the same server. The server
will handle cleaning up idle connections.
"""
# Only cancel the reconnect task
if self._reconnect_task:
self._reconnect_task.cancel()
# Don't set closed flag or close websocket by default
# This allows the server to stay connected for other clients
# self._closed = True
# if self._ws:
# asyncio.create_task(self._ws.close())
# self._ws = None
def force_close(self):
"""Force close the WebSocket connection.
This method should be called when you want to completely
shut down the connection, not just for regular cleanup.
"""
self._closed = True
if self._reconnect_task:
self._reconnect_task.cancel()
if self._ws:
asyncio.create_task(self._ws.close())
self._ws = None
# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("left_click", {"x": x, "y": y})
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("right_click", {"x": x, "y": y})
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
await self._send_command("double_click", {"x": x, "y": y})
async def move_cursor(self, x: int, y: int) -> None:
await self._send_command("move_cursor", {"x": x, "y": y})
async def drag_to(self, x: int, y: int, button: "MouseButton" = "left", duration: float = 0.5) -> None:
await self._send_command(
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
)
async def drag(self, path: List[Tuple[int, int]], button: "MouseButton" = "left", duration: float = 0.5) -> None:
await self._send_command(
"drag", {"path": path, "button": button, "duration": duration}
)
# Keyboard Actions
async def key_down(self, key: "KeyType") -> None:
await self._send_command("key_down", {"key": key})
async def key_up(self, key: "KeyType") -> None:
await self._send_command("key_up", {"key": key})
async def type_text(self, text: str) -> None:
# For Windows, use clipboard for Unicode text like Linux
if any(ord(char) > 127 for char in text):
# For Unicode text, use clipboard and paste
await self.set_clipboard(text)
await self.hotkey(Key.CTRL, 'v') # Windows uses Ctrl+V instead of Cmd+V
else:
# For ASCII text, use the regular typing method
await self._send_command("type_text", {"text": text})
async def press(self, key: "KeyType") -> None:
"""Press a single key.
Args:
key: The key to press. Can be any of:
- A Key enum value (recommended), e.g. Key.PAGE_DOWN
- A direct key value string, e.g. 'pagedown'
- A single character string, e.g. 'a'
Examples:
```python
# Using enum (recommended)
await interface.press(Key.PAGE_DOWN)
await interface.press(Key.ENTER)
# Using direct values
await interface.press('pagedown')
await interface.press('enter')
# Using single characters
await interface.press('a')
```
Raises:
ValueError: If the key type is invalid or the key is not recognized
"""
if isinstance(key, Key):
actual_key = key.value
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("press_key", {"key": actual_key})
async def press_key(self, key: "KeyType") -> None:
"""DEPRECATED: Use press() instead.
This method is kept for backward compatibility but will be removed in a future version.
Please use the press() method instead.
"""
await self.press(key)
async def hotkey(self, *keys: "KeyType") -> None:
"""Press multiple keys simultaneously.
Args:
*keys: Multiple keys to press simultaneously. Each key can be any of:
- A Key enum value (recommended), e.g. Key.CTRL
- A direct key value string, e.g. 'ctrl'
- A single character string, e.g. 'a'
Examples:
```python
# Using enums (recommended)
await interface.hotkey(Key.CTRL, Key.C) # Copy
await interface.hotkey(Key.CTRL, Key.V) # Paste
# Using mixed formats
await interface.hotkey(Key.CTRL, 'a') # Select all
```
Raises:
ValueError: If any key type is invalid or not recognized
"""
actual_keys = []
for key in keys:
if isinstance(key, Key):
actual_keys.append(key.value)
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("hotkey", {"keys": actual_keys})
# Scrolling Actions
async def scroll(self, x: int, y: int) -> None:
await self._send_command("scroll", {"x": x, "y": y})
async def scroll_down(self, clicks: int = 1) -> None:
await self._send_command("scroll_down", {"clicks": clicks})
async def scroll_up(self, clicks: int = 1) -> None:
await self._send_command("scroll_up", {"clicks": clicks})
# Screen Actions
async def screenshot(
self,
boxes: Optional[List[Tuple[int, int, int, int]]] = None,
box_color: str = "#FF0000",
box_thickness: int = 2,
scale_factor: float = 1.0,
) -> bytes:
"""Take a screenshot with optional box drawing and scaling.
Args:
boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
box_color: Color of the boxes in hex format (default: "#FF0000" red)
box_thickness: Thickness of the box borders in pixels (default: 2)
scale_factor: Factor to scale the final image by (default: 1.0)
Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
Returns:
bytes: The screenshot image data, optionally with boxes drawn on it and scaled
"""
result = await self._send_command("screenshot")
if not result.get("image_data"):
raise RuntimeError("Failed to take screenshot")
screenshot = decode_base64_image(result["image_data"])
if boxes:
# Get the natural scaling between screen and screenshot
screen_size = await self.get_screen_size()
screenshot_width, screenshot_height = bytes_to_image(screenshot).size
width_scale = screenshot_width / screen_size["width"]
height_scale = screenshot_height / screen_size["height"]
# Scale box coordinates from screen space to screenshot space
for box in boxes:
scaled_box = (
int(box[0] * width_scale), # x
int(box[1] * height_scale), # y
int(box[2] * width_scale), # width
int(box[3] * height_scale), # height
)
screenshot = draw_box(
screenshot,
x=scaled_box[0],
y=scaled_box[1],
width=scaled_box[2],
height=scaled_box[3],
color=box_color,
thickness=box_thickness,
)
if scale_factor != 1.0:
screenshot = resize_image(screenshot, scale_factor)
return screenshot
async def get_screen_size(self) -> Dict[str, int]:
result = await self._send_command("get_screen_size")
if result["success"] and result["size"]:
return result["size"]
raise RuntimeError("Failed to get screen size")
async def get_cursor_position(self) -> Dict[str, int]:
result = await self._send_command("get_cursor_position")
if result["success"] and result["position"]:
return result["position"]
raise RuntimeError("Failed to get cursor position")
# Clipboard Actions
async def copy_to_clipboard(self) -> str:
result = await self._send_command("copy_to_clipboard")
if result["success"] and result["content"]:
return result["content"]
raise RuntimeError("Failed to get clipboard content")
async def set_clipboard(self, text: str) -> None:
await self._send_command("set_clipboard", {"text": text})
# File System Actions
async def file_exists(self, path: str) -> bool:
result = await self._send_command("file_exists", {"path": path})
return result.get("exists", False)
async def directory_exists(self, path: str) -> bool:
result = await self._send_command("directory_exists", {"path": path})
return result.get("exists", False)
async def list_dir(self, path: str) -> list[str]:
result = await self._send_command("list_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to list directory"))
return result.get("files", [])
async def read_text(self, path: str) -> str:
result = await self._send_command("read_text", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
return result.get("content", "")
async def write_text(self, path: str, content: str) -> None:
result = await self._send_command("write_text", {"path": path, "content": content})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def read_bytes(self, path: str) -> bytes:
result = await self._send_command("read_bytes", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
content_b64 = result.get("content_b64", "")
return decode_base64_image(content_b64)
async def write_bytes(self, path: str, content: bytes) -> None:
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content)})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def delete_file(self, path: str) -> None:
result = await self._send_command("delete_file", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to delete file"))
async def create_dir(self, path: str) -> None:
result = await self._send_command("create_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to create directory"))
async def delete_dir(self, path: str) -> None:
result = await self._send_command("delete_dir", {"path": path})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to delete directory"))
async def run_command(self, command: str) -> Tuple[str, str]:
result = await self._send_command("run_command", {"command": command})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to run command"))
return result.get("stdout", ""), result.get("stderr", "")
# Accessibility Actions
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current screen."""
result = await self._send_command("get_accessibility_tree")
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
return result
async def get_active_window_bounds(self) -> Dict[str, int]:
"""Get the bounds of the currently active window."""
result = await self._send_command("get_active_window_bounds")
if result["success"] and result["bounds"]:
return result["bounds"]
raise RuntimeError("Failed to get active window bounds")
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
Args:
x: X coordinate in screenshot space
y: Y coordinate in screenshot space
Returns:
tuple[float, float]: (x, y) coordinates in screen space
"""
screen_size = await self.get_screen_size()
screenshot = await self.screenshot()
screenshot_img = bytes_to_image(screenshot)
screenshot_width, screenshot_height = screenshot_img.size
# Calculate scaling factors
width_scale = screen_size["width"] / screenshot_width
height_scale = screen_size["height"] / screenshot_height
# Convert coordinates
screen_x = x * width_scale
screen_y = y * height_scale
return screen_x, screen_y
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screen coordinates to screenshot coordinates.
Args:
x: X coordinate in screen space
y: Y coordinate in screen space
Returns:
tuple[float, float]: (x, y) coordinates in screenshot space
"""
screen_size = await self.get_screen_size()
screenshot = await self.screenshot()
screenshot_img = bytes_to_image(screenshot)
screenshot_width, screenshot_height = screenshot_img.size
# Calculate scaling factors
width_scale = screenshot_width / screen_size["width"]
height_scale = screenshot_height / screen_size["height"]
# Convert coordinates
screenshot_x = x * width_scale
screenshot_y = y * height_scale
return screenshot_x, screenshot_y
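The two coordinate conversions above reduce to a linear rescale between screen space and screenshot space; a standalone sketch with hypothetical display sizes:

```python
def to_screen_coordinates(x: float, y: float,
                          screen: tuple, shot: tuple) -> tuple:
    """Map a point from screenshot space to screen space by linear scaling."""
    return x * screen[0] / shot[0], y * screen[1] / shot[1]

def to_screenshot_coordinates(x: float, y: float,
                              screen: tuple, shot: tuple) -> tuple:
    """Inverse mapping: screen space to screenshot space."""
    return x * shot[0] / screen[0], y * shot[1] / screen[1]

# A 1024x768 screen captured as a 2048x1536 (2x / Retina) screenshot:
print(to_screen_coordinates(200, 300, (1024, 768), (2048, 1536)))      # (100.0, 150.0)
print(to_screenshot_coordinates(100, 150, (1024, 768), (2048, 1536)))  # (200.0, 300.0)
```

Round-tripping a point through both functions returns the original coordinates, which is why the interface can freely convert between model output and cursor positions.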


@@ -0,0 +1,84 @@
"""Logging utilities for the Computer module."""
import logging
from enum import IntEnum
# Keep LogLevel for backward compatibility, but it will be deprecated
class LogLevel(IntEnum):
"""Log levels for logging. Deprecated - use standard logging levels instead."""
QUIET = 0 # Only warnings and errors
NORMAL = 1 # Info level, standard output
VERBOSE = 2 # More detailed information
DEBUG = 3 # Full debug information
# Map LogLevel to standard logging levels for backward compatibility
LOGLEVEL_MAP = {
LogLevel.QUIET: logging.WARNING,
LogLevel.NORMAL: logging.INFO,
LogLevel.VERBOSE: logging.DEBUG,
LogLevel.DEBUG: logging.DEBUG,
}
class Logger:
"""Logger class for Computer."""
def __init__(self, name: str, verbosity: int):
"""Initialize the logger.
Args:
name: The name of the logger.
verbosity: The log level (use standard logging levels like logging.INFO).
For backward compatibility, LogLevel enum values are also accepted.
"""
self.logger = logging.getLogger(name)
# Convert LogLevel enum to standard logging level if needed
if isinstance(verbosity, LogLevel):
self.verbosity = LOGLEVEL_MAP.get(verbosity, logging.INFO)
else:
self.verbosity = verbosity
self._configure()
def _configure(self):
"""Configure the logger based on log level."""
# Set the logging level directly
self.logger.setLevel(self.verbosity)
# Log the verbosity level that was set
if self.verbosity <= logging.DEBUG:
self.logger.info("Logger set to DEBUG level")
elif self.verbosity <= logging.INFO:
self.logger.info("Logger set to INFO level")
elif self.verbosity <= logging.WARNING:
self.logger.warning("Logger set to WARNING level")
elif self.verbosity <= logging.ERROR:
self.logger.warning("Logger set to ERROR level")
elif self.verbosity <= logging.CRITICAL:
self.logger.warning("Logger set to CRITICAL level")
def debug(self, message: str):
"""Log a debug message if log level is DEBUG or lower."""
self.logger.debug(message)
def info(self, message: str):
"""Log an info message if log level is INFO or lower."""
self.logger.info(message)
def verbose(self, message: str):
"""Log a verbose message between INFO and DEBUG levels."""
# Since there's no standard verbose level,
# use debug level with [VERBOSE] prefix for backward compatibility
self.logger.debug(f"[VERBOSE] {message}")
def warning(self, message: str):
"""Log a warning message."""
self.logger.warning(message)
def error(self, message: str):
"""Log an error message."""
self.logger.error(message)
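A brief sketch of how the backward-compatible `LogLevel` values resolve to stdlib logging levels, mirroring the mapping above:

```python
import logging
from enum import IntEnum

class LogLevel(IntEnum):
    QUIET = 0
    NORMAL = 1
    VERBOSE = 2
    DEBUG = 3

LOGLEVEL_MAP = {
    LogLevel.QUIET: logging.WARNING,
    LogLevel.NORMAL: logging.INFO,
    LogLevel.VERBOSE: logging.DEBUG,
    LogLevel.DEBUG: logging.DEBUG,
}

def resolve(verbosity) -> int:
    """Accept either a legacy LogLevel enum or a raw stdlib level int."""
    if isinstance(verbosity, LogLevel):
        return LOGLEVEL_MAP.get(verbosity, logging.INFO)
    return verbosity

print(resolve(LogLevel.QUIET))  # 30 (logging.WARNING)
print(resolve(logging.DEBUG))   # 10
```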


@@ -0,0 +1,47 @@
"""Models for computer configuration."""
from dataclasses import dataclass
from typing import Optional, Any, Dict
# Import base provider interface
from .providers.base import BaseVMProvider
@dataclass
class Display:
"""Display configuration."""
width: int
height: int
@dataclass
class Image:
"""VM image configuration."""
image: str
tag: str
name: str
@dataclass
class Computer:
"""Computer configuration."""
image: str
tag: str
name: str
display: Display
memory: str
cpu: str
vm_provider: Optional[BaseVMProvider] = None
async def get_ip(self) -> Optional[str]:
"""Get the IP address of the VM."""
if not self.vm_provider:
return None
vm = await self.vm_provider.get_vm(self.name)
# Handle both object attribute and dictionary access for ip_address
if vm:
if isinstance(vm, dict):
return vm.get("ip_address")
else:
# Access as attribute for object-based return values
return getattr(vm, "ip_address", None)
return None
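The dual dict/attribute handling in `get_ip` can be isolated into a small helper; `VMRecord` here is a hypothetical record type for illustration:

```python
from dataclasses import dataclass
from typing import Optional

def extract_ip(vm) -> Optional[str]:
    """Handle both dict-style and attribute-style VM records."""
    if vm is None:
        return None
    if isinstance(vm, dict):
        return vm.get("ip_address")
    return getattr(vm, "ip_address", None)

@dataclass
class VMRecord:  # hypothetical record type for illustration
    ip_address: str

print(extract_ip({"ip_address": "192.168.64.2"}))  # 192.168.64.2
print(extract_ip(VMRecord("192.168.64.3")))        # 192.168.64.3
print(extract_ip(None))                            # None
```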


@@ -0,0 +1,4 @@
"""Provider implementations for different VM backends."""
# Import specific providers only when needed to avoid circular imports
__all__ = [] # Let each provider module handle its own exports


@@ -0,0 +1,106 @@
"""Base provider interface for VM backends."""
import abc
from enum import StrEnum
from typing import Dict, List, Optional, Any, AsyncContextManager
class VMProviderType(StrEnum):
"""Enum of supported VM provider types."""
LUME = "lume"
LUMIER = "lumier"
CLOUD = "cloud"
WINSANDBOX = "winsandbox"
UNKNOWN = "unknown"
class BaseVMProvider(AsyncContextManager):
"""Base interface for VM providers.
All VM provider implementations must implement this interface.
"""
@property
@abc.abstractmethod
def provider_type(self) -> VMProviderType:
"""Get the provider type."""
pass
@abc.abstractmethod
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM information by name.
Args:
name: Name of the VM to get information for
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM information including status, IP address, etc.
"""
pass
@abc.abstractmethod
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all available VMs."""
pass
@abc.abstractmethod
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Run a VM by name with the given options.
Args:
image: Name/tag of the image to use
name: Name of the VM to run
run_opts: Dictionary of run options (memory, cpu, etc.)
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM run status and information
"""
pass
@abc.abstractmethod
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Stop a VM by name.
Args:
name: Name of the VM to stop
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM stop status and information
"""
pass
@abc.abstractmethod
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration.
Args:
name: Name of the VM to update
update_opts: Dictionary of update options (memory, cpu, etc.)
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM update status and information
"""
pass
@abc.abstractmethod
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.
Args:
name: Name of the VM to get the IP for
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
retry_delay: Delay between retries in seconds (default: 2)
Returns:
IP address of the VM when it becomes available
"""
pass
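A minimal concrete provider satisfying this contract might look like the following sketch. To keep it self-contained, only two methods of a trimmed-down ABC are shown; the real interface requires all of the abstract methods above:

```python
import abc
import asyncio
from typing import Any, Dict, Optional

# Trimmed-down sketch of the provider contract above (two members only).
class MiniProvider(abc.ABC):
    @property
    @abc.abstractmethod
    def provider_type(self) -> str: ...

    @abc.abstractmethod
    async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]: ...

class EchoProvider(MiniProvider):
    """Toy implementation that fabricates a running-VM record."""
    @property
    def provider_type(self) -> str:
        return "echo"

    async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
        return {"name": name, "status": "running", "ip_address": "127.0.0.1"}

vm = asyncio.run(EchoProvider().get_vm("demo"))
print(vm["name"], vm["status"])  # demo running
```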


@@ -0,0 +1,5 @@
"""CloudProvider module for interacting with cloud-based virtual machines."""
from .provider import CloudProvider
__all__ = ["CloudProvider"]


@@ -0,0 +1,75 @@
"""Cloud VM provider implementation.
This module contains a stub implementation for a future cloud VM provider.
"""
import asyncio
import logging
from typing import Dict, List, Optional, Any
from urllib.parse import urlparse

import aiohttp

from ..base import BaseVMProvider, VMProviderType

# Setup logging
logger = logging.getLogger(__name__)
class CloudProvider(BaseVMProvider):
"""Cloud VM Provider implementation."""
def __init__(
self,
api_key: str,
verbose: bool = False,
**kwargs,
):
"""
Args:
api_key: API key for authentication
name: Name of the VM
verbose: Enable verbose logging
"""
assert api_key, "api_key required for CloudProvider"
self.api_key = api_key
self.verbose = verbose
@property
def provider_type(self) -> VMProviderType:
return VMProviderType.CLOUD
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM VNC URL by name using the cloud API."""
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
async def list_vms(self) -> List[Dict[str, Any]]:
logger.warning("CloudProvider.list_vms is not implemented")
return []
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.run_vm is not implemented")
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
logger.warning("CloudProvider.update_vm is not implemented")
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""
Return the VM's IP address as '{container_name}.containers.cloud.trycua.com'.
Uses the provided 'name' argument (the VM name requested by the caller),
falling back to self.name only if 'name' is None.
Retries up to 3 times with retry_delay seconds if hostname is not available.
"""
if name is None:
raise ValueError("VM name is required for CloudProvider.get_ip")
return f"{name}.containers.cloud.trycua.com"


@@ -0,0 +1,138 @@
"""Factory for creating VM providers."""
import logging
from typing import Dict, Optional, Any, Type, Union
from .base import BaseVMProvider, VMProviderType
logger = logging.getLogger(__name__)
class VMProviderFactory:
"""Factory for creating VM providers based on provider type."""
@staticmethod
def create_provider(
provider_type: Union[str, VMProviderType],
port: int = 7777,
host: str = "localhost",
bin_path: Optional[str] = None,
storage: Optional[str] = None,
shared_path: Optional[str] = None,
image: Optional[str] = None,
verbose: bool = False,
ephemeral: bool = False,
noVNC_port: Optional[int] = None,
**kwargs,
) -> BaseVMProvider:
"""Create a VM provider of the specified type.
Args:
provider_type: Type of VM provider to create
port: Port for the API server
host: Hostname for the API server
bin_path: Path to provider binary if needed
storage: Path for persistent VM storage
shared_path: Path for shared folder between host and VM
image: VM image to use (for Lumier provider)
verbose: Enable verbose logging
ephemeral: Use ephemeral (temporary) storage
noVNC_port: Specific port for noVNC interface (for Lumier provider)
Returns:
An instance of the requested VM provider
Raises:
ImportError: If the required dependencies for the provider are not installed
ValueError: If the provider type is not supported
"""
# Convert string to enum if needed
if isinstance(provider_type, str):
try:
provider_type = VMProviderType(provider_type.lower())
except ValueError:
provider_type = VMProviderType.UNKNOWN
if provider_type == VMProviderType.LUME:
try:
from .lume import LumeProvider, HAS_LUME
if not HAS_LUME:
raise ImportError(
"The pylume package is required for LumeProvider. "
"Please install it with 'pip install cua-computer[lume]'"
)
return LumeProvider(
port=port,
host=host,
storage=storage,
verbose=verbose,
ephemeral=ephemeral
)
except ImportError as e:
logger.error(f"Failed to import LumeProvider: {e}")
raise ImportError(
"The pylume package is required for LumeProvider. "
"Please install it with 'pip install cua-computer[lume]'"
) from e
elif provider_type == VMProviderType.LUMIER:
try:
from .lumier import LumierProvider, HAS_LUMIER
if not HAS_LUMIER:
raise ImportError(
"Docker is required for LumierProvider. "
"Please install Docker for Apple Silicon and Lume CLI before using this provider."
)
return LumierProvider(
port=port,
host=host,
storage=storage,
shared_path=shared_path,
image=image or "macos-sequoia-cua:latest",
verbose=verbose,
ephemeral=ephemeral,
noVNC_port=noVNC_port
)
except ImportError as e:
logger.error(f"Failed to import LumierProvider: {e}")
raise ImportError(
"Docker and Lume CLI are required for LumierProvider. "
"Please install Docker for Apple Silicon and run the Lume installer script."
) from e
elif provider_type == VMProviderType.CLOUD:
try:
from .cloud import CloudProvider
return CloudProvider(
verbose=verbose,
**kwargs,
)
except ImportError as e:
logger.error(f"Failed to import CloudProvider: {e}")
raise ImportError(
"The CloudProvider is not fully implemented yet. "
"Please use LUME or LUMIER provider instead."
) from e
elif provider_type == VMProviderType.WINSANDBOX:
try:
from .winsandbox import WinSandboxProvider, HAS_WINSANDBOX
if not HAS_WINSANDBOX:
raise ImportError(
"pywinsandbox is required for WinSandboxProvider. "
"Please install it with 'pip install -U git+https://github.com/karkason/pywinsandbox.git'"
)
return WinSandboxProvider(
port=port,
host=host,
storage=storage,
verbose=verbose,
ephemeral=ephemeral,
**kwargs
)
except ImportError as e:
logger.error(f"Failed to import WinSandboxProvider: {e}")
raise ImportError(
"pywinsandbox is required for WinSandboxProvider. "
"Please install it with 'pip install -U git+https://github.com/karkason/pywinsandbox.git'"
) from e
else:
raise ValueError(f"Unsupported provider type: {provider_type}")
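The factory accepts either a `VMProviderType` enum member or a plain string, normalizing the string case-insensitively and mapping anything unrecognized to `UNKNOWN` (which then falls through to the `ValueError` above). That normalization step can be exercised in isolation; this is a self-contained sketch, not an import of the actual module:

```python
from enum import Enum

class VMProviderType(Enum):
    # Trimmed copy of the provider enum for illustration.
    LUME = "lume"
    LUMIER = "lumier"
    CLOUD = "cloud"
    WINSANDBOX = "winsandbox"
    UNKNOWN = "unknown"

def normalize(provider_type):
    # Same string-to-enum conversion used at the top of create_provider.
    if isinstance(provider_type, str):
        try:
            provider_type = VMProviderType(provider_type.lower())
        except ValueError:
            provider_type = VMProviderType.UNKNOWN
    return provider_type

print(normalize("Lume"))    # VMProviderType.LUME
print(normalize("bogus"))   # VMProviderType.UNKNOWN
```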

View File

@@ -0,0 +1,9 @@
"""Lume VM provider implementation."""
try:
from .provider import LumeProvider
HAS_LUME = True
__all__ = ["LumeProvider"]
except ImportError:
HAS_LUME = False
__all__ = []

View File

@@ -0,0 +1,541 @@
"""Lume VM provider implementation using curl commands.
This provider uses direct curl commands to interact with the Lume API,
removing the dependency on the pylume Python package.
"""
import os
import re
import asyncio
import json
import logging
import subprocess
import urllib.parse
from typing import Dict, Any, Optional, List, Tuple
from ..base import BaseVMProvider, VMProviderType
from ...logger import Logger, LogLevel
from ..lume_api import (
lume_api_get,
lume_api_run,
lume_api_stop,
lume_api_update,
lume_api_pull,
HAS_CURL,
parse_memory
)
# Setup logging
logger = logging.getLogger(__name__)
class LumeProvider(BaseVMProvider):
"""Lume VM provider implementation using direct curl commands.
This provider uses curl to interact with the Lume API server,
removing the dependency on the pylume Python package.
"""
def __init__(
self,
port: int = 7777,
host: str = "localhost",
storage: Optional[str] = None,
verbose: bool = False,
ephemeral: bool = False,
):
"""Initialize the Lume provider.
Args:
port: Port for the Lume API server (default: 7777)
host: Host to use for API connections (default: localhost)
storage: Path to store VM data
verbose: Enable verbose logging
"""
if not HAS_CURL:
raise ImportError(
"curl is required for LumeProvider. "
"Please ensure it is installed and in your PATH."
)
self.host = host
self.port = port # Default port for Lume API
self.storage = storage
self.verbose = verbose
self.ephemeral = ephemeral # If True, VMs will be deleted after stopping
# Base API URL for Lume API calls
self.api_base_url = f"http://{self.host}:{self.port}"
self.logger = logging.getLogger(__name__)
@property
def provider_type(self) -> VMProviderType:
"""Get the provider type."""
return VMProviderType.LUME
async def __aenter__(self):
"""Enter async context manager."""
# No initialization needed, just return self
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit async context manager."""
# No cleanup needed
pass
def _lume_api_get(self, vm_name: str = "", storage: Optional[str] = None, debug: bool = False) -> Dict[str, Any]:
"""Get VM information using shared lume_api function.
Args:
vm_name: Optional name of the VM to get info for.
If empty, lists all VMs.
storage: Optional storage path override. If provided, this will be used instead of self.storage
debug: Whether to show debug output
Returns:
Dictionary with VM status information parsed from JSON response
"""
# Use the shared implementation from lume_api module
return lume_api_get(
vm_name=vm_name,
host=self.host,
port=self.port,
storage=storage if storage is not None else self.storage,
debug=debug,
verbose=self.verbose
)
def _lume_api_run(self, vm_name: str, run_opts: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
"""Run a VM using shared lume_api function.
Args:
vm_name: Name of the VM to run
run_opts: Dictionary of run options
debug: Whether to show debug output
Returns:
Dictionary with API response or error information
"""
# Use the shared implementation from lume_api module
return lume_api_run(
vm_name=vm_name,
host=self.host,
port=self.port,
run_opts=run_opts,
storage=self.storage,
debug=debug,
verbose=self.verbose
)
def _lume_api_stop(self, vm_name: str, debug: bool = False) -> Dict[str, Any]:
"""Stop a VM using shared lume_api function.
Args:
vm_name: Name of the VM to stop
debug: Whether to show debug output
Returns:
Dictionary with API response or error information
"""
# Use the shared implementation from lume_api module
return lume_api_stop(
vm_name=vm_name,
host=self.host,
port=self.port,
storage=self.storage,
debug=debug,
verbose=self.verbose
)
def _lume_api_update(self, vm_name: str, update_opts: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
"""Update VM configuration using shared lume_api function.
Args:
vm_name: Name of the VM to update
update_opts: Dictionary of update options
debug: Whether to show debug output
Returns:
Dictionary with API response or error information
"""
# Use the shared implementation from lume_api module
return lume_api_update(
vm_name=vm_name,
host=self.host,
port=self.port,
update_opts=update_opts,
storage=self.storage,
debug=debug,
verbose=self.verbose
)
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM information by name.
Args:
name: Name of the VM to get information for
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM information including status, IP address, etc.
Note:
If storage is not provided, the provider's default storage path will be used.
The storage parameter allows overriding the storage location for this specific call.
"""
if not HAS_CURL:
logger.error("curl is not available. Cannot get VM status.")
return {
"name": name,
"status": "unavailable",
"error": "curl is not available"
}
# First try to get detailed VM info from the API
try:
# Query the Lume API for VM status using the provider's storage_path
vm_info = self._lume_api_get(
vm_name=name,
storage=storage if storage is not None else self.storage,
debug=self.verbose
)
# Check for API errors
if "error" in vm_info:
logger.debug(f"API request error: {vm_info['error']}")
# If we got an error from the API, report the VM as not ready yet
return {
"name": name,
"status": "starting", # VM is still starting - do not attempt to connect yet
"api_status": "error",
"error": vm_info["error"]
}
# Process the VM status information
vm_status = vm_info.get("status", "unknown")
# Check if VM is stopped or not running - don't wait for IP in this case
if vm_status == "stopped":
logger.info(f"VM {name} is in '{vm_status}' state - not waiting for IP address")
# Return the status as-is without waiting for an IP
result = {
"name": name,
"status": vm_status,
**vm_info # Include all original fields from the API response
}
return result
# Handle field name differences between APIs
# Some APIs use camelCase, others use snake_case
if "vncUrl" in vm_info:
vnc_url = vm_info["vncUrl"]
elif "vnc_url" in vm_info:
vnc_url = vm_info["vnc_url"]
else:
vnc_url = ""
if "ipAddress" in vm_info:
ip_address = vm_info["ipAddress"]
elif "ip_address" in vm_info:
ip_address = vm_info["ip_address"]
else:
# If no IP address is provided and VM is supposed to be running,
# report it as still starting
ip_address = None
logger.info(f"VM {name} is in '{vm_status}' state but no IP address found - reporting as still starting")
logger.info(f"VM {name} status: {vm_status}")
# Return the complete status information
result = {
"name": name,
"status": vm_status if vm_status else "running",
"ip_address": ip_address,
"vnc_url": vnc_url,
"api_status": "ok"
}
# Include all original fields from the API response
if isinstance(vm_info, dict):
for key, value in vm_info.items():
if key not in result: # Don't override our carefully processed fields
result[key] = value
return result
except Exception as e:
logger.error(f"Failed to get VM status: {e}")
# Return a fallback status that indicates the VM is not ready yet
return {
"name": name,
"status": "initializing", # VM is still initializing
"error": f"Failed to get VM status: {str(e)}"
}
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all available VMs."""
result = self._lume_api_get(debug=self.verbose)
# Extract the VMs list from the response
if "vms" in result and isinstance(result["vms"], list):
return result["vms"]
elif "error" in result:
logger.error(f"Error listing VMs: {result['error']}")
return []
else:
return []
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Run a VM with the given options.
If the VM does not exist in the storage location, this will attempt to pull it
from the Lume registry first.
Args:
image: Image name to use when pulling the VM if it doesn't exist
name: Name of the VM to run
run_opts: Dictionary of run options (memory, cpu, etc.)
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM run status and information
"""
# First check if VM exists by trying to get its info
vm_info = await self.get_vm(name, storage=storage)
if "error" in vm_info:
# VM doesn't exist, try to pull it
self.logger.info(f"VM {name} not found, attempting to pull image {image} from registry...")
# Call pull_vm with the image parameter
pull_result = await self.pull_vm(
name=name,
image=image,
storage=storage
)
# Check if pull was successful
if "error" in pull_result:
self.logger.error(f"Failed to pull VM image: {pull_result['error']}")
return pull_result # Return the error from pull
self.logger.info(f"Successfully pulled VM image {image} as {name}")
# Now run the VM with the given options
self.logger.info(f"Running VM {name} with options: {run_opts}")
from ..lume_api import lume_api_run
return lume_api_run(
vm_name=name,
host=self.host,
port=self.port,
run_opts=run_opts,
storage=storage if storage is not None else self.storage,
debug=self.verbose,
verbose=self.verbose
)
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Stop a running VM.
If this provider was initialized with ephemeral=True, the VM will also
be deleted after it is stopped.
Args:
name: Name of the VM to stop
storage: Optional storage path override
Returns:
Dictionary with stop status and information
"""
# Stop the VM first
stop_result = self._lume_api_stop(name, debug=self.verbose)
# Log ephemeral status for debugging
self.logger.info(f"Ephemeral mode status: {self.ephemeral}")
# If ephemeral mode is enabled, delete the VM after stopping
if self.ephemeral and (stop_result.get("success", False) or "error" not in stop_result):
self.logger.info(f"Ephemeral mode enabled - deleting VM {name} after stopping")
try:
delete_result = await self.delete_vm(name, storage=storage)
# Return combined result
return {
**stop_result, # Include all stop result info
"deleted": True,
"delete_result": delete_result
}
except Exception as e:
self.logger.error(f"Failed to delete ephemeral VM {name}: {e}")
# Include the error but still return stop result
return {
**stop_result,
"deleted": False,
"delete_error": str(e)
}
# Just return the stop result if not ephemeral
return stop_result
async def pull_vm(
self,
name: str,
image: str,
storage: Optional[str] = None,
registry: str = "ghcr.io",
organization: str = "trycua",
pull_opts: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Pull a VM image from the registry.
Args:
name: Name for the VM after pulling
image: The image name to pull (e.g. 'macos-sequoia-cua:latest')
storage: Optional storage path to use
registry: Registry to pull from (default: ghcr.io)
organization: Organization in registry (default: trycua)
pull_opts: Additional options for pulling the VM (optional)
Returns:
Dictionary with information about the pulled VM
Raises:
RuntimeError: If pull operation fails or image is not provided
"""
# Validate image parameter
if not image:
raise ValueError("Image parameter is required for pull_vm")
self.logger.info(f"Pulling VM image '{image}' as '{name}'")
self.logger.info("You can check the pull progress using: lume logs -f")
# Set default pull_opts if not provided
if pull_opts is None:
pull_opts = {}
# Log information about the operation
self.logger.debug(f"Pull storage location: {storage or 'default'}")
try:
# Call the lume_api_pull function from lume_api.py
from ..lume_api import lume_api_pull
result = lume_api_pull(
image=image,
name=name,
host=self.host,
port=self.port,
storage=storage if storage is not None else self.storage,
registry=registry,
organization=organization,
debug=self.verbose,
verbose=self.verbose
)
# Check for errors in the result
if "error" in result:
self.logger.error(f"Failed to pull VM image: {result['error']}")
return result
self.logger.info(f"Successfully pulled VM image '{image}' as '{name}'")
return result
except Exception as e:
self.logger.error(f"Failed to pull VM image '{image}': {e}")
return {"error": f"Failed to pull VM: {str(e)}"}
async def delete_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Delete a VM permanently.
Args:
name: Name of the VM to delete
storage: Optional storage path override
Returns:
Dictionary with delete status and information
"""
self.logger.info(f"Deleting VM {name}...")
try:
# Call the lume_api_delete function we created
from ..lume_api import lume_api_delete
result = lume_api_delete(
vm_name=name,
host=self.host,
port=self.port,
storage=storage if storage is not None else self.storage,
debug=self.verbose,
verbose=self.verbose
)
# Check for errors in the result
if "error" in result:
self.logger.error(f"Failed to delete VM: {result['error']}")
return result
self.logger.info(f"Successfully deleted VM '{name}'")
return result
except Exception as e:
self.logger.error(f"Failed to delete VM '{name}': {e}")
return {"error": f"Failed to delete VM: {str(e)}"}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration."""
return self._lume_api_update(name, update_opts, debug=self.verbose)
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.
Args:
name: Name of the VM to get the IP for
storage: Optional storage path override
retry_delay: Delay between retries in seconds (default: 2)
Returns:
IP address of the VM when it becomes available
"""
# Track total attempts for logging purposes
total_attempts = 0
# Loop indefinitely until we get a valid IP
while True:
total_attempts += 1
# Log retry message but not on first attempt
if total_attempts > 1:
self.logger.info(f"Waiting for VM {name} IP address (attempt {total_attempts})...")
try:
# Get VM information
vm_info = await self.get_vm(name, storage=storage)
# Check if we got a valid IP
ip = vm_info.get("ip_address", None)
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
self.logger.info(f"Got valid VM IP address: {ip}")
return ip
# Check the VM status
status = vm_info.get("status", "unknown")
# If VM is not running yet, log and wait
if status != "running":
self.logger.info(f"VM is not running yet (status: {status}). Waiting...")
# If VM is running but no IP yet, wait and retry
else:
self.logger.info("VM is running but no valid IP address yet. Waiting...")
except Exception as e:
self.logger.warning(f"Error getting VM {name} IP: {e}, continuing to wait...")
# Wait before next retry
await asyncio.sleep(retry_delay)
# Add progress log every 10 attempts
if total_attempts % 10 == 0:
self.logger.info(f"Still waiting for VM {name} IP after {total_attempts} attempts...")

View File

@@ -0,0 +1,546 @@
"""Shared API utilities for Lume and Lumier providers.
This module contains shared functions for interacting with the Lume API,
used by both the LumeProvider and LumierProvider classes.
"""
import logging
import json
import subprocess
import urllib.parse
from typing import Dict, List, Optional, Any
# Setup logging
logger = logging.getLogger(__name__)
# Check if curl is available
try:
subprocess.run(["curl", "--version"], capture_output=True, check=True)
HAS_CURL = True
except (subprocess.SubprocessError, FileNotFoundError):
HAS_CURL = False
def lume_api_get(
vm_name: str,
host: str,
port: int,
storage: Optional[str] = None,
debug: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
"""Use curl to get VM information from Lume API.
Args:
vm_name: Name of the VM to get info for
host: API host
port: API port
storage: Storage path for the VM
debug: Whether to show debug output
verbose: Enable verbose logging
Returns:
Dictionary with VM status information parsed from JSON response
"""
# URL encode the storage parameter for the query
encoded_storage = ""
storage_param = ""
if storage:
# First encode the storage path properly
encoded_storage = urllib.parse.quote(storage, safe='')
storage_param = f"?storage={encoded_storage}"
# Construct API URL with encoded storage parameter if needed
api_url = f"http://{host}:{port}/lume/vms/{vm_name}{storage_param}"
# Construct the curl command with increased timeouts for more reliability
# --connect-timeout: Time to establish connection (15 seconds)
# --max-time: Maximum time for the whole operation (20 seconds)
# -f: Fail silently (no output at all) on server errors
# Add single quotes around URL to ensure special characters are handled correctly
cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", f"'{api_url}'"]
# For logging and display, show the properly escaped URL
display_cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", api_url]
# Only print the curl command when debug is enabled
display_curl_string = ' '.join(display_cmd)
logger.debug(f"Executing API request: {display_curl_string}")
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
try:
# Use a single string with shell=True for proper URL handling
shell_cmd = ' '.join(cmd)
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
# Handle curl exit codes
if result.returncode != 0:
curl_error = "Unknown error"
# Map common curl error codes to helpful messages
if result.returncode == 7:
curl_error = "Failed to connect to the API server - it might still be starting up"
elif result.returncode == 22:
curl_error = "HTTP error returned from API server"
elif result.returncode == 28:
curl_error = "Operation timeout - the API server is taking too long to respond"
elif result.returncode == 52:
curl_error = "Empty reply from server - the API server is starting but not fully ready yet"
elif result.returncode == 56:
curl_error = "Network problem during data transfer - check container networking"
# Only log at debug level to reduce noise during retries
logger.debug(f"API request failed with code {result.returncode}: {curl_error}")
# Return a more useful error message
return {
"error": f"API request failed: {curl_error}",
"curl_code": result.returncode,
"vm_name": vm_name,
"status": "unknown" # We don't know the actual status due to API error
}
# Try to parse the response as JSON
if result.stdout and result.stdout.strip():
try:
vm_status = json.loads(result.stdout)
if debug or verbose:
logger.info(f"Successfully parsed VM status: {vm_status.get('status', 'unknown')}")
return vm_status
except json.JSONDecodeError as e:
# Return the raw response if it's not valid JSON
logger.warning(f"Invalid JSON response: {e}")
if "Virtual machine not found" in result.stdout:
return {"status": "not_found", "message": "VM not found in Lume API"}
return {"error": f"Invalid JSON response: {result.stdout[:100]}...", "status": "unknown"}
else:
return {"error": "Empty response from API", "status": "unknown"}
except subprocess.SubprocessError as e:
logger.error(f"Failed to execute API request: {e}")
return {"error": f"Failed to execute API request: {str(e)}", "status": "unknown"}
def lume_api_run(
vm_name: str,
host: str,
port: int,
run_opts: Dict[str, Any],
storage: Optional[str] = None,
debug: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
"""Run a VM using curl.
Args:
vm_name: Name of the VM to run
host: API host
port: API port
run_opts: Dictionary of run options
storage: Storage path for the VM
debug: Whether to show debug output
verbose: Enable verbose logging
Returns:
Dictionary with API response or error information
"""
# Construct API URL
api_url = f"http://{host}:{port}/lume/vms/{vm_name}/run"
# Prepare JSON payload with required parameters
payload = {}
# Add CPU cores if specified
if "cpu" in run_opts:
payload["cpu"] = run_opts["cpu"]
# Add memory if specified
if "memory" in run_opts:
payload["memory"] = run_opts["memory"]
# Add storage parameter if specified
if storage:
payload["storage"] = storage
elif "storage" in run_opts:
payload["storage"] = run_opts["storage"]
# Add shared directories if specified
if "shared_directories" in run_opts and run_opts["shared_directories"]:
payload["sharedDirectories"] = run_opts["shared_directories"]
# Log the payload for debugging
logger.debug(f"API payload: {json.dumps(payload, indent=2)}")
# Construct the curl command
cmd = [
"curl", "--connect-timeout", "30", "--max-time", "30",
"-s", "-X", "POST", "-H", "Content-Type: application/json",
"-d", json.dumps(payload),
api_url
]
# Execute the command
try:
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.warning(f"API request failed with code {result.returncode}: {result.stderr}")
return {"error": f"API request failed: {result.stderr}"}
# Try to parse the response as JSON
if result.stdout and result.stdout.strip():
try:
response = json.loads(result.stdout)
return response
except json.JSONDecodeError:
# Return the raw response if it's not valid JSON
return {"success": True, "message": "VM started successfully", "raw_response": result.stdout}
else:
return {"success": True, "message": "VM started successfully"}
except subprocess.SubprocessError as e:
logger.error(f"Failed to execute run request: {e}")
return {"error": f"Failed to execute run request: {str(e)}"}
def lume_api_stop(
vm_name: str,
host: str,
port: int,
storage: Optional[str] = None,
debug: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
"""Stop a VM using curl.
Args:
vm_name: Name of the VM to stop
host: API host
port: API port
storage: Storage path for the VM
debug: Whether to show debug output
verbose: Enable verbose logging
Returns:
Dictionary with API response or error information
"""
# Construct API URL
api_url = f"http://{host}:{port}/lume/vms/{vm_name}/stop"
# Prepare JSON payload with required parameters
payload = {}
# Add storage path if specified
if storage:
payload["storage"] = storage
# Construct the curl command
cmd = [
"curl", "--connect-timeout", "15", "--max-time", "20",
"-s", "-X", "POST", "-H", "Content-Type: application/json",
"-d", json.dumps(payload),
api_url
]
# Execute the command
try:
if debug or verbose:
logger.info(f"Executing: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.warning(f"API request failed with code {result.returncode}: {result.stderr}")
return {"error": f"API request failed: {result.stderr}"}
# Try to parse the response as JSON
if result.stdout and result.stdout.strip():
try:
response = json.loads(result.stdout)
return response
except json.JSONDecodeError:
# Return the raw response if it's not valid JSON
return {"success": True, "message": "VM stopped successfully", "raw_response": result.stdout}
else:
return {"success": True, "message": "VM stopped successfully"}
except subprocess.SubprocessError as e:
logger.error(f"Failed to execute stop request: {e}")
return {"error": f"Failed to execute stop request: {str(e)}"}
def lume_api_update(
vm_name: str,
host: str,
port: int,
update_opts: Dict[str, Any],
storage: Optional[str] = None,
debug: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
"""Update VM settings using curl.
Args:
vm_name: Name of the VM to update
host: API host
port: API port
update_opts: Dictionary of update options
storage: Storage path for the VM
debug: Whether to show debug output
verbose: Enable verbose logging
Returns:
Dictionary with API response or error information
"""
# Construct API URL
api_url = f"http://{host}:{port}/lume/vms/{vm_name}/update"
# Prepare JSON payload with required parameters
payload = {}
# Add CPU cores if specified
if "cpu" in update_opts:
payload["cpu"] = update_opts["cpu"]
# Add memory if specified
if "memory" in update_opts:
payload["memory"] = update_opts["memory"]
# Add storage path if specified
if storage:
payload["storage"] = storage
# Construct the curl command
cmd = [
"curl", "--connect-timeout", "15", "--max-time", "20",
"-s", "-X", "POST", "-H", "Content-Type: application/json",
"-d", json.dumps(payload),
api_url
]
# Execute the command
try:
if debug:
logger.info(f"Executing: {' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.warning(f"API request failed with code {result.returncode}: {result.stderr}")
return {"error": f"API request failed: {result.stderr}"}
# Try to parse the response as JSON
if result.stdout and result.stdout.strip():
try:
response = json.loads(result.stdout)
return response
except json.JSONDecodeError:
# Return the raw response if it's not valid JSON
return {"success": True, "message": "VM updated successfully", "raw_response": result.stdout}
else:
return {"success": True, "message": "VM updated successfully"}
except subprocess.SubprocessError as e:
logger.error(f"Failed to execute update request: {e}")
return {"error": f"Failed to execute update request: {str(e)}"}
def lume_api_pull(
image: str,
name: str,
host: str,
port: int,
storage: Optional[str] = None,
registry: str = "ghcr.io",
organization: str = "trycua",
debug: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
"""Pull a VM image from a registry using curl.
Args:
image: Name/tag of the image to pull
name: Name to give the VM after pulling
host: API host
port: API port
storage: Storage path for the VM
registry: Registry to pull from (default: ghcr.io)
organization: Organization in registry (default: trycua)
debug: Whether to show debug output
verbose: Enable verbose logging
Returns:
Dictionary with pull status and information
"""
# Prepare pull request payload
pull_payload = {
"image": image, # Use provided image name
"name": name, # Always use name as the target VM name
"registry": registry,
"organization": organization
}
if storage:
pull_payload["storage"] = storage
# Construct pull command with proper JSON payload
pull_cmd = [
"curl"
]
if not verbose:
pull_cmd.append("-s")
pull_cmd.extend([
"-X", "POST",
"-H", "Content-Type: application/json",
"-d", json.dumps(pull_payload),
f"http://{host}:{port}/lume/pull"
])
logger.debug(f"Executing API request: {' '.join(pull_cmd)}")
try:
# Execute pull command
result = subprocess.run(pull_cmd, capture_output=True, text=True)
if result.returncode != 0:
error_msg = f"Failed to pull VM {name}: {result.stderr}"
logger.error(error_msg)
return {"error": error_msg}
try:
response = json.loads(result.stdout)
logger.info(f"Successfully initiated pull for VM {name}")
return response
except json.JSONDecodeError:
if result.stdout:
logger.info(f"Pull response: {result.stdout}")
return {"success": True, "message": f"Successfully initiated pull for VM {name}"}
except subprocess.SubprocessError as e:
error_msg = f"Failed to execute pull command: {str(e)}"
logger.error(error_msg)
return {"error": error_msg}
def lume_api_delete(
vm_name: str,
host: str,
port: int,
storage: Optional[str] = None,
debug: bool = False,
verbose: bool = False
) -> Dict[str, Any]:
"""Delete a VM using curl.
Args:
vm_name: Name of the VM to delete
host: API host
port: API port
storage: Storage path for the VM
debug: Whether to show debug output
verbose: Enable verbose logging
Returns:
Dictionary with API response or error information
"""
# URL encode the storage parameter for the query
encoded_storage = ""
storage_param = ""
if storage:
# First encode the storage path properly
encoded_storage = urllib.parse.quote(storage, safe='')
storage_param = f"?storage={encoded_storage}"
# Construct API URL with encoded storage parameter if needed
api_url = f"http://{host}:{port}/lume/vms/{vm_name}{storage_param}"
# Construct the curl command for DELETE operation - using much longer timeouts matching shell implementation
cmd = ["curl", "--connect-timeout", "6000", "--max-time", "5000", "-s", "-X", "DELETE", f"'{api_url}'"]
# For logging and display, show the properly escaped URL
display_cmd = ["curl", "--connect-timeout", "6000", "--max-time", "5000", "-s", "-X", "DELETE", api_url]
# Only print the curl command when debug is enabled
display_curl_string = ' '.join(display_cmd)
logger.debug(f"Executing API request: {display_curl_string}")
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
try:
# Use a single string with shell=True for proper URL handling
shell_cmd = ' '.join(cmd)
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
# Handle curl exit codes
if result.returncode != 0:
curl_error = "Unknown error"
# Map common curl error codes to helpful messages
if result.returncode == 7:
curl_error = "Failed to connect to the API server - it might still be starting up"
elif result.returncode == 22:
curl_error = "HTTP error returned from API server"
elif result.returncode == 28:
curl_error = "Operation timeout - the API server is taking too long to respond"
elif result.returncode == 52:
curl_error = "Empty reply from server - the API server is starting but not fully ready yet"
elif result.returncode == 56:
curl_error = "Network problem during data transfer - check container networking"
# Only log at debug level to reduce noise during retries
logger.debug(f"API request failed with code {result.returncode}: {curl_error}")
# Return a more useful error message
return {
"error": f"API request failed: {curl_error}",
"curl_code": result.returncode,
"vm_name": vm_name,
"storage": storage
}
# Try to parse the response as JSON
if result.stdout and result.stdout.strip():
try:
response = json.loads(result.stdout)
return response
except json.JSONDecodeError:
# Return the raw response if it's not valid JSON
return {"success": True, "message": "VM deleted successfully", "raw_response": result.stdout}
else:
return {"success": True, "message": "VM deleted successfully"}
except subprocess.SubprocessError as e:
logger.error(f"Failed to execute delete request: {e}")
return {"error": f"Failed to execute delete request: {str(e)}"}
def parse_memory(memory_str: str) -> int:
"""Parse memory string to MB integer.
Examples:
"8GB" -> 8192
"1024MB" -> 1024
"512" -> 512
Returns:
Memory value in MB
"""
if isinstance(memory_str, int):
return memory_str
if isinstance(memory_str, str):
# Extract number and unit
import re
match = re.match(r"(\d+)([A-Za-z]*)", memory_str)
if match:
value, unit = match.groups()
value = int(value)
unit = unit.upper()
if unit == "GB" or unit == "G":
return value * 1024
elif unit == "MB" or unit == "M" or unit == "":
return value
# Default fallback
logger.warning(f"Could not parse memory string '{memory_str}', using 8GB default")
return 8192 # Default to 8GB
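The unit handling above can be sanity-checked with a condensed, standalone copy of the helper (illustrative only — the real `parse_memory` also logs a warning before falling back to the 8GB default):

```python
import re

def parse_memory(memory_str) -> int:
    """Condensed re-statement of the helper above, for illustration."""
    if isinstance(memory_str, int):
        return memory_str
    match = re.match(r"(\d+)([A-Za-z]*)", memory_str)
    if match:
        value, unit = int(match.group(1)), match.group(2).upper()
        if unit in ("GB", "G"):
            return value * 1024
        if unit in ("MB", "M", ""):
            return value
    return 8192  # same 8GB fallback as above

assert parse_memory("8GB") == 8192
assert parse_memory("1024MB") == 1024
assert parse_memory("512") == 512
assert parse_memory(2048) == 2048  # ints pass through unchanged
```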


@@ -0,0 +1,8 @@
"""Lumier VM provider implementation."""
try:
# Use the same import approach as in the Lume provider
from .provider import LumierProvider
HAS_LUMIER = True
except ImportError:
HAS_LUMIER = False
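The guarded import above follows a common optional-dependency pattern: attempt the import, record availability in a module-level flag, and let callers fail gracefully only when the feature is actually used. A generic, self-contained sketch of the pattern (the `shutil.which("docker")` probe and `make_provider` name are illustrative stand-ins, not this package's API):

```python
import shutil

# Probe for the optional capability at import time; record the result in a
# flag instead of raising (mirrors the HAS_LUMIER pattern above).
try:
    HAS_BACKEND = shutil.which("docker") is not None
except Exception:
    HAS_BACKEND = False

def make_provider() -> str:
    """Fail with a clear message only when the backend is actually needed."""
    if not HAS_BACKEND:
        raise RuntimeError("backend unavailable; install Docker first")
    return "provider-ready"  # placeholder for constructing the real provider
```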


@@ -0,0 +1,942 @@
"""
Lumier VM provider implementation.
This provider uses Docker containers running the Lumier image to create
macOS and Linux VMs. It handles VM lifecycle operations through Docker
commands and container management.
"""
import logging
import os
import json
import asyncio
from typing import Dict, List, Optional, Any
import subprocess
import time
import re
from ..base import BaseVMProvider, VMProviderType
from ..lume_api import (
lume_api_get,
lume_api_run,
lume_api_stop,
lume_api_update
)
# Setup logging
logger = logging.getLogger(__name__)
# Check if Docker is available
try:
subprocess.run(["docker", "--version"], capture_output=True, check=True)
HAS_LUMIER = True
except (subprocess.SubprocessError, FileNotFoundError):
HAS_LUMIER = False
class LumierProvider(BaseVMProvider):
"""
Lumier VM Provider implementation using Docker containers.
This provider uses Docker to run Lumier containers that can create
macOS and Linux VMs through containerization.
"""
def __init__(
self,
port: Optional[int] = 7777,
host: str = "localhost",
storage: Optional[str] = None, # Can be a path or 'ephemeral'
shared_path: Optional[str] = None,
image: str = "macos-sequoia-cua:latest", # VM image to use
verbose: bool = False,
ephemeral: bool = False,
noVNC_port: Optional[int] = 8006,
):
"""Initialize the Lumier VM Provider.
Args:
port: Port for the API server (default: 7777)
host: Hostname for the API server (default: localhost)
storage: Path for persistent VM storage
shared_path: Path for shared folder between host and VM
image: VM image to use (e.g. "macos-sequoia-cua:latest")
verbose: Enable verbose logging
ephemeral: Use ephemeral (temporary) storage
noVNC_port: Specific port for noVNC interface (default: 8006)
"""
self.host = host
# Always ensure api_port has a valid value (7777 is the default)
self.api_port = 7777 if port is None else port
self.vnc_port = noVNC_port # User-specified noVNC port, will be set in run_vm if provided
self.ephemeral = ephemeral
# Handle ephemeral storage (temporary directory)
if ephemeral:
self.storage = "ephemeral"
else:
self.storage = storage
self.shared_path = shared_path
self.image = image # Store the VM image name to use
# The container_name will be set in run_vm using the VM name
self.verbose = verbose
self._container_id = None
self._api_url = None # Will be set after container starts
@property
def provider_type(self) -> VMProviderType:
"""Return the provider type."""
return VMProviderType.LUMIER
def _parse_memory(self, memory_str: str) -> int:
"""Parse memory string to MB integer.
Examples:
"8GB" -> 8192
"1024MB" -> 1024
"512" -> 512
"""
if isinstance(memory_str, int):
return memory_str
if isinstance(memory_str, str):
# Extract number and unit
match = re.match(r"(\d+)([A-Za-z]*)", memory_str)
if match:
value, unit = match.groups()
value = int(value)
unit = unit.upper()
if unit == "GB" or unit == "G":
return value * 1024
elif unit == "MB" or unit == "M" or unit == "":
return value
# Default fallback
logger.warning(f"Could not parse memory string '{memory_str}', using 8GB default")
return 8192 # Default to 8GB
# Helper methods for interacting with the Lumier API through curl
# These methods handle the various VM operations via API calls
def _get_curl_error_message(self, return_code: int) -> str:
"""Get a descriptive error message for curl return codes.
Args:
return_code: The curl return code
Returns:
A descriptive error message
"""
# Map common curl error codes to helpful messages
if return_code == 7:
return "Failed to connect - API server is starting up"
elif return_code == 22:
return "HTTP error returned from API server"
elif return_code == 28:
return "Operation timeout - API server is slow to respond"
elif return_code == 52:
return "Empty reply from server - API is starting but not ready"
elif return_code == 56:
return "Network problem during data transfer"
else:
return f"Unknown curl error code: {return_code}"
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM information by name.
Args:
name: Name of the VM to get information for
storage: Optional storage path override. If provided, this will be used
instead of the provider's default storage path.
Returns:
Dictionary with VM information including status, IP address, etc.
"""
if not HAS_LUMIER:
logger.error("Docker is not available. Cannot get VM status.")
return {
"name": name,
"status": "unavailable",
"error": "Docker is not available"
}
# Store the current name for API requests
self.container_name = name
try:
# Check if the container exists and is running
check_cmd = ["docker", "ps", "-a", "--filter", f"name={name}", "--format", "{{.Status}}"]
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
container_status = check_result.stdout.strip()
if not container_status:
logger.info(f"Container {name} does not exist. Will create when run_vm is called.")
return {
"name": name,
"status": "not_found",
"message": "Container doesn't exist yet"
}
# Container exists, check if it's running
is_running = container_status.startswith("Up")
if not is_running:
logger.info(f"Container {name} exists but is not running. Status: {container_status}")
return {
"name": name,
"status": "stopped",
"container_status": container_status,
}
# Container is running, get the IP address and API status from Lumier API
logger.info(f"Container {name} is running. Getting VM status from API.")
# Use the shared lume_api_get function directly
vm_info = lume_api_get(
vm_name=name,
host=self.host,
port=self.api_port,
storage=storage if storage is not None else self.storage,
debug=self.verbose,
verbose=self.verbose
)
# Check for API errors
if "error" in vm_info:
# Use debug level instead of warning to reduce log noise during polling
logger.debug(f"API request error: {vm_info['error']}")
return {
"name": name,
"status": "running", # Container is running even if API is not responsive
"api_status": "error",
"error": vm_info["error"],
"container_status": container_status
}
# Process the VM status information
vm_status = vm_info.get("status", "unknown")
vnc_url = vm_info.get("vncUrl", "")
ip_address = vm_info.get("ipAddress", "")
# IMPORTANT: Always ensure we have a valid IP address for connectivity
# If the API doesn't return an IP address, default to localhost (127.0.0.1)
# This makes the behavior consistent with LumeProvider
if not ip_address and vm_status == "running":
ip_address = "127.0.0.1"
logger.info(f"No IP address returned from API, defaulting to {ip_address}")
vm_info["ipAddress"] = ip_address
logger.info(f"VM {name} status: {vm_status}")
if ip_address and vnc_url:
logger.info(f"VM {name} has IP: {ip_address} and VNC URL: {vnc_url}")
elif not ip_address and not vnc_url and vm_status != "running":
# Not running is expected in this case
logger.info(f"VM {name} is not running yet. Status: {vm_status}")
else:
# Missing IP or VNC but status is running - this is unusual but handled with default IP
logger.warning(f"VM {name} is running but missing expected fields. API response: {vm_info}")
# Return the full status information
return {
"name": name,
"status": vm_status,
"ip_address": ip_address,
"vnc_url": vnc_url,
"api_status": "ok",
"container_status": container_status,
**vm_info # Include all fields from the API response
}
except subprocess.SubprocessError as e:
logger.error(f"Failed to check container status: {e}")
return {
"name": name,
"status": "error",
"error": f"Failed to check container status: {str(e)}"
}
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all VMs managed by this provider.
For Lumier provider, there is only one VM per container.
"""
try:
status = await self.get_vm("default")
return [status] if status.get("status") != "unknown" else []
except Exception as e:
logger.error(f"Failed to list VMs: {e}")
return []
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Run a VM with the given options.
Args:
image: Name/tag of the image to use
name: Name of the VM to run (used for the container name and Docker image tag)
run_opts: Options for running the VM, including:
- cpu: Number of CPU cores
- memory: Amount of memory (e.g. "8GB")
- noVNC_port: Specific port for noVNC interface
Returns:
Dictionary with VM status information
"""
# Set the container name using the VM name for consistency
self.container_name = name
try:
# First, check if container already exists and remove it
try:
check_cmd = ["docker", "ps", "-a", "--filter", f"name={self.container_name}", "--format", "{{.ID}}"]
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
existing_container = check_result.stdout.strip()
if existing_container:
logger.info(f"Removing existing container: {self.container_name}")
remove_cmd = ["docker", "rm", "-f", self.container_name]
subprocess.run(remove_cmd, check=True)
except subprocess.CalledProcessError as e:
logger.warning(f"Error removing existing container: {e}")
# Continue anyway, next steps will fail if there's a real problem
# Prepare the Docker run command
cmd = ["docker", "run", "-d", "--name", self.container_name]
cmd.extend(["-p", f"{self.vnc_port}:8006"])
logger.debug(f"Using specified noVNC_port: {self.vnc_port}")
# Set API URL using the API port
self._api_url = f"http://{self.host}:{self.api_port}"
# Parse memory setting
memory_mb = self._parse_memory(run_opts.get("memory", "8GB"))
# Add storage volume mount if storage is specified (for persistent VM storage)
if self.storage and self.storage != "ephemeral":
# Create storage directory if it doesn't exist
storage_dir = os.path.abspath(os.path.expanduser(self.storage or ""))
os.makedirs(storage_dir, exist_ok=True)
# Add volume mount for storage
cmd.extend([
"-v", f"{storage_dir}:/storage",
"-e", f"HOST_STORAGE_PATH={storage_dir}"
])
logger.debug(f"Using persistent storage at: {storage_dir}")
# Add shared folder volume mount if shared_path is specified
if self.shared_path:
# Create shared directory if it doesn't exist
shared_dir = os.path.abspath(os.path.expanduser(self.shared_path or ""))
os.makedirs(shared_dir, exist_ok=True)
# Add volume mount for shared folder
cmd.extend([
"-v", f"{shared_dir}:/shared",
"-e", f"HOST_SHARED_PATH={shared_dir}"
])
logger.debug(f"Using shared folder at: {shared_dir}")
# Add environment variables
# Always use the container_name as the VM_NAME for consistency
# Use the VM image passed from the Computer class
logger.debug(f"Using VM image: {self.image}")
# If ghcr.io is in the image, use the full image name
if "ghcr.io" in self.image:
vm_image = self.image
else:
vm_image = f"ghcr.io/trycua/{self.image}"
cmd.extend([
"-e", f"VM_NAME={self.container_name}",
"-e", f"VERSION={vm_image}",
"-e", f"CPU_CORES={run_opts.get('cpu', '4')}",
"-e", f"RAM_SIZE={memory_mb}",
])
# Specify the Lumier image with the full image name
lumier_image = "trycua/lumier:latest"
# First check if the image exists locally
try:
logger.debug(f"Checking if Docker image {lumier_image} exists locally...")
check_image_cmd = ["docker", "image", "inspect", lumier_image]
subprocess.run(check_image_cmd, capture_output=True, check=True)
logger.debug(f"Docker image {lumier_image} found locally.")
except subprocess.CalledProcessError:
# Image doesn't exist locally
logger.warning(f"\nWARNING: Docker image {lumier_image} not found locally.")
logger.warning("The system will attempt to pull it from Docker Hub, which may fail if you have network connectivity issues.")
logger.warning("If the Docker pull fails, you may need to manually pull the image first with:")
logger.warning(f" docker pull {lumier_image}\n")
# Add the image to the command
cmd.append(lumier_image)
# Print the Docker command for debugging
logger.debug(f"DOCKER COMMAND: {' '.join(cmd)}")
# Run the container with improved error handling
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
except subprocess.CalledProcessError as e:
if "no route to host" in str(e.stderr).lower() or "failed to resolve reference" in str(e.stderr).lower():
error_msg = (f"Network error while trying to pull Docker image '{lumier_image}'\n"
f"Error: {e.stderr}\n\n"
f"SOLUTION: Please try one of the following:\n"
f"1. Check your internet connection\n"
f"2. Pull the image manually with: docker pull {lumier_image}\n"
f"3. Check if Docker is running properly\n")
logger.error(error_msg)
raise RuntimeError(error_msg)
raise
# Container started, now check VM status with polling
logger.debug("Container started, checking VM status...")
logger.debug("NOTE: This may take some time while the VM image is being pulled and initialized")
# Start a background thread to show container logs in real-time
import threading
def show_container_logs():
# Give the container a moment to start generating logs
time.sleep(1)
logger.debug(f"\n---- CONTAINER LOGS FOR '{name}' (LIVE) ----")
logger.debug("Showing logs as they are generated. Press Ctrl+C to stop viewing logs...\n")
try:
# Use docker logs with follow option
log_cmd = ["docker", "logs", "--tail", "30", "--follow", name]
process = subprocess.Popen(log_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
text=True, bufsize=1, universal_newlines=True)
# Read and print logs line by line
for line in process.stdout:
logger.debug(line.rstrip())  # logging calls don't accept print()'s end= keyword
# Break if process has exited
if process.poll() is not None:
break
except Exception as e:
logger.error(f"\nError showing container logs: {e}")
if self.verbose:
logger.error(f"Error in log streaming thread: {e}")
finally:
logger.debug("\n---- LOG STREAMING ENDED ----")
# Make sure process is terminated
if 'process' in locals() and process.poll() is None:
process.terminate()
# Start log streaming in a background thread if verbose mode is enabled
log_thread = threading.Thread(target=show_container_logs)
log_thread.daemon = True # Thread will exit when main program exits
log_thread.start()
# Skip waiting for container readiness and just poll get_vm directly
# Poll the get_vm method indefinitely until the VM is ready with an IP address
attempt = 0
consecutive_errors = 0
vm_running = False
while True: # Wait indefinitely
try:
# Use longer delays to give the system time to initialize
if attempt > 0:
# Start with 5s delay, then increase gradually up to 30s for later attempts
# But use shorter delays while we're getting API errors
if consecutive_errors > 0 and consecutive_errors < 5:
wait_time = 3 # Use shorter delays when we're getting API errors
else:
wait_time = min(30, 5 + (attempt * 2))
logger.debug(f"Waiting {wait_time}s before retry #{attempt+1}...")
await asyncio.sleep(wait_time)
# Try to get VM status
logger.debug(f"Checking VM status (attempt {attempt+1})...")
vm_status = await self.get_vm(name)
# Check for API errors
if 'error' in vm_status:
consecutive_errors += 1
error_msg = vm_status.get('error', 'Unknown error')
# Only print a user-friendly status message, not the raw error
# since lume_api_get already logged the technical details
if consecutive_errors == 1 or attempt % 5 == 0:
if 'Empty reply from server' in error_msg:
logger.info("API server is starting up - container is running, but API isn't fully initialized yet.")
logger.info("This is expected during the initial VM setup - will continue polling...")
else:
# Don't repeat the exact same error message each time
logger.warning(f"API request error (attempt {attempt+1}): {error_msg}")
# Just log that we're still working on it
if attempt > 3:
logger.debug("Still waiting for the API server to become available...")
# If we're getting errors but container is running, that's normal during startup
if vm_status.get('status') == 'running':
if not vm_running:
logger.info("Container is running, waiting for the VM within it to become fully ready...")
logger.info("This might take a minute while the VM initializes...")
vm_running = True
# Increase counter and continue
attempt += 1
continue
# Reset consecutive error counter when we get a successful response
consecutive_errors = 0
# If the VM is running, check if it has an IP address (which means it's fully ready)
if vm_status.get('status') == 'running':
vm_running = True
# Check if we have an IP address, which means the VM is fully ready
if 'ip_address' in vm_status and vm_status['ip_address']:
logger.info(f"VM is now fully running with IP: {vm_status.get('ip_address')}")
if 'vnc_url' in vm_status and vm_status['vnc_url']:
logger.info(f"VNC URL: {vm_status.get('vnc_url')}")
return vm_status
else:
logger.debug("VM is running but still initializing network interfaces...")
logger.debug("Waiting for IP address to be assigned...")
else:
# VM exists but might still be starting up
status = vm_status.get('status', 'unknown')
logger.debug(f"VM found but status is: {status}. Continuing to poll...")
# Increase counter for next iteration's delay calculation
attempt += 1
# If we reach a very large number of attempts, give a reassuring message but continue
if attempt % 10 == 0:
logger.debug(f"Still waiting after {attempt} attempts. This might take several minutes for first-time setup.")
if not vm_running and attempt >= 20:
logger.warning("\nNOTE: First-time VM initialization can be slow as images are downloaded.")
logger.warning("If this continues for more than 10 minutes, you may want to check:")
logger.warning(" 1. Docker logs with: docker logs " + name)
logger.warning(" 2. If your network can access container registries")
logger.warning("Press Ctrl+C to abort if needed.\n")
# After 150 attempts (likely over 30-40 minutes), return current status
if attempt >= 150:
logger.debug(f"Reached 150 polling attempts. VM status is: {vm_status.get('status', 'unknown')}")
logger.debug("Returning current VM status, but please check Docker logs if there are issues.")
return vm_status
except Exception as e:
# Always continue retrying, but with increasing delays
logger.warning(f"Error checking VM status (attempt {attempt+1}): {e}. Will retry.")
consecutive_errors += 1
# If we've had too many consecutive errors, might be a deeper problem
if consecutive_errors >= 10:
logger.warning(f"\nWARNING: Encountered {consecutive_errors} consecutive errors while checking VM status.")
logger.warning("You may need to check the Docker container logs or restart the process.")
logger.warning(f"Error details: {str(e)}\n")
# Increase attempt counter for next iteration
attempt += 1
# After many consecutive errors, add a delay to avoid hammering the system
if attempt > 5:
error_delay = min(30, 10 + attempt)
logger.warning(f"Multiple connection errors, waiting {error_delay}s before next attempt...")
await asyncio.sleep(error_delay)
except subprocess.CalledProcessError as e:
error_msg = f"Failed to start Lumier container: {e.stderr if hasattr(e, 'stderr') else str(e)}"
logger.error(error_msg)
raise RuntimeError(error_msg)
async def _wait_for_container_ready(self, container_name: str, timeout: int = 90) -> bool:
"""Wait for the Lumier container to be fully ready with a valid API response.
Args:
container_name: Name of the Docker container to check
timeout: Maximum time to wait in seconds (default: 90 seconds)
Returns:
True if the container is running, even if API is not fully ready.
This allows operations to continue with appropriate fallbacks.
"""
start_time = time.time()
api_ready = False
container_running = False
logger.debug(f"Waiting for container {container_name} to be ready (timeout: {timeout}s)...")
while time.time() - start_time < timeout:
# Check if container is running
try:
check_cmd = ["docker", "ps", "--filter", f"name={container_name}", "--format", "{{.Status}}"]
result = subprocess.run(check_cmd, capture_output=True, text=True, check=True)
container_status = result.stdout.strip()
if container_status and container_status.startswith("Up"):
container_running = True
logger.info(f"Container {container_name} is running with status: {container_status}")
else:
logger.warning(f"Container {container_name} not yet running, status: {container_status}")
# container is not running yet, wait and try again
await asyncio.sleep(2) # Longer sleep to give Docker time
continue
except subprocess.CalledProcessError as e:
logger.warning(f"Error checking container status: {e}")
await asyncio.sleep(2)
continue
# Container is running, check if API is responsive
try:
# First check the health endpoint
api_url = f"http://{self.host}:{self.api_port}/health"
logger.info(f"Checking API health at: {api_url}")
# Use longer timeout for API health check since it may still be initializing
curl_cmd = ["curl", "-s", "--connect-timeout", "5", "--max-time", "10", api_url]
result = subprocess.run(curl_cmd, capture_output=True, text=True)
if result.returncode == 0 and "ok" in result.stdout.lower():
api_ready = True
logger.info(f"API is ready at {api_url}")
break
else:
# API health check failed, now let's check if the VM status endpoint is responsive
# This covers cases where the health endpoint isn't implemented but the VM API is working
vm_api_url = f"http://{self.host}:{self.api_port}/lume/vms/{container_name}"
if self.storage:
import urllib.parse
encoded_storage = urllib.parse.quote_plus(self.storage)
vm_api_url += f"?storage={encoded_storage}"
curl_vm_cmd = ["curl", "-s", "--connect-timeout", "5", "--max-time", "10", vm_api_url]
vm_result = subprocess.run(curl_vm_cmd, capture_output=True, text=True)
if vm_result.returncode == 0 and vm_result.stdout.strip():
# VM API responded with something - consider the API ready
api_ready = True
logger.info(f"VM API is ready at {vm_api_url}")
break
else:
curl_code = result.returncode
if curl_code == 0:
curl_code = vm_result.returncode
# Map common curl error codes to helpful messages
if curl_code == 7:
curl_error = "Failed to connect - API server is starting up"
elif curl_code == 22:
curl_error = "HTTP error returned from API server"
elif curl_code == 28:
curl_error = "Operation timeout - API server is slow to respond"
elif curl_code == 52:
curl_error = "Empty reply from server - API is starting but not ready"
elif curl_code == 56:
curl_error = "Network problem during data transfer"
else:
curl_error = f"Unknown curl error code: {curl_code}"
logger.info(f"API not ready yet: {curl_error}")
except subprocess.SubprocessError as e:
logger.warning(f"Error checking API status: {e}")
# If the container is running but API is not ready, that's OK - we'll just wait
# a bit longer before checking again, as the container may still be initializing
elapsed_seconds = time.time() - start_time
if int(elapsed_seconds) % 5 == 0: # Only print status every 5 seconds to reduce verbosity
logger.debug(f"Waiting for API to initialize... ({elapsed_seconds:.1f}s / {timeout}s)")
await asyncio.sleep(3) # Longer sleep between API checks
# Handle timeout - if the container is running but API is not ready, that's not
# necessarily an error - the API might just need more time to start up
if not container_running:
logger.warning(f"Timed out waiting for container {container_name} to start")
return False
if not api_ready:
logger.warning(f"Container {container_name} is running, but API is not fully ready yet.")
logger.warning(f"NOTE: You may see some 'API request failed' messages while the API initializes.")
# Return True if container is running, even if API isn't ready yet
# This allows VM operations to proceed, with appropriate retries for API calls
return container_running
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Stop a running VM by stopping the Lumier container."""
try:
# Use Docker commands to stop the container directly
if hasattr(self, '_container_id') and self._container_id:
logger.info(f"Stopping Lumier container: {self.container_name}")
cmd = ["docker", "stop", self.container_name]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
logger.info(f"Container stopped: {result.stdout.strip()}")
# Return minimal status info
return {
"name": name,
"status": "stopped",
"container_id": self._container_id,
}
else:
# Try to find the container by name
check_cmd = ["docker", "ps", "-a", "--filter", f"name={self.container_name}", "--format", "{{.ID}}"]
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
container_id = check_result.stdout.strip()
if container_id:
logger.info(f"Found container ID: {container_id}")
cmd = ["docker", "stop", self.container_name]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
logger.info(f"Container stopped: {result.stdout.strip()}")
return {
"name": name,
"status": "stopped",
"container_id": container_id,
}
else:
logger.warning(f"No container found with name {self.container_name}")
return {
"name": name,
"status": "unknown",
}
except subprocess.CalledProcessError as e:
error_msg = f"Failed to stop container: {e.stderr if hasattr(e, 'stderr') else str(e)}"
logger.error(error_msg)
raise RuntimeError(f"Failed to stop Lumier container: {error_msg}")
# update_vm is not implemented as it's not needed for Lumier
# The BaseVMProvider requires it, so we provide a minimal implementation
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Not implemented for Lumier provider."""
logger.warning("update_vm is not implemented for Lumier provider")
return {"name": name, "status": "unchanged"}
async def get_logs(self, name: str, num_lines: int = 100, follow: bool = False, timeout: Optional[int] = None) -> str:
"""Get the logs from the Lumier container.
Args:
name: Name of the VM/container to get logs for
num_lines: Number of recent log lines to return (default: 100)
follow: If True, follow the logs (stream new logs as they are generated)
timeout: Optional timeout in seconds for follow mode (None means no timeout)
Returns:
Container logs as a string
Note:
If follow=True, this function will continuously stream logs until timeout
or until interrupted. The output will be printed to console in real-time.
"""
if not HAS_LUMIER:
error_msg = "Docker is not available. Cannot get container logs."
logger.error(error_msg)
return error_msg
# Make sure we have a container name
container_name = name
# Check if the container exists and is running
try:
# Check if the container exists
inspect_cmd = ["docker", "container", "inspect", container_name]
result = subprocess.run(inspect_cmd, capture_output=True, text=True)
if result.returncode != 0:
error_msg = f"Container '{container_name}' does not exist or is not accessible"
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"Error checking container status: {str(e)}"
logger.error(error_msg)
return error_msg
# Base docker logs command
log_cmd = ["docker", "logs"]
# Add tail parameter to limit the number of lines
log_cmd.extend(["--tail", str(num_lines)])
# Handle follow mode with or without timeout
if follow:
log_cmd.append("--follow")
if timeout is not None:
# For follow mode with timeout, we'll run the command and handle the timeout
log_cmd.append(container_name)
logger.info(f"Following logs for container '{container_name}' with timeout {timeout}s")
logger.info(f"\n---- CONTAINER LOGS FOR '{container_name}' (LIVE) ----")
logger.info(f"Press Ctrl+C to stop following logs\n")
try:
# Run with timeout
process = subprocess.Popen(log_cmd, text=True)
# Wait for the specified timeout
if timeout:
try:
process.wait(timeout=timeout)
except subprocess.TimeoutExpired:
process.terminate() # Stop after timeout
logger.info(f"\n---- LOG FOLLOWING STOPPED (timeout {timeout}s reached) ----")
else:
# Without timeout, wait for user interruption
process.wait()
return "Logs were displayed to console in follow mode"
except KeyboardInterrupt:
process.terminate()
logger.info("\n---- LOG FOLLOWING STOPPED (user interrupted) ----")
return "Logs were displayed to console in follow mode (interrupted)"
else:
# For follow mode without timeout, we'll print a helpful message
log_cmd.append(container_name)
logger.info(f"Following logs for container '{container_name}' indefinitely")
logger.info(f"\n---- CONTAINER LOGS FOR '{container_name}' (LIVE) ----")
logger.info(f"Press Ctrl+C to stop following logs\n")
try:
# Run the command and let it run until interrupted
process = subprocess.Popen(log_cmd, text=True)
process.wait() # Wait indefinitely (until user interrupts)
return "Logs were displayed to console in follow mode"
except KeyboardInterrupt:
process.terminate()
logger.info("\n---- LOG FOLLOWING STOPPED (user interrupted) ----")
return "Logs were displayed to console in follow mode (interrupted)"
else:
# For non-follow mode, capture and return the logs as a string
log_cmd.append(container_name)
logger.info(f"Getting {num_lines} log lines for container '{container_name}'")
try:
result = subprocess.run(log_cmd, capture_output=True, text=True, check=True)
logs = result.stdout
# Only print header and logs if there's content
if logs.strip():
logger.info(f"\n---- CONTAINER LOGS FOR '{container_name}' (LAST {num_lines} LINES) ----\n")
logger.info(logs)
logger.info(f"\n---- END OF LOGS ----")
else:
logger.info(f"\nNo logs available for container '{container_name}'")
return logs
except subprocess.CalledProcessError as e:
error_msg = f"Error getting logs: {e.stderr}"
logger.error(error_msg)
return error_msg
except Exception as e:
error_msg = f"Unexpected error getting logs: {str(e)}"
logger.error(error_msg)
return error_msg
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.
Args:
name: Name of the VM to get the IP for
storage: Optional storage path override
retry_delay: Delay between retries in seconds (default: 2)
Returns:
IP address of the VM when it becomes available
"""
# Use container_name = name for consistency
self.container_name = name
# Track total attempts for logging purposes
total_attempts = 0
# Loop indefinitely until we get a valid IP
while True:
total_attempts += 1
# Log retry message but not on first attempt
if total_attempts > 1:
logger.info(f"Waiting for VM {name} IP address (attempt {total_attempts})...")
try:
# Get VM information
vm_info = await self.get_vm(name, storage=storage)
# Check if we got a valid IP
ip = vm_info.get("ip_address", None)
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
logger.info(f"Got valid VM IP address: {ip}")
return ip
# Check the VM status
status = vm_info.get("status", "unknown")
# Special handling for Lumier: it may report "stopped" even when the VM is starting
# If the VM information contains an IP but status is stopped, it might be a race condition
if status == "stopped" and "ip_address" in vm_info:
ip = vm_info.get("ip_address")
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
logger.info(f"Found valid IP {ip} despite VM status being {status}")
return ip
logger.info(f"VM status is {status}, but still waiting for IP to be assigned")
# If VM is not running yet, log and wait
elif status != "running":
logger.info(f"VM is not running yet (status: {status}). Waiting...")
# If VM is running but no IP yet, wait and retry
else:
logger.info("VM is running but no valid IP address yet. Waiting...")
except Exception as e:
logger.warning(f"Error getting VM {name} IP: {e}, continuing to wait...")
# Wait before next retry
await asyncio.sleep(retry_delay)
# Add progress log every 10 attempts
if total_attempts % 10 == 0:
logger.info(f"Still waiting for VM {name} IP after {total_attempts} attempts...")
async def __aenter__(self):
"""Async context manager entry.
This method is called when entering an async context manager block.
Returns self to be used in the context.
"""
logger.debug("Entering LumierProvider context")
# Initialize the API URL with the default value if not already set
# This ensures get_vm can work before run_vm is called
if not hasattr(self, '_api_url') or not self._api_url:
self._api_url = f"http://{self.host}:{self.api_port}"
logger.info(f"Initialized default Lumier API URL: {self._api_url}")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit.
This method is called when exiting an async context manager block.
It handles proper cleanup of resources, including stopping any running containers.
"""
logger.debug(f"Exiting LumierProvider context, handling exceptions: {exc_type}")
try:
# If we have a container ID, we should stop it to clean up resources
if hasattr(self, '_container_id') and self._container_id:
logger.info(f"Stopping Lumier container on context exit: {self.container_name}")
try:
cmd = ["docker", "stop", self.container_name]
subprocess.run(cmd, capture_output=True, text=True, check=True)
logger.info(f"Container stopped during context exit: {self.container_name}")
except subprocess.CalledProcessError as e:
logger.warning(f"Failed to stop container during cleanup: {e.stderr}")
# Don't raise an exception here, we want to continue with cleanup
except Exception as e:
logger.error(f"Error during LumierProvider cleanup: {e}")
# We don't want to suppress the original exception if there was one
if exc_type is None:
raise
# Return False to indicate that any exception should propagate
return False
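The polling contract of `get_ip` (query `get_vm`, accept only a usable `ip_address`, sleep between attempts) can be exercised against a stand-in provider. `FakeProvider`, `wait_for_ip`, and the short delay are hypothetical test fixtures, not part of the library:

```python
import asyncio
from typing import Any, Dict, Optional

class FakeProvider:
    """Hypothetical stand-in that reports an IP only on the third poll."""

    def __init__(self) -> None:
        self.calls = 0

    async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
        self.calls += 1
        if self.calls < 3:
            # Simulate the VM still booting: no address yet.
            return {"name": name, "status": "starting", "ip_address": None}
        return {"name": name, "status": "running", "ip_address": "192.168.64.5"}

async def wait_for_ip(provider: FakeProvider, name: str, retry_delay: float = 0.01) -> str:
    """Poll get_vm until a usable ip_address appears, as get_ip does."""
    while True:
        vm_info = await provider.get_vm(name)
        ip = vm_info.get("ip_address")
        if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
            return ip
        await asyncio.sleep(retry_delay)

ip = asyncio.run(wait_for_ip(FakeProvider(), "demo-vm"))
print(ip)  # 192.168.64.5
```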

View File

@@ -0,0 +1,11 @@
"""Windows Sandbox provider for CUA Computer."""
try:
import winsandbox
HAS_WINSANDBOX = True
except ImportError:
HAS_WINSANDBOX = False
from .provider import WinSandboxProvider
__all__ = ["WinSandboxProvider", "HAS_WINSANDBOX"]
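The try/except guard above is one way to probe for an optional dependency; `importlib.util.find_spec` expresses the same check without actually importing the module (the helper name is illustrative):

```python
import importlib.util

def has_module(mod: str) -> bool:
    """Return True if `mod` is importable, without importing it."""
    return importlib.util.find_spec(mod) is not None

print(has_module("os"))  # True
```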

View File

@@ -0,0 +1,468 @@
"""Windows Sandbox VM provider implementation using pywinsandbox."""
import os
import asyncio
import logging
import time
from typing import Dict, Any, Optional, List
from ..base import BaseVMProvider, VMProviderType
# Setup logging
logger = logging.getLogger(__name__)
try:
import winsandbox
HAS_WINSANDBOX = True
except ImportError:
HAS_WINSANDBOX = False
class WinSandboxProvider(BaseVMProvider):
"""Windows Sandbox VM provider implementation using pywinsandbox.
This provider uses Windows Sandbox to create isolated Windows environments.
Storage is always ephemeral with Windows Sandbox.
"""
def __init__(
self,
port: int = 7777,
host: str = "localhost",
storage: Optional[str] = None,
verbose: bool = False,
ephemeral: bool = True, # Windows Sandbox is always ephemeral
memory_mb: int = 4096,
networking: bool = True,
**kwargs
):
"""Initialize the Windows Sandbox provider.
Args:
port: Port for the computer server (default: 7777)
host: Host to use for connections (default: localhost)
storage: Storage path (ignored - Windows Sandbox is always ephemeral)
verbose: Enable verbose logging
ephemeral: Always True for Windows Sandbox
memory_mb: Memory allocation in MB (default: 4096)
networking: Enable networking in sandbox (default: True)
"""
if not HAS_WINSANDBOX:
raise ImportError(
"pywinsandbox is required for WinSandboxProvider. "
"Please install it with 'pip install pywinsandbox'"
)
self.host = host
self.port = port
self.verbose = verbose
self.memory_mb = memory_mb
self.networking = networking
# Windows Sandbox is always ephemeral
if not ephemeral:
logger.warning("Windows Sandbox storage is always ephemeral. Ignoring ephemeral=False.")
self.ephemeral = True
# Storage is always ephemeral for Windows Sandbox
if storage and storage != "ephemeral":
logger.warning("Windows Sandbox does not support persistent storage. Using ephemeral storage.")
self.storage = "ephemeral"
self.logger = logging.getLogger(__name__)
# Track active sandboxes
self._active_sandboxes: Dict[str, Any] = {}
@property
def provider_type(self) -> VMProviderType:
"""Get the provider type."""
return VMProviderType.WINSANDBOX
async def __aenter__(self):
"""Enter async context manager."""
# Verify Windows Sandbox is available
if not HAS_WINSANDBOX:
raise ImportError("pywinsandbox is not available")
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit async context manager."""
# Clean up any active sandboxes
for name, sandbox in self._active_sandboxes.items():
try:
sandbox.shutdown()
self.logger.info(f"Terminated sandbox: {name}")
except Exception as e:
self.logger.error(f"Error terminating sandbox {name}: {e}")
self._active_sandboxes.clear()
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Get VM information by name.
Args:
name: Name of the VM to get information for
storage: Ignored for Windows Sandbox (always ephemeral)
Returns:
Dictionary with VM information including status, IP address, etc.
"""
if name not in self._active_sandboxes:
return {
"name": name,
"status": "stopped",
"ip_address": None,
"storage": "ephemeral"
}
sandbox = self._active_sandboxes[name]
# Check if sandbox is still running
try:
# Try to ping the sandbox to see if it's responsive
try:
sandbox.rpyc.modules.os.getcwd()
sandbox_responsive = True
except Exception:
sandbox_responsive = False
if not sandbox_responsive:
return {
"name": name,
"status": "starting",
"ip_address": None,
"storage": "ephemeral",
"memory_mb": self.memory_mb,
"networking": self.networking
}
# Check for computer server address file
server_address_file = r"C:\Users\WDAGUtilityAccount\Desktop\shared_windows_sandbox_dir\server_address"
try:
# Check if the server address file exists
file_exists = sandbox.rpyc.modules.os.path.exists(server_address_file)
if file_exists:
# Read the server address file
with sandbox.rpyc.builtin.open(server_address_file, 'r') as f:
server_address = f.read().strip()
if server_address and ':' in server_address:
# Parse IP:port from the file
ip_address, port = server_address.split(':', 1)
# Verify the server is actually responding
try:
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(3)
result = sock.connect_ex((ip_address, int(port)))
sock.close()
if result == 0:
# Server is responding
status = "running"
self.logger.debug(f"Computer server found at {ip_address}:{port}")
else:
# Server file exists but not responding
status = "starting"
ip_address = None
except Exception as e:
self.logger.debug(f"Error checking server connectivity: {e}")
status = "starting"
ip_address = None
else:
# File exists but doesn't contain valid address
status = "starting"
ip_address = None
else:
# Server address file doesn't exist yet
status = "starting"
ip_address = None
except Exception as e:
self.logger.debug(f"Error checking server address file: {e}")
status = "starting"
ip_address = None
except Exception as e:
self.logger.error(f"Error checking sandbox status: {e}")
status = "error"
ip_address = None
return {
"name": name,
"status": status,
"ip_address": ip_address,
"storage": "ephemeral",
"memory_mb": self.memory_mb,
"networking": self.networking
}
async def list_vms(self) -> List[Dict[str, Any]]:
"""List all available VMs."""
vms = []
for name in self._active_sandboxes.keys():
vm_info = await self.get_vm(name)
vms.append(vm_info)
return vms
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Run a VM with the given options.
Args:
image: Image name (ignored for Windows Sandbox - always uses host Windows)
name: Name of the VM to run
run_opts: Dictionary of run options (memory, cpu, etc.)
storage: Ignored for Windows Sandbox (always ephemeral)
Returns:
Dictionary with VM run status and information
"""
if name in self._active_sandboxes:
return {
"success": False,
"error": f"Sandbox {name} is already running"
}
try:
# Extract options from run_opts
memory_mb = run_opts.get("memory_mb", self.memory_mb)
if isinstance(memory_mb, str):
# Convert memory string like "4GB" to MB
if memory_mb.upper().endswith("GB"):
memory_mb = int(float(memory_mb[:-2]) * 1024)
elif memory_mb.upper().endswith("MB"):
memory_mb = int(memory_mb[:-2])
else:
memory_mb = self.memory_mb
networking = run_opts.get("networking", self.networking)
# Create folder mappers if shared directories are specified
folder_mappers = []
shared_directories = run_opts.get("shared_directories", [])
for shared_dir in shared_directories:
if isinstance(shared_dir, dict):
host_path = shared_dir.get("hostPath", "")
elif isinstance(shared_dir, str):
host_path = shared_dir
else:
continue
if host_path and os.path.exists(host_path):
folder_mappers.append(winsandbox.FolderMapper(host_path))
self.logger.info(f"Creating Windows Sandbox: {name}")
self.logger.info(f"Memory: {memory_mb}MB, Networking: {networking}")
if folder_mappers:
self.logger.info(f"Shared directories: {len(folder_mappers)}")
# Create the sandbox without logon script
sandbox = winsandbox.new_sandbox(
memory_mb=str(memory_mb),
networking=networking,
folder_mappers=folder_mappers
)
# Store the sandbox
self._active_sandboxes[name] = sandbox
self.logger.info(f"Windows Sandbox {name} created successfully")
# Setup the computer server in the sandbox
await self._setup_computer_server(sandbox, name)
return {
"success": True,
"name": name,
"status": "starting",
"memory_mb": memory_mb,
"networking": networking,
"storage": "ephemeral"
}
except Exception as e:
self.logger.error(f"Failed to create Windows Sandbox {name}: {e}")
# Log the full stack trace to aid debugging
import traceback
self.logger.error(f"Stack trace: {traceback.format_exc()}")
return {
"success": False,
"error": f"Failed to create sandbox: {str(e)}"
}
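The inline memory-string conversion in `run_vm` can be factored into a helper; `parse_memory_mb` below is an illustrative sketch of the same rules ("GB" scaled to MB, "MB" stripped, anything else falling back to the default), not code from the provider:

```python
from typing import Union

def parse_memory_mb(value: Union[str, int], default_mb: int = 4096) -> int:
    """Normalize '4GB' / '512MB' / 2048 into an integer MB count."""
    if isinstance(value, int):
        return value
    text = str(value).upper()
    if text.endswith("GB"):
        return int(float(text[:-2]) * 1024)
    if text.endswith("MB"):
        return int(text[:-2])
    return default_mb

print(parse_memory_mb("4GB"), parse_memory_mb("512MB"), parse_memory_mb(2048))
# 4096 512 2048
```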
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
"""Stop a running VM.
Args:
name: Name of the VM to stop
storage: Ignored for Windows Sandbox
Returns:
Dictionary with stop status and information
"""
if name not in self._active_sandboxes:
return {
"success": False,
"error": f"Sandbox {name} is not running"
}
try:
sandbox = self._active_sandboxes[name]
# Terminate the sandbox
sandbox.shutdown()
# Remove from active sandboxes
del self._active_sandboxes[name]
self.logger.info(f"Windows Sandbox {name} stopped successfully")
return {
"success": True,
"name": name,
"status": "stopped"
}
except Exception as e:
self.logger.error(f"Failed to stop Windows Sandbox {name}: {e}")
return {
"success": False,
"error": f"Failed to stop sandbox: {str(e)}"
}
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
"""Update VM configuration.
Note: Windows Sandbox does not support runtime configuration updates.
The sandbox must be stopped and restarted with new configuration.
Args:
name: Name of the VM to update
update_opts: Dictionary of update options
storage: Ignored for Windows Sandbox
Returns:
Dictionary with update status and information
"""
return {
"success": False,
"error": "Windows Sandbox does not support runtime configuration updates. "
"Please stop and restart the sandbox with new configuration."
}
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
"""Get the IP address of a VM, waiting indefinitely until it's available.
Args:
name: Name of the VM to get the IP for
storage: Ignored for Windows Sandbox
retry_delay: Delay between retries in seconds (default: 2)
Returns:
IP address of the VM when it becomes available
"""
total_attempts = 0
# Loop indefinitely until we get a valid IP
while True:
total_attempts += 1
# Log retry message but not on first attempt
if total_attempts > 1:
self.logger.info(f"Waiting for Windows Sandbox {name} IP address (attempt {total_attempts})...")
try:
# Get VM information
vm_info = await self.get_vm(name, storage=storage)
# Check if we got a valid IP
ip = vm_info.get("ip_address", None)
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
self.logger.info(f"Got valid Windows Sandbox IP address: {ip}")
return ip
# Check the VM status
status = vm_info.get("status", "unknown")
# If VM is not running yet, log and wait
if status != "running":
self.logger.info(f"Windows Sandbox is not running yet (status: {status}). Waiting...")
# If VM is running but no IP yet, wait and retry
else:
self.logger.info("Windows Sandbox is running but no valid IP address yet. Waiting...")
except Exception as e:
self.logger.warning(f"Error getting Windows Sandbox {name} IP: {e}, continuing to wait...")
# Wait before next retry
await asyncio.sleep(retry_delay)
# Add progress log every 10 attempts
if total_attempts % 10 == 0:
self.logger.info(f"Still waiting for Windows Sandbox {name} IP after {total_attempts} attempts...")
async def _setup_computer_server(self, sandbox, name: str, visible: bool = False):
"""Setup the computer server in the Windows Sandbox using RPyC.
Args:
sandbox: The Windows Sandbox instance
name: Name of the sandbox
visible: Whether the opened process should be visible (default: False)
"""
try:
self.logger.info(f"Setting up computer server in sandbox {name}...")
# Read the PowerShell setup script
script_path = os.path.join(os.path.dirname(__file__), "setup_script.ps1")
with open(script_path, 'r', encoding='utf-8') as f:
setup_script_content = f.read()
# Write the setup script to the sandbox using RPyC
script_dest_path = r"C:\Users\WDAGUtilityAccount\setup_cua.ps1"
self.logger.info(f"Writing setup script to {script_dest_path}")
with sandbox.rpyc.builtin.open(script_dest_path, 'w') as f:
f.write(setup_script_content)
# Execute the PowerShell script in the background
self.logger.info("Executing setup script in sandbox...")
# Use subprocess to run PowerShell script
import subprocess
powershell_cmd = [
"powershell.exe",
"-ExecutionPolicy", "Bypass",
"-NoExit", # Keep window open after script completes
"-File", script_dest_path
]
# Set creation flags based on visibility preference
if visible:
# CREATE_NEW_CONSOLE - creates a new console window (visible)
creation_flags = 0x00000010
else:
creation_flags = 0x08000000 # CREATE_NO_WINDOW
# Start the process using RPyC
process = sandbox.rpyc.modules.subprocess.Popen(
powershell_cmd,
creationflags=creation_flags,
shell=False
)
# # Sleep for 30 seconds
# await asyncio.sleep(30)
ip = await self.get_ip(name)
self.logger.info(f"Sandbox IP: {ip}")
self.logger.info(f"Setup script started in background in sandbox {name} with PID: {process.pid}")
except Exception as e:
self.logger.error(f"Failed to setup computer server in sandbox {name}: {e}")
import traceback
self.logger.error(f"Stack trace: {traceback.format_exc()}")
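The `server_address` handling in `get_vm` splits an `IP:port` line from the shared file and TCP-probes it with `connect_ex`. A minimal standalone sketch of both steps (the helper names are illustrative assumptions):

```python
import socket
from typing import Optional, Tuple

def parse_server_address(text: str) -> Optional[Tuple[str, int]]:
    """Split an 'IP:port' line from the sandbox's server_address file."""
    text = text.strip()
    if ":" not in text:
        return None
    host, port = text.split(":", 1)
    if not host or not port.isdigit():
        return None
    return host, int(port)

def server_responding(host: str, port: int, timeout: float = 3.0) -> bool:
    """TCP-probe the computer server, mirroring the connect_ex check."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.settimeout(timeout)
    try:
        return sock.connect_ex((host, port)) == 0
    finally:
        sock.close()

print(parse_server_address("192.168.250.2:7777"))
# ('192.168.250.2', 7777)
```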

View File

@@ -0,0 +1,124 @@
# Setup script for Windows Sandbox CUA Computer provider
# This script runs when the sandbox starts
Write-Host "Starting CUA Computer setup in Windows Sandbox..."
# Function to find the mapped Python installation from pywinsandbox
function Find-MappedPython {
Write-Host "Looking for mapped Python installation from pywinsandbox..."
# pywinsandbox maps the host Python installation to the sandbox
# Look for mapped shared folders on the desktop (common pywinsandbox pattern)
$desktopPath = "C:\Users\WDAGUtilityAccount\Desktop"
$sharedFolders = Get-ChildItem -Path $desktopPath -Directory -ErrorAction SilentlyContinue
foreach ($folder in $sharedFolders) {
# Look for Python executables in shared folders
$pythonPaths = @(
"$($folder.FullName)\python.exe",
"$($folder.FullName)\Scripts\python.exe",
"$($folder.FullName)\bin\python.exe"
)
foreach ($pythonPath in $pythonPaths) {
if (Test-Path $pythonPath) {
try {
$version = & $pythonPath --version 2>&1
if ($version -match "Python") {
Write-Host "Found mapped Python: $pythonPath - $version"
return $pythonPath
}
} catch {
continue
}
}
}
# Also check subdirectories that might contain Python
$subDirs = Get-ChildItem -Path $folder.FullName -Directory -ErrorAction SilentlyContinue
foreach ($subDir in $subDirs) {
$pythonPath = "$($subDir.FullName)\python.exe"
if (Test-Path $pythonPath) {
try {
$version = & $pythonPath --version 2>&1
if ($version -match "Python") {
Write-Host "Found mapped Python in subdirectory: $pythonPath - $version"
return $pythonPath
}
} catch {
continue
}
}
}
}
# Fallback: try common Python commands that might be available
$pythonCommands = @("python", "py", "python3")
foreach ($cmd in $pythonCommands) {
try {
$version = & $cmd --version 2>&1
if ($version -match "Python") {
Write-Host "Found Python via command '$cmd': $version"
return $cmd
}
} catch {
continue
}
}
throw "Could not find any Python installation (mapped or otherwise)"
}
try {
# Step 1: Find the mapped Python installation
Write-Host "Step 1: Finding mapped Python installation..."
$pythonExe = Find-MappedPython
Write-Host "Using Python: $pythonExe"
# Verify Python works and show version
$pythonVersion = & $pythonExe --version 2>&1
Write-Host "Python version: $pythonVersion"
# Step 2: Install cua-computer-server directly
Write-Host "Step 2: Installing cua-computer-server..."
Write-Host "Upgrading pip..."
& $pythonExe -m pip install --upgrade pip --quiet
Write-Host "Installing cua-computer-server..."
& $pythonExe -m pip install cua-computer-server --quiet
Write-Host "cua-computer-server installation completed."
# Step 3: Start computer server in background
Write-Host "Step 3: Starting computer server in background..."
Write-Host "Starting computer server with: $pythonExe"
# Start the computer server in the background
$serverProcess = Start-Process -FilePath $pythonExe -ArgumentList "-m", "computer_server.main" -WindowStyle Hidden -PassThru
Write-Host "Computer server started in background with PID: $($serverProcess.Id)"
# Give it a moment to start
Start-Sleep -Seconds 3
# Check if the process is still running
if (Get-Process -Id $serverProcess.Id -ErrorAction SilentlyContinue) {
Write-Host "Computer server is running successfully in background"
} else {
throw "Computer server failed to start or exited immediately"
}
} catch {
Write-Error "Setup failed: $_"
Write-Host "Error details: $($_.Exception.Message)"
Write-Host "Stack trace: $($_.ScriptStackTrace)"
Write-Host ""
Write-Host "Press any key to close this window..."
$null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
exit 1
}
Write-Host ""
Write-Host "Setup completed successfully!"
Write-Host "Press any key to close this window..."
$null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
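The Start-Process / Get-Process liveness check at the end of the script has a direct Python analogue, sketched here with a short-lived child process standing in for the computer server:

```python
import subprocess
import sys
import time

# Launch a child in the background, then confirm it is still alive,
# mirroring the Start-Process / Get-Process check in the PowerShell script.
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(1)"])
time.sleep(0.2)
alive = proc.poll() is None  # None means the process has not exited yet
print(alive)  # True while the child is still sleeping
proc.wait()
```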

View File

@@ -0,0 +1,116 @@
"""Computer telemetry for tracking anonymous usage and feature usage."""
import logging
import platform
from typing import Any
# Import the core telemetry module
TELEMETRY_AVAILABLE = False
try:
from core.telemetry import (
increment,
is_telemetry_enabled,
is_telemetry_globally_disabled,
record_event,
)
def increment_counter(counter_name: str, value: int = 1) -> None:
"""Wrapper for increment to maintain backward compatibility."""
if is_telemetry_enabled():
increment(counter_name, value)
def set_dimension(name: str, value: Any) -> None:
"""Set a dimension that will be attached to all events."""
logger = logging.getLogger("computer.telemetry")
logger.debug(f"Setting dimension {name}={value}")
TELEMETRY_AVAILABLE = True
logger = logging.getLogger("computer.telemetry")
logger.info("Successfully imported telemetry")
except ImportError as e:
logger = logging.getLogger("computer.telemetry")
logger.warning(f"Could not import telemetry: {e}")
TELEMETRY_AVAILABLE = False
# Local fallbacks in case core telemetry isn't available
def _noop(*args: Any, **kwargs: Any) -> None:
"""No-op function for when telemetry is not available."""
pass
logger = logging.getLogger("computer.telemetry")
# If telemetry isn't available, use no-op functions
if not TELEMETRY_AVAILABLE:
logger.debug("Telemetry not available, using no-op functions")
record_event = _noop # type: ignore
increment_counter = _noop # type: ignore
set_dimension = _noop # type: ignore
get_telemetry_client = lambda: None # type: ignore
flush = _noop # type: ignore
is_telemetry_enabled = lambda: False # type: ignore
is_telemetry_globally_disabled = lambda: True # type: ignore
# Get system info once to use in telemetry
SYSTEM_INFO = {
"os": platform.system().lower(),
"os_version": platform.release(),
"python_version": platform.python_version(),
}
def enable_telemetry() -> bool:
"""Enable telemetry if available.
Returns:
bool: True if telemetry was successfully enabled, False otherwise
"""
global TELEMETRY_AVAILABLE
# Check if globally disabled using core function
if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
logger.info("Telemetry is globally disabled via environment variable - cannot enable")
return False
# Already enabled
if TELEMETRY_AVAILABLE:
return True
# Try to import and enable
try:
# Verify we can import core telemetry
from core.telemetry import record_event # type: ignore
TELEMETRY_AVAILABLE = True
logger.info("Telemetry successfully enabled")
return True
except ImportError as e:
logger.warning(f"Could not enable telemetry: {e}")
return False
def is_telemetry_enabled() -> bool:
"""Check if telemetry is enabled.
Returns:
bool: True if telemetry is enabled, False otherwise
"""
# Use the core function if available, otherwise use our local flag
if TELEMETRY_AVAILABLE:
from core.telemetry import is_telemetry_enabled as core_is_enabled
return core_is_enabled()
return False
def record_computer_initialization() -> None:
"""Record when a computer instance is initialized."""
if TELEMETRY_AVAILABLE and is_telemetry_enabled():
record_event("computer_initialized", SYSTEM_INFO)
# Set dimensions that will be attached to all events
set_dimension("os", SYSTEM_INFO["os"])
set_dimension("os_version", SYSTEM_INFO["os_version"])
set_dimension("python_version", SYSTEM_INFO["python_version"])

View File

@@ -0,0 +1 @@
"""UI modules for the Computer Interface."""

View File

@@ -0,0 +1,15 @@
"""
Main entry point for computer.ui module.
This allows running the computer UI with:
python -m computer.ui
Instead of:
python -m computer.ui.gradio.app
"""
from .gradio.app import create_gradio_ui
if __name__ == "__main__":
app = create_gradio_ui()
app.launch()

View File

@@ -0,0 +1,6 @@
"""Gradio UI for Computer UI."""
from .app import create_gradio_ui

File diff suppressed because it is too large.

View File

@@ -0,0 +1,101 @@
import base64
from typing import Tuple, Optional, Dict, Any
from PIL import Image, ImageDraw
import io
def decode_base64_image(base64_str: str) -> bytes:
"""Decode a base64 string into image bytes."""
return base64.b64decode(base64_str)
def encode_base64_image(image_bytes: bytes) -> str:
"""Encode image bytes to base64 string."""
return base64.b64encode(image_bytes).decode('utf-8')
def bytes_to_image(image_bytes: bytes) -> Image.Image:
"""Convert bytes to PIL Image.
Args:
image_bytes: Raw image bytes
Returns:
PIL.Image: The converted image
"""
return Image.open(io.BytesIO(image_bytes))
def image_to_bytes(image: Image.Image, format: str = 'PNG') -> bytes:
"""Convert PIL Image to bytes."""
buf = io.BytesIO()
image.save(buf, format=format)
return buf.getvalue()
def resize_image(image_bytes: bytes, scale_factor: float) -> bytes:
"""Resize an image by a scale factor.
Args:
image_bytes: The original image as bytes
scale_factor: Factor to scale the image by (e.g., 0.5 for half size, 2.0 for double)
Returns:
bytes: The resized image as bytes
"""
image = bytes_to_image(image_bytes)
if scale_factor != 1.0:
new_size = (int(image.width * scale_factor), int(image.height * scale_factor))
image = image.resize(new_size, Image.Resampling.LANCZOS)
return image_to_bytes(image)
def draw_box(
image_bytes: bytes,
x: int,
y: int,
width: int,
height: int,
color: str = "#FF0000",
thickness: int = 2
) -> bytes:
"""Draw a box on an image.
Args:
image_bytes: The original image as bytes
x: X coordinate of top-left corner
y: Y coordinate of top-left corner
width: Width of the box
height: Height of the box
color: Color of the box in hex format
thickness: Thickness of the box border in pixels
Returns:
bytes: The modified image as bytes
"""
# Convert bytes to PIL Image
image = bytes_to_image(image_bytes)
# Create drawing context
draw = ImageDraw.Draw(image)
# Draw rectangle
draw.rectangle(
[(x, y), (x + width, y + height)],
outline=color,
width=thickness
)
# Convert back to bytes
return image_to_bytes(image)
def get_image_size(image_bytes: bytes) -> Tuple[int, int]:
"""Get the dimensions of an image.
Args:
image_bytes: The image as bytes
Returns:
Tuple[int, int]: Width and height of the image
"""
image = bytes_to_image(image_bytes)
return image.size
def parse_vm_info(vm_info: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Parse VM info from pylume response."""
if not vm_info:
return None
return vm_info
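The `encode_base64_image` / `decode_base64_image` helpers at the top of this file round-trip losslessly; a quick self-check using the PNG magic bytes as stand-in image data:

```python
import base64

def encode_base64_image(image_bytes: bytes) -> str:
    """Encode image bytes to a base64 string (as in the helpers above)."""
    return base64.b64encode(image_bytes).decode("utf-8")

def decode_base64_image(base64_str: str) -> bytes:
    """Decode a base64 string back into image bytes."""
    return base64.b64decode(base64_str)

payload = b"\x89PNG\r\n\x1a\n"  # PNG magic bytes as stand-in image data
encoded = encode_base64_image(payload)
assert decode_base64_image(encoded) == payload
print(encoded)  # iVBORw0KGgo=
```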

View File

@@ -0,0 +1,2 @@
[virtualenvs]
in-project = true

View File

@@ -0,0 +1,73 @@
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
[project]
name = "cua-computer"
version = "0.2.0"
description = "Computer-Use Interface (CUI) framework powering Cua"
readme = "README.md"
authors = [
{ name = "TryCua", email = "gh@trycua.com" }
]
dependencies = [
"pillow>=10.0.0",
"websocket-client>=1.8.0",
"websockets>=12.0",
"aiohttp>=3.9.0",
"cua-core>=0.1.0,<0.2.0",
"pydantic>=2.11.1"
]
requires-python = ">=3.11"
[project.optional-dependencies]
lume = [
]
lumier = [
]
ui = [
"gradio>=5.23.3,<6.0.0",
"python-dotenv>=1.0.1,<2.0.0",
"datasets>=3.6.0,<4.0.0",
]
all = [
# Include all optional dependencies
"gradio>=5.23.3,<6.0.0",
"python-dotenv>=1.0.1,<2.0.0",
"datasets>=3.6.0,<4.0.0",
]
[tool.pdm]
distribution = true
[tool.pdm.build]
includes = ["computer/"]
source-includes = ["tests/", "README.md", "LICENSE"]
[tool.black]
line-length = 100
target-version = ["py311"]
[tool.ruff]
line-length = 100
target-version = "py311"
select = ["E", "F", "B", "I"]
fix = true
[tool.ruff.format]
docstring-code-format = true
[tool.mypy]
strict = true
python_version = "3.11"
ignore_missing_imports = true
disallow_untyped_defs = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = "test_*.py"